PyTorch 中的傅里叶卷积实现示例

1. PyTorch 中的傅里叶卷积简介

傅里叶变换在信号处理和图像处理中被广泛使用,也是卷积神经网络(Convolutional Neural Network,CNN)中卷积操作的基础。傅里叶卷积是将两个函数的乘积经过傅里叶变换得到的变换成果的逆变换,这个过程可以利用快速傅里叶变换(Fast Fourier Transform,FFT)实现。

PyTorch 中通过 torch.fft 模块来实现快速傅里叶变换,进而实现傅里叶卷积。

2. 傅里叶卷积的实现流程

傅里叶卷积的实现流程可以概括为以下几个步骤:

2.1. 输入数据变换

将输入数据经过傅里叶变换得到频率域的数据,可以利用 torch.fft.fftn 实现。

import torch

input = torch.randn(1, 3, 32, 32)

kernel = torch.randn(3, 3, 3)

stride = 1

padding = 1

# 将输入数据从空间域变换到频率域

input_f = torch.fft.fftn(input, dim=[-2, -1])

2.2. 卷积核变换

将卷积核经过傅里叶变换得到频率域的卷积核,可以利用 torch.fft.fftn 实现。

kernel = torch.flip(kernel, dims=[-2, -1]) # 若卷积核不是空间对称的,需要先进行翻转

kernel = torch.nn.functional.pad(kernel, [0, input.shape[-1]-kernel.shape[-1], 0, input.shape[-2]-kernel.shape[-2]], mode='constant', value=0) # 若卷积核尺寸小于输入数据尺寸,需要进行 zero-padding

kernel_f = torch.fft.fftn(kernel, dim=[-2, -1])

2.3. 卷积操作

在频率域上进行卷积操作,可以利用 torch.fft.ifftn 实现。

output_f = input_f * kernel_f # 在频率域上进行卷积操作

output = torch.fft.ifftn(output_f, dim=[-2, -1]).real # 将卷积结果从频率域变换到空间域,取实部得到最终输出

2.4. 卷积输出

根据步长(stride)和填充(padding)参数,调整卷积输出的尺寸和形状。

def calc_output_size(input_size, kernel_size, stride, padding):

output_size = (input_size + padding * 2 - kernel_size) // stride + 1

return output_size

output_size = (calc_output_size(input.shape[-2], kernel.shape[-2], stride, padding),

calc_output_size(input.shape[-1], kernel.shape[-1], stride, padding))

output = output[:, :, :output_size[0], :output_size[1]]

3. 傅里叶卷积实现示例

下面是一个在 PyTorch 中实现傅里叶卷积的示例,其中采用的是 2D 傅里叶卷积。

class FourierConv2d(torch.nn.Module):

def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):

super(FourierConv2d, self).__init__()

self.in_channels = in_channels

self.out_channels = out_channels

self.kernel_size = kernel_size

self.stride = stride

self.padding = padding

self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size))

def forward(self, x):

# 定义卷积核

weight = self.weight

weight = torch.flip(weight, dims=[-2, -1]) # 若卷积核不是空间对称的,需要先进行翻转

weight = torch.nn.functional.pad(weight, [0, x.shape[-1]-weight.shape[-1], 0, x.shape[-2]-weight.shape[-2]], mode='constant', value=0) # 若卷积核尺寸小于输入数据尺寸,需要进行 zero-padding

weight_f = torch.fft.fftn(weight, dim=[-2, -1])

# 定义输入数据

x = torch.nn.functional.pad(x, [0, 0, self.padding, self.padding, self.padding, self.padding], mode='constant', value=0) # 进行 zero-padding

x_f = torch.fft.fftn(x, dim=[-2, -1])

# 在频率域上进行卷积操作

output_f = x_f * weight_f

# 将卷积结果从频率域变换到空间域,取实部得到最终输出

output = torch.fft.ifftn(output_f, dim=[-2, -1]).real

# 根据步长和填充参数,调整卷积输出的尺寸和形状

output_size = (calc_output_size(x.shape[-2], weight.shape[-2], self.stride, self.padding),

calc_output_size(x.shape[-1], weight.shape[-1], self.stride, self.padding))

output = output[:, :, :output_size[0], :output_size[1]]

return output

# 测试示例代码

input = torch.randn(1, 3, 32, 32)

conv = FourierConv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)

output = conv(input)

assert output.shape == (1, 16, 32, 32)

通过上述代码,我们可以在 PyTorch 中实现傅里叶卷积,从而加深我们对卷积操作的理解。

后端开发标签