torch.where用法
前言
本文主要讲述
torch.where()
的两种用法,第一种是最常规的,也是官方文档所注明的;第二种就是配合
bool
型张量的计算
1、torch.where()常规用法
我们先看官方文档的解释:
torch.where(condition, x, y)
根据条件,也就是
condiction
,返回从
x
或
y
中选择的元素的张量(这里会创建一个新的张量,新张量的元素就是从x或y中选的,形状要符合
x
和
y
的广播条件)。
Parameters
解释如下:
1、condition (bool型张量)
:当
condition
为真,返回
x
的值,否则返回
y
的值
2、x (张量或标量)
:当
condition=True
时选
x
的值
2、y (张量或标量)
:当
condition=False
时选
y
的值
我看了好些博文,他们都说
x
和
y
的形状必须相同,完全胡扯嘛,官方文档写的明明白白的:The tensors condition, x, y must be broadcastable. 也就是说
condition、x、y
能进行广播就行,并不要求形状一样。下面看用法:
1.1 形状相同
先演示形状相同的情况:
import torch
x = torch.tensor([[1,2,3],[3,4,5],[5,6,7]])
y = torch.tensor([[5,6,7],[7,8,9],[9,10,11]])
z = torch.where(x >5, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 5 = {x >5}')print(f'=========================')print(f'z = {z}')>print result:
x = tensor([[1,2,3],[3,4,5],[5,6,7]])=========================
y = tensor([[5,6,7],[7,8,9],[9,10,11]])=========================
x >5= tensor([[False,False,False],[False,False,False],[False,True,True]])=========================
z = tensor([[5,6,7],[7,8,9],[9,6,7]])
上面定义了
x
和
y
,两者的形状
shape=(3, 3)
相同,然后
condition = x > 5
也是就
x
中的每个元素值都要大于5,这里就能看到
x
中第0行和第1行都是
False
,只有第2行的1、2列是
True
,然后前面说了,为
True
时使用的是
x
中的值,为
False
时使用的是
y
中的值,那么新创建的
z
前两行和第2行0列使用的是
y
中的值,剩下两个使用
x
中的值,
z
的
shape
也是
(3, 3)
。
1.2 标量情况
x =3
y = torch.tensor([[1,5,7]])
z = torch.where(y >2, y, x)print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')>print result:
y >2= tensor([[False,True,True]])=========================
z = tensor([[3,5,7]])
在这里,
x
是一个标量,
condition = y > 2
,你要是问我为什么不把
condition
设为
condition = x > 2
,很简单,
x > 2
不是
bool Tensor
。这里标量和张量是可以进行广播的!!
example:
a = torch.tensor([1,5,7])
b =3
c = a + b
d = torch.tensor([3,3,3])
e = a + d
print(f'c = {c}')print(f'e = {e}')>print result:
c = tensor([4,8,10])
d = tensor([4,8,10])
其实就是把
b = 3
拉成了
[3, 3, 3]
,也是就
d
那样。
1.3 形状不同
其实标量那里也算是形状不同了,这里我再啰嗦一下吧,看例子:
x = torch.tensor([[1,3,5]])
y = torch.tensor([[2],[4],[6]])
z = torch.where(x >2, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 2 = {x >2}')print(f'=========================')print(f'z = {z}')>print result:
x = tensor([[1,3,5]])=========================
y = tensor([[2],[4],[6]])=========================
x >2= tensor([[False,True,True]])=========================
z = tensor([[2,3,5],[4,3,5],[6,3,5]])
上面
x.shape=(1, 3) y.shape=(3, 1)
,然后
condition = x > 2
的
shape=(1, 3)
,是可广播的,所以运算也能成功,在计算
torch.where(x > 2, x, y)
时,分别对
x、y、condition
进行广播,
x.shape=(3, 3)
,
y.shape=(3, 3)
,
condition.shape=(3, 3)
所以
y
的值替换第0列,第1、2列为
x
的值。
更多的广播形式请读者朋友自行尝试
2、torch.where()特殊用法
torch.where(a & b)
a
和
b
都是
bool Tensor
,返回的是一个元组,元组第一项是
a、b
中都为
True
的行的
index
的
Tensor
,第二项是
a、b
都为
True
列的
index
的
Tensor
请看例子:
a = torch.tensor([[0,1,1],[1,0,0],[0,0,1]], dtype=torch.bool)
b = torch.ones((3,3), dtype=torch.bool)
c = torch.where(a & b)print(f'a = {a}')print(f'=========================')print(f'b = {b}')print(f'=========================')print(f'c = {c}')>print result:
a = tensor([[False,True,True],[True,False,False],[False,False,True]])=========================
b = tensor([[True,True,True],[True,True,True],[True,True,True]])=========================
c =(tensor([0,0,1,2]), tensor([1,2,0,2]))
c
就是一个元组,第0项是
a、b
都为
True
的行标,第1项是
a、b
都为
True
的列标
总结
以上就是torch.where()的两种用法,看起来比较麻烦,多练练也就是那样,特别一点的就是一个广播机制一个特殊用法,欢迎评论指正!
请尊重原创,拒绝转载!!!
参考链接
https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
https://numpy.org/doc/stable/user/basics.broadcasting.html
版权归原作者 euqlll 所有, 如有侵权,请联系我们删除。