PyTorch的torch.cat用法

1. torch.cat介绍

torch.cat是PyTorch的一个重要函数,它能够将多个张量(tensor)拼接在一起,形成一个更大的张量。这个函数非常常用,尤其是在深度学习中。

2. torch.cat用法详解

2.1. torch.cat语法

语法为:

torch.cat(tensors, dim=0, out=None) -> Tensor

其中,tensors是需要拼接的张量序列,dim是指拼接的维度,out是指输出张量的名称(可选)。

2.2. torch.cat参数详解

2.2.1. tensors

tensors是需要拼接的张量序列。它可以是一个张量数组,也可以是一个张量的元组或者列表。

2.2.2. dim

dim是指把张量按照哪个维度进行拼接。

例如:张量A的shape为(3, 5),张量B的shape为(4, 5)。如果dim=0,则拼接后的张量shape为(7, 5);如果dim=1,则拼接后的张量shape为(3, 10)。

2.2.3. out

out是指输出张量的名称。如果用户不指定,则函数会自动创建一个输出张量,并返回此张量。如果指定了,则必须保证此张量的shape、dtype和device与拼接后的张量相同。

2.3. torch.cat使用示例

下面我们来看一个简单的示例。我们先创建两个张量:

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])

b = torch.tensor([[7, 8, 9], [10, 11, 12]])

然后我们使用torch.cat将它们拼接在一起:

c = torch.cat([a, b])

print(c)

# tensor([[ 1, 2, 3],

# [ 4, 5, 6],

# [ 7, 8, 9],

# [10, 11, 12]])

在上述例子中,对于两个张量a和b,我们没有指定dim。它们被默认拼接在了dim=0的维度上。

2.4. torch.cat常见错误分析

2.4.1. 张量形状不匹配

如果要拼接的张量形状不匹配,则会抛出如下错误:

RuntimeError: Sizes of tensors must match except in dimension 0. Got 3 and 4 in dimension 1 at /tmp/pip-req-build-jle7vp7w/aten/src/TH/generic/THTensor.cpp:612

上述错误提示表明,在拼接的过程中,张量的形状不匹配。具体地,在第2个维度上,张量a的形状为3,而张量b的形状为4。

2.4.2. out指定的形状与拼接后的形状不匹配

如果用户指定了out张量,并且其形状与拼接后的形状不匹配,则会抛出如下错误:

RuntimeError: The expanded size of the tensor (7) must match the existing size (6) at non-singleton dimension 1. ...

上述错误提示表明,用户指定的输出张量具有6个元素,但拼接后的张量具有7个元素。

3. 总结

通过以上的介绍,我们了解了torch.cat的用法和特点。在实际开发中,根据需要调整拼接的维度设置能够让我们更好地进行张量的拼接。同时,也要注意形状和dtype的匹配问题。

下面是上述示例的完整代码:

import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]])

b = torch.tensor([[7, 8, 9], [10, 11, 12]])

c = torch.cat([a, b])

print(c)

# tensor([[ 1, 2, 3],

# [ 4, 5, 6],

# [ 7, 8, 9],

# [10, 11, 12]])

后端开发标签