Pytorch学习: 张量进阶操作


2020.10.03: 因torch版本更新, 对gather描述进行了修正.

2021.03.11: 更新了对gather的描述.

Pytorch学习: 张量进阶操作

整理内容顺序来自龙龙老师的<深度学习与PyTorch入门实战教程>, 根据个人所需情况进行删减或扩充. 如果想要自己创建新的模块, 这些操作都是基本功, 需要掌握扎实.

拼接与拆分

cat

torch.cat(*tensor, dim)能在指定的维度上将tensor拼接:

a = torch.rand(2, 3)
b = torch.rand(5, 3)
print('a.shape:', a.shape)
print('b.shape:', b.shape)
print('a concat b:', torch.cat([a, b], dim=0).shape)
"""
a.shape: torch.Size([2, 3])
b.shape: torch.Size([5, 3])
a concat b: torch.Size([7, 3])
"""

当然, concat之前必须保证除了cat的维度的shape不同外, 其他维度的shape均相同.

stack

torch.cat()不同, torch.stack()是将tensor堆叠在一个新的维度上, 即创建一个新的维度:

a = torch.rand(5, 3)
b = torch.rand(5, 3)
print('a.shape:', a.shape)
print('b.shape:', b.shape)
print('a concat b:', torch.cat([a, b], dim=0).shape)
print('a stack b:', torch.stack([a, b], dim=0).shape)
"""
a.shape: torch.Size([5, 3])
b.shape: torch.Size([5, 3])
a concat b: torch.Size([10, 3])
a stack b: torch.Size([2, 5, 3])
"""

两个stack的tensor必须保持完全一致的维度.

split

torch.split()既可以按照指定dim的长度来拆分, 也可以按照类似步长的拆法来拆分.

当传入的参数是list时, 按照列表中指定的长度拆分:

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
s = c.split([1, 1, 1], dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
    print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([1, 32, 8])
bb.shape:torch.Size([1, 32, 8])
cc.shape:torch.Size([1, 32, 8])
"""

这样能将c分别拆分为长度为1, 1, 1的tensor.

当传入的参数是int时, 意在使指定dim上每个tensor应该具有多长, 也可以理解为步长:

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
# s = c.split([1, 1, 1], dim=0)
s = c.split(2, dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
    print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([2, 32, 8])
bb.shape:torch.Size([1, 32, 8])
"""

虽然在dim=0上tensor长度为3, 但是还是能够拆分成2个tensor. 如果传入的整数是3, 则只会得到aa.shape:torch.Size([3, 32, 8]), 即每个tensor应该在dim=0上有3个元素.

chunk

我认为torch.chunk()是为了和torch.split()同参数而防止歧义的区分函数.

torch.chunk()中需要传入的就是需要分多少个tensor了, 还是上面那个例子, 体会区别:

a = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b, b], dim=0)
print('c.shape:', c.shape)
# s = c.split([1, 1, 1], dim=0)
s = c.chunk(3, dim=0)
for i, n in zip(['aa', 'bb', 'cc'], s):
    print('{}.shape:{}'.format(i, n.shape))
"""
c.shape: torch.Size([3, 32, 8])
aa.shape:torch.Size([1, 32, 8])
bb.shape:torch.Size([1, 32, 8])
cc.shape:torch.Size([1, 32, 8])
"""

数学运算

大多数操作与python中的基本类型操作区别不大, pytorch使用了运算符重载使得它们在tensor上保持相同的含义. pytorch也提供了函数为tensor做运算. 大部分函数仍然和numpy中的格式一致.

加减乘除

a = torch.rand(3, 4)
b = torch.rand(4)
print('a + b和add:', torch.all(torch.eq(a + b, torch.add(a, b))))
print('a - b和sub:', torch.all(torch.eq(a - b, torch.sub(a, b))))
print('a * b和mul:', torch.all(torch.eq(a * b, torch.mul(a, b))))
print('a + b和div:', torch.all(torch.eq(a / b, torch.div(a, b))))
"""
a + b和add: tensor(1, dtype=torch.uint8)
a - b和sub: tensor(1, dtype=torch.uint8)
a * b和mul: tensor(1, dtype=torch.uint8)
a + b和div: tensor(1, dtype=torch.uint8)
"""

当然上面的例子中说的乘法不是矩阵乘法, 而是点乘(element - wise).

矩阵乘法

如果要实现矩阵乘有三种方式:

  • torch.mm()仅限于2d - tensor.
  • torch.matmul()推荐使用.
  • @是重载的运算符.
a = torch.rand(2, 3)
b = torch.ones(3, 2)
print('mm in 2d:\n', torch.mm(a, b))
print('matmul:\n', torch.matmul(a, b))
print('@:\n', a @ b)
"""
mm in 2d:
 tensor([[0.8500, 0.8500],
        [1.0964, 1.0964]])
matmul:
 tensor([[0.8500, 0.8500],
        [1.0964, 1.0964]])
@:
 tensor([[0.8500, 0.8500],
        [1.0964, 1.0964]])
"""

