Pytorch使用MNIST数据集实现基础GAN和DCGAN详解

1. 简介

生成对抗网络(GAN, Generative Adversarial Network)是一种当前非常热门的深度学习技术,用于从数据中学习生成具有类似于训练数据的新数据。本文将一个基本的GAN模型和一个Deep Convolutional GAN(DCGAN)应用于手写数字MNIST数据集的生成。

2. MNIST数据集介绍

MNIST(Modified National Institute of Standards and Technology)是一个手写数字数据集,它由0-9共10个数字的手写数字图像组成。这些图像是28x28像素大小。MNIST数据集包含60000个训练样本和10000个测试样本。闵可夫斯基空间隔的Mnist数据集已成为深度学习领域中最知名的基准测试数据之一。

import torch 

from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)),])

trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

3. GAN模型

GAN模型由一个生成器和一个分类器组成。生成器输入一个随机的噪声向量,输出一张能够欺骗分类器的图像。分类器输入一张图像,输出一个0到1之间的实数,表示这张图像是真实的概率。首先定义一个生成器网络和一个鉴别器网络:

import torch.nn as nn 

# 定义生成器

class Generator(nn.Module):

def __init__(self):

super(Generator, self).__init__()

self.fc1 = nn.Linear(100, 256)

self.fc2 = nn.Linear(256, 512)

self.fc3 = nn.Linear(512, 28*28)

def forward(self, x):

x = nn.functional.relu(self.fc1(x))

x = nn.functional.relu(self.fc2(x))

x = nn.functional.tanh(self.fc3(x))

return x.view(-1, 1, 28, 28)

# 定义鉴别器

class Discriminator(nn.Module):

def __init__(self):

super(Discriminator, self).__init__()

self.fc1 = nn.Linear(28*28, 512)

self.fc2 = nn.Linear(512, 256)

self.fc3 = nn.Linear(256, 1)

def forward(self, x):

x = x.view(-1, 28*28)

x = nn.functional.relu(self.fc1(x))

x = nn.functional.relu(self.fc2(x))

x = nn.functional.sigmoid(self.fc3(x))

return x

其中生成器网络接受一个随机噪声向量,它通过一个全连接层,两个线性转换和一个tanh激活函数输出28×28图像。鉴别器网络是一个前馈神经网络,将28×28图像映射到一个实数概率。

4. 训练GAN模型

下面的训练过程包括两个基本步骤。

4.1 生成器训练

生成器的目标是生成看起来真实的图像,使鉴别器无法决定输入数据是否是真实数据。为了实现这个目标,生成器网络需要最小化其输出的图像与真实图像之间的距离。当鉴别器网络把生成的图像判断为真实图像时,生成器网络的损失函数将最小化,其损失函数定义为:

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

g_optimizer = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

def train_generator(noise, labels):

g_optimizer.zero_grad()

fake_data = generator(noise).detach()

prediction_fake = discriminator(fake_data)

loss = criterion(prediction_fake, labels)

loss.backward()

g_optimizer.step()

return loss

4.2 鉴别器训练

训练鉴别器网络时,它需要将生成器产生的假图像和训练集中的真实图像区分开来。因此,它需要最小化它对真实图像和生成图像之间差异的损失函数。当生成器网络生成的图像被识别为假图像时,鉴别器网络的损失函数将最小化,其损失函数定义为:

def train_discriminator(real_data, fake_data, labels):

d_optimizer.zero_grad()

predictions_real = discriminator(real_data)

loss_real = criterion(predictions_real, labels)

predictions_fake = discriminator(fake_data)

loss_fake = criterion(predictions_fake, labels)

loss = (loss_real + loss_fake) / 2

loss.backward()

d_optimizer.step()

return loss

4.3 训练GAN模型

虽然GAN的训练过程可以是很复杂的,但是训练过程仍然可以被简化为生成器和鉴别器网络之间的交替训练,如下代码所示:

real_labels = torch.ones(batch_size, 1)

fake_labels = torch.zeros(batch_size, 1)

for epoch in range(10):

for i, (real_images, _) in enumerate(trainloader):

# 训练鉴别器

d_loss_real = criterion(discriminator(real_images), real_labels)

noise = torch.randn(batch_size, 100)

fake_images = generator(noise)

d_loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)

d_loss = d_loss_real + d_loss_fake

train_discriminator(real_images, fake_images.detach(), real_labels)

# 训练生成器

noise = torch.randn(batch_size, 100)

fake_images = generator(noise)

g_loss = criterion(discriminator(fake_images), real_labels)

train_generator(noise, real_labels)

if i % 10 == 0:

print(f'Epoch {epoch}/{10}, Discriminator Loss: {d_loss:.4f}, Generator Loss: {g_loss:.4f}')

5. DCGAN模型

DCGAN是基于GAN提出的一种改进型模型,它能够生成更高质量的图像。下面给出了DCGAN网络的结构和训练过程。

5.1 DCGAN网络

与普通的GAN相比,DCGAN使用了以下技术优化GAN架构:

使用卷积运算代替全连接层。

在生成器网络中使用反卷积来生成图像。

