GANs很难?这篇文章教你50行代码搞定(PyTorch)

AI资讯1年前 (2023)发布 AI工具箱
415 0 0

作者:Dev Nag,Wavefront创始人、CTO,曾是Google、PayPal工程师。量子位编译。

2014年,Ian Goodfellow和他在蒙特利尔大学的同事们发表了一篇令人惊叹的论文,正式把生成对抗网络(GANs)介绍给全世界。通过把计算图和博弈论创新性的结合起来,GANs有能力让两个互相对抗的模型通过反向传播共同训练。

模型中有两个相互对抗的角色,我们分别称为GD,简单解释如下:G是一个生成器,它试图通过学习真实数据集R,来创建逼真的假数据;D鉴别器,从R和G处获得数据并标记差异。

Goodfellow有个很好的比喻:G是一个造假团队,试图造出跟真画一样的赝品;D是鉴定专家,试图找出真画和赝品的差异。当然在GANs的设定里,G是一群永远见不到真画的造假团队,他们能够获得的反馈只有D的鉴定意见。

在理想情况下,D和G都会随着时间的推移变得更好,直到G变成一个造假大师,最终让D无法区分出真画和赝品。实际上,Goodfellow已经表明G能够对原始数据集进行无监督学习,并且找到这些数据的低维表达方式。


这么厉害的技术,代码怎么也得一大堆吧?

并不是。使用刚刚发布的PyTorch,实际上可以只用不到50行代码,就能创建一个GAN。我们需要考虑的组件只有下面五个:

 R:原始的真实数据集

 I:作为熵源输入生成器的随机噪声

 G:尝试复制/模仿原始数据集的生成器

 D:尝试分辨G输出的鉴别器

 一个训练循环:教G造假,再教D来鉴别……

1)R: 我们将从最简单的R,一个钟形曲线开始。这个函数以平均值和标准偏差为参数,然后返回一个函数。在我们的示例代码中,使用了平均值4.0和标准差1.25。

2)I: 输入生成器的噪声也是随机的,但是为了增加点难度,我们使用了一个均匀分布,而不是正态分布。这意味着模型G不能简单地通过移动/缩放复制R,而必须以非线性的方式重塑数据。

3)G: 生成器是一个标准的前馈图,包含两个隐藏层,三个线性映射。在这里,我们使用了ELU(指数线性单位)。G将从I获得均匀分布的数据样本,并以某种方式模仿来自R的正态分布样本。

4)D: 鉴别器与生成器G的代码非常相似,都是有两个隐藏层和三个线性映射的前馈图。它将从R或G获取样本,并输出介于0和1之间的单个标量,0和1分别表示“假”和“真”。

5)训练循环 最后,训练循环在两种模式之间交替:首先,用带有准确标签的真实数据和假数据来训练D;然后,训练G来愚弄D。

即使你从没用过PyTorch,也大致能看出发生了什么。在上图标为绿色的第一部分,我们将不同类型的数据输入D,并对D的猜测结果和实际的标签进行评判。这一步是“正向”的,然后我们用“反向”来计算梯度,并用它来更新d_optimizer step()调用的D参数。

上面,我们用到了G,但没有训练它。

在标为红色的下半部分中,我们对G做了同样的事情,注意:我们还会通过D来运行G的输出,相当于给了造假者一个侦探练习。但是在这一步中,我们不会对D进行优化或更改,因为我们不希望D学到错误的标签。因此,我们只调用g_optimizer.step()。

就这些啦,还有一些其他的样本代码,但是针对GAN的只有这五个组件。


对D和G进行几千轮训练之后,我们能得到什么?鉴别器D优化得很快,而G一开始优化得比较慢,不过,一旦到达了特定水平,G就开始迅速成长。

两万轮训练过后,G的输出的平均值超过4.0,但随后回到一个相当稳定,正确的范围(如左图)。同样,标准偏差最初在错误的方向下降,但随后上升到所要求的1.25范围(右图),与R相当。

所以,基本的统计最终与R相当,那么高阶矩如何呢?分布的形状是否正确?毕竟,你当然可以有一个平均值为4.0、标准差为1.25的均匀分布,但这不会真正与R相匹配。让我们看看G形成的最终分布。

还不错。左尾比右边稍微长了一点,但是我们可以说,它的偏斜和峰态符合原始的高斯函数。

G几乎完美还原了R的原始分布,而D独自在角落徘徊,无法分清真伪。这正是我们想要的结果。用不到50行的代码,就能实现。

© 版权声明

相关文章

暂无评论

暂无评论...