记住, 如果tensor的维度比2大, 则默认在最后两个维度上进行运算, 也可以理解为对多个矩阵并行做矩阵乘法.

a = torch.rand(4, 3, 28, 64)
b = torch.rand(4, 3, 64, 28)
print('matmul(a, b).shape:', torch.matmul(a, b).shape)
# matmul(a, b).shape: torch.Size([4, 3, 28, 28])

指数和对数

既可以沿用python中的**做幂计算, 也可以使用Tensor.pow().

a = torch.full([2, 2], 3)
print('a.pow(2):\n', a.pow(2))
print('a ** 2:\n', a ** 2)
print('a ** 2.sqrt():\n', (a ** 2).sqrt())
print('a ** 2.rsqrt():\n', (a ** 2).rsqrt())
print('a ** 2 ** 0.5:\n', a ** 2 ** 0.5)
"""
a.pow(2):
 tensor([[9., 9.],
        [9., 9.]])
a ** 2:
 tensor([[9., 9.],
        [9., 9.]])
a ** 2.sqrt():
 tensor([[3., 3.],
        [3., 3.]])
a ** 2.rsqrt():
 tensor([[0.3333, 0.3333],
        [0.3333, 0.3333]])
a ** 2 ** 0.5:
 tensor([[4.7288, 4.7288],
        [4.7288, 4.7288]])
"""

Tensor.rsqrt()代表的是求平方根后的倒数.

对数指数也是一样的用法:

a = torch.ones(2, 3)
b = torch.exp(a)
print('exp:', b)
print('log:', torch.log(a))
"""
exp: tensor([[2.7183, 2.7183, 2.7183],
        [2.7183, 2.7183, 2.7183]])
log: tensor([[0., 0., 0.],
        [0., 0., 0.]])
"""

近似值

近似值会在有小数点进行取舍时用到.

a = torch.tensor(3.14)
for i in ['a.floor()', 'a.ceil()', 'a.trunc()', 'a.frac()', 'a.round()']:
    print(i + ':', eval(i))
"""
a.floor(): tensor(3.)
a.ceil(): tensor(4.)
a.trunc(): tensor(3.)
a.frac(): tensor(0.1400)
a.round(): tensor(3.)
"""
  • floor: 下取整.
  • ceil: 上取整.
  • trunc: 只要整数部分.
  • frac: 只要小数部分.
  • round: 四舍五入.

裁剪

这个操作在梯度裁剪非常常用. 当发生梯度爆炸或消失时, 使用梯度裁剪能将梯度控制在可控范围内.

在pytorch中,Tensor.clamp(min, max)函数作用等价于numpy的np.clip().

grad = torch.rand(2, 3) * 15
print('max:', grad.max())
print('min:', grad.min())
print('grad:\n', grad)
print('clamp(10):\n', grad.clamp(10))
print('clamp(0, 10):\n', grad.clamp(0, 10))
"""
max: tensor(13.1130)
min: tensor(1.6760)
grad:
 tensor([[10.3501, 13.1130,  1.6760],
        [ 3.1330,  7.5342,  4.7226]])
clamp(10):
 tensor([[10.3501, 13.1130, 10.0000],
        [10.0000, 10.0000, 10.0000]])
clamp(0, 10):
 tensor([[10.0000, 10.0000,  1.6760],
        [ 3.1330,  7.5342,  4.7226]])
"""

属性统计

norm

这里的norm指的不是标准化的那个normalization, 而是指的范数. 具体对范数的定义, 在这里就不再给出了, 请自己查询.

a = torch.full([8], 1)
b = a.view(2, 4)
c = a.view(2, 2, 2)
print('a.norm(1):', a.norm(1))
print('b.norm(1):', b.norm(1))
print('c.norm(1):', c.norm(1))
print('a.norm(2):', a.norm(2))
print('b.norm(2):', b.norm(2))
print('c.norm(2):', c.norm(2))
"""
a.norm(1): tensor(8.)
b.norm(1): tensor(8.)
c.norm(1): tensor(8.)
a.norm(2): tensor(2.8284)
b.norm(2): tensor(2.8284)
c.norm(2): tensor(2.8284)
"""

这样得到的结果都是标量, 也可以按照dim来求范数:

a = torch.full([8], 1)
b = a.view(2, 4)
c = a.view(2, 2, 2)
print('dim 1:')
print('b.norm(1):', b.norm(1, dim=1))
print('c.norm(1):', c.norm(1, dim=1))
print('b.norm(2):', b.norm(2, dim=1))
print('c.norm(2):', c.norm(2, dim=1))
"""
dim 1:
b.norm(1): tensor([4., 4.])
c.norm(1): tensor([[2., 2.],
        [2., 2.]])
b.norm(2): tensor([2., 2.])
c.norm(2): tensor([[1.4142, 1.4142],
        [1.4142, 1.4142]])
"""

