torch.einsum()的实际应用分析

释放双眼,带上耳机,听听看~!
本文从实际复杂案例的角度对torch.einsum()的计算过程进行分析,介绍了爱因斯坦求和约定和复杂案例推导,帮助读者更好地理解和应用torch.einsum()。

引言: torch.einsum()的分析和介绍已经有很多博客介绍过了, 但大多数的落脚点都是爱因斯坦求和约定,许多篇幅是用于介绍爱因斯坦求和约定到的各项法则,而实际案例分析方面只是草草给出一笔带过,涉及到的案例也较为简单。而实际我们要用到或者看到torch.einsum()的时候往往是在计算非常复杂的情况下。
因此本文将从实际复杂案例的角度对torch.einsum()的计算过程进行分析,一步一步的推导最终输出的每个元素和输入元素之间的关系。

爱因斯坦求和约定

 首先,torch.einsum()的基础原理是爱因斯坦求和约定,此处为了行文的整体性将对其进行简要的介绍,如果只关注计算本身,可以跳到下一节。爱因斯坦求和约定是为了简化计算而诞生的一种“记法”,就类似于我们用×times来标记乘法一样,不同之处在于爱因斯坦求和约定可表示的运算更为复杂、灵活性也更高。爱因斯坦求和约定的典型写法为:

i1i2…iN,j1j2…jM→ik1ik2..jl1jl1,k1…∈N,l1..l∈Mi_1i_2…i_N,j_1j_2…j_Mrightarrow i_{k_1}i_{k_2}..j_{l_1}j_{l_1},k_1…in N,l_1..lin M

其中左端i1i2…iN,j1j2…jMi_1i_2…i_N,j_1j_2…j_M就表示了输入两个矩阵元素的坐标索引,右端ik1ik2..jl1jl1i_{k_1}i_{k_2}..j_{l_1}j_{l_1}为输出矩阵元素的坐标索引,可以看到输出矩阵元素索引相较于输入端的索引可能会缺少几项,运算就是发生这几个维度上的乘累加操作。
其中同时出现在左端和右端的坐标索引为自由索引,只用于标记位置;而仅仅出现在右端的索引为求和索引,爱因斯坦求和约定的本质就是沿着求和索引的方向计算两个输入逐元素乘累加和的结果放到输出自由索引的位置上,更为细致的介绍参见:一文学会 Pytorch 中的 einsumeinsum:爱因斯坦求和约定
举例而言:

ij,jk→ikij,jkrightarrow ik

就表示沿着jj这个维度进行乘累加操作:

Oik=∑jAijBjkO_{ik}=sum_{j}A_{ij}B_{jk}

输出的第(i,k)(i,k)个元素为Ai⋅A_{i cdot}的行向量和B⋅kB_{cdot k}列向量逐元素乘累加,实际上就是矩阵相乘。

复杂案例推导

 正如第一节中所介绍的,torch.enisum()的核心计算过程就是沿着只在算式右边出现的轴对输入矩阵元素进行乘累加得到对应位置的输出元素。因此,想要弄清一个复杂的torch.eisum()表达式含义需要做的也只是将这个求和公式写出来再仔细分析。

案例. 四维张量乘三维张量

 给出一个复杂案例:

ncjt,npj−>ncptncjt,npj->ncpt

则其输出元素可以表示为:

Cncpt=∑jAncjtBnpjC_{ncpt}=sum_j A_{ncjt}B_{npj}

 首先我们可以注意到对于C的第一维nn而言,它同时出现在A和B的首位,也就是对于这一维的每个元素,都是会逐元素的执行A和B剩余维度的计算再在当前维度上排布,用深度学习中的描述来说就是对BATCH中的每个元素都独立的执行后续子操作,子操作可以记为:

Ccpt=∑jAcjtBpjC_{cpt}=sum_j A_{cjt}B_{pj}

紧接着,对当前算式的第一维cc来说它只出现在AA中,每沿着cc计算一个不同的元素都要和“相同”的B计算,也就出现了广播机制,B有了个隐藏的、元素重复的维度cc,计算变为Ccpt=∑jAcjtBcpjC_{cpt}=sum_j A_{cjt}B_{cpj},同第一步计算的原理,这里又可以化简成逐元素的子操作:

Cpt=∑jAjtBpj=∑jBpjAjtC_{pt}=sum_jA_{jt}B_{pj}=sum_jB_{pj}A_{jt}

此时易看出(p,t)(p,t)元素就是B的第pp行向量和A的第tt列向量求内积。
 从而我们可以得出结论,这一表达式的意思是,对于BATCH内的每个元素(A‘三维,B’二维),对B在第一维度进行广播(A”三维,B”三维),最后沿着第二维和第三维计算矩阵相乘B”’A”’(A”’二维,B”’二维)。

而整个的推导过程可以总结为以下几要点

  1. 沿着维数较高输入的第一维开始,判断是否存在于B中,如果在的话就可认为是逐元素操作,暂时忽略该维度;
  2. 如果该维度指示不在B中,则进行广播操作,重新回到1,否则3.
  3. 判断当前最简表达式的意义。
本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

数据挖掘实践:金融风控之贷款违约预测挑战赛(下篇)

2023-11-28 8:39:14

AI教程

ATC模型转换动态shape问题案例及解决方法

2023-11-28 8:48:14

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索