
Published: 2021-03-16
DOI: 10.11834/jig.200069
2021 | Volume 26 | Number 3




Image Understanding and Computer Vision









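Re-GAN: residual generative adversarial network algorithm
Shi Caijuan, Tu Dongjing, Liu Jingyi
College of Artificial Intelligence, North China University of Science and Technology, Tangshan 063210, China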

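Abstract

Objective Generative adversarial network (GAN) is an unsupervised generative model that learns to generate images through a game between a generative model and a discriminative model. The generative model of GAN generates images directly, layer by layer, and a lower-level network has no access to the features learned by a higher-level network, so the generated images lack diversity. In addition, as the number of layers increases, the number of parameters grows and backpropagation becomes difficult, causing problems such as unstable training and gradient vanishing. To address these problems, a residual generative adversarial network (Re-GAN) is proposed based on the residual network (ResNet) and group normalization (GN). Method Re-GAN builds deep residual network modules in the generative model and fuses the features learned by higher-level networks through skip connections, which enhances the diversity and quality of generated images, improves backpropagation, strengthens training stability, and alleviates gradient vanishing. Group normalization (GN) is then adopted to accommodate different batch sizes and make training more stable. Result The algorithm is tested on the Cifar10, CelebA, and LSUN datasets. With a batch size of 64, the mean inception score (IS) of Re-GAN is 5% and 30% higher than that of DCGAN (deep convolutional GAN) and WGAN (Wasserstein-GAN), respectively; with a batch size of 4, it is 0.2% and 13% higher, respectively, showing that images generated by Re-GAN have good diversity regardless of batch size. With a batch size of 64, the Fréchet inception distance (FID) of Re-GAN is 18% and 11% lower than that of DCGAN and WGAN, respectively; with a batch size of 4, it is 4% and 10% lower, respectively, showing that Re-GAN generates higher-quality images. Re-GAN also alleviates the training instability and gradient vanishing that arise during training. Conclusion The experimental results show that, for image generation, Re-GAN produces images of high quality and rich diversity; for network training, Re-GAN is more compatible with different batch sizes, which makes training more stable and alleviates gradient vanishing.

Keywords

image generation; deep learning; convolutional neural network; generative adversarial network; residual network; group normalization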

Re-GAN: residual generative adversarial network algorithm
Shi Caijuan, Tu Dongjing, Liu Jingyi
College of Artificial Intelligence, North China University of Science and Technology, Tangshan 063210, China
Supported by: National Natural Science Foundation of China (61502143)

Abstract

Objective A generative adversarial network (GAN) is a popular unsupervised generative model that generates images through a game between a generative model and a discriminative model. The generative model uses Gaussian noise to generate a probability distribution, and the discriminative model distinguishes between the generated and real probability distributions. In the ideal state, the discriminative model cannot distinguish between the two data distributions. However, achieving Nash equilibrium between the generative and discriminative models is difficult, and problems such as unstable training, gradient vanishing, and poor image quality occur. Many studies have been conducted to address these problems, and they follow two directions. One direction involves selecting an appropriate loss function, and the other involves changing the structure of GAN, e.g., from a fully connected neural network to a convolutional neural network (CNN). A typical work is the deep convolutional GAN (DCGAN), which adopts CNN and batch normalization (BN). Although DCGANs have achieved good performance, some problems persist in the training process. Increasing the number of network layers leads to more errors, particularly gradient vanishing when the number of layers is extremely high. In addition, BN leads to poor stability in the training process, particularly with small batches. In general, as the number of layers increases, the number of parameters increases and backpropagation becomes difficult, resulting in problems such as unstable training and gradient vanishing. Moreover, the generative model generates images directly, layer by layer, and a lower-level network cannot access the features learned by a higher-level network; thus, the diversity of the generated images is not sufficiently rich. 
To address the aforementioned problems, a residual GAN (Re-GAN) is proposed based on a residual network (ResNet) and group normalization (GN). Method ResNet was proposed to solve the network degradation caused by too many layers in a deep neural network and has been applied to image classification because of its good performance. In contrast with BN, GN divides channels into groups and calculates the normalization mean and variance within each group. The calculation is stable and independent of batch size. Therefore, we apply ResNet and GN to GAN to propose Re-GAN. First, a residual module is introduced into the generative model of GAN by adding the input and the mapping at the output of each layer, to prevent gradient vanishing and enhance training stability. Moreover, the residual module optimizes feature transmission between neural network layers and enhances the diversity and quality of the generated images. Second, Re-GAN adopts GN to adapt to learning with different batch sizes. GN reduces the difficulty of normalization caused by a lack of training samples and stabilizes the training process of the network. Moreover, when the number of samples is sufficient, GN makes the calculated results match the sample distribution well and exhibits good compatibility. Result To verify the effectiveness of the proposed Re-GAN, we compare it with DCGAN and Wasserstein-GAN (WGAN) with different batch sizes on three datasets, namely, Cifar10, CelebA, and LSUN bedroom. Two evaluation criteria, i.e., inception score (IS) and Fréchet inception distance (FID), are adopted in our experiments. As a common evaluation criterion for GAN, IS uses the inception network trained on ImageNet to calculate the information of the generated images. IS focuses on the evaluation of the quality but not the diversity of the generated images. When IS is larger, the quality of the generated images is better. 
FID is more robust to noise and more suitable for describing the diversity of the generated images. It is computed from a set of generated images and a set of real images. When FID is smaller, the diversity of the generated images is better. We obtain the following experimental results. 1) When the batch size is 64, the IS of the proposed Re-GAN is 5% higher than that of DCGAN and 30% higher than that of WGAN. When the batch size is 4, the IS of Re-GAN is 0.2% higher than that of DCGAN and 13% higher than that of WGAN. These results show that the images generated by Re-GAN exhibit good diversity regardless of batch size. 2) When the batch size is 64, the FID of Re-GAN is 18% lower than that of DCGAN and 11% lower than that of WGAN. When the batch size is 4, the FID of Re-GAN is 4% lower than that of DCGAN and 10% lower than that of WGAN. These results indicate that the proposed Re-GAN can generate images of higher quality. 3) Training instability and gradient vanishing are alleviated during the training process. Conclusion The performance of the proposed Re-GAN is tested using two evaluation criteria, i.e., IS and FID, on three datasets. The experimental results indicate the following findings. For image generation, Re-GAN generates high-quality images with rich diversity. For network training, Re-GAN trains compatibly with both large and small batches, which makes the training process more stable and alleviates gradient vanishing. In addition, compared with DCGAN and WGAN, the proposed Re-GAN exhibits better performance, which can be attributed to the ResNet and GN adopted in Re-GAN.

Key words

image generation; deep learning; convolutional neural network (CNN); generative adversarial network (GAN); residual network (ResNet); group normalization (GN)

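0 Introduction

Building a generative model requires prior knowledge of the data and a large number of parameters. The accuracy of the prior knowledge directly affects the quality of the model, and the large number of parameters leads to heavy computation. To solve these problems, Goodfellow et al. (2014) proposed generative adversarial networks (GAN). As a probabilistic generative model, GAN can capture the underlying probability distribution of data and generate entirely new data, including but not limited to images, music, speech, and text (Cao et al., 2018). With extensive research, GAN has gradually been applied to video prediction and generation (Mathieu et al., 2015), image inpainting (Yeh et al., 2017), image translation (Isola et al., 2017), semantic segmentation (Zhu et al., 2016), and other fields.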

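GAN mainly consists of a generative model (G) and a discriminative model (D). The generative model uses Gaussian noise to generate a probability distribution, and the discriminative model distinguishes between the generated and the real probability distributions. GAN training thus becomes a game: the generative model tries to generate data resembling the real distribution to fool the discriminative model, whereas the discriminative model tries to tell the two data distributions apart. In the ideal state, the discriminator cannot distinguish between the two distributions. However, Nash equilibrium between the generative and discriminative models is not easy to reach, and problems such as unstable training and gradient vanishing also arise. Two approaches address these problems. One is to choose an appropriate loss function; the other is to change the architecture of GAN, e.g., using fully connected neural networks, Laplacian pyramids, convolutional neural networks, self-attention mechanisms, or multilayer neural networks to improve the generation and feature extraction capabilities of GAN.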

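Convolutional neural networks have strong abstraction ability. Radford et al. (2015) applied them to the GAN architecture and proposed deep convolutional generative adversarial networks (DCGAN), which improve on the traditional GAN in three main ways: 1) convolutional networks are used in both the generative and discriminative models, allowing the generator to learn its own spatial upsampling; 2) the fully connected layers on top of the convolutional features are eliminated; 3) batch normalization (BN) is used to standardize the input of each layer to zero mean and unit variance. Although DCGAN has good feature extraction ability, errors grow and gradients vanish during training, and the BN it uses suffers from poor stability or cannot be used at all in some settings.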

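To solve the network degradation problem caused by too many layers in deep neural networks, He et al. (2016) proposed the residual network (ResNet), which performs very well in image recognition and has since been widely studied and applied. Qiu et al. (2017) proposed a deep 3D residual neural network for video understanding tasks; Lim et al. (2017) proposed an enhanced deep residual network for single image super-resolution; Silver et al. (2017) used residual networks in AlphaGo. Nitanda and Suzuki (2018) and Huang et al. (2018) analyzed ResNet theoretically, studying its generalization ability from the perspective of boosting.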

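In deep learning, data normalization is a very important step that can prevent overfitting and speed up training. The batch normalization (BN) proposed by Ioffe and Szegedy (2015) has been widely used, but BN requires sufficiently large batches; small batches lead to inaccurate estimates of the statistics and increased model error. To overcome the influence of batch size, Ba et al. (2016) proposed layer normalization (LN), which computes statistics along the channel dimension, and Ulyanov et al. (2016) proposed instance normalization (IN), which operates on each sample, but both have relatively poor accuracy. Wu and He (2020) proposed group normalization (GN), which divides the channels into groups and computes the normalization mean and variance within each group; the computation is independent of batch size and its accuracy is stable.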

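Based on the residual network and group normalization, this paper proposes a new generative adversarial network, the residual generative adversarial network (Re-GAN). First, residual modules are added to the transposed convolution layers of the generative model, adding the input and the mapping at the output of each layer, to prevent gradient vanishing and increase training stability. The residual modules also allow the features of shallow transposed convolution layers to be passed intact to the next layer, enhancing the diversity and quality of the generated images. Second, group normalization is adopted, which not only handles normalization for large-batch training but also suits small-batch training, giving the proposed network better compatibility and a more stable training process. Finally, the proposed algorithm is evaluated on the Cifar10, CelebA, and LSUN datasets, and the experimental results demonstrate its effectiveness.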

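1 Generative adversarial networks

GAN is a powerful generative model comprising two deep neural networks, the generative model ($G$) and the discriminative model ($D$). The objective function of GAN can be expressed as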

$ \min_{G} \max_{D} V(D, G)=E_{x \sim P_{\text{data}}}[\log D(x)]+E_{z \sim P_{z}}[\log (1-D(G(z)))] $ (1)

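where $P_{\text{data}}$ denotes the probability distribution of the real data, $P_{z}$ denotes the prior distribution of the input noise $z$, and $E$ denotes expectation.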

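The generative and discriminative models in a GAN compete with each other. In image generation, the goal of the generative model is to produce images that are as realistic as possible, and the discriminative model judges whether its input is a real image. The traditional architecture of GAN is shown in Fig. 1. Given input noise $z$ drawn at random from a Gaussian distribution $P_{z}(z)$, while the GAN learns the distribution $P_{\text{data}}$ of the data $x$, the generative model $G\left(z; \theta_{g}\right)$ maps the noise variable to a data sample $x'$, and the discriminative model $D\left(x; \theta_{d}\right)$ maps any input to a scalar in (0, 1) that represents the probability that the input comes from the real distribution.

Fig. 1 Traditional framework of GAN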

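The probability-based training process of GAN is illustrated in Fig. 2, where the dotted, dashed, and solid lines represent the real sample distribution, the generated image distribution, and the discriminative model, respectively, and the arrows show how the mapping $x=G(z)$ imposes the non-uniform distribution ${P_g}$ on the transformed samples. The process is as follows:

Fig. 2 Probability-based training diagram of GAN ((a) initial distribution; (b) training $D$; (c) training $G$; (d) equilibrium)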

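1) Consider an adversarial pair near convergence: the generated distribution ${P_g}$ is similar to the real distribution $P_{\text{data}}$, and $D$ is a partially accurate classifier;

2) In the inner loop of the algorithm, $D$ is trained to discriminate samples from data, converging to $D^{*}(x)$, where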

$ D^{*}(x)=\frac{P_{\text {data }}(x)}{P_{\text {data }}(x)+P_{g}(x)} $ (2)

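3) After an update to $G$, the gradient of $D$ has guided $G(z)$ to flow to regions that are more likely to be classified as real data;

4) Equilibrium is reached: ${P_g}=P_{\text{data}}$, and the discriminator can no longer tell the two distributions apart, i.e., $D(x)=1/2$.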
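As a small numerical illustration of Eq. (2) and of the equilibrium value $D(x)=1/2$, the sketch below evaluates the optimal discriminator for two one-dimensional Gaussians. The Gaussian parameters and the helper names `gaussian_pdf` and `optimal_discriminator` are hypothetical choices made for this example only; they do not come from the paper.

```python
import math

def gaussian_pdf(x, mean, std):
    """Density of a 1-D normal distribution at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

def optimal_discriminator(p_data, p_g):
    """D*(x) from Eq. (2), given the density values p_data(x) and p_g(x)."""
    return p_data / (p_data + p_g)

# Toy case (hypothetical numbers): real data ~ N(0, 1), generator ~ N(1, 1).
x = 0.0
d_star_far = optimal_discriminator(gaussian_pdf(x, 0.0, 1.0),
                                   gaussian_pdf(x, 1.0, 1.0))

# At equilibrium the generator matches the data, so D*(x) = 1/2 for every x.
d_star_eq = optimal_discriminator(gaussian_pdf(x, 0.0, 1.0),
                                  gaussian_pdf(x, 0.0, 1.0))

print(d_star_far, d_star_eq)
```

While the two distributions differ, $D^{*}$ leans toward "real" wherever the data density dominates; once they coincide, it is 1/2 everywhere.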

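To improve the training stability of GAN, various modifications to its architecture have been proposed. Radford et al. (2015) applied convolutional neural networks to the GAN architecture and proposed the deep convolutional generative adversarial network (DCGAN). Mirza and Osindero (2014) proposed the conditional generative adversarial network (CGAN), which uses a condition variable $y$ as additional information to constrain the generation process. Chen et al. (2016) proposed InfoGAN (interpretable representation learning by information maximizing generative adversarial network), which splits structured latent variables off from the noise vector and uses them as condition variables to control the generated images. The bidirectional generative adversarial network (BiGAN) proposed by Donahue et al. (2016) and adversarially learned inference (ALI) proposed by Dumoulin et al. (2016) turn the unidirectional GAN into a bidirectional one, enabling effective inference while preserving the quality of the generated images. Generative models based on the variational autoencoder (VAE) (Rezende et al., 2014) can be used for tasks such as unsupervised learning. Larsen et al. (2016) merged VAE and GAN into a single unsupervised generative model, treating the encoder and decoder together as a generative model. Che et al. (2016) used the reconstruction error of VAE as a regularizer for missing modes, improving the stability of GAN and the quality of generated images. Wang et al. (2016) adjusted how generative and discriminative models are combined by stacking GANs, running them in parallel, or letting them feed back to each other, and proposed self-ensemble and cascade forms of GAN. The coupled generative adversarial networks (CoGAN) proposed by Liu and Tuzel (2016) contain a pair of GANs. The cycle-consistent adversarial networks (CycleGAN) proposed by Zhu et al. (2017) contain two discriminative models, ${D_x}$ and ${D_y}$. The Wasserstein generative adversarial network (Wasserstein-GAN, WGAN) proposed by Arjovsky et al. (2017) introduces the Wasserstein distance to measure the distance between two distributions, which addresses mode collapse and ensures the diversity of generated samples.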

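These methods improve GAN from different angles and achieve some performance gains, but problems such as unstable training and gradient vanishing remain. For example, adding layers to DCGAN leads to larger errors during training, and gradients vanish when the network is very deep; WGAN relies on a 1-Lipschitz constraint when searching for the discriminative model $D$. In addition, the batch normalization (BN) used by DCGAN and WGAN performs poorly when reconstructing test samples and is unstable during training; in particular, it cannot be used at all for training with small batches.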

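2 Residual generative adversarial network

The proposed residual generative adversarial network (Re-GAN) has two advantages. First, the residual module ResNet is introduced into the generative model of GAN: the input and the mapping are added at the output of every hidden layer, which alleviates gradient vanishing, enhances training stability, optimizes feature transmission between neural network layers, improves backpropagation, and raises the quality and diversity of the generated images. Second, group normalization (GN) divides the input by channel into groups and computes the normalization mean and variance within each group, which addresses the shortage of training samples, reduces computation error, and keeps network training stable. When samples are sufficient, the results computed by GN match the sample distribution well, indicating that Re-GAN has good compatibility.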
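To make the batch independence of group normalization concrete, here is a minimal pure-Python sketch that normalizes a single sample. It is an illustration under stated assumptions, not the authors' implementation: the function name `group_norm`, the toy input, and the `eps` value are hypothetical, and the learnable per-channel scale and shift of the original GN formulation are omitted for brevity.

```python
import math

def group_norm(sample, num_groups, eps=1e-5):
    """Normalize one sample given as a list of channels (each a flat list of values).

    Channels are split into num_groups contiguous groups; the mean and variance
    are computed over all values in a group, so no other sample is involved.
    """
    num_channels = len(sample)
    if num_channels % num_groups != 0:
        raise ValueError("num_channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out = []
    for g in range(num_groups):
        channels = sample[g * per_group:(g + 1) * per_group]
        values = [v for ch in channels for v in ch]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        scale = 1.0 / math.sqrt(var + eps)
        out.extend([[(v - mean) * scale for v in ch] for ch in channels])
    return out

# Hypothetical sample: 4 channels with 2 values each, split into 2 groups.
x = [[1.0, 2.0], [3.0, 4.0], [10.0, 20.0], [30.0, 40.0]]
y = group_norm(x, num_groups=2)
```

Because the statistics come from one sample only, the result is the same whether that sample is trained in a batch of 4 or of 64, which is the property the text relies on.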

2.1 Generative model architecture

The proposed Re-GAN adds residual modules to the generative model of GAN. The structure of its generative model is shown in Fig. 3.

Fig. 3 Generative model architecture of Re-GAN

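The generative model of Re-GAN contains five neural network layers. Randomly generated Gaussian noise is fed into the generative model, passes through the transposed convolutions of the five layers, and is output as a generated image. Each residual transposed convolution layer (Res-TransposeConv) in the generative model contains a transposed convolution layer (TransposeConv), group normalization, a residual block, and an activation function. The transposed convolution layers change the input dimension using multiple 4 × 4 transposed convolution kernels: TransposeConv1 expands the dimension of the input noise tensor from 100 to 512; TransposeConv2, TransposeConv3, and TransposeConv4 each halve the dimension; and the last layer, TransposeConv5, outputs a dimension of 3, corresponding to the three channels of an RGB image. Normalization in the Res-TransposeConv layers is group normalization: each layer has many convolution kernels, the features these kernels learn are not fully independent, and some features share the same distribution. Because a bounded function helps the model quickly cover the color space of the training distribution, tanh is used as the activation function at the output of the last layer, and the other four Res-TransposeConv layers use ReLU. To let the features passed to the next layer retain the input information and to prevent gradient degradation, the input of each Res-TransposeConv layer is skip-connected to the output after the activation function, building a residual module through channel connection. Because the two tensors have different dimensions at this point, the input tensor must first pass through multiple 4 × 4 transposed convolution kernels so that the skip connection has the same dimensions as the output tensor; only then are the two joined, passed through the activation function, and output to the next layer.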
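The channel progression described above can be traced with a short shape calculation. The channel widths (100, 512, 256, 128, 64, 3) and the 4 × 4 kernel come from the text; the strides and paddings are not stated in the paper and are assumed here to follow the common DCGAN-style choice, so the resulting spatial sizes are illustrative only.

```python
def transpose_conv_out(size, kernel, stride, padding):
    """Output spatial size of a transposed convolution (no dilation, no output padding)."""
    return (size - 1) * stride - 2 * padding + kernel

# Channel widths come from the text; stride/padding are assumed DCGAN-style values.
channels = [100, 512, 256, 128, 64, 3]
strides = [1, 2, 2, 2, 2]     # assumption, not stated in the paper
paddings = [0, 1, 1, 1, 1]    # assumption, not stated in the paper

size = 1  # the noise vector is treated as a 1 x 1 map with 100 channels
shapes = [(channels[0], size, size)]
for out_ch, s, p in zip(channels[1:], strides, paddings):
    size = transpose_conv_out(size, kernel=4, stride=s, padding=p)
    shapes.append((out_ch, size, size))

for shape in shapes:
    print(shape)
```

Under these assumptions every layer changes both the channel count and the spatial size, which is why the skip path needs its own transposed convolution before it can be joined with the layer output.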

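Group normalization is used to make the normalization compatible with different batch sizes, and it performs better when the batch is small. In addition, a second generative model, Re-GAN2, is designed for comparison with Re-GAN. The main difference is that Re-GAN2 adds group normalization before the skip connection to normalize it, as shown in Fig. 4. Re-GAN and Re-GAN2 use the same discriminative model.

Fig. 4 Generative models of Re-GAN and Re-GAN2 ((a) Re-GAN; (b) Re-GAN2)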

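2.2 Discriminative model architecture

The discriminative model of Re-GAN is shown in Fig. 5. A three-channel RGB image is input and passes through five convolution layers that extract image features. Each convolution layer uses a different number of 4 × 4 convolution kernels to obtain tensors of different dimensions. No pooling layers are added, so that the convolution layers obtain a larger receptive field. The dimension of the input tensor changes from layer to layer in the reverse order of the generative model. All layers use LeakyReLU as the activation function, and a sigmoid function finally converts the tensor into a scalar in (0, 1) that represents the probability that the image is real.

Fig. 5 Discriminative model architecture of Re-GAN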
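The two activation functions used by the discriminative model are simple enough to write out directly. The sketch below is only illustrative: the slope of 0.2 is the value given for LeakyReLU in Section 2.3, while the input score is a made-up number.

```python
import math

def leaky_relu(x, negative_slope=0.2):
    """LeakyReLU: identity for x >= 0, a small linear slope for x < 0."""
    return x if x >= 0 else negative_slope * x

def sigmoid(x):
    """Logistic sigmoid, mapping a real score into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical pre-activation score for one image.
score = leaky_relu(-1.5)       # about -0.3 with slope 0.2
prob_real = sigmoid(score)     # read as the probability that the image is real
print(score, prob_real)
```

Unlike ReLU, LeakyReLU keeps a nonzero gradient for negative inputs, and the final sigmoid turns the last score into the probability described above.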

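2.3 Proposed algorithm

When training Re-GAN, the learning rate of both the generative and discriminative models is set to 0.0002. In the generative model, the output layer uses tanh as the activation function and all other layers use ReLU. All layers of the discriminative model use LeakyReLU with the slope set to 0.2. The Adam optimizer is used to update the network parameters, the number of training iterations is set to 10 000, and a generated image is output each time. The pseudocode of the algorithm is as follows: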

for number in iters do:

1) Update network $D$: maximize log($D$($x$)) + log(1 - $D$($G$(noise)))

 (1) real data ← $D$(real images) // discriminate real images

 (2) fake data ← $D$($G$(noise)) // discriminate fake images

 (3) $\operatorname{loss}_{d(\text{real})}$ ← BCEloss(real data, 1) // loss between the real-image outputs and their labels

 (4) $\operatorname{loss}_{d(\text{fake})}$ ← BCEloss(fake data, 0) // loss between the fake-image outputs and their labels

 (5) $d_{\theta} \leftarrow \nabla_{d} \frac{1}{m} \sum\left[\log \left(\operatorname{loss}_{d(\text{real})}\right)+\log \left(1-\operatorname{loss}_{d(\text{fake})}\right)\right]$ // loss function of the discriminative network

 (6) $w_{d} \leftarrow \operatorname{Adam}\left(d_{\theta}, w_{d}\right)$ // update parameters

2) Update network $G$: maximize log($D$($G$(noise)))

 (1) $\operatorname{loss}_{g(\text{fake})}$ ← BCEloss(fake data, 1) // loss between the fake-image outputs and the real-image label

 (2) $g_{\theta} \leftarrow \nabla_{g} \frac{1}{m} \sum\left(\log \left(1-\operatorname{loss}_{g(\text{fake})}\right)\right)$ // loss function of the generative network

 (3) $w_{g} \leftarrow \operatorname{Adam}\left(g_{\theta}, w_{g}\right)$ // update parameters

end
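The loss bookkeeping in the pseudocode can be sketched without any deep learning framework. The example below only mirrors the BCE steps: there are no networks, gradients, or Adam updates, the discriminator outputs are invented numbers, and `bce_loss` is a hypothetical helper rather than the authors' code. Summing the real and fake terms for the discriminator is a common convention and is assumed here.

```python
import math

def bce_loss(predictions, target, eps=1e-12):
    """Mean binary cross-entropy between predicted probabilities and a constant label."""
    total = 0.0
    for p in predictions:
        p = min(max(p, eps), 1.0 - eps)  # clamp so that log() stays finite
        total += -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))
    return total / len(predictions)

# Hypothetical discriminator outputs for one mini-batch.
d_real = [0.9, 0.8, 0.95]   # D(real images)
d_fake = [0.2, 0.1, 0.3]    # D(G(noise))

# Discriminator step: real images are labeled 1, generated images 0.
loss_d = bce_loss(d_real, 1.0) + bce_loss(d_fake, 0.0)

# Generator step: the same fake outputs are scored against the "real" label 1.
loss_g = bce_loss(d_fake, 1.0)

print(loss_d, loss_g)
```

With these numbers the discriminator is doing well, so its loss is small while the generator loss is large; training pushes the generator to raise $D$($G$(noise)) and thereby lower `loss_g`.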

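3 Experiments

3.1 Datasets and evaluation criteria

To verify the effectiveness of the proposed Re-GAN, experiments are conducted on the Cifar10, CelebA, and LSUN datasets. The Cifar10 dataset contains 60 000 color images in 10 classes (6 000 per class), of which 50 000 are training images and 10 000 are test images; each image is 32 × 32 pixels. The CelebA dataset is a large-scale face attribute dataset containing more than 200 000 celebrity images, each annotated with 40 attributes. The LSUN dataset contains 10 scene categories and 20 object categories, with about 1 million labeled images in total; the bedroom scene is used in our experiments.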

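The inception score (IS) (Barratt and Sharma, 2018) and Fréchet inception distance (FID) (Heusel et al., 2017) are used to evaluate the proposed algorithm. IS accounts for the quality of the images generated by a GAN but not for the real data, so it cannot reflect whether the generated images approximate real images. FID is more robust to noise and better suited to describing the diversity of a GAN. The two criteria are therefore combined to evaluate the proposed Re-GAN.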

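IS is a common evaluation method for GAN. It uses an inception network trained on ImageNet to compute the information of the generated images, and is calculated as

$ IS(G)=\exp \left(E_{\boldsymbol{x} \sim P_{g}} D_{\mathrm{KL}}(p(y \mid \boldsymbol{x}) \| p(y))\right) $ (3)

where $\boldsymbol{x} \sim P_{g}$ indicates that $\boldsymbol{x}$ is a sample of the generated images; $p(y \mid \boldsymbol{x})$ is the probability distribution over classes $y$ obtained by feeding image $\boldsymbol{x}$ into Inception V3; $p(y)$ is the marginal distribution over all classes, obtained by averaging $p(y \mid \boldsymbol{x})$ over the generated images; and ${D_{{\rm{KL}}}}$ denotes the KL divergence between two probability distributions. A larger IS indicates a better generative model.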
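Eq. (3) can be evaluated directly once the class probabilities are available. The sketch below assumes the Inception V3 outputs $p(y \mid \boldsymbol{x})$ have already been computed and are passed in as plain lists; the two tiny inputs are invented to show the extreme cases, and `inception_score` is a hypothetical helper name.

```python
import math

def inception_score(cond_probs):
    """IS from a list of per-image class distributions p(y|x), following Eq. (3)."""
    n = len(cond_probs)
    num_classes = len(cond_probs[0])
    # Marginal p(y): average of the conditional distributions over all images.
    marginal = [sum(p[k] for p in cond_probs) / n for k in range(num_classes)]
    kl_sum = 0.0
    for p in cond_probs:
        # Terms with p(y|x) = 0 contribute nothing to the KL divergence.
        kl_sum += sum(pk * math.log(pk / marginal[k])
                      for k, pk in enumerate(p) if pk > 0.0)
    return math.exp(kl_sum / n)

# Hypothetical classifier outputs over 2 classes.
confident_and_varied = [[1.0, 0.0], [0.0, 1.0]]  # sharp predictions, both classes used
uninformative = [[0.5, 0.5], [0.5, 0.5]]         # flat predictions

print(inception_score(confident_and_varied))
print(inception_score(uninformative))
```

The score is highest (here about 2, the number of classes) when each image is classified confidently and the classes are evenly covered, and it drops to 1 when the predictions carry no information.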

FID is a metric designed specifically to evaluate generative adversarial networks. It assesses the quality of generated images by comparing the statistics of a set of generated data with those of a set from the target domain; a smaller FID indicates better image quality and diversity. It is calculated as

$ FID(\boldsymbol{r}, \boldsymbol{g})=\left\|\boldsymbol{\mu}_{r}-\boldsymbol{\mu}_{g}\right\|^{2}+\operatorname{tr}\left(\boldsymbol{\Sigma}_{r}+\boldsymbol{\Sigma}_{g}-2\left(\boldsymbol{\Sigma}_{r} \boldsymbol{\Sigma}_{g}\right)^{\frac{1}{2}}\right) $ (4)

where $\boldsymbol{\mu}_{r}$ is the mean of the real image features, $\boldsymbol{\mu}_{g}$ is the mean of the generated image features, $\boldsymbol{\Sigma}_{r}$ is the covariance matrix of the real image features, and $\boldsymbol{\Sigma}_{g}$ is the covariance matrix of the generated image features.
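For intuition about Eq. (4), the sketch below evaluates it in the special case where both covariance matrices are diagonal. This is a deliberate simplification for illustration, not how FID is computed in practice (the full formula needs a matrix square root of $\boldsymbol{\Sigma}_{r}\boldsymbol{\Sigma}_{g}$); the feature statistics and the name `fid_diagonal` are hypothetical.

```python
import math

def fid_diagonal(mu_r, var_r, mu_g, var_g):
    """Eq. (4) when both covariance matrices are diagonal (a simplifying assumption).

    For diagonal covariances, tr(S_r + S_g - 2 (S_r S_g)^(1/2)) reduces to a sum
    over feature dimensions of var_r + var_g - 2 * sqrt(var_r * var_g).
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu_r, mu_g))
    cov_term = sum(vr + vg - 2.0 * math.sqrt(vr * vg)
                   for vr, vg in zip(var_r, var_g))
    return mean_term + cov_term

# Hypothetical 2-D feature statistics (means and per-dimension variances).
same = fid_diagonal([0.0, 1.0], [1.0, 2.0], [0.0, 1.0], [1.0, 2.0])
shifted = fid_diagonal([0.0, 1.0], [1.0, 2.0], [3.0, 1.0], [4.0, 2.0])

print(same, shifted)
```

Identical statistics give a distance of zero, and the distance grows as either the feature means or the feature spreads of the generated images drift away from those of the real images.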

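3.2 Experimental results

3.2.1 Performance comparison

To verify the performance of the proposed Re-GAN and Re-GAN2, they are compared with DCGAN and WGAN on the three datasets above. Each algorithm is trained for 10 000 iterations on each dataset, and one sample image is output every 100 iterations, giving 100 sample images in total; IS and FID are then computed on these 100 images. The experiments cover two batch sizes: 64 images per batch and 4 images per batch. Tables 1 and 2 report IS and FID, respectively, for the two batch sizes.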

Table 1 Comparison of IS performance among different algorithms on Cifar10, CelebA and LSUN datasets

Algorithm   64 images/batch              4 images/batch
            Cifar10   CelebA   LSUN      Cifar10   CelebA   LSUN
DCGAN       1.51      1.51     1.41      1.51      1.66     1.66
WGAN        1.12      1.20     1.23      1.39      1.44     1.45
Re-GAN      1.59      1.40     1.65      1.45      1.70     1.70
Re-GAN2     1.64      1.28     1.67      1.50      1.67     1.62
Note: bold type in the original table marks the best result in each column.

Table 2 Comparison of FID performance among different algorithms on Cifar10, CelebA and LSUN datasets

Algorithm   64 images/batch              4 images/batch
            Cifar10   CelebA   LSUN      Cifar10   CelebA   LSUN
DCGAN       175.2     112.8    124.9     262.4     181.3    232.7
WGAN        117.1     148.6    121.2     287.2     174.2    266.7
Re-GAN      134.5     93.4     105.4     241.9     164.6    242.8
Re-GAN2     180.9     96.1     126.7     237.6     171.8    241.7
Note: bold type in the original table marks the best result in each column.

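As Table 1 shows, with a batch size of 64 the mean IS of Re-GAN is 5% and 30% higher than that of DCGAN and WGAN, respectively; with a batch size of 4 it is 0.2% and 13% higher, respectively. This indicates that Re-GAN outperforms DCGAN and WGAN in the diversity of generated images for both large and small batches.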
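One way to relate these percentages to Table 1 is to average each algorithm's IS over the three datasets and compare the means. The sketch below does that with the table values; averaging over datasets is an assumption about how the figures were obtained, so the results land near the quoted percentages but need not match them exactly.

```python
# IS values copied from Table 1, ordered (Cifar10, CelebA, LSUN).
IS_SCORES = {
    64: {"DCGAN": [1.51, 1.51, 1.41],
         "WGAN": [1.12, 1.20, 1.23],
         "Re-GAN": [1.59, 1.40, 1.65]},
    4: {"DCGAN": [1.51, 1.66, 1.66],
        "WGAN": [1.39, 1.44, 1.45],
        "Re-GAN": [1.45, 1.70, 1.70]},
}

def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)

def relative_gain_percent(ours, baseline):
    """Percentage by which mean(ours) exceeds mean(baseline)."""
    return 100.0 * (mean(ours) - mean(baseline)) / mean(baseline)

for batch, scores in IS_SCORES.items():
    for baseline in ("DCGAN", "WGAN"):
        gain = relative_gain_percent(scores["Re-GAN"], scores[baseline])
        print(f"batch {batch}: Re-GAN vs {baseline}: {gain:+.1f}%")
```

The same helper can be applied to the FID values in Table 2, where a negative result corresponds to a reduction.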

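As Table 2 shows, Re-GAN can generate images closer to the real ones at different batch sizes, especially with small batches. With small batches, the FID of Re-GAN is 241.9 and 164.6 on the Cifar10 and CelebA datasets, respectively; on the LSUN dataset Re-GAN does not perform as well as DCGAN, but its FID is lower than that of WGAN. Notably, Re-GAN2 performs better than Re-GAN, indicating that GN is better suited than BN to normalization when the batch is very small; with GN, the images generated by GAN are more diverse. With large batches, Re-GAN also performs well compared with the other methods, with FID values of 93.4 and 105.4 on the CelebA and LSUN datasets, respectively. Here Re-GAN2 does not compare with Re-GAN as favorably as it does with small batches, but it remains competitive with the other methods. This indicates that group normalization (GN) has good compatibility with large batches.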

3.2.2 Comparison of generated images

The images generated by Re-GAN, DCGAN, and WGAN on the Cifar10, CelebA, and LSUN datasets are compared for large batches (64 images per batch) and small batches (4 images per batch). The results are shown in Fig. 6.

Fig. 6 Images generated by different algorithms with different batch sizes on three datasets ((a) Cifar10 dataset; (b) CelebA dataset; (c) LSUN dataset)

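As Fig. 6 shows, with large batches the images generated by Re-GAN on the three datasets are of higher quality, more recognizable, and more varied than those generated by DCGAN and WGAN; with small batches, the images generated by Re-GAN contain less noise and show more distinct features. The images generated by the proposed Re-GAN are therefore better.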

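3.2.3 Gradient comparison

On the Cifar10 dataset, the gradients of Re-GAN and DCGAN during training are compared for different batch sizes. Figs. 7 and 8 show the gradient distributions of DCGAN and Re-GAN with large and small batches, respectively. Each line in the figures represents one sampling, and lower lines correspond to earlier samplings. Gradient vanishing is observed by comparing the weight distributions in the figures. With large batches (64 images per batch), training runs for 2 000 iterations with one sampling every 50 iterations. As Fig. 7 shows, at the fifth layer the gradients of DCGAN converge more slowly, so the gradients obtained at each sampling are more dispersed, whereas Re-GAN converges faster. At the first layer, which the gradients reach after backpropagating from the fifth layer, most gradients of DCGAN are close to 0 because of gradient vanishing, whereas gradient vanishing is milder in Re-GAN. With small batches (4 images per batch), training also runs for 2 000 iterations. As Fig. 8 shows, the gradients of the two algorithms behave as they do with large batches. Therefore, Re-GAN is more stable than DCGAN with both large and small batches.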

Fig. 7 Gradient distributions of different algorithms trained on the Cifar10 dataset (batch size = 64) ((a) DCGAN; (b) Re-GAN)
Fig. 8 Gradient distributions of different algorithms trained on the Cifar10 dataset (batch size = 4) ((a) DCGAN; (b) Re-GAN)

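4 Conclusion

Based on the residual network ResNet and group normalization GN, this paper proposes a residual generative adversarial network, Re-GAN. First, deep residual modules are built in the generative model to alleviate the gradient vanishing and training instability of generative adversarial networks and to increase the diversity of the generated images. Second, group normalization enables Re-GAN to accommodate both large-batch and small-batch training, making the training process more stable. Experiments are conducted on the Cifar10, CelebA, and LSUN datasets, with IS and FID used to evaluate the proposed algorithm and with WGAN and DCGAN as baselines. The results show that Re-GAN trains in less time, trains more stably, and generates images of higher quality and richer diversity.

The proposed Re-GAN alleviates gradient vanishing to some extent and improves the quality of generated images, but some weights still vanish during propagation, mainly because the influence of the loss function on the model is not fully considered. In future work, an appropriate loss function will therefore be adopted to better measure the difference between the generated and real data distributions, further improving the quality of generated images and the stability of training.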

References

  • Arjovsky M, Chintala S and Bottou L. 2017. Wasserstein GAN[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1701.07875.pdf
  • Ba J L, Kiros J R and Hinton G E. 2016. Layer normalization[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1607.06450.pdf
  • Barratt S and Sharma R. 2018. A note on the inception score[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1801.01973.pdf
  • Cao Y J, Jia L L, Chen Y X, Lin N, Li X X. 2018. Review of computer vision based on generative adversarial networks. Journal of Image and Graphics, 23(10): 1433-1449 (曹仰杰, 贾丽丽, 陈永霞, 林楠, 李学相. 2018. 生成式对抗网络及其计算机视觉应用研究综述. 中国图象图形学报, 23(10): 1433-1449) [DOI:10.11834/jig.180103]
  • Che T, Li Y R, Jacob A P, Bengio Y and Li W J. 2016. Mode regularized generative adversarial networks[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1612.02136.pdf
  • Chen X, Duan Y, Houthooft R, Schulman J, Sutskever I and Abbeel P. 2016. InfoGAN: Interpretable representation learning by information maximizing generative adversarial nets//Proceedings of the 30th International Conference on Neural Information Processing Systems. Barcelona, Spain: ACM: 2180-2188[DOI: 10.5555/3157096.3157340]
  • Donahue J, Krähenbühl P and Darrell T. 2016. Adversarial feature learning[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1605.09782.pdf
  • Dumoulin V, Belghazi I, Poole B, Mastropietro O, Lamb A, Arjovsky M and Courville A. 2016. Adversarially learned inference[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1606.00704.pdf
  • Goodfellow I J, Pouget-Abadie J, Mirza M, Xu B, Warde-Farley D, Ozair S, Courville A and Bengio Y. 2014. Generative adversarial nets//Proceedings of the 27th International Conference on Neural Information Processing Systems. Montreal, Canada: MIT Press: 2672-2680
  • He K M, Zhang X Y, Ren S Q and Sun J. 2016. Deep residual learning for image recognition//Proceedings of 2016 IEEE Conference on Computer Vision and Pattern Recognition. Las Vegas, USA: IEEE: 770-778[DOI: 10.1109/CVPR.2016.90]
  • Heusel M, Ramsauer H, Unterthiner T, Nessler B and Hochreiter S. 2017. GANs trained by a two time-scale update rule converge to a local Nash equilibrium//Proceedings of the 31st International Conference on Neural Information Processing Systems. Long Beach, USA: ACM: 6629-6640
  • Huang F R, Ash J, Langford J and Schapire R. 2018. Learning deep ResNet blocks sequentially using boosting theory[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1706.04964.pdf
  • Ioffe S and Szegedy C. 2015. Batch normalization: accelerating deep network training by reducing internal covariate shift//Proceedings of the 32nd International Conference on Machine Learning. Lille, France: ACM: 448-456
  • Isola P, Zhu J Y, Zhou T H and Efros A A. 2017. Image-to-image translation with conditional adversarial networks//Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition. Honolulu, USA: IEEE: 5967-5976[DOI: 10.1109/CVPR.2017.632]
  • Larsen A B L, Sønderby S K, Larochelle H and Winther O. 2016. Autoencoding beyond pixels using a learned similarity metric//Proceedings of the 33rd International Conference on Machine Learning. New York, USA: ACM: 1558-1566
  • Lim B, Son S, Kim H, Nah S and Lee K M. 2017. Enhanced deep residual networks for single image super-resolution//Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition Workshops. Honolulu, USA: IEEE: 1132-1140[DOI: 10.1109/CVPRW.2017.151]
  • Liu M Y and Tuzel O. 2016. Coupled generative adversarial networks[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1606.07536.pdf
  • Mathieu M, Couprie C and LeCun Y. 2015. Deep multi-scale video prediction beyond mean square error[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1511.05440.pdf
  • Mirza M and Osindero S. 2014. Conditional generative adversarial nets[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1411.1784.pdf
  • Nitanda A and Suzuki T. 2018. Functional gradient boosting based on residual network perception[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1802.09031.pdf
  • Qiu Z F, Yao T and Mei T. 2017. Learning spatio-temporal representation with pseudo-3D residual networks//Proceedings of 2017 IEEE International Conference on Computer Vision. Venice, Italy: IEEE: 5534-5542[DOI: 10.1109/ICCV.2017.590]
  • Radford A, Metz L and Chintala S. 2015. Unsupervised representation learning with deep convolutional generative adversarial networks[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1511.06434.pdf
  • Rezende D J, Mohamed S and Wierstra D. 2014. Stochastic backpropagation and approximate inference in deep generative models//Proceedings of the 31st International Conference on Machine Learning. Beijing, China: ACM: Ⅱ-1278-1286
  • Silver D, Schrittwieser J, Simonyan K, Antonoglou I, Huang A, Guez A, Hubert T, Baker L, Lai M, Bolton A, Chen Y T, Lillicrap T, Hui F, Sifre L, Van Den Driessche G, Graepel T, Hassabis D. 2017. Mastering the game of Go without human knowledge. Nature, 550(7676): 354-359 [DOI:10.1038/nature24270]
  • Ulyanov D, Vedaldi A and Lempitsky V. 2016. Instance normalization: the missing ingredient for fast stylization[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1607.08022.pdf
  • Wang Y X, Zhang L C and Van De Weijer J. 2016. Ensembles of generative adversarial networks[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1612.00991.pdf
  • Wu Y X, He K M. 2020. Group normalization. International Journal of Computer Vision, 128(3): 742-755 [DOI:10.1007/s11263-019-01198-w]
  • Yeh R A, Chen C, Lim T Y, Schwing A G, Hasegawa-Johnson M and Do M N. 2017. Semantic image inpainting with deep generative models//Proceedings of 2017 IEEE Conference on Computer Vision and Pattern Recognition. Honolulu, USA: IEEE: 6882-6890[DOI: 10.1109/CVPR.2017.728]
  • Zhu J Y, Park T, Isola P and Efros A A. 2017. Unpaired image-to-image translation using cycle-consistent adversarial networks//Proceedings of 2017 IEEE International Conference on Computer Vision (ICCV). Venice, Italy: IEEE: 2242-2251[DOI: 10.1109/ICCV.2017.244]
  • Zhu W T, Xiang X, Tran T D and Xie X H. 2016. Adversarial deep structural networks for mammographic mass segmentation[EB/OL].[2020-02-13]. https://arxiv.org/pdf/1612.05970.pdf