mean / sum / min / max / prod

这一系列都是统计操作, 不是很难理解.

a = torch.arange(8).view(2, 4).float()
for i in ['a.min()', 'a.max()', 'a.mean()', 'a.prod()', 'a.sum()']:
    print(i + ':', eval(i))
"""
a.min(): tensor(0.)
a.max(): tensor(7.)
a.mean(): tensor(3.5000)
a.prod(): tensor(0.)
a.sum(): tensor(28.)
"""

argmin / argmax

在求min和max时经常有一种操作, 找到一个tensor中最大或最小的元素并返回其索引, Tensor.argmin()Tensor.argmax()就能实现这个功能. 不加参数默认为返回整个tensor中最大或最小的元素索引, 加dim后为沿着该维度切分tensor, 找到每个tensor最大或最小的元素索引.

a = torch.randn(2, 3)
print('a.argmin():', a.argmin())
print('a.argmin(dim=1):', a.argmin(dim=1))
print('a.argmax():', a.argmax())
print('a.argmax(dim=1):', a.argmax(dim=1))
"""
a.argmin(): tensor(2)
a.argmin(dim=1): tensor([2, 2])
a.argmax(): tensor(4)
a.argmax(dim=1): tensor([0, 1])
"""

dim / keepdim

dimkeepdim是作为参数放在前面所说的函数中的, 对于dim我们已经接触很多次了, 理解为沿着该维度进行某种操作. keepdim指的是在函数做完操作后还要不要维持原来的维度, 如:

a = torch.randn(4, 10)
print('a.argmin(dim=1):', a.argmax(dim=1))
print('a.argmin(dim=1, keepdim=True):\n', a.argmax(dim=1, keepdim=True))
print('a.argmin(dim=1).shape:', a.argmax(dim=1).shape)
print('a.argmin(dim=1, keepdim=True).shape:\n', a.argmax(dim=1, keepdim=True).shape)
"""
a.argmin(dim=1): tensor([7, 2, 8, 7])
a.argmin(dim=1, keepdim=True):
 tensor([[7],
        [2],
        [8],
        [7]])
a.argmin(dim=1).shape: torch.Size([4])
a.argmin(dim=1, keepdim=True).shape:
 torch.Size([4, 1])
"""

在保存维度后, 结果的shape是torch.Size([4, 1]), 否则会被自动消去, 即torch.Size([4]).

topk / kthvalue

Top - k也是很常用的操作, 函数能返回最大的前k个值的相关信息. 在pytorch中, Tensor.topk(k)能返回最值和它们的索引.

a = torch.rand(4, 10)
print('top3 in dim 1:\n', a.topk(3, dim=1))
"""
top3 in dim 1:
torch.return_types.topk(
values=tensor([[0.9354, 0.8272, 0.8214],
        [0.9918, 0.8757, 0.8410],
        [0.9744, 0.8817, 0.8365],
        [0.9985, 0.8475, 0.8181]]),
indices=tensor([[3, 6, 0],
        [6, 0, 8],
        [7, 0, 9],
        [5, 4, 6]]))
"""

因为输入的tensor大小为(4, 10), 并绑定dim=1, 所以返回了4条top3的value和index.

通过largest参数来控制选择最大值还是最小值, 当其为True时选择最大的k个值, False时选择最小的k个值.

a = torch.rand(4, 10)
print('top-3 in dim 1:\n', a.topk(3, dim=1, largest=False))
"""
top-3 in dim 1:
torch.return_types.topk(
values=tensor([[0.0362, 0.0593, 0.1323],
        [0.1024, 0.1237, 0.1397],
        [0.0160, 0.1430, 0.2003],
        [0.2674, 0.2686, 0.4065]]),
indices=tensor([[5, 6, 4],
        [5, 1, 9],
        [9, 8, 0],
        [9, 7, 4]]))

如果你只是想要从小到大排列第k个值, 那么Tensor.kthvalue能满足你的需求.

a = torch.rand(4, 5)
print('a:', a)
print('a.kthvalue(2):', a.kthvalue(2, dim=1))

"""
a: tensor([[0.2830, 0.2628, 0.2188, 0.9593, 0.6418],
        [0.8727, 0.2504, 0.8656, 0.3067, 0.6215],
        [0.0977, 0.7201, 0.1081, 0.2605, 0.7691],
        [0.4776, 0.9503, 0.9577, 0.4100, 0.6476]])
