0


Pytorch 常见运算(mul、mm、dot、mv)

Pytorch 常见运算

1.矩阵与标量

矩阵(张量)每一个元素与标量进行操作。

import torch
a = torch.tensor([1,2])print(a+1)>>> tensor([2,3])

2.哈达玛积(mul)

两个相同尺寸的张量相乘,然后对应元素的相乘就是这个哈达玛积。

a = torch.tensor([1,2])
b = torch.tensor([2,3])print(a*b)print(torch.mul(a,b))>>> tensor([2,6])>>> tensor([2,6])

这个torch.mul()和*以及torch.dot()是等价的

当然,除法也是类似的:

a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])print(a/b)print(torch.div(a/b))>>> tensor([0.5000,0.6667])>>> tensor([0.5000,0.6667])

我们可以发现的torch.div()其实就是/, 类似的:torch.add就是+,torch.sub()就是-,不过符号的运算更简单常用。

3.矩阵乘法

在代码中矩阵相乘有三种写法:

  • torch.mm()
  • torch.matmul()
  • @
a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.]).view(1,2)print(torch.mm(a, b))print(torch.matmul(a, b))print(a @ b)

输出结果:

tensor([[2.,3.],[4.,6.]])
tensor([[2.,3.],[4.,6.]])
tensor([[2.,3.],[4.,6.]])

上面的是对二维矩阵而言的,假如参与运算的是一个多维张量,那么只有torch.matmul()可以使用

torch.mv()等价于torch.mm(),不过不同的是mv适用与矩阵和向量相乘

在多维张量中,参与矩阵运算的其实只有后两个维度,前面的维度其实就像是索引一样,举个例子:

a = torch.rand((1,2,64,32))
b = torch.rand((1,2,32,64))print(torch.matmul(a, b).shape)>>> torch.Size([1,2,64,64])

4.幂与开方

a = torch.tensor([1.,2.])
b = torch.tensor([2.,3.])
c1 = a ** b
c2 = torch.pow(a, b)print(c1,c2)>>> tensor([1.,8.]) tensor([1.,8.])

5.对数运算

pytorch中log是以e自然数为底数的,然后log2和log10才是以2和10为底数的运算。

import numpy as np
print('对数运算')
a = torch.tensor([2,10,np.e])print(torch.log(a))print(torch.log2(a))print(torch.log10(a))>>> tensor([0.6931,2.3026,1.0000])>>> tensor([1.0000,3.3219,1.4427])>>> tensor([0.3010,1.0000,0.4343])

6.近似值运算

  • .ceil() 向上取整
  • .floor()向下取整
  • .trunc()取整数
  • .frac()取小数
  • .round()四舍五入
a = torch.tensor(1.2345)print(a.ceil())>>>tensor(2.)print(a.floor())>>> tensor(1.)print(a.trunc())>>> tensor(1.)print(a.frac())>>> tensor(0.2345)print(a.round())>>> tensor(1.)

7.剪裁运算

这个是让一个数,限制在你自己设置的一个范围内[min,max],小于min的话就被设置为min,大于max的话就被设置为max。这个操作在一些对抗生成网络中,好像是WGAN-GP,通过强行限制模型的参数的值。

a = torch.rand(5)print(a)print(a.clamp(0.3,0.7))

输出为:

tensor([0.5271,0.6924,0.9919,0.0095,0.0340])
tensor([0.5271,0.6924,0.7000,0.3000,0.3000])

在这里插入图片描述


本文转载自: https://blog.csdn.net/weixin_42010722/article/details/122333051
版权归原作者 今晚去我家住吧 所有, 如有侵权,请联系我们删除。

“Pytorch 常见运算(mul、mm、dot、mv)”的评论:

还没有评论