0


torch.where()用法

torch.where用法


前言

本文主要讲述

  1. torch.where()

的两种用法,第一种是最常规的,也是官方文档所注明的;第二种就是配合

  1. bool

型张量的计算


1、torch.where()常规用法

我们先看官方文档的解释:

  1. torch.where(condition, x, y)

根据条件,也就是

  1. condiction

,返回从

  1. x

  1. y

中选择的元素的张量(这里会创建一个新的张量,新张量的元素就是从x或y中选的,形状要符合

  1. x

  1. y

的广播条件)。

  1. Parameters

解释如下:
1、

  1. condition (bool型张量)

:当

  1. condition

为真,返回

  1. x

的值,否则返回

  1. y

的值
2、

  1. x (张量或标量)

:当

  1. condition=True

时选

  1. x

的值
2、

  1. y (张量或标量)

:当

  1. condition=False

时选

  1. y

的值

我看了好些博文,他们都说

  1. x

  1. y

的形状必须相同,完全胡扯嘛,官方文档写的明明白白的:The tensors condition, x, y must be broadcastable. 也就是说

  1. conditionxy

能进行广播就行,并不要求形状一样。下面看用法:

1.1 形状相同

先演示形状相同的情况:

  1. import torch
  2. x = torch.tensor([[1,2,3],[3,4,5],[5,6,7]])
  3. y = torch.tensor([[5,6,7],[7,8,9],[9,10,11]])
  4. z = torch.where(x >5, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 5 = {x >5}')print(f'=========================')print(f'z = {z}')>print result:
  5. x = tensor([[1,2,3],[3,4,5],[5,6,7]])=========================
  6. y = tensor([[5,6,7],[7,8,9],[9,10,11]])=========================
  7. x >5= tensor([[False,False,False],[False,False,False],[False,True,True]])=========================
  8. z = tensor([[5,6,7],[7,8,9],[9,6,7]])

上面定义了

  1. x

  1. y

,两者的形状

  1. shape=(3, 3)

相同,然后

  1. condition = x > 5

也是就

  1. x

中的每个元素值都要大于5,这里就能看到

  1. x

中第0行和第1行都是

  1. False

,只有第2行的1、2列是

  1. True

,然后前面说了,为

  1. True

时使用的是

  1. x

中的值,为

  1. False

时使用的是

  1. y

中的值,那么新创建的

  1. z

前两行和第2行0列使用的是

  1. y

中的值,剩下两个使用

  1. x

中的值,

  1. z

  1. shape

也是

  1. (3, 3)

1.2 标量情况

  1. x =3
  2. y = torch.tensor([[1,5,7]])
  3. z = torch.where(y >2, y, x)print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')>print result:
  4. y >2= tensor([[False,True,True]])=========================
  5. z = tensor([[3,5,7]])

在这里,

  1. x

是一个标量,

  1. condition = y > 2

,你要是问我为什么不把

  1. condition

设为

  1. condition = x > 2

,很简单,

  1. x > 2

不是

  1. bool Tensor

这里标量和张量是可以进行广播的!!
example:

  1. a = torch.tensor([1,5,7])
  2. b =3
  3. c = a + b
  4. d = torch.tensor([3,3,3])
  5. e = a + d
  6. print(f'c = {c}')print(f'e = {e}')>print result:
  7. c = tensor([4,8,10])
  8. d = tensor([4,8,10])

其实就是把

  1. b = 3

拉成了

  1. [3, 3, 3]

,也是就

  1. d

那样。

1.3 形状不同

其实标量那里也算是形状不同了,这里我再啰嗦一下吧,看例子:

  1. x = torch.tensor([[1,3,5]])
  2. y = torch.tensor([[2],[4],[6]])
  3. z = torch.where(x >2, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 2 = {x >2}')print(f'=========================')print(f'z = {z}')>print result:
  4. x = tensor([[1,3,5]])=========================
  5. y = tensor([[2],[4],[6]])=========================
  6. x >2= tensor([[False,True,True]])=========================
  7. z = tensor([[2,3,5],[4,3,5],[6,3,5]])

上面

  1. x.shape=(1, 3) y.shape=(3, 1)

,然后

  1. condition = x > 2

  1. shape=(1, 3)

,是可广播的,所以运算也能成功,在计算

  1. torch.where(x > 2, x, y)

时,分别对

  1. xycondition

进行广播,

  1. x.shape=(3, 3)

  1. y.shape=(3, 3)

  1. condition.shape=(3, 3)

在这里插入图片描述
所以

  1. y

的值替换第0列,第1、2列为

  1. x

的值。
更多的广播形式请读者朋友自行尝试


2、torch.where()特殊用法

  1. torch.where(a & b)
  1. a

  1. b

都是

  1. bool Tensor

,返回的是一个元组,元组第一项是

  1. ab

中都为

  1. True

  1. index

  1. Tensor

,第二项是

  1. ab

都为

  1. True

  1. index

  1. Tensor

请看例子:

  1. a = torch.tensor([[0,1,1],[1,0,0],[0,0,1]], dtype=torch.bool)
  2. b = torch.ones((3,3), dtype=torch.bool)
  3. c = torch.where(a & b)print(f'a = {a}')print(f'=========================')print(f'b = {b}')print(f'=========================')print(f'c = {c}')>print result:
  4. a = tensor([[False,True,True],[True,False,False],[False,False,True]])=========================
  5. b = tensor([[True,True,True],[True,True,True],[True,True,True]])=========================
  6. c =(tensor([0,0,1,2]), tensor([1,2,0,2]))
  1. c

就是一个元组,第0项是

  1. ab

都为

  1. True

行标,第1项是

  1. ab

都为

  1. True

列标


总结

以上就是torch.where()的两种用法,看起来比较麻烦,多练练也就是那样,特别一点的就是一个广播机制一个特殊用法,欢迎评论指正!
请尊重原创,拒绝转载!!!


参考链接

https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
https://numpy.org/doc/stable/user/basics.broadcasting.html


本文转载自: https://blog.csdn.net/euqlll/article/details/127791397
版权归原作者 euqlll 所有, 如有侵权,请联系我们删除。

“torch.where()用法”的评论:

还没有评论