2020.10.03: 因torch版本更新, 对gather描述进行了修正.
2021.03.11: 更新了对gather的描述.
Pytorch学习: 张量进阶操作
整理内容顺序来自龙龙老师的<深度学习与PyTorch入门实战教程>
, 根据个人所需情况进行删减或扩充. 如果想要自己创建新的模块, 这些操作都是基本功, 需要掌握扎实.
拼接与拆分
cat
torch.cat(*tensor, dim)
能在指定的维度上将tensor拼接:
a = torch.rand(2, 3)
b = torch.rand(5, 3)
print('a.shape:', a.shape)
print('b.shape:', b.shape)
print('a concat b:', torch.cat([a, b], dim=0).shape)
"""
a.shape: torch.Size([2, 3])
b.shape: torch.Size([5, 3])
a concat b: torch.Size([7, 3])
"""
当然, concat之前必须保证除了cat的维度的shape不同外, 其他维度的shape均相同.
stack
与torch.cat()
不同, torch.stack()
是将tensor堆叠在一个新的维度上, 即创建一个新的维度:
a = torch.rand(5, 3)
b = torch.rand(5, 3)
print('a.shape:', a.shape)
print('b.shape:', b.shape)
print('a concat b:', torch.cat([a, b], dim=0).shape)
print('a stack b:', torch.stack([a, b], dim=0).shape)
"""
a.shape: torch.Size([5, 3])
b.shape: torch.Size([5, 3])
a concat b: torch.Size([10, 3])
a stack b: torch.Size([2, 5, 3])
"""
两个stack的tensor必须保持完全一致的维度.
split
torch.split()
既可以按照指定dim的长度来拆分, 也可以按照类似步长的拆法来拆分.
当传入的参数是list
时, 按照列表中指定的长度拆分:
a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
s = c.split([1, 1, 1], dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([1, 32, 8])
bb.shape:torch.Size([1, 32, 8])
cc.shape:torch.Size([1, 32, 8])
"""
这样能将c分别拆分为长度为1, 1, 1的tensor.
当传入的参数是int
时, 意在使指定dim上每个tensor应该具有多长, 也可以理解为步长:
a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
# s = c.split([1, 1, 1], dim=0)
s = c.split(2, dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([2, 32, 8])
bb.shape:torch.Size([1, 32, 8])
"""
虽然在dim=0上tensor长度为3, 但是还是能够拆分成2个tensor. 如果传入的整数是3, 则只会得到aa.shape:torch.Size([3, 32, 8])
, 即每个tensor应该在dim=0上有3个元素.
chunk
我认为torch.chunk()
是为了和torch.split()
同参数而防止歧义的区分函数.
torch.chunk()
中需要传入的就是需要分多少个tensor了, 还是上面那个例子, 体会区别:
a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
# s = c.split([1, 1, 1], dim=0)
s = c.chunk(3, dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([1, 32, 8])
bb.shape:torch.Size([1, 32, 8])
cc.shape:torch.Size([1, 32, 8])
"""
数学运算
大多数操作与python中的基本类型操作区别不大, pytorch使用了运算符重载使得它们在tensor上保持相同的含义. pytorch也提供了函数为tensor做运算. 大部分函数仍然和numpy中的格式一致.
加减乘除
a = torch.rand(3, 4)
b = torch.rand(4)
print('a + b和add:', torch.all(torch.eq(a + b, torch.add(a, b))))
print('a - b和sub:', torch.all(torch.eq(a - b, torch.sub(a, b))))
print('a * b和mul:', torch.all(torch.eq(a * b, torch.mul(a, b))))
print('a + b和div:', torch.all(torch.eq(a / b, torch.div(a, b))))
"""
a + b和add: tensor(1, dtype=torch.uint8)
a - b和sub: tensor(1, dtype=torch.uint8)
a * b和mul: tensor(1, dtype=torch.uint8)
a + b和div: tensor(1, dtype=torch.uint8)
"""
当然上面的例子中说的乘法不是矩阵乘法, 而是点乘(element - wise).
矩阵乘法
如果要实现矩阵乘有三种方式:
torch.mm()
仅限于2d - tensor.torch.matmul()
推荐使用.@
是重载的运算符.
a = torch.rand(2, 3)
b = torch.ones(3, 2)
print('mm in 2d:\n', torch.mm(a, b))
print('matmul:\n', torch.matmul(a, b))
print('@:\n', a @ b)
"""
mm in 2d:
tensor([[0.8500, 0.8500],
[1.0964, 1.0964]])
matmul:
tensor([[0.8500, 0.8500],
[1.0964, 1.0964]])
@:
tensor([[0.8500, 0.8500],
[1.0964, 1.0964]])
"""
记住, 如果tensor的维度比2大, 则默认在最后两个维度上进行运算, 也可以理解为对多个矩阵并行做矩阵乘法.
a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 28)
print('matmul(a, b).shape:', torch.matmul(a, b).shape)
# matmul(a, b).shape: torch.Size([4, 3, 28, 28])
指数和对数
既可以沿用python中的**
做幂计算, 也可以使用Tensor.pow()
.
a = torch.full([2, 2], 3)
print('a.pow(2):\n', a.pow(2))
print('a ** 2:\n', a ** 2)
print('a ** 2.sqrt():\n', (a ** 2).sqrt())
print('a ** 2.rsqrt():\n', (a ** 2).rsqrt())
print('a ** 2 ** 0.5:\n', a ** 2 ** 0.5)
"""
a.pow(2):
tensor([[9., 9.],
[9., 9.]])
a ** 2:
tensor([[9., 9.],
[9., 9.]])
a ** 2.sqrt():
tensor([[3., 3.],
[3., 3.]])
a ** 2.rsqrt():
tensor([[0.3333, 0.3333],
[0.3333, 0.3333]])
a ** 2 ** 0.5:
tensor([[4.7288, 4.7288],
[4.7288, 4.7288]])
"""
Tensor.rsqrt()
代表的是求平方根后的倒数.
对数指数也是一样的用法:
a = torch.ones(2, 3)
b = torch.exp(a)
print('exp:', b)
print('log:', torch.log(a))
"""
exp: tensor([[2.7183, 2.7183, 2.7183],
[2.7183, 2.7183, 2.7183]])
log: tensor([[0., 0., 0.],
[0., 0., 0.]])
"""
近似值
近似值会在有小数点进行取舍时用到.
a = torch.tensor(3.14)
for i in ['a.floor()', 'a.ceil()', 'a.trunc()', 'a.frac()', 'a.round()']:
print(i + ':', eval(i))
"""
a.floor(): tensor(3.)
a.ceil(): tensor(4.)
a.trunc(): tensor(3.)
a.frac(): tensor(0.1400)
a.round(): tensor(3.)
"""
floor
: 下取整.ceil
: 上取整.trunc
: 只要整数部分.frac
: 只要小数部分.round
: 四舍五入.
裁剪
这个操作在梯度裁剪非常常用. 当发生梯度爆炸或消失时, 使用梯度裁剪能将梯度控制在可控范围内.
在pytorch中,Tensor.clamp(min, max)
函数作用等价于numpy的np.clip()
.
grad = torch.rand(2, 3) * 15
print('max:', grad.max())
print('min:', grad.min())
print('grad:\n', grad)
print('clamp(10):\n', grad.clamp(10))
print('clamp(0, 10):\n', grad.clamp(0, 10))
"""
max: tensor(13.1130)
min: tensor(1.6760)
grad:
tensor([[10.3501, 13.1130, 1.6760],
[ 3.1330, 7.5342, 4.7226]])
clamp(10):
tensor([[10.3501, 13.1130, 10.0000],
[10.0000, 10.0000, 10.0000]])
clamp(0, 10):
tensor([[10.0000, 10.0000, 1.6760],
[ 3.1330, 7.5342, 4.7226]])
"""
属性统计
norm
这里的norm指的不是标准化的那个normalization, 而是指的范数. 具体对范数的定义, 在这里就不再给出了, 请自己查询.
a = torch.full([8], 1)
b = a.view(2, 4)
c = a.view(2, 2, 2)
print('a.norm(1):', a.norm(1))
print('b.norm(1):', b.norm(1))
print('c.norm(1):', c.norm(1))
print('a.norm(2):', a.norm(2))
print('b.norm(2):', b.norm(2))
print('c.norm(2):', c.norm(2))
"""
a.norm(1): tensor(8.)
b.norm(1): tensor(8.)
c.norm(1): tensor(8.)
a.norm(2): tensor(2.8284)
b.norm(2): tensor(2.8284)
c.norm(2): tensor(2.8284)
"""
这样得到的结果都是标量, 也可以按照dim来求范数:
a = torch.full([8], 1)
b = a.view(2, 4)
c = a.view(2, 2, 2)
print('dim 1:')
print('b.norm(1):', b.norm(1, dim=1))
print('c.norm(1):', c.norm(1, dim=1))
print('b.norm(2):', b.norm(2, dim=1))
print('c.norm(2):', c.norm(2, dim=1))
"""
dim 1:
b.norm(1): tensor([4., 4.])
c.norm(1): tensor([[2., 2.],
[2., 2.]])
b.norm(2): tensor([2., 2.])
c.norm(2): tensor([[1.4142, 1.4142],
[1.4142, 1.4142]])
"""
mean / sum / min / max / prod
这一系列都是统计操作, 不是很难理解.
a = torch.arange(8).view(2, 4).float()
for i in ['a.min()', 'a.max()', 'a.mean()', 'a.prod()', 'a.sum()']:
print(i + ':', eval(i))
"""
a.min(): tensor(0.)
a.max(): tensor(7.)
a.mean(): tensor(3.5000)
a.prod(): tensor(0.)
a.sum(): tensor(28.)
"""
argmin / argmax
在求min和max时经常有一种操作, 找到一个tensor中最大或最小的元素并返回其索引, Tensor.argmin()
和Tensor.argmax()
就能实现这个功能. 不加参数默认为返回整个tensor中最大或最小的元素索引, 加dim后为沿着该维度切分tensor, 找到每个tensor最大或最小的元素索引.
a = torch.randn(2, 3)
print('a.argmin():', a.argmin())
print('a.argmin(dim=1):', a.argmin(dim=1))
print('a.argmax():', a.argmax())
print('a.argmax(dim=1):', a.argmax(dim=1))
"""
a.argmin(): tensor(2)
a.argmin(dim=1): tensor([2, 2])
a.argmax(): tensor(4)
a.argmax(dim=1): tensor([0, 1])
"""
dim / keepdim
dim
和keepdim
是作为参数放在前面所说的函数中的, 对于dim
我们已经接触很多次了, 理解为沿着该维度进行某种操作. keepdim
指的是在函数做完操作后还要不要维持原来的维度, 如:
a = torch.randn(4, 10)
print('a.argmin(dim=1):', a.argmax(dim=1))
print('a.argmin(dim=1, keepdim=True):\n', a.argmax(dim=1, keepdim=True))
print('a.argmin(dim=1).shape:', a.argmax(dim=1).shape)
print('a.argmin(dim=1, keepdim=True).shape:\n', a.argmax(dim=1, keepdim=True).shape)
"""
a.argmin(dim=1): tensor([7, 2, 8, 7])
a.argmin(dim=1, keepdim=True):
tensor([[7],
[2],
[8],
[7]])
a.argmin(dim=1).shape: torch.Size([4])
a.argmin(dim=1, keepdim=True).shape:
torch.Size([4, 1])
"""
在保存维度后, 结果的shape是torch.Size([4, 1])
, 否则会被自动消去, 即torch.Size([4])
.
topk / kthvalue
Top - k也是很常用的操作, 函数能返回最大的前k个值的相关信息. 在pytorch中, Tensor.topk(k)
能返回最值和它们的索引.
a = torch.rand(4, 10)
print('top3 in dim 1:\n', a.topk(3, dim=1))
"""
top3 in dim 1:
torch.return_types.topk(
values=tensor([[0.9354, 0.8272, 0.8214],
[0.9918, 0.8757, 0.8410],
[0.9744, 0.8817, 0.8365],
[0.9985, 0.8475, 0.8181]]),
indices=tensor([[3, 6, 0],
[6, 0, 8],
[7, 0, 9],
[5, 4, 6]]))
"""
因为输入的tensor大小为(4, 10), 并绑定dim=1, 所以返回了4条top3的value和index.
通过largest
参数来控制选择最大值还是最小值, 当其为True
时选择最大的k个值, False
时选择最小的k个值.
a = torch.rand(4, 10)
print('top-3 in dim 1:\n', a.topk(3, dim=1, largest=False))
"""
top-3 in dim 1:
torch.return_types.topk(
values=tensor([[0.0362, 0.0593, 0.1323],
[0.1024, 0.1237, 0.1397],
[0.0160, 0.1430, 0.2003],
[0.2674, 0.2686, 0.4065]]),
indices=tensor([[5, 6, 4],
[5, 1, 9],
[9, 8, 0],
[9, 7, 4]]))
如果你只是想要从小到大排列第k个值, 那么Tensor.kthvalue
能满足你的需求.
a = torch.rand(4, 5)
print('a:', a)
print('a.kthvalue(2):', a.kthvalue(2, dim=1))
"""
a: tensor([[0.2830, 0.2628, 0.2188, 0.9593, 0.6418],
[0.8727, 0.2504, 0.8656, 0.3067, 0.6215],
[0.0977, 0.7201, 0.1081, 0.2605, 0.7691],
[0.4776, 0.9503, 0.9577, 0.4100, 0.6476]])
a.kthvalue(2): torch.return_types.kthvalue(
values=tensor([0.2628, 0.3067, 0.1081, 0.4776]),
indices=tensor([1, 3, 2, 0]))
"""
compare
和python中的比较大小运算符一样, 有>
, >=
, <
, <=
, !=
, ==
. 它们都能比较tensor之间的大小关系, 它们也都有相应的缩写函数. 这里不详细说了, 非要用函数再从网上查就可以了.
a = torch.randn(2, 3)
print(a)
print(a > 0.5)
"""
tensor([[ 0.1539, -0.9071, -1.6426],
[ 0.7671, -1.7312, -0.8053]])
tensor([[0, 0, 0],
[1, 0, 0]], dtype=torch.uint8)
"""
有一特殊函数torch.eq(a, b)
和torch.equal(a, b)
的返回值是不一样的, 前者返回Booltensor, 后者返回一个逻辑值, 只有全部相同时才为True, 否则为False.
a = torch.rand(2, 3)
print(torch.eq(a, a))
print(torch.equal(a, a))
"""
tensor([[1, 1, 1],
[1, 1, 1]], dtype=torch.uint8)
True
"""
高阶操作
where
where的用法是torch.where(condition, x, y)
, 返回对应位置上符合condition的tensor. 如果符合condition, 则对应元素为x, 否则为y.
condition = torch.rand(2, 2)
print(condition)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(torch.where(condition > 0.5, a, b))
"""
tensor([[0.0559, 0.1984],
[0.9894, 0.2738]])
tensor([[1., 1.],
[0., 1.]])
"""
gather
torch.gather(input, dim, index)
的作用是将输入tensor按照指定的dim和index重新组合, 类似于查表, 返回一个新的tensor. 它能收集特定维度的指定位置的数值. 这种操作常用于将概率转化为具体的类. 理解起来比较抽象.
prob = torch.randn(4, 10)
index = prob.topk(3, dim=1)[1]
print('index:', index)
label = torch.arange(10) + 100
print('取到的classes:', torch.gather(label.expand(4, 10), dim=1, index=index.long()))
"""
index: tensor([[5, 3, 2],
[0, 6, 7],
[1, 4, 5],
[2, 9, 7]])
取到的classes: tensor([[105, 103, 102],
[100, 106, 107],
[101, 104, 105],
[102, 109, 107]])
"""
index
传入的数据类型必须是torch.LongTensor
.
再说的通俗一点, index
的作用就像一个Mask一样, 存储的是指定dim上的位置索引, 再看一个更简单的例子:
a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
index_2 = torch.LongTensor([[0,1],[2,0]])
print(a)
print(torch.gather(a, dim=0, index=index_1))
print(torch.gather(a, dim=1, index=index_2))
"""
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[1., 5., 6.],
[1., 5., 3.]])
tensor([[1., 2.],
[6., 4.]])
"""
如果还没明白, 参照官网给的3d - tensor解释结合起来:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
以官网的角度说一下上面那个例子. 假设gather后的向量名为new_tensor
, 我们可以手动模拟这个过程和构造出gather相同的结果.
dim=0时, 对于a来说, 第0维位置index是通过查表获得的.
"""
out[i][j] = input[index[i][j]][j] if dim == 0
a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
"""
new_tensor = torch.tensor([
[a[0][0], a[1][1], a[1][2]],
[a[0][0], a[1][1], a[0][2]]
])
print(new_tensor)
"""
tensor([[1., 5., 6.],
[1., 5., 3.]])
"""
dim=1时, 对于a来说, 第1维位置index是通过查表获得的.
"""
out[i][j] = input[i][index[i][j]] if dim == 1
a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
"""
new_tensor = torch.tensor([
[a[0][0], a[0][1]],
[a[1][2], a[1][0]]
])
print(new_tensor)
"""
tensor([[1., 2.],
[6., 4.]])
"""
index能遍历到沿dim上所有的元素, 最起码应该保证index和输入在除去指定的dim外其他dim上shape相同.