DaNN: Domain Adversarial Neural Network – 整体框架和梯度下降目标详解

释放双眼,带上耳机,听听看~!
本文详细解释了DaNN的整体框架和梯度下降目标,介绍了域适应的概念和特征提取器的作用。

本文正在参加 人工智能创作者扶持计划

0.论文信息

DaNN: Domain Adversarial Neural Network - 整体框架和梯度下降目标详解

  这篇论文大概是我在两年前多读过的第一篇关于 Domain Adaptation 的论文,当时我还是一个小白 (现在是大白),对于 Domain Adaptation 这个领域的知识了解的很少,所以当时读这篇论文的时候,我只是看了一下论文的整体框架,做代码复现的时候对于很多细节了解得不是很透彻 (感觉没有那个意识)。所以这次不谈论文的整体 insight,整体框架也只做简要阐述,从整体梯度下降和梯度反转层 (Gradient Reversal Layer) 的角度来进行一些细节的补充。

1.DaNN 整体框架

DaNN: Domain Adversarial Neural Network - 整体框架和梯度下降目标详解

  DaNN 的整体框架如上图所示,整体框架分为两个部分,第一部分是特征提取器 (Feature Extractor) 和分类器 (Classifier),第二部分是域分类器 (Domain Classifier)。特征提取器和分类器的作用是将输入的数据映射到一个特征空间,然后在特征空间中进行分类。Domain Classifier 的作用是判断输入的数据是属于 Source Domain 还是 Target Domain。整体框架的目标是让 Source Domain 和 Target Domain 的数据经过特征提取器之后更大程度地具有相似的分布,同时 Source Domain 上的数据仍有较好的分类效果,以一定程度上保障 Target Domain 上的分类效果。

2.整体梯度下降目标

E(θf,θy,θd)=∑i=1..Ndi=0Ly(Gy(Gf(xi;θf);θy),yi)−λ∑i=1..NLd(Gd(Gf(xi;θf);θd),yi)==∑i=1.Ndi=0Lyi(θf,θy)−λ∑i=1..NLdi(θf,θd)begin{gathered}
Eleft(theta_f, theta_y, theta_dright)=sum_{substack{i=1 . . N
d_i=0}} L_yleft(G_yleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_yright), y_iright)-
lambda sum_{i=1 . . N} L_dleft(G_dleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_dright), y_iright)=
=sum_{substack{i=1 . N
d_i=0}} L_y^ileft(theta_f, theta_yright)-lambda sum_{i=1 . . N} L_d^ileft(theta_f, theta_dright)
end{gathered}

  其中,Ly(⋅,⋅)L_y(cdot,cdot) 是标签预测的损失,Ld(⋅,⋅)L_d(cdot,cdot) 是域分类,而 LyiL^i_yLdiL^i_d 表示第 ii 个训练示例评估的相应损失函数。即可以转化为 :

(θ^f,θ^y)=arg⁡min⁡θf,θyE(θf,θy,θ^d)θ^d=arg⁡max⁡θdE(θ^f,θ^y,θd).begin{gathered}
left(hat{theta}_f, hat{theta}_yright)=arg min _{theta_f, theta_y} Eleft(theta_f, theta_y, hat{theta}_dright)
hat{theta}_d=arg max _{theta_d} Eleft(hat{theta}_f, hat{theta}_y, theta_dright) .
end{gathered}

  即希望通过 θf,θytheta_f,theta_y 的优化使得 E(θf,θy,θd)Eleft(theta_f, theta_y, theta_dright) 尽可能小,同时通过 θdtheta_d 的优化使得 E(θf,θy,θd)Eleft(theta_f, theta_y, theta_dright) 尽可能大。这样就可以使得 Source Domain 和 Target Domain 的数据经过特征提取器之后更大程度地具有相似的分布的同时,Source Domain 上的数据仍有较好的分类效果。因此可以用随机梯度下降法 (SGD) 来进行迭代化表示 :

θf⟵θf−μ(∂Lyi∂θf−λ∂Ldi∂θf)θy⟵θy−μ∂Lyi∂θyθd⟵θd−μ∂Ldi∂θdbegin{aligned}
& theta_f quad longleftarrow quad theta_f-muleft(frac{partial L_y^i}{partial theta_f}-lambda frac{partial L_d^i}{partial theta_f}right)
& theta_y quad longleftarrow quad theta_y-mu frac{partial L_y^i}{partial theta_y}
& theta_d quad longleftarrow quad theta_d-mu frac{partial L_d^i}{partial theta_d}
end{aligned}

  然而直接实现 (4)-(6) 作为 SGD 并不好实现,这时候就要引入梯度反转层 (Gradient Reversal Layer,GRL) 层来进行实现了。

