1. torch.cat介绍
torch.cat是PyTorch的一个重要函数,它能够将多个张量(tensor)拼接在一起,形成一个更大的张量。这个函数非常常用,尤其是在深度学习中。
2. torch.cat用法详解
2.1. torch.cat语法
语法为:
torch.cat(tensors, dim=0, out=None) -> Tensor
其中,tensors是需要拼接的张量序列,dim是指拼接的维度,out是指输出张量的名称(可选)。
2.2. torch.cat参数详解
2.2.1. tensors
tensors是需要拼接的张量序列。它可以是一个张量数组,也可以是一个张量的元组或者列表。
2.2.2. dim
dim是指把张量按照哪个维度进行拼接。
例如:张量A的shape为(3, 5),张量B的shape为(4, 5)。如果dim=0,则拼接后的张量shape为(7, 5);如果dim=1,则拼接后的张量shape为(3, 10)。
2.2.3. out
out是指输出张量的名称。如果用户不指定,则函数会自动创建一个输出张量,并返回此张量。如果指定了,则必须保证此张量的shape、dtype和device与拼接后的张量相同。
2.3. torch.cat使用示例
下面我们来看一个简单的示例。我们先创建两个张量:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
然后我们使用torch.cat将它们拼接在一起:
c = torch.cat([a, b])
print(c)
# tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])
在上述例子中,对于两个张量a和b,我们没有指定dim。它们被默认拼接在了dim=0的维度上。
2.4. torch.cat常见错误分析
2.4.1. 张量形状不匹配
如果要拼接的张量形状不匹配,则会抛出如下错误:
RuntimeError: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /tmp/pip-req-build-jle7vp7w/aten/src/TH/generic/THTensor.cpp:612
上述错误提示表明,在拼接的过程中,张量的形状不匹配。具体地,在第2个维度上,张量a的形状为3,而张量b的形状为4。
2.4.2. out指定的形状与拼接后的形状不匹配
如果用户指定了out张量,并且其形状与拼接后的形状不匹配,则会抛出如下错误:
RuntimeError: The expanded size of the tensor (7) must match the existing size (6) at non-singleton dimension 1. ...
上述错误提示表明,用户指定的输出张量具有6个元素,但拼接后的张量具有7个元素。
3. 总结
通过以上的介绍,我们了解了torch.cat的用法和特点。在实际开发中,根据需要调整拼接的维度设置能够让我们更好地进行张量的拼接。同时,也要注意形状和dtype的匹配问题。
下面是上述示例的完整代码:
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([[7, 8, 9], [10, 11, 12]])
c = torch.cat([a, b])
print(c)
# tensor([[ 1, 2, 3],
# [ 4, 5, 6],
# [ 7, 8, 9],
# [10, 11, 12]])