0


torch.where()用法

torch.where用法


前言

本文主要讲述

torch.where()

的两种用法,第一种是最常规的,也是官方文档所注明的;第二种就是配合

bool

型张量的计算


1、torch.where()常规用法

我们先看官方文档的解释:

torch.where(condition, x, y)

根据条件,也就是

condiction

,返回从

x

y

中选择的元素的张量(这里会创建一个新的张量,新张量的元素就是从x或y中选的,形状要符合

x

y

的广播条件)。

Parameters

解释如下:
1、

condition (bool型张量) 

:当

condition

为真,返回

x

的值,否则返回

y

的值
2、

x (张量或标量)

:当

condition=True

时选

x

的值
2、

y (张量或标量)

:当

condition=False

时选

y

的值

我看了好些博文,他们都说

x

y

的形状必须相同,完全胡扯嘛,官方文档写的明明白白的:The tensors condition, x, y must be broadcastable. 也就是说

condition、x、y

能进行广播就行,并不要求形状一样。下面看用法:

1.1 形状相同

先演示形状相同的情况:

import torch

x = torch.tensor([[1,2,3],[3,4,5],[5,6,7]])
y = torch.tensor([[5,6,7],[7,8,9],[9,10,11]])
z = torch.where(x >5, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 5 = {x >5}')print(f'=========================')print(f'z = {z}')>print result:
x = tensor([[1,2,3],[3,4,5],[5,6,7]])=========================
y = tensor([[5,6,7],[7,8,9],[9,10,11]])=========================
x >5= tensor([[False,False,False],[False,False,False],[False,True,True]])=========================
z = tensor([[5,6,7],[7,8,9],[9,6,7]])

上面定义了

x

y

,两者的形状

shape=(3, 3)

相同,然后

condition = x > 5

也是就

x

中的每个元素值都要大于5,这里就能看到

x

中第0行和第1行都是

False

,只有第2行的1、2列是

True

,然后前面说了,为

True

时使用的是

x

中的值,为

False

时使用的是

y

中的值,那么新创建的

z

前两行和第2行0列使用的是

y

中的值,剩下两个使用

x

中的值,

z

shape

也是

(3, 3)

1.2 标量情况

x =3
y = torch.tensor([[1,5,7]])
z = torch.where(y >2, y, x)print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')print(f'y > 2 = {y >2}')print(f'=========================')print(f'z = {z}')>print result:
y >2= tensor([[False,True,True]])=========================
z = tensor([[3,5,7]])

在这里,

x

是一个标量,

condition = y > 2

,你要是问我为什么不把

condition

设为

condition = x > 2

,很简单,

x > 2

不是

bool Tensor

这里标量和张量是可以进行广播的!!
example:

a = torch.tensor([1,5,7])
b =3
c = a + b
d = torch.tensor([3,3,3])
e = a + d

print(f'c = {c}')print(f'e = {e}')>print result:
c = tensor([4,8,10])
d = tensor([4,8,10])

其实就是把

b = 3

拉成了

[3, 3, 3]

,也是就

d

那样。

1.3 形状不同

其实标量那里也算是形状不同了,这里我再啰嗦一下吧,看例子:

x = torch.tensor([[1,3,5]])
y = torch.tensor([[2],[4],[6]])
z = torch.where(x >2, x, y)print(f'x = {x}')print(f'=========================')print(f'y = {y}')print(f'=========================')print(f'x > 2 = {x >2}')print(f'=========================')print(f'z = {z}')>print result:
x = tensor([[1,3,5]])=========================
y = tensor([[2],[4],[6]])=========================
x >2= tensor([[False,True,True]])=========================
z = tensor([[2,3,5],[4,3,5],[6,3,5]])

上面

x.shape=(1, 3) y.shape=(3, 1)

,然后

condition = x > 2

shape=(1, 3)

,是可广播的,所以运算也能成功,在计算

torch.where(x > 2, x, y)

时,分别对

x、y、condition

进行广播,

x.shape=(3, 3)

y.shape=(3, 3)

condition.shape=(3, 3)

在这里插入图片描述
所以

y

的值替换第0列,第1、2列为

x

的值。
更多的广播形式请读者朋友自行尝试


2、torch.where()特殊用法

torch.where(a & b)
a

b

都是

bool Tensor

,返回的是一个元组,元组第一项是

a、b

中都为

True

index

Tensor

,第二项是

a、b

都为

True

index

Tensor

请看例子:

a = torch.tensor([[0,1,1],[1,0,0],[0,0,1]], dtype=torch.bool)
b = torch.ones((3,3), dtype=torch.bool)
c = torch.where(a & b)print(f'a = {a}')print(f'=========================')print(f'b = {b}')print(f'=========================')print(f'c = {c}')>print result:
a = tensor([[False,True,True],[True,False,False],[False,False,True]])=========================
b = tensor([[True,True,True],[True,True,True],[True,True,True]])=========================
c =(tensor([0,0,1,2]), tensor([1,2,0,2]))
c

就是一个元组,第0项是

a、b

都为

True

行标,第1项是

a、b

都为

True

列标


总结

以上就是torch.where()的两种用法,看起来比较麻烦,多练练也就是那样,特别一点的就是一个广播机制一个特殊用法,欢迎评论指正!
请尊重原创,拒绝转载!!!


参考链接

https://pytorch.org/docs/stable/generated/torch.where.html#torch.where
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics
https://numpy.org/doc/stable/user/basics.broadcasting.html


本文转载自: https://blog.csdn.net/euqlll/article/details/127791397
版权归原作者 euqlll 所有, 如有侵权,请联系我们删除。

“torch.where()用法”的评论:

还没有评论