3.梯度反转层 (Gradient Reversal Layer)

  可以通过引入定义如下的特殊梯度反转层 (GRL) 来完成。梯度反转层没有与其相关的参数 (除了参数 λlambda 不通过反向传播更新)。在前向传播期间,GRL 充当恒等变换。然而,在反向传播期间,GRL 从后续层获取梯度,将其 ×(−λ)times(-lambda) 并将其传递给前一层。

  可以进行如下形式化定义 :

Rλ(x)=xdRλdx=−λIbegin{array}{r}
R_lambda(mathbf{x})=mathbf{x}
frac{d R_lambda}{d mathbf{x}}=-lambda mathbf{I}
end{array}

  其中,Imathbf{I} 是单位矩阵。然后我们可以定义 (θf,θy,θd)(theta_f,theta_y,theta_d) 的目标“伪函数”,通过我们的方法中的随机梯度下降进行优化 :

E~(θf,θy,θd)=∑i=1…Ndi=0Ly(Gy(Gf(xi;θf);θy),yi)+∑i=1…NLd(Gd(Rλ(Gf(xi;θf));θd),yi)begin{gathered}
tilde{E}left(theta_f, theta_y, theta_dright)=sum_{substack{i=1 ldots N
d_i=0}} L_yleft(G_yleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_yright), y_iright)+
sum_{i=1 ldots N} L_dleft(G_dleft(R_lambdaleft(G_fleft(mathbf{x}_i ; theta_fright)right) ; theta_dright), y_iright)
end{gathered}

  为了方便理解,结合上图,我们可以进行一次反向过程的模拟,首先我们会计算得到 E~(θf,θy,θd)tilde{E}left(theta_f, theta_y, theta_dright),需要在域分类器上得到对应的 −∂E~(θf,θy,θd)∂θd-frac{partial tilde{E}left(theta_f, theta_y, theta_dright)}{partial theta_d},实际计算过程中下降梯度为

−μ∂E~(θf,θy,θd)∂θd=−μ∂Ldi(Gd(Gf(xi;θf);θd),yi)∂θd=−μ∂Ldi∂θd-mufrac{partial tilde{E}left(theta_f, theta_y, theta_dright)}{partial theta_d}=-mufrac{partial L_d^ileft(G_dleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_dright), y_iright)}{partial theta_d}=-mu frac{partial L_d^i}{partial theta_d}

  而当梯度传播过 Rλ(⋅)R_lambda(cdot) 之后,梯度会 ×(−λ)times(-lambda),此时对于 θy,θftheta_y,theta_f 的下降梯度则变为

−μ∂E~(θf,θy,θd)∂θy=−μ∂(Ly(Gy(Gf(xi;θf);θy),yi))∂θy+μλ(Ld(Gd(Gf(xi;θf);θd),yi))∂θy=−μ∂Lyi∂θy−μ∂E~(θf,θy,θd)∂θf=−μ∂(Ly(Gy(Gf(xi;θf);θy),yi))∂θf+μλ(Ld(Gd(Gf(xi;θf);θd),yi))∂θf=−μ(∂Lyi∂θf−λ∂Ldi∂θf)begin{aligned}
-mufrac{partial tilde{E}left(theta_f, theta_y, theta_dright)}{partial theta_y}=&-mufrac{partialleft( L_yleft(G_yleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_yright), y_iright)right)}{partial theta_y}
&+mufrac{lambdaleft(
L_dleft(G_dleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_dright), y_iright)right)}{partial theta_y}
=&-mu frac{partial L_y^i}{partial theta_y}
-mufrac{partial tilde{E}left(theta_f, theta_y, theta_dright)}{partial theta_f}=&-mufrac{partialleft( L_yleft(G_yleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_yright), y_iright)right)}{partial theta_f}
&+mufrac{lambdaleft(
L_dleft(G_dleft(G_fleft(mathbf{x}_i ; theta_fright) ; theta_dright), y_iright)right)}{partial theta_f}
=&-muleft(frac{partial L_y^i}{partial theta_f}-lambda frac{partial L_d^i}{partial theta_f}right)
end{aligned}

  不得不说,还是很精巧的设计,巧妙地将对抗转化为了一致的迭代目标。

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

基于深度学习的高精度Caltech数据集行人检测识别系统

2023-12-16 19:27:14

AI教程

AI算法在视频上的应用及未来发展

2023-12-16 19:37:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索