NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

释放双眼,带上耳机,听听看~!
本文介绍了NVIDIA开源的cutlass库,以及基于Volta架构对底层所使用的wmma API原理进行的解析,涵盖了GPU加速计算和矩阵乘法优化的相关内容。

引言

cutlass是nvidia官方开源的一套用于通用矩阵乘法(GEMM)的C++模板库。相比于cuBlas、cuTensor等其他功能近似的cuda库,cutlass具有以下优点:

  • 开源:方便进行自定义和客制化的开发;
  • 模板化:API设计灵活,开发更高效;
  • 高性能:通常能达到cuBlas 90+%的性能。

cutlass底层依赖tensor core和warp矩阵乘累加(Warp Matrix Multiply Accumulate, WMMA) API。在2017年Nvidia发布了Tesla V100 GPU中采用了Volta架构,引入了第一代tensor core。并在同一时期发布的cuda 9.0中新增了与tensor core配套的wmma API。本文将结合Volta架构对cutlass底层所使用的wmma API原理进行简单介绍。

cutlass GEMM hierarchy

CUDA中一次GEMM操作可以分为两个阶段,如图1所示:

  • main loop: m×k的矩阵A与k×n的矩阵B做矩阵乘,需要在k的维度进行分块,然后在k维度进行遍历,这个过程称之为main loop。在此阶段,cutlass支持batched gemm、splitK等优化。

  • epilogue: 做完矩阵乘后可以进行element wise操作,如add bias、activation等。在此阶段cutlass可以进行kernel fuse。

图1 Gemm计算流程图

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

cutlass的main loop阶段的层次结构如图2所示,执行过程如下:

  1. 申请一个二维的grid,网格中每个block负责一个小的矩阵块(tile),将数据从global memory读取到shared memory,这一步与使用cuda core的GEMM操作类似;
  2. 每个block中的每个warp负责从shared memory中读取一个矩阵片段(fragment)的数据到寄存器(register file),在register file上使用wmma指令调动cuda core执行矩阵乘法。

图2 cutlass main loop层次结构

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

在最内层warp级别使用tensor core进行矩阵乘累加计算,需要依赖cuda 9.0之后的wmma API,下面我们将借助Volta架构中的第一代tensor core结构,对底层的mma指令进行了解。

tensor core & wmma API

与cuda core相比,tensor core是一种SM级别的硬件结构,可编程的粒度为warp level,在开发上不如cuda core的thread level灵活。Volta架构的第一代tensor core可以在一个时钟周期实现4(m)×4(n)×4(k)的矩阵乘累加,其中输入矩阵A(m×k: 4×4)、B(k×n: 4×4)数据类型为FP16,输出数据类型可以为FP16或FP32:

图3 tensor core在一个时钟周期内实现4(m)×4(n)×4(k)的矩阵乘累加

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

而在cuda core中做一次同样4×4×4的场景FP32的矩阵乘法,一个时钟周期只能实现1×4的运算,完成4x4x4矩阵乘法需要16个时钟周期:

图4 cuda core(Pascal) vs tensor core(Volta)

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

图4展示的X12倍是指tensor core相比于cuda core每秒浮点数运算理论峰值速度的差异,并不是指左边的cuda core完成1次运算完花费了16个时钟周期,右边的tensor core就要完成16次4×4×4算,毕竟cuda core和tensor core单个时钟周期所花费的时间也是不一样的。

cuda中使用tensor core来加速Gemm要借助wmma API来实现,wmma API位于cuda头文件mma.h中,有以下几个功能,结合使用可实现最底层的warp level gemm运算:

// 定义warp所负责的矩阵片段(fragment)的数据布局
template<typename Use, int m, int n, int k, typename T, typename Layout = void>
class fragment;
template<> class fragment<matrix_a, 16, 16, 16, __half, row_major>
template<> class fragment<matrix_a, 16, 16, 16, __half, col_major>
template<> class fragment<matrix_b, 16, 16, 16, __half, row_major>
template<> class fragment<matrix_b, 16, 16, 16, __half, col_major>
template<> class fragment<accumulator, 16, 16, 16, __half>
template<> class fragment<accumulator, 16, 16, 16, float>

// 从内存中加载数据到warp负责的fragment
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t,layout);
// 用定值v填充fragment
void fill_fragment(fragment<...> &a, const T& v);

// 执行mma运算
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...>
                            &b, const fragment<...> &c, bool satf = false);

// 将warp负责fragment数据储存到内存中
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t
                            layout);