使用Batch Normalization来提高模型鲁棒性。

使用LeakyReLU激活函数来避免梯度消失/爆炸问题。

class DCGANGenerator(nn.Module):

def __init__(self):

super(DCGANGenerator, self).__init__()

self.fc1 = nn.Linear(100, 7 * 7 * 256)

self.deconv1 = nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1)

self.deconv2 = nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1)

self.deconv3 = nn.ConvTranspose2d(64, 1, 4, stride=2, padding=1)

def forward(self, x):

x = nn.functional.relu(self.fc1(x))

x = x.view(-1, 256, 7, 7)

x = nn.functional.relu(self.deconv1(x))

x = nn.functional.relu(self.deconv2(x))

x = nn.functional.tanh(self.deconv3(x))

return x

class DCGANDiscriminator(nn.Module):

def __init__(self):

super(DCGANDiscriminator, self).__init__()

self.conv1 = nn.Conv2d(1, 64, 4, stride=2, padding=1)

self.conv2 = nn.Conv2d(64, 128, 4, stride=2, padding=1)

self.conv3 = nn.Conv2d(128, 256, 4, stride=2, padding=1)

self.fc1 = nn.Linear(7 * 7 * 256, 1)

def forward(self, x):

x = nn.functional.relu(self.conv1(x))

x = nn.functional.relu(self.conv2(x))

x = nn.functional.relu(self.conv3(x))

x = x.view(-1, 7 * 7 * 256)

x = nn.functional.sigmoid(self.fc1(x))

return x

5.2 DCGAN训练

训练过程与普通GAN模型非常相似,相应的修改只是模型的结构,以及训练过程的优化器和训练轮数。

import torch.optim as optim

dataloader = torch.utils.data.DataLoader(

datasets.MNIST(

'./data',

train=True,

download=True,

transform=transforms.Compose(

[

transforms.Resize(64),

transforms.CenterCrop(64),

transforms.ToTensor(),

transforms.Normalize((0.5,), (0.5,)),

]

),

),

batch_size=batch_size,

shuffle=True,

)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = DCGANGenerator().to(device)

discriminator = DCGANDiscriminator().to(device)

criterion = nn.BCELoss()

optimizer_g = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))

optimizer_d = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

for epoch in range(num_epochs):

for i, (images, _) in enumerate(dataloader):

# train discriminator

discriminator.zero_grad()

noise = torch.randn(images.shape[0], 100, 1, 1, device=device)

fake_images = generator(noise)

real_labels = torch.ones(images.shape[0], 1, device=device) * 0.9

fake_labels = torch.zeros(images.shape[0], 1, device=device)

d_loss_real = criterion(discriminator(images), real_labels)

d_loss_fake = criterion(discriminator(fake_images.detach()), fake_labels)

d_loss = d_loss_real + d_loss_fake

d_loss.backward()

optimizer_d.step()

# train generator

generator.zero_grad()

noise = torch.randn(images.shape[0], 100, 1, 1, device=device)

fake_images = generator(noise)

g_loss = criterion(discriminator(fake_images), real_labels)

g_loss.backward()

optimizer_g.step()

if i % 100 == 0:

print(

f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}] Loss_D: {d_loss.item():.4f} Loss_G: {g_loss.item():.4f}"

)

6. 结果展示

下面展示了GAN和DCGAN在MNIST数据集上生成的数字图像。

6.1 基础GAN生成数字图像

在训练的过程中选择了随机的一张图片作为baseline:

import matplotlib.pyplot as plt

import numpy as np

# 选择一张图片作为baseline

baseline = trainset[0][0]

def sample_images(generator):

with torch.no_grad():

noise = torch.randn(64, 100)

fake_images = generator(noise).detach().cpu()

fake_images = fake_images * 0.5 + 0.5

fake_images = np.transpose(fake_images.data.numpy(), (0, 2, 3, 1))

fake_images = np.clip(fake_images, 0, 1)

plt.figure(figsize=(8, 8))

plt.axis("off")

plt.imshow(

np.concatenate(

[

np.concatenate(fake_images[:8], axis=1),

np.concatenate(fake_images[8:16], axis=1),

np.concatenate(fake_images[16:24], axis=1),

np.concatenate(fake_images[24:32], axis=1),

np.concatenate(fake_images[32:40], axis=1),

np.concatenate(fake_images[40:48], axis=1),

np.concatenate(fake_images[48:56], axis=1),

np.concatenate(fake_images[56:], axis=1),

],

axis=0,

)

)

plt.show()

sample_images(generator)

生成的结果如下:

6.2 DCGAN生成数字图像

与基本的GAN模型相比,DCGAN模型能够生成更加细腻、清晰的数字图像。

# 显示DCGAN生成数字图像

sample_images(generator)

生成的结果如下:

总结

本文实现了两个GAN模型(基础GAN和DCGAN)在MNIST数据集上生成手写数字图像,并通过代码详解和样例展示了这两个模型的效果。本文只是简单介绍了GAN模型的基本原理和应用,对GAN模型仍然有非常多的研究价值。

后端开发标签