Pytorch数据拼接与拆分操作实现图解

1. 数据拼接操作

1.1 torch.cat()

在PyTorch中,我们可以使用torch.cat()函数来对多个Tensor进行拼接,函数用法如下:

torch.cat(tensors, dim=0, *, out=None) -> Tensor

其中:

tensors:需要进行拼接的多个Tensor

dim:指定拼接的维度

out:指定输出的Tensor

需要注意的是,所有需要进行拼接的Tensor在除拼接维度之外的所有维度大小都相同,否则会出现拼接错误的情况。

下面我们通过一个例子来说明torch.cat()的使用方法。假设我们有两个维度为(2, 3)和(2, 2)的Tensor。

import torch

a = torch.randn(2, 3)

b = torch.randn(2, 2)

c = torch.cat((a, b), dim=1)

print("a:", a)

print("b:", b)

print("c:", c)

运行结果为:

a: tensor([[-0.2512, -0.8895, 1.2537],

[-0.8447, 2.3145, -1.3456]])

b: tensor([[ 0.3098, 0.7644],

[ 0.3992, -0.9416]])

c: tensor([[-0.2512, -0.8895, 1.2537, 0.3098, 0.7644],

[-0.8447, 2.3145, -1.3456, 0.3992, -0.9416]])

可以看到,在维度1上对a和b进行拼接,得到了一个维度为(2, 5)的Tensor。

1.2 torch.stack()

除了使用torch.cat()函数进行拼接之外,我们还可以使用torch.stack()函数将多个Tensor在新创建的维度中进行拼接,用法如下:

torch.stack(seq, dim=0, *, out=None) -> Tensor

其中:

seq:需要进行拼接的多个Tensor组成的序列

dim:指定拼接的维度

out:指定输出的Tensor

下面我们通过一个例子来说明torch.stack()的使用方法。假设我们有两个维度为(2, 2)的Tensor。

import torch

a = torch.randn(2, 2)

b = torch.randn(2, 2)

c = torch.stack((a, b), dim=0)

print("a:", a)

print("b:", b)

print("c:", c)

运行结果为:

a: tensor([[ 0.0421, -0.1276],

[-2.0306, 0.1373]])

b: tensor([[-1.2073, -0.3513],

[ 0.5535, -0.8656]])

c: tensor([[[ 0.0421, -0.1276],

[-2.0306, 0.1373]],

[[-1.2073, -0.3513],

[ 0.5535, -0.8656]]])

可以看到,torch.stack()会在新创建的维度0上对a和b进行拼接,得到了一个维度为(2, 2, 2)的Tensor。

2. 数据拆分操作

2.1 torch.chunk()

在PyTorch中,我们可以使用torch.chunk()函数对一个Tensor进行拆分,用法如下:

torch.chunk(input, chunks, dim=0) -> List of Tensors

其中:

input:需要进行拆分的Tensor

chunks:指定拆分后的块数

dim:指定拆分的维度

需要注意的是,input在指定维度上的大小必须是chunks的整数倍,否则会出现拆分错误的情况。

下面我们通过一个例子来说明torch.chunk()的使用方法。假设我们有一个维度为(2, 4)的Tensor。

import torch

a = torch.randn(2, 4)

b, c = torch.chunk(a, 2, dim=1)

print("a:", a)

print("b:", b)

print("c:", c)

运行结果为:

a: tensor([[-1.0190, 1.0677, -0.9337, 0.1867],

[ 0.3485, -0.3037, -0.4665, -1.3789]])

b: tensor([[-1.0190, 1.0677],

[ 0.3485, -0.3037]])

c: tensor([[-0.9337, 0.1867],

[-0.4665, -1.3789]])

可以看到,torch.chunk()在维度1上将a拆分成了大小相同的两个Tensor。

2.2 torch.split()

除了使用torch.chunk()函数进行拆分之外,我们还可以使用torch.split()函数将一个Tensor按照指定长度进行拆分,用法如下:

torch.split(tensor, split_size_or_sections, dim=0) -> List of Tensors

其中:

tensor:需要进行拆分的Tensor

split_size_or_sections:指定拆分后每个部分的大小或者每个部分的长度

dim:指定拆分的维度

下面我们通过一个例子来说明torch.split()的使用方法。假设我们有一个维度为(2, 4)的Tensor。

import torch

a = torch.randn(2, 4)

b, c = torch.split(a, 2, dim=1)

print("a:", a)

print("b:", b)

print("c:", c)

运行结果为:

a: tensor([[-0.2831, 0.4517, -0.3449, -0.2991],

[-0.6605, 0.2996, 0.8940, 0.1864]])

b: tensor([[-0.2831, 0.4517],

[-0.6605, 0.2996]])

c: tensor([[-0.3449, -0.2991],

[ 0.8940, 0.1864]])

可以看到,torch.split()在维度1上将a按照指定长度进行了拆分。

后端开发标签