wmma API属于指令级别的API,操作的数据是fragment级别的,运算发生在register file上。单个warp每次从shared memory中拿出一个fragment,在寄存器上实现整个fragment的运算,fragment层级的运算是由一个warp内32个线程借助tensor core共同完成的,Volta中一个fragment大小为16×16,即m、n、k尺寸都是16,下面将介绍Volta架构中为什么是16×16的fragment,以及单个warp上是怎么利用tensor core执行mma的。

Volta架构mma执行原理

根据Modeling Deep Learning Accelerator Enabled GPUs (arXiv:1811.08309)相关研究,在对16×16的FP16矩阵A、B进行加载时,一个warp内4个连续的thread分成一个组(thread group),32个thread共分为8组,编号0~7,每个group中的4个thread可以负责4行(A row major)或4列(B colum major的4×16子块(sub fragment)加载,具体加载方式如图5所示。

图5 Volta架构中每个warp加载16×16 fragment A B C的方式

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

图5左边展示了fragment A和B的加载方式,①中一个颜色块对应的4×16 sub fragment被2个thread group同时加载,A和B中子fragment与thread group映射关系如④所示(感觉论文影印版这张图有问题,后面会进行说明)。一行16个FP16的数据共需要256位,针对不同的layout,thread group 在对4×16 sub fragment加载方式如②和③所示。对于A(row major)和B(column major)的layout,1个线程会使用两个合并的128位的加载指令共256位加载16个FP16,对于A(column major)和B(row major),1个线程会使用4个间隔的64位加载指令完成。

图5右边展示了fragment C的加载方式,C中1个thread group中的4个线程加载1个4×8 sub fragment,thread group中线程与sub fragment中数据映射关系和布局无关,只与数据类型有关。

图6展示了threadGroup 0和threadGroup 4执行mma指令计算fragment C中1个4×8 sub fragment的全过程:一个mma指令分为4个组(set 1 ~ 4),如图6(a)所示。其中每个set按照数据类型混精、FP16又分为4或2个step,如图6(b)和6(c)所示。如果是混精运算,(a)中set 1将分为(b)中的4步;如果是FP16,将会分为(c)中的2步。因此完成一个4×8 sub fragment的计算共需4×4或4×2步,每1步都是一个4×4的结构,便可借助tensor core在1个时钟周期内完成

图6 Volta架构中每个warp执行16×16 fragment运算方式

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

接下来我们回到图5中,说说为什么感觉影印版中的图有问题,回到图5(b),其实threadGroup 0和threadGroup 4作为一个group对,除了负责C中左上角的4×8的子片段[0:3,0:7],还负责靠下位置的4×8的子片段[4:7,0:7],共负责C中左上角[0:7,0:7]的1/4块。这也是为什么每个4×16的 sub fragment会被2个thread group同时加载的原因。这样的2个thread group做为一组,称之为一个octet。表1展示了octet的组对方式和fragment A和B被加载方式。如果按照影印版原图(图5)会和表1对不上,对图5中matrixB两处进行对调后可以对上。

表1 octet加载A和B

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

在理清octet关系后,Volta架构中整个warp负责的16×16 fragment C = AxB+C计算过程如图7所示:每个octet负责一个4×8的A、8×4的B、4×4的C小片段做乘累加。threadGroup-0和threadGroup-4所组成的octet-0计算过程如图7(1-b)和7(2)表格所示。

图7 Volta架构中的octet与混精计算详细过程

NVIDIA开源的cutlass库介绍及tensor core与wmma API原理解析

至此应该可以理解在Volta架构中一个warp是怎么通过32个线程,利用tensor core恰好对16×16的fragment,完成一次mma运算。论文中还介绍了Turing架构的wmma执行原理,感兴趣的读者可以去原论文中进行了解。

小结

cutlass底层依靠tensor core和WMMA API,按照matrix(grid) => tile(block) => warp(fragment)的层次结构进行GEMM操作,了解在warp level如何利用tensor core对fragment执行mma运算的原理,将有助于我们进一步熟悉cutlass源码实现,从而更得心应手的使用cutlass进行GEMM开发。

本网站的内容主要来自互联网上的各种资源,仅供参考和信息分享之用,不代表本网站拥有相关版权或知识产权。如您认为内容侵犯您的权益,请联系我们,我们将尽快采取行动,包括删除或更正。
AI教程

LangChain:AI应用开发的链式思想

2024-1-2 16:54:00

AI教程

Stable Diffusion AI绘图工具实例演示及教程分享

2024-1-2 17:01:00

个人中心
购物车
优惠劵
今日签到
有新私信 私信列表
搜索