您现在的位置是:首页 >学无止境 >Pytorch实现之融合注意力机制和残差结构的EIGGAN网站首页学无止境

Pytorch实现之融合注意力机制和残差结构的EIGGAN

LJ1147517021 2025-12-20 00:01:03
简介Pytorch实现之融合注意力机制和残差结构的EIGGAN

简介

论文题目:An Enhanced GAN for Image Generation

论文期刊:Computers, Materials & Continua (CMC)

文章提到一个均衡卷积的论文题目:Progressive growing of gans for improved quality, stability, and variation,来自于ICLR2018.

还有一个是CBAM模块,这个了解通道注意力和空间注意力的小伙伴肯定非常熟悉了!

摘要:具有博弈能力的生成对抗网络(GANs)在图像生成中得到了广泛的应用。 然而,在不同的场景下,游戏生成器和鉴别器可能会降低生成的gan在图像生成中的鲁棒性。 增强生成网络中层次信息之间的关系,扩大不同网络架构之间的差异,可以促进更多的结构化信息,从而提高图像生成的生成效果。 在本文中,我们通过改进图像生成生成器(EIGGAN)提出了一种增强的GAN。 EIGGAN将空间注意力应用于生成器,提取显著信息,增强生成图像的真实性。 考虑上下文关系,将并行残差操作融合成生成网络,从不同层提取更多的结构信息。 最后,利用混合损失函数在速度和精度之间进行权衡,以生成更逼真的图像。 实验结果表明,该方法在Frechet初始距离、学习到的感知图像斑块相似度、多尺度结构相似度指标、核初始距离、统计不同Bins数、初始分数和一些视觉图像生成等指标上均优于Wasserstein梯度惩罚GAN (WGAN-GP)。

其实这篇论文结构是一个非常简单的GAN模型的改进,有非常多内容值得初学者学习,即融合残差结构、注意力机制在模型的生成器或者鉴别器当中!同时也组合了多类损失函数来帮助模型收敛,也更好地权衡速度与精度!非常建议GAN的初学者学习这种改进结构,并后续再次基础上不断创新,得到更为高效、效果更为精良的模型!

模型结构

论文的模型结构很简单,一张图即可完整囊括。就是GAN模型的基础上补充了残差结构,加入了注意力机制等操作。这些方法可以非常快的上手修改GAN结构!!!!!

损失函数为多种损失函数的融合,主要包括了正常损失、非饱和损失和R1和R2正则化。生成器和鉴别器的损失如下:

代码实现

接下来我们展示一下核心的一些代码块

Transopose ConvBlock

class TransposeConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(TransposeConvBlock, self).__init__()
        self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.pixel_norm = PixelNorm()
        self.EConv = EqualizedConv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.conv(x)
        x = self.leaky_relu(x)
        x = self.pixel_norm(x)
        x = self.EConv(x)
        x = self.leaky_relu(x)
        x = self.pixel_norm(x)
        return x

AttentionBlock

class AttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(AttentionBlock, self).__init__()
        self.conv1 = EqualizedConv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = EqualizedConv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.CBAM = CBAM(out_channels)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.pixel_norm = PixelNorm()

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.CBAM(out)
        out = self.leaky_relu(out)
        out2 = self.pixel_norm(out)
        out = self.pixel_norm(out)
        out = self.conv2(out)
        out = self.leaky_relu(out)
        out = self.pixel_norm(out)
        residual = self.conv1(residual)

        return out + residual +
风语者!平时喜欢研究各种技术,目前在从事后端开发工作,热爱生活、热爱工作。