PyTorch中的padding(边缘填充)操作方式

1. Padding是什么?

在深度学习中,我们经常需要对图像或序列等数据进行处理。但是不同的数据长度不同,这就给模型训练带来了困难。Padding就是在边缘填充一定的值,使得所有的输入数据维度一致,从而方便网络的处理。

2. PyTorch中的Padding操作

在PyTorch中,Padding的操作可以通过使用torch.nn模块中的函数来实现。下面是一个例子,对一个输入的大小为(1,3,32,32)的张量进行Padding操作:

import torch

import torch.nn as nn

inputs = torch.randn(1, 3, 32, 32)

padding = nn.ZeroPad2d((2, 2, 2, 2))

outputs = padding(inputs)

print(outputs.size()) # (1, 3, 36, 36)

2.1 nn.ZeroPad2d函数

nn.ZeroPad2d函数是一个常用的Padding函数,它可以在输入张量的四周填充值0。它的参数是一个四元组(pad_left, pad_right, pad_top, pad_bottom),分别表示在输入张量的左边,右边,上面和下面填充的值。需要注意的是,当我们需要在某一维度填充相同的值时,我们可以使用一个单独的整数值代替其左右两个元素(如上例中的(2, 2, 2, 2))。

作为另一个例子,我们可以使用nn.ZeroPad2d函数为MNIST中的手写数字数据集的图像进行Padding操作:

import torch

from torchvision import datasets

from torchvision.transforms import ToTensor, Normalize, Compose

import matplotlib.pyplot as plt

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]))

# 取第一张训练图像作为例子

image, _ = mnist_train[0]

fig, axs = plt.subplots(1, 2, figsize=(10,5))

axs[0].set_title('Original image')

axs[0].imshow(image.squeeze(), cmap='gray')

# 在两侧各填充2个0

padding = nn.ZeroPad2d((2, 2, 0, 0))

padded = padding(image.unsqueeze(0)).squeeze()

axs[1].set_title('Padded image')

axs[1].imshow(padded, cmap='gray')

plt.show()

该例子演示了如何对MNIST数据集中的手写数字图像进行Padding操作,并使用可视化的方式展示出Padding前后的效果。如图所示,我们在输入的左右两侧分别填充了2个像素,使输入张量的宽度增加了4个像素。

2.2 nn.ConstantPad2d函数

除了nn.ZeroPad2d函数之外,PyTorch中还提供了nn.ConstantPad2d函数来进行Padding操作。与nn.ZeroPad2d不同之处在于,nn.ConstantPad2d可以指定任何常量值进行填充。

举个例子,我们可以使用nn.ConstantPad2d来对一个输入进行填充,并使用不同的填充值来可视化边界:

import torch

import torch.nn as nn

import numpy as np

import matplotlib.pyplot as plt

np.random.seed(123)

inputs = torch.from_numpy(np.random.rand(1, 3, 32, 32))

padding = nn.ConstantPad2d((2, 1, 2, 1), 0.3)

outputs = padding(inputs)

fig, axs = plt.subplots(1, 2, figsize=(10,5))

axs[0].set_title('Original image')

axs[0].imshow(inputs.squeeze().permute(1, 2, 0))

axs[1].set_title('Padded image')

axs[1].imshow(outputs.squeeze().permute(1, 2, 0))

plt.show()

该例子中,我们使用torch.from_numpy来随机生成一个大小为(1,3,32,32)的张量。接着,我们使用nn.ConstantPad2d函数将其四周各填充一个宽度为1高度为2的边界,并将填充值设置为0.3。这里需要注意的是我们可以使用不同的填充值进行不同的填充。如图所示,左边为原始的输入图像,右边为进行Padding后的图像,可以看到Padding被成功地应用于输入的四周。

后端开发标签