Domain-Adversarial Training of Neural Networks(DaNN)实现
Domain-Adversarial Training of Neural Networks(DaNN)实现
总体介绍
在传统的机器学习中,我们经常需要大量带标签的数据进行训练, 并且需要保证训练集和测试集中的数据分布相似。在一些问题中,如果训练集和测试集的数据具有不同的分布,训练后的分类器在测试集上就没有好的表现。
域适应(Domain Adaption)是迁移学习中一个重要的分支,目的是把具有不同分布的源域(Source Domain) 和目标域 (Target Domain) 中的数据,映射到同一个特征空间,寻找某一种度量准则,使其在这个空间上的“距离”尽可能近。然后,我们在源域 (带标签) 上训练好的分类器,就可以直接用于目标域数据的分类。
DaNN是一种域适应学习方法,它采用了GAN的思想。为了使模型在目标集上也能有好的表现,它的目的是使模型特征提取器在源域和目标域提取的特征具有相同的分布。
DANN结构主要包含3个部分:
- 特征提取器 (feature extractor) - 图示绿色部分,用来将数据映射到特定的特征空间,使标签预测器能够分辨出来自源域数据的类别的同时,域判别器无法区分数据来自哪个域。
- 标签预测器 (label predictor) - 图示蓝色部分,对来自源域的数据进行分类,尽可能分出正确的标签。
- 域判别器(domain classifier)- 图示红色部分,对特征空间的数据进行分类,尽可能分出数据来自哪个域。
对抗迁移网络的总损失由两部分构成:网络的训练损失(标签预测器损失)和域判别损失。
我们通过最小化目标函数来更新标签预测器的参数,最大化目标函数来更新域判别器的参数。
相关论文:https://arxiv.org/abs/1505.07818
代码实现
特征提取器实现,VGG网络。
1 |
|
标签预测器实现,由线性层组成。
1 |
|
域判别器实现,由线性层组成。
1 |
|
损失函数选择
1 |
|
训练过程
1 |
|
结果
训练损失上下波动,是特征提取器和域判别器在不断对抗。
训练10、100、200个epoch后,特征提取器提取的特征分布按类别和source/target展示。可以看到,随着训练的过程,特征越来越能区分不同的类,source和target domain的分布越来越一致。
Domain-Adversarial Training of Neural Networks(DaNN)实现
https://wangyinan.cn/Domain-Adversarial-Training-of-Neural-Networks(DaNN)实现