换脸火了,我用 python 快速入门生成模型

发布于 2021-04-29 07:41

作者 | 李秋键

头图 | 下载于ICphoto

出品 | AI科技大本营(ID:rgznai100)


引言:

近几年来,GAN生成对抗式应用十分火热,不论是抖音上大火的“蚂蚁牙黑”还是B站上的“复原老旧照片”以及换脸等功能,都是基于GAN生成对抗式的模型。但是GAN算法对于大多数而言上手较难,故今天我们将使用最少的代码,简单入门“生成对抗式网络”,实现用GAN生成数字。

其中生成的图片效果如下可见:

模型建立


1.1 环境要求

本次环境使用的是python3.6.5+windows平台

主要用的库有:

  • OS模块用来对本地文件读写删除、查找到等文件操作
  • numpy模块用来矩阵和数据的运算处理,其中也包括和深度学习框架之间的交互等
  • Keras模块是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化 。在这里我们用来搭建网络层和直接读取数据集操作,简单方便
  • Matplotlib模块用来可视化训练效果等数据图的制作

1.2 GAN简单介绍
GAN 由生成器 (Generator)和判别器 (Discriminator) 两个网络模型组成,这两个模型作用并不相同,而是相互对抗。我们可以很简单的理解成,Generator是造假的的人,Discriminator是负责鉴宝的人。正是因为生成模型和对抗模型的相互对抗关系才称之为生成对抗式。
那我们为什么不适用VAE去生成模型呢,又怎么知道GAN生成的图片会比VAE生成的更优呢?问题就在于VAE模型作用是使得生成效果越相似越好,但事实上仅仅是相似却只是依葫芦画瓢。而 GAN 是通过 discriminator 来生成目標,而不是像 VAE线性般的学习。
这个项目里我们目标是训练神经网络生成新的图像,这些图像与数据集中包含的图像尽可能相近,而不是简单的复制粘贴。神经网络学习什么是图像的“本质”,然后能够从一个随机的数字数组开始创建它。其主要思想是让两个独立的神经网络,一个产生器和一个鉴别器,相互竞争。生成器会创建与数据集中的图片尽可能相似的新图像。判别器试图了解它们是原始图片还是合成图片。
1.3 模型初始化
在这里我们初始化需要使用到的变量,以及优化器、对抗式模型等。
def __init__(self, width=28, height=28, channels=1):
    self.width = width
    self.height = height
    self.channels = channels
    self.shape = (self.width, self.height, self.channels)
    self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
    self.G = self.__generator()
    self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    self.D = self.__discriminator()
    self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
    self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
    self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)

1.4 生成器模型的搭建
这里我们尽可能简单的搭建一个生成器模型,3个完全连接的层,使用sequential标准化。神经元数分别是256,512,1024等:
 def __generator(self):
        """ Declare generator """
        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
        model.add(Reshape((self.width, self.height, self.channels)))
        return model

1.5 判别器模型的搭建
在这里同样简单搭建判别器网络层,和生成器模型类似:
def __discriminator(self):
    """ Declare discriminator """
    model = Sequential()
    model.add(Flatten(input_shape=self.shape))
    model.add(Dense((self.width * self.height * self.channels), input_shape=self.shape))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(np.int64((self.width * self.height * self.channels)/2)))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    model.summary()
    return model

1.6 对抗式模型的搭建
这里是较为难理解的部分。让我们创建一个对抗性模型,简单来说这只是一个后面跟着一个鉴别器的生成器。注意,在这里鉴别器的权重被冻结了,所以当我们训练这个模型时,生成器层将不受影响,只是向上传递梯度。代码很简单如下:
def __stacked_generator_discriminator(self):
    self.D.trainable = False
    model = Sequential()
    model.add(self.G)
    model.add(self.D)
    return model

模型的训练使用

