转载、参考: https://zhuanlan.zhihu.com/p/1910636263666610461

计算量推导

矩阵乘法:
$C = \alpha AB + \beta C$
$A$ 形状为 $M \times K$ ,$B$ 形状为 $ K \times N $ ,$C$ 形状为 $M \times N$。

矩阵 $A$($M \times K$)与 $B$($K \times N$)相乘,得到 $C$($M \times N$)。

每个 $C_{i,j}$ 的计算:

$$
C_{i,j} = \sum_{k=1}^{K} A_{i,k} B_{k,j}
$$

每个元素需要 $K$ 次乘法和 $K-1$ 次加法,计算$AB$需要$(2K-1)MN$次浮点计算,缩放$MN$和$C$各需要$MN$次计算,最后两个矩阵相加需要$MN$次计算,总计算为$(2K-1)MN+MN+MN+MN = (2K+2)MN$次计算,约为

每个 thread 负责 C 中一个元素的计算

1
2
3
4
5
6
7
8
9
10
11
__global__ void GEMM(float* A,float*B,float*C,const int M,const int N,const int K){
int r = threadIdx.y + blockIdx.y * blockDim.y;
int c = threadIdx.x + blockIdx.y * blockDix.x;
if(r < M && c < N){
float sum = 0 ;
for(int i = 0 ; i < K ; i++){
sum += A[r * K + i]*B[k * N + c];//下标计算怎么来的
}
C[r * N + c] = sum;
}
}

问题:全局显存访问2MNK,全局带宽占用受限制

思路:转移到shared_memory

分块计算


分块计算以后,将Bm和Bn计算放置到shared-memory中可以有效减少全局内存访问次数,因为在块内复用了一部分

但是K通常很大,可能无法放下,所以说需要在K的维度上拆分:

bk \bm\bn的数据会影响性能,文章根据计算量、访存量进行了推算,选定bk=8 bm=bn=128

流程:

  1. 申请 shared memory 空间,用于存储每次循环中矩阵  和矩阵  参与计算的 tile,变量名为 As 和 Bs
  2. 计算每个 thread 负责的矩阵  中的元素个数,定义相同大小的寄存器数组 Ct,用于存储累加结果。
  3. K-Loop 循环,循环步长为  ,循环体包括:
    1. 从 global memory 上的矩阵  中读取参与计算的 tileA,存入 As
    1. 从 global memory 上的矩阵  中读取参与计算的 tileB,存入 Bs
  4. 线程块内同步;
  5. 计算 As 与 Bs 的矩阵相乘结果,累加存入 Ct
  6. 线程块内同步;
  7. 循环完成后,将Ct 结果写入 global memory 上的矩阵  中对应位置。

本站由 Zane Jiang 使用 Stellar 1.33.1 主题创建,一款很棒的 Hexo 主题!

总访问 次 || 本页访问
总访客 人 || 本页访客