使用 DCGAN 生成动漫小姐姐头像

生成对抗网络是 2014 年由伊恩·古德费洛等人提出的一种非监督式学习方法 arXiv:1406.2661v1,该方法的特点是通过让两个神经网络相互博弈的方式进行学习。DCGAN 是 GAN 的一种十分实用的延伸网络,它由 Alec Radford 等人于 2015 年提出 arXiv:1511.06434

DCGAN 的全称为:Deep Convolutional Generative Adversarial Networks,翻译成中文也就是:深度卷积生成对抗网络。DCGAN 将卷积网络引入到生成式模型当中来做无监督的训练。这种结构很好地利用了卷积网络强大的特征提取能力,从而有效提高了生成网络的学习效果。

数据集

首先,你需要找到一些动漫人物头像数据。不用特别多,几千张就可以了。这里推荐一个网站 Danbooru,可以直接使用爬虫爬取。

网络搭建

我们可以使用 TensorFlow 或者 PyTorch 来构建网络。这里推荐使用 PyTorch,因为其提供的 Dataloader 数据加载器非常适合用来对图像进行预处理和小批量加载。

下面提供一段示例代码,你需要将头像图片放置在 avatar 目录下方。

import torch
from torchvision import datasets, transforms

# 定义图片处理方法
transforms = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

dataset = datasets.ImageFolder('avatar/',
                               transform=transforms)

dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=16,
                                         shuffle=True,
                                         num_workers=2
                                         )

网络主体需要参考 DCGAN 论文来实现,下面是 DCGAN 网络结构图。

结合网络结构图,就可以实现生成器和判别器网络了。


from torch import nn

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.ConvTranspose2d( 100, 64 * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.ReLU(True),
            nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64 * 4, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64 * 2, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d( 64, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        return self.main(input)

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64 * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 2, 64 * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 4, 64 * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64 * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64 * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)

定义损失函数和优化器是必不可少的步骤:

netD = Discriminator()
netG = Generator()
criterion = nn.BCELoss()

lr = 0.0002
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

最后就可以完成网络训练,并使用噪声数据进行测试了。

from torchvision.utils import make_grid
from matplotlib import pyplot as plt
from IPython import display
%matplotlib inline

epochs = 100
for epoch in range(epochs):
    for n, (images, _) in enumerate(dataloader):

        real_labels = torch.ones(images.size(0))
        fake_labels = torch.zeros(images.size(0))

        netD.zero_grad()
        output = netD(images)
        lossD_real = criterion(output.squeeze(), real_labels)

        noise = torch.randn(images.size(0), 100, 1, 1)
        fake_images = netG(noise)
        output2 = netD(fake_images.detach())
        lossD_fake = criterion(output2.squeeze(), fake_labels)
        lossD = lossD_real + lossD_fake
        lossD.backward()
        optimizerD.step()

        netG.zero_grad()
        output3 = netD(fake_images)
        lossG = criterion(output3.squeeze(), real_labels)
        lossG.backward()
        optimizerG.step()

        fixed_noise = torch.randn(64, 100, 1, 1)
        fixed_images = netG(fixed_noise)
        fixed_images = make_grid(fixed_images.data,
                                 nrow=8, normalize=True).cpu()
        plt.figure(figsize=(6, 6))
        plt.title("Epoch[{}/{}], Batch[{}/{}]".format(epoch+1,
                                                      epochs,
                                                      n+1,
                                                      len(dataloader)))
        plt.imshow(fixed_images.permute(1, 2, 0).numpy())
        display.display(plt.gcf())
        display.clear_output(wait=True)

以下,是迭代 100 个 Epoch 之后的测试结果,可以看到效果还是不错的。

GAN 是近些年来快速发展的网络结构,但由于其模型过于自由,很容易出现训练不收敛,模型崩溃等问题。不过在数据增强,图片生成等方面,GAN 还是有很多的发展空间。未来值得期待。