a.kthvalue(2): torch.return_types.kthvalue(
values=tensor([0.2628, 0.3067, 0.1081, 0.4776]),
indices=tensor([1, 3, 2, 0]))
"""

compare

和python中的比较大小运算符一样, 有>, >=, <, <=, !=, ==. 它们都能比较tensor之间的大小关系, 它们也都有相应的缩写函数. 这里不详细说了, 非要用函数再从网上查就可以了.

a = torch.randn(2, 3)
print(a)
print(a > 0.5)
"""
tensor([[ 0.1539, -0.9071, -1.6426],
        [ 0.7671, -1.7312, -0.8053]])
tensor([[0, 0, 0],
        [1, 0, 0]], dtype=torch.uint8)
"""

有一特殊函数torch.eq(a, b)torch.equal(a, b)的返回值是不一样的, 前者返回Booltensor, 后者返回一个逻辑值, 只有全部相同时才为True, 否则为False.

a = torch.rand(2, 3)
print(torch.eq(a, a))
print(torch.equal(a, a))
"""
tensor([[1, 1, 1],
        [1, 1, 1]], dtype=torch.uint8)
True
"""

高阶操作

where

where的用法是torch.where(condition, x, y), 返回对应位置上符合condition的tensor. 如果符合condition, 则对应元素为x, 否则为y.

condition = torch.rand(2, 2)
print(condition)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(torch.where(condition > 0.5, a, b))
"""
tensor([[0.0559, 0.1984],
        [0.9894, 0.2738]])
tensor([[1., 1.],
        [0., 1.]])
"""

gather

torch.gather(input, dim, index)的作用是将输入tensor按照指定的dim和index重新组合, 类似于查表, 返回一个新的tensor. 它能收集特定维度的指定位置的数值. 这种操作常用于将概率转化为具体的类. 理解起来比较抽象.

prob = torch.randn(4, 10)
index = prob.topk(3, dim=1)[1]
print('index:', index)
label = torch.arange(10) + 100
print('取到的classes:', torch.gather(label.expand(4, 10), dim=1, index=index.long()))
"""
index: tensor([[5, 3, 2],
        [0, 6, 7],
        [1, 4, 5],
        [2, 9, 7]])
取到的classes: tensor([[105, 103, 102],
        [100, 106, 107],
        [101, 104, 105],
        [102, 109, 107]])
"""

index传入的数据类型必须是torch.LongTensor.

再说的通俗一点, index的作用就像一个Mask一样, 存储的是指定dim上的位置索引, 再看一个更简单的例子:

a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
index_2 = torch.LongTensor([[0,1],[2,0]])
print(a)
print(torch.gather(a, dim=0, index=index_1))
print(torch.gather(a, dim=1, index=index_2))
"""
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[1., 5., 6.],
        [1., 5., 3.]])
tensor([[1., 2.],
        [6., 4.]])
"""

如果还没明白, 参照官网给的3d - tensor解释结合起来:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

以官网的角度说一下上面那个例子. 假设gather后的向量名为new_tensor, 我们可以手动模拟这个过程和构造出gather相同的结果.

dim=0时, 对于a来说, 第0维位置index是通过查表获得的.

"""
out[i][j] = input[index[i][j]][j] if dim == 0

a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
"""
new_tensor = torch.tensor([
    [a[0][0], a[1][1], a[1][2]],
    [a[0][0], a[1][1], a[0][2]]
])
print(new_tensor)
"""
tensor([[1., 5., 6.],
        [1., 5., 3.]])
"""

dim=1时, 对于a来说, 第1维位置index是通过查表获得的.

"""
out[i][j] = input[i][index[i][j]] if dim == 1

a = torch.Tensor([[1,2,3],[4,5,6]])
index_1 = torch.LongTensor([[0,1,1],[0,1,0]])
"""
new_tensor = torch.tensor([
    [a[0][0], a[0][1]],
    [a[1][2], a[1][0]]
])
print(new_tensor)
"""
tensor([[1., 2.],
        [6., 4.]])
"""

index能遍历到沿dim上所有的元素, 最起码应该保证index和输入在除去指定的dim外其他dim上shape相同.


文章作者: DaNing
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 DaNing !
评论
 上一篇
ELMo, GPT, BERT ELMo, GPT, BERT
本文的前置知识: RNN Transformer Language Model ELMo, GPT, BERT本文是对ELMo, GPT, BERT三个模型的结构介绍以及个人理解, 多图预警. Introduction由于NLP领域
2020-10-04
下一篇 
Pytorch学习: 张量基础操作 Pytorch学习: 张量基础操作
Pytorch学习: 张量基础操作整理内容顺序来自龙龙老师的<深度学习与PyTorch入门实战教程>, 根据个人所需情况进行删减或扩充. 如果想要自己创建新的模块, 这些操作都是基本功, 需要掌握扎实. 张量数据类型下表摘自Py
2020-10-02
  目录