2.1 模型的训练
在这里,我们并没有直接去训练生成器。而是通过对抗性模型间接地训练它。我们将噪声传递给了对抗模型,并将所有从数据库中获取的图像标记为负标签,而它们将由生成器生成。
对真实图像进行预先训练的鉴别器把不能合成的图像标记为真实图像,所犯的错误将导致由损失函数计算出的损失越来越高。这就是反向传播发挥作用的地方。由于鉴别器的参数是冻结的,在这种情况下,反向传播不会影响它们。相反,它会影响生成器的参数。所以优化对抗性模型的损失函数意味着使生成的图像尽可能的相似,鉴别器将识别为真实的。这既是生成对抗式的神奇之处!
故训练阶段结束时,我们的目标是对抗性模型的损失值很小,而鉴别器的误差尽可能高,这意味着它不再能够分辨出差异。
最终在我门的训练结束时,鉴别器损失约为0.73。考虑到我们给它输入了50%的真实图像和50%的合成图像,这意味着它有时无法识别假图像。这是一个很好的结果,考虑到这个例子绝对不是优化的结果。要知道确切的百分比,我可以在编译时添加一个精度指标,这样它可能得到很多更好的结果实现更复杂的结构的生成器和判别器。
代码如下,这里legit_images是指原始训练的图像,而syntetic_images是生成的图像。:
def train(self, X_train, epochs=20000, batch = 32, save_interval = 100):
    for cnt in range(epochs):
        ## train discriminator
        random_index = np.random.randint(0, len(X_train) - np.int64(batch/2))
        legit_images = X_train[random_index : random_index + np.int64(batch/2)].reshape(np.int64(batch/2), self.width, self.height, self.channels)
        gen_noise = np.random.normal(01, (np.int64(batch/2), 100))
        syntetic_images = self.G.predict(gen_noise)
        x_combined_batch = np.concatenate((legit_images, syntetic_images))
        y_combined_batch = np.concatenate((np.ones((np.int64(batch/2), 1)), np.zeros((np.int64(batch/2), 1))))
        d_loss = self.D.train_on_batch(x_combined_batch, y_combined_batch)
        # train generator
        noise = np.random.normal(01, (batch, 100))
        y_mislabled = np.ones((batch, 1))
        g_loss = self.stacked_generator_discriminator.train_on_batch(noise, y_mislabled)
        print ('epoch: %d, [Discriminator :: d_loss: %f], [ Generator :: loss: %f]' % (cnt, d_loss[0], g_loss))
        if cnt % save_interval == 0:
            self.plot_images(save2file=True, step=cnt)

2.2 可视化
使用matplotlib来可视化模型训练效果。
def plot_images(self, save2file=False, samples=16, step=0):
    ''' Plot and generated images '''
    if not os.path.exists("./images"):
        os.makedirs("./images")
    filename = "./images/mnist_%d.png" % step
    noise = np.random.normal(01, (samples, 100))
    images = self.G.predict(noise)
    plt.figure(figsize=(1010))
    for i in range(images.shape[0]):
        plt.subplot(44, i+1)
        image = images[i, :, :, :]
        image = np.reshape(image, [self.height, self.width])
        plt.imshow(image, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    if save2file:
        plt.savefig(filename)
        plt.close('all')
    else:
        plt.show()
使用方法

考虑到代码较少,下述代码复制粘贴即可运行。
# -*- coding: utf-8 -*-
import os
import numpy as np
from IPython.core.debugger import Tracer
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam
import matplotlib.pyplot as plt
plt.switch_backend('agg')   # allows code to run without a system DISPLAY
class GAN(object):
    """ Generative Adversarial Network class """
    def __init__(self, width=28, height=28, channels=1):
        self.width = width
        self.height = height
        self.channels = channels
        self.shape = (self.width, self.height, self.channels)
        self.optimizer = Adam(lr=0.0002, beta_1=0.5, decay=8e-8)
        self.G = self.__generator()
        self.G.compile(loss='binary_crossentropy', optimizer=self.optimizer)
        self.D = self.__discriminator()
        self.D.compile(loss='binary_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])
        self.stacked_generator_discriminator = self.__stacked_generator_discriminator()
        self.stacked_generator_discriminator.compile(loss='binary_crossentropy', optimizer=self.optimizer)
    def __generator(self):
        """ Declare generator """
        model = Sequential()
        model.add(Dense(256, input_shape=(100,)))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(self.width  * self.height * self.channels, activation='tanh'))
        model.add(Reshape((self.width, self.height, self.channels)))
        return model
    def __discriminator(self):