pytorch下的unsqueeze和squeeze的用法说明

1. 简介

在PyTorch中,unsqueeze和squeeze是两个常用的操作函数,用于在张量的维度上进行扩展或压缩。unsqueeze用于在指定维度上插入一个大小为1的维度,而squeeze用于压缩张量中大小为1的维度。

2. unsqueeze的用法

2.1 插入一个维度

unsqueeze函数的主要功能是在张量的指定位置插入一个大小为1的新维度。这个新维度将在原有维度的前面或后面插入,具体取决于指定的位置。

例如,有一个形状为(3, 4)的张量a:

import torch

a = torch.randn(3, 4)

print(a.shape)

# Output: torch.Size([3, 4])

我们可以使用unsqueeze函数在a的第二个维度后面插入一个新的维度:

b = torch.unsqueeze(a, dim=1)

print(b.shape)

# Output: torch.Size([3, 1, 4])

在这个例子中,原来的形状为(3, 4)的张量a变成了形状为(3, 1, 4)的张量b。新插入的维度具有大小为1。

2.2 批量插入维度

除了可以在一个指定位置插入一个维度,unsqueeze函数还支持一次插入多个维度。只需在dim参数中传入一个元组或列表,指定各个维度的插入位置即可。

例如,有一个形状为(3, 4)的张量a:

a = torch.randn(3, 4)

print(a.shape)

# Output: torch.Size([3, 4])

我们可以使用unsqueeze函数在a的第一个维度前面插入一个新的维度,并在第三个维度后面插入一个新的维度:

b = torch.unsqueeze(a, dim=(0, 2))

print(b.shape)

# Output: torch.Size([1, 3, 4, 1])

在这个例子中,原来的形状为(3, 4)的张量a变成了形状为(1, 3, 4, 1)的张量b。新插入的维度具有大小为1。

3. squeeze的用法

3.1 压缩大小为1的维度

squeeze函数的主要功能是压缩张量中大小为1的维度。它会自动检测张量中所有大小为1的维度,并将其压缩。

例如,有一个形状为(1, 3, 1, 4)的张量a:

import torch

a = torch.randn(1, 3, 1, 4)

print(a.shape)

# Output: torch.Size([1, 3, 1, 4])

我们可以使用squeeze函数将a中的大小为1的维度进行压缩:

b = torch.squeeze(a)

print(b.shape)

# Output: torch.Size([3, 4])

在这个例子中,原来的形状为(1, 3, 1, 4)的张量a变成了形状为(3, 4)的张量b。大小为1的维度被压缩。

3.2 指定压缩的维度

除了自动检测大小为1的维度进行压缩,squeeze函数还支持指定需要压缩的维度。只需在dim参数中传入一个元组或列表,指定需要压缩的维度即可。

例如,有一个形状为(1, 3, 1, 4)的张量a:

a = torch.randn(1, 3, 1, 4)

print(a.shape)

# Output: torch.Size([1, 3, 1, 4])

我们可以使用squeeze函数压缩a的第一个维度和第三个维度:

b = torch.squeeze(a, dim=(0, 2))

print(b.shape)

# Output: torch.Size([3, 4])

在这个例子中,原来的形状为(1, 3, 1, 4)的张量a变成了形状为(3, 4)的张量b。指定的维度大小为1的维度被压缩。

4. 总结

unsqueeze和squeeze是PyTorch中非常有用的操作函数,用于在张量的维度上进行扩展或压缩。unsqueeze可以在指定位置插入一个大小为1的新维度,而squeeze可以压缩张量中大小为1的维度。使用这两个函数,可以灵活地操作张量的形状,满足不同的需求。

在本文中,我们介绍了unsqueeze和squeeze的基本用法,并且通过代码示例进行了演示。希望读者能够通过本文对这两个函数有更深入的理解,并能在实际应用中灵活运用。

后端开发标签