您现在的位置是:首页 >学无止境 >Pytorch实现之融合注意力机制和残差结构的EIGGAN网站首页学无止境
Pytorch实现之融合注意力机制和残差结构的EIGGAN
简介
论文题目:An Enhanced GAN for Image Generation
论文期刊:Computers, Materials & Continua (CMC)
文章提到一个均衡卷积的论文题目:Progressive growing of gans for improved quality, stability, and variation,来自于ICLR2018.
还有一个是CBAM模块,这个了解通道注意力和空间注意力的小伙伴肯定非常熟悉了!
摘要:具有博弈能力的生成对抗网络(GANs)在图像生成中得到了广泛的应用。 然而,在不同的场景下,游戏生成器和鉴别器可能会降低生成的gan在图像生成中的鲁棒性。 增强生成网络中层次信息之间的关系,扩大不同网络架构之间的差异,可以促进更多的结构化信息,从而提高图像生成的生成效果。 在本文中,我们通过改进图像生成生成器(EIGGAN)提出了一种增强的GAN。 EIGGAN将空间注意力应用于生成器,提取显著信息,增强生成图像的真实性。 考虑上下文关系,将并行残差操作融合成生成网络,从不同层提取更多的结构信息。 最后,利用混合损失函数在速度和精度之间进行权衡,以生成更逼真的图像。 实验结果表明,该方法在Frechet初始距离、学习到的感知图像斑块相似度、多尺度结构相似度指标、核初始距离、统计不同Bins数、初始分数和一些视觉图像生成等指标上均优于Wasserstein梯度惩罚GAN (WGAN-GP)。
其实这篇论文结构是一个非常简单的GAN模型的改进,有非常多内容值得初学者学习,即融合残差结构、注意力机制在模型的生成器或者鉴别器当中!同时也组合了多类损失函数来帮助模型收敛,也更好地权衡速度与精度!非常建议GAN的初学者学习这种改进结构,并后续再次基础上不断创新,得到更为高效、效果更为精良的模型!
模型结构
论文的模型结构很简单,一张图即可完整囊括。就是GAN模型的基础上补充了残差结构,加入了注意力机制等操作。这些方法可以非常快的上手修改GAN结构!!!!!

损失函数为多种损失函数的融合,主要包括了正常损失、非饱和损失和R1和R2正则化。生成器和鉴别器的损失如下:

代码实现
接下来我们展示一下核心的一些代码块
Transopose ConvBlock
class TransposeConvBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(TransposeConvBlock, self).__init__()
self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
self.leaky_relu = nn.LeakyReLU(0.2)
self.pixel_norm = PixelNorm()
self.EConv = EqualizedConv2d(out_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
x = self.conv(x)
x = self.leaky_relu(x)
x = self.pixel_norm(x)
x = self.EConv(x)
x = self.leaky_relu(x)
x = self.pixel_norm(x)
return x
AttentionBlock
class AttentionBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super(AttentionBlock, self).__init__()
self.conv1 = EqualizedConv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.conv2 = EqualizedConv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.CBAM = CBAM(out_channels)
self.leaky_relu = nn.LeakyReLU(0.2)
self.pixel_norm = PixelNorm()
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.CBAM(out)
out = self.leaky_relu(out)
out2 = self.pixel_norm(out)
out = self.pixel_norm(out)
out = self.conv2(out)
out = self.leaky_relu(out)
out = self.pixel_norm(out)
residual = self.conv1(residual)
return out + residual +





U8W/U8W-Mini使用与常见问题解决
QT多线程的5种用法,通过使用线程解决UI主界面的耗时操作代码,防止界面卡死。...
stm32使用HAL库配置串口中断收发数据(保姆级教程)
分享几个国内免费的ChatGPT镜像网址(亲测有效)
Allegro16.6差分等长设置及走线总结