1. 数据拼接操作
1.1 torch.cat()
在PyTorch中,我们可以使用torch.cat()函数来对多个Tensor进行拼接,函数用法如下:
torch.cat(tensors, dim=0, *, out=None) -> Tensor
其中:
tensors:需要进行拼接的多个Tensor
dim:指定拼接的维度
out:指定输出的Tensor
需要注意的是,所有需要进行拼接的Tensor在除拼接维度之外的所有维度大小都相同,否则会出现拼接错误的情况。
下面我们通过一个例子来说明torch.cat()的使用方法。假设我们有两个维度为(2, 3)和(2, 2)的Tensor。
import torch
a = torch.randn(2, 3)
b = torch.randn(2, 2)
c = torch.cat((a, b), dim=1)
print("a:", a)
print("b:", b)
print("c:", c)
运行结果为:
a: tensor([[-0.2512, -0.8895, 1.2537],
[-0.8447, 2.3145, -1.3456]])
b: tensor([[ 0.3098, 0.7644],
[ 0.3992, -0.9416]])
c: tensor([[-0.2512, -0.8895, 1.2537, 0.3098, 0.7644],
[-0.8447, 2.3145, -1.3456, 0.3992, -0.9416]])
可以看到,在维度1上对a和b进行拼接,得到了一个维度为(2, 5)的Tensor。
1.2 torch.stack()
除了使用torch.cat()函数进行拼接之外,我们还可以使用torch.stack()函数将多个Tensor在新创建的维度中进行拼接,用法如下:
torch.stack(seq, dim=0, *, out=None) -> Tensor
其中:
seq:需要进行拼接的多个Tensor组成的序列
dim:指定拼接的维度
out:指定输出的Tensor
下面我们通过一个例子来说明torch.stack()的使用方法。假设我们有两个维度为(2, 2)的Tensor。
import torch
a = torch.randn(2, 2)
b = torch.randn(2, 2)
c = torch.stack((a, b), dim=0)
print("a:", a)
print("b:", b)
print("c:", c)
运行结果为:
a: tensor([[ 0.0421, -0.1276],
[-2.0306, 0.1373]])
b: tensor([[-1.2073, -0.3513],
[ 0.5535, -0.8656]])
c: tensor([[[ 0.0421, -0.1276],
[-2.0306, 0.1373]],
[[-1.2073, -0.3513],
[ 0.5535, -0.8656]]])
可以看到,torch.stack()会在新创建的维度0上对a和b进行拼接,得到了一个维度为(2, 2, 2)的Tensor。
2. 数据拆分操作
2.1 torch.chunk()
在PyTorch中,我们可以使用torch.chunk()函数对一个Tensor进行拆分,用法如下:
torch.chunk(input, chunks, dim=0) -> List of Tensors
其中:
input:需要进行拆分的Tensor
chunks:指定拆分后的块数
dim:指定拆分的维度
需要注意的是,input在指定维度上的大小必须是chunks的整数倍,否则会出现拆分错误的情况。
下面我们通过一个例子来说明torch.chunk()的使用方法。假设我们有一个维度为(2, 4)的Tensor。
import torch
a = torch.randn(2, 4)
b, c = torch.chunk(a, 2, dim=1)
print("a:", a)
print("b:", b)
print("c:", c)
运行结果为:
a: tensor([[-1.0190, 1.0677, -0.9337, 0.1867],
[ 0.3485, -0.3037, -0.4665, -1.3789]])
b: tensor([[-1.0190, 1.0677],
[ 0.3485, -0.3037]])
c: tensor([[-0.9337, 0.1867],
[-0.4665, -1.3789]])
可以看到,torch.chunk()在维度1上将a拆分成了大小相同的两个Tensor。
2.2 torch.split()
除了使用torch.chunk()函数进行拆分之外,我们还可以使用torch.split()函数将一个Tensor按照指定长度进行拆分,用法如下:
torch.split(tensor, split_size_or_sections, dim=0) -> List of Tensors
其中:
tensor:需要进行拆分的Tensor
split_size_or_sections:指定拆分后每个部分的大小或者每个部分的长度
dim:指定拆分的维度
下面我们通过一个例子来说明torch.split()的使用方法。假设我们有一个维度为(2, 4)的Tensor。
import torch
a = torch.randn(2, 4)
b, c = torch.split(a, 2, dim=1)
print("a:", a)
print("b:", b)
print("c:", c)
运行结果为:
a: tensor([[-0.2831, 0.4517, -0.3449, -0.2991],
[-0.6605, 0.2996, 0.8940, 0.1864]])
b: tensor([[-0.2831, 0.4517],
[-0.6605, 0.2996]])
c: tensor([[-0.3449, -0.2991],
[ 0.8940, 0.1864]])
可以看到,torch.split()在维度1上将a按照指定长度进行了拆分。