CUDA之通用矩陣乘法:從入門到熟練!
本文經自動駕駛之心公眾號授權轉載,轉載請洽出處。
通用矩陣乘法(General Matrix Multiplication,GEMM) 是各種模型和計算中的核心部分,同時也是評估計算硬體效能(FLOPS) 的標準技術。本文將透過對GEMM 的實現和最佳化,來試圖理解高效能運算和軟硬體系統。
一、GEMM的基本特徵
1.1 GEMM計算過程及複雜度
GEMM 的定義為:
矩陣乘法的計算示意
1.2 簡單實作及流程分析
以下是依照原始定義實現的CPU 上實現的程式碼,之後用以作為精度的對照
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
void cpuSgemm(
float *a, float *b, float *c, const int M, const int N, const int K) {
for (int m = 0; m < M; m++) {
for (int n = 0; n < N; n++) {
float psum = 0.0;
for (int k = 0; k < K; k++) {
psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
}
c[OFFSET(m, n, N)] = psum;
}
}
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
下面使用CUDA實現最簡單的矩陣乘法的Kernal,總共使用M * N 個執行緒完成整個矩陣乘法。每個執行緒負責矩陣C中一個元素的計算,需要完成K次乘累加。矩陣A,B,C皆存放與全域記憶體中(由修飾符決定),完整代碼見sgemm_naive.cu 。 __global__
__global__ void naiveSgemm(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
int n = blockIdx.x * blockDim.x + threadIdx.x;
int m = blockIdx.y * blockDim.y + threadIdx.y;
if (m < M && n < N) {
float psum = 0.0;
#pragma unroll
for (int k = 0; k < K; k++) {
psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
}
c[OFFSET(m, n, N)] = psum;
}
}
const int BM = 32, BN = 32;
const int M = 512, N = 512, K = 512;
dim3 blockDim(BN, BM);
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
編譯完成,在Tesla V100-PCIE-32GB上執行的結果如下,根據V100的白皮書,FP32 的峰值算力為15.7 TFLOPS,因此此方式算力利用率僅有11.5%。
M N K = 128 128 1024, Time = 0.00010083 0.00010260 0.00010874 s, AVG Performance = 304.5951 Gflops
M N K = 192 192 1024, Time = 0.00010173 0.00010198 0.00010253 s, AVG Performance = 689.4680 Gflops
M N K = 256 256 1024, Time = 0.00010266 0.00010318 0.00010384 s, AVG Performance = 1211.4281 Gflops
M N K = 384 384 1024, Time = 0.00019475 0.00019535 0.00019594 s, AVG Performance = 1439.7206 Gflops
M N K = 512 512 1024, Time = 0.00037693 0.00037794 0.00037850 s, AVG Performance = 1322.9753 Gflops
M N K = 768 768 1024, Time = 0.00075238 0.00075558 0.00075776 s, AVG Performance = 1488.9271 Gflops
M N K = 1024 1024 1024, Time = 0.00121562 0.00121669 0.00121789 s, AVG Performance = 1643.8068 Gflops
M N K = 1536 1536 1024, Time = 0.00273072 0.00275611 0.00280208 s, AVG Performance = 1632.7386 Gflops
M N K = 2048 2048 1024, Time = 0.00487622 0.00488028 0.00488614 s, AVG Performance = 1639.2518 Gflops
M N K = 3072 3072 1024, Time = 0.01001603 0.01071136 0.01099990 s, AVG Performance = 1680.4589 Gflops
M N K = 4096 4096 1024, Time = 0.01771046 0.01792170 0.01803462 s, AVG Performance = 1785.5450 Gflops
M N K = 6144 6144 1024, Time = 0.03988969 0.03993405 0.04000595 s, AVG Performance = 1802.9724 Gflops
M N K = 8192 8192 1024, Time = 0.07119219 0.07139694 0.07160816 s, AVG Performance = 1792.7940 Gflops
M N K = 12288 12288 1024, Time = 0.15978026 0.15993242 0.16043369 s, AVG Performance = 1800.7606 Gflops
M N K = 16384 16384 1024, Time = 0.28559187 0.28567238 0.28573316 s, AVG Performance = 1792.2629 Gflops
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
以下以M=512,K=512,N=512,為例,詳細分析上述計算過程的workflow:
- 在Global Memory 中分別為矩陣A,B,C分配儲存空間.
- 由於矩陣C中每個元素的計算均相互獨立, 因此在並行度映射中讓每個thread 對應矩陣C中1 個元素的計算.
- 執行配置(execution configuration)中gridSize 和blockSize 都有x(列向)、y(行向)兩個維度, 其中
nsys 記錄的naive 版本的profiling
二、GEMM的最佳化探究
前文僅在功能上實現了GEMM,性能上還遠遠不及預期,本節將主要研究GEMM 性能上的優化。
2.1 矩陣分塊利用Shared Memory
上述的計算需要兩次Global Memory的load才能完成一次乘累加運算,計算訪存比極低,沒有有效的資料重複使用。所以可以用Shared Memory 來減少重複的記憶體讀取。
首先把矩陣C等分成BMxBN大小的分塊,每個分塊由一個Block 計算,其中每個Thread負責計算矩陣C中的TMxTN個元素。之後計算所需的資料全部從smem 讀取,就消除了一部分重複的A,B矩陣記憶體讀取。考慮到Shared Memory 容量有限,可以在K維上每次讀取BK大小的分塊,這樣的循環一共需要K / BK次以完成整個矩陣乘法操作,即可得到Block 的結果。其過程如下圖所示:
利用Shared Memory 最佳化後,對每一個分塊,可得:
由上式可知BM和BN越大,計算訪存比越高,效能就會越好。但由於Shared Memory 容量的限制(V100 1個SM僅96KB),而一個Block需要佔用BK * (BM + BN) * 4 Bytes大小。
TM和TN的取值也受到兩方面限制,一方面是線程數的限制,一個Block中有BM / TM * BN / TN個線程,這個數字不能超過1024,且不能太高防止影響SM內Block間的並行;另一方面是寄存器數目的限制,一個線程至少需要TM * TN個寄存器用於存放矩陣C的部分和,再加上一些其它的寄存器,所有的寄存器數目不能超過256,且不能太高防止影響SM內同時並行的執行緒數目。
最終選取BM = BN = 128,BK = 8,TM = TN = 8,則此時計算訪存比為32。根據V100的理論算力15.7TFLOPS,可得15.7TFLOPS/32 = 490GB/s,根據實測的HBM頻寬為763GB/s,可知此時頻寬不再會限制運算效能。
根據上述分析,kernel 函數實作過程如下,完整程式碼參考sgemm_v1.cu,主要步驟包括:
AB 矩陣分塊的線程索引關係
確定好單一block的執行過程,接下來需要確定多block處理的不同分塊在Global Memory中的對應關係,仍以A為例進行說明。由於分塊沿著行的方向移動,那麼首先需要確定行號,根據Grid 的二維全域線性索引關係,by * BM
表示該分塊的起始行號,同時我們已知load_a_smem_m
為分塊內部的行號,因此全域的行號為load_a_gmem_m = by * BM + load_a_smem_m
。由於分塊沿著行的方向移動,因此列是變化的,需要在循環內部計算,同樣也是先計算起始列號bk * BK
加速分塊內部列號load_a_smem_k
得到,由此我們便可以確定了分塊在原始數據中的位置OFFSET( , K) 。同理可分析矩陣分塊的情況,不再贅述。 load_a_gmem_k = bk * BK + load_a_smem_k
load_a_gmem_m, load_a_gmem_k
計算完後,還需要將其存入Global Memory 中,這就需要計算其在Global Memory 中的對應關係。由於存在較小的分塊,則行和列均由3部分構成:全域行號store_c_gmem_m
等於大分塊的起始行號by * BM
+小分塊的起始行號ty * TM
+小分塊內部的相對行號。列同理。 i
__global__ void sgemm_V1(
float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 128;
const int BK = 8;
const int TM = 8;
const int TN = 8;
const int bx = blockIdx.x;
const int by = blockIdx.y;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int tid = ty * blockDim.x + tx;
__shared__ float s_a[BM][BK];
__shared__ float s_b[BK][BN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1; // tid/2, row of s_a
int load_a_smem_k = (tid & 1) << 2; // (tid % 2 == 0) ? 0 : 4, col of s_a
int load_b_smem_k = tid >> 5; // tid/32, row of s_b
int load_b_smem_n = (tid & 31) << 2; // (tid % 32) * 4, col of s_b
int load_a_gmem_m = by * BM + load_a_smem_m; // global row of a
int load_b_gmem_n = bx * BN + load_b_smem_n; // global col of b
for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
int load_a_gmem_k = bk * BK + load_a_smem_k; // global col of a
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
int load_b_gmem_k = bk * BK + load_b_smem_k; // global row of b
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);
__syncthreads();
#pragma unroll
for (int k = 0; k < BK; k++) {
#pragma unroll
for (int m = 0; m < TM; m++) {
#pragma unroll
for (int n = 0; n < TN; n++) {
int comp_a_smem_m = ty * TM + m;
int comp_b_smem_n = tx * TN + n;
r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
}
}
}
__syncthreads();
}
#pragma unroll
for (int i = 0; i < TM; i++) {
int store_c_gmem_m = by * BM + ty * TM + i;
#pragma unroll
for (int j = 0; j < TN; j += 4) {
int store_c_gmem_n = bx * BN + tx * TN + j;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
}
}
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
計算結果如下,性能達到了理論峰值性能的51.7%:
M N K = 128 128 1024, Time = 0.00031578 0.00031727 0.00032288 s, AVG Performance = 98.4974 Gflops
M N K = 192 192 1024, Time = 0.00031638 0.00031720 0.00031754 s, AVG Performance = 221.6661 Gflops
M N K = 256 256 1024, Time = 0.00031488 0.00031532 0.00031606 s, AVG Performance = 396.4287 Gflops
M N K = 384 384 1024, Time = 0.00031686 0.00031814 0.00032080 s, AVG Performance = 884.0425 Gflops
M N K = 512 512 1024, Time = 0.00031814 0.00032007 0.00032493 s, AVG Performance = 1562.1563 Gflops
M N K = 768 768 1024, Time = 0.00032397 0.00034419 0.00034848 s, AVG Performance = 3268.5245 Gflops
M N K = 1024 1024 1024, Time = 0.00034570 0.00034792 0.00035331 s, AVG Performance = 5748.3952 Gflops
M N K = 1536 1536 1024, Time = 0.00068797 0.00068983 0.00069094 s, AVG Performance = 6523.3424 Gflops
M N K = 2048 2048 1024, Time = 0.00136173 0.00136552 0.00136899 s, AVG Performance = 5858.5604 Gflops
M N K = 3072 3072 1024, Time = 0.00271910 0.00273115 0.00274006 s, AVG Performance = 6590.6331 Gflops
M N K = 4096 4096 1024, Time = 0.00443805 0.00445964 0.00446883 s, AVG Performance = 7175.4698 Gflops
M N K = 6144 6144 1024, Time = 0.00917891 0.00950608 0.00996963 s, AVG Performance = 7574.0999 Gflops
M N K = 8192 8192 1024, Time = 0.01628838 0.01645271 0.01660790 s, AVG Performance = 7779.8733 Gflops
M N K = 12288 12288 1024, Time = 0.03592557 0.03597434 0.03614323 s, AVG Performance = 8005.7066 Gflops
M N K = 16384 16384 1024, Time = 0.06304122 0.06306373 0.06309302 s, AVG Performance = 8118.7715 Gflops
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
下面仍以M=512,K=512,N=512為例,分析一下結果。首先透過profiling 可以看到Shared Memory 佔用為8192 bytes,這與理論上(128+128)X8X4完全一致。
nsys 記錄的V1 版本的profiling
profiling 顯示Occupancy 為12.5%,可以透過cuda-calculator 加以印證,該例中threads per block = 256, Registers per thread = 136, 由此可以計算得到每個SM中活躍的warp 為8,而對於V100,每個SM中的warp 總數為64,因此Occupancy 為8/64 = 12.5%。
2.2 解決Bank Conflict 問題
上節透過利用Shared Memory 大幅提高了存取效率,進而提高了效能,本節將進一步優化Shared Memory 的使用。
Shared Memory共分為32個Bank,每個Bank的寬度為4 Bytes,如果需要存取同一個Bank的多個數據,就會發生Bank Conflict。例如一個Warp的32個線程,如果訪問的地址分別為0、4、8、...、124,就不會發生Bank Conflict,只佔用Shared Memory一拍的時間;如果訪問的地址為0、8 、16、...、248,這樣一來地址0和地址128對應的資料位於同一Bank、地址4和地址132對應的資料位於同一Bank,以此類推,那麼就需要佔用Shared Memory兩拍的時間才能讀出。
有Bank Conflict VS 無Bank Conflict
再看V1 版本計算部分的三層循環,每次從Shared memory中取矩陣A的長度為TM的向量和矩陣B的長度為TN的向量,這兩個向量做外積並累加到部分和中,一次外積共TM * TN次乘累加,共需循環BK次取數與外積。
接下來分析從Shared Memory load的過程中存在的Bank Conflict:
i) 取矩陣A需要取一個列向量,而矩陣A在Shared Memory中是按行儲存的;
ii) 在TM = TN = 8的情況下,無論矩陣A或矩陣B,從Shared Memory中取數時需要取連續的8個數,即便用LDS.128指令一條指令取四個數,也需要兩條指令,由於一個執行緒的兩個load指令的位址是連續的,那麼同一個Warp不同執行緒的同一條load指令的訪存位址就是被間隔開的,就存在著Bank Conflict。
為了解決上述的兩點Shared Memory的Bank Conflict,採用了一下兩點優化:
i) 為矩陣A指派Shared Memory時形狀分配為[BK][BM],即讓矩陣A在Shared Memory中按列存儲
ii) 將原本每個執行緒負責計算的TM * TN的矩陣C,分成下圖中這樣的兩塊TM/2 * TN的矩陣C,由於TM/2=4,一條指令即可完成A的一塊的load操作,兩個load可同時進行。
kernel 函數的核心部分實作如下,完整程式碼見sgemm_v2.cu 。
__shared__ float s_a[BK][BM];
__shared__ float s_b[BK][BN];
float r_load_a[4];
float r_load_b[4];
float r_comp_a[TM];
float r_comp_b[TN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1;
int load_a_smem_k = (tid & 1) << 2;
int load_b_smem_k = tid >> 5;
int load_b_smem_n = (tid & 31) << 2;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
int load_a_gmem_k = bk * BK + load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = bk * BK + load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
__syncthreads();
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
__syncthreads();
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
結果如下,相對未解決Bank Conflict 版(V1) 性能提高了14.4%,達到了理論峰值的74.3%。
M N K = 128 128 1024, Time = 0.00029699 0.00029918 0.00030989 s, AVG Performance = 104.4530 Gflops
M N K = 192 192 1024, Time = 0.00029776 0.00029828 0.00029882 s, AVG Performance = 235.7252 Gflops
M N K = 256 256 1024, Time = 0.00029485 0.00029530 0.00029619 s, AVG Performance = 423.2949 Gflops
M N K = 384 384 1024, Time = 0.00029734 0.00029848 0.00030090 s, AVG Performance = 942.2843 Gflops
M N K = 512 512 1024, Time = 0.00029853 0.00029945 0.00030070 s, AVG Performance = 1669.7479 Gflops
M N K = 768 768 1024, Time = 0.00030458 0.00032467 0.00032790 s, AVG Performance = 3465.1038 Gflops
M N K = 1024 1024 1024, Time = 0.00032406 0.00032494 0.00032621 s, AVG Performance = 6155.0281 Gflops
M N K = 1536 1536 1024, Time = 0.00047990 0.00048224 0.00048461 s, AVG Performance = 9331.3912 Gflops
M N K = 2048 2048 1024, Time = 0.00094426 0.00094636 0.00094992 s, AVG Performance = 8453.4569 Gflops
M N K = 3072 3072 1024, Time = 0.00187866 0.00188096 0.00188538 s, AVG Performance = 9569.5816 Gflops
M N K = 4096 4096 1024, Time = 0.00312589 0.00319050 0.00328147 s, AVG Performance = 10029.7885 Gflops
M N K = 6144 6144 1024, Time = 0.00641280 0.00658940 0.00703498 s, AVG Performance = 10926.6372 Gflops
M N K = 8192 8192 1024, Time = 0.01101130 0.01116194 0.01122950 s, AVG Performance = 11467.5446 Gflops
M N K = 12288 12288 1024, Time = 0.02464854 0.02466705 0.02469344 s, AVG Performance = 11675.4946 Gflops
M N K = 16384 16384 1024, Time = 0.04385955 0.04387468 0.04388355 s, AVG Performance = 11669.5995 Gflops
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
分析一下profiling 可以看到Static Shared Memory 仍然是使用了8192 Bytes,奇怪的是,Shared Memory executed 卻翻倍變成了16384 Bytes(知友如果知道原因可以告訴我一下)。
2.3 流水並行化:Double Buffering
Double Buffering,即雙重緩衝,即透過增加buffer的方式,使得訪存-計算的串列模式流水線化,以減少等待時間,提高計算效率,其原理如下圖所示:
單緩衝 VS 雙緩衝
具體到GEMM 任務中來,就是需要兩倍的Shared Memory,之前只需要BK * (BM + BN) * 4 Bytes的Shared Memory,採用Double Buffering之後需要2BK * (BM + BN) * 4 Bytes的Shared Memory ,然後使其pipeline 流動起來。
程式碼核心部分如下所示,完整程式碼請參考sgemm_v3.cu 。有以下幾點要注意:
1)主循環從bk = 1
開始,第一次資料載入在主循環之前,最後一次計算在主循環之後,這是pipeline 的特徵決定的;
2)由於計算和下次訪存使用的Shared Memory不同,因此在主循環中每次循環只需要一次__syncthreads()即可
3)由於GPU無法向CPU那樣支援亂序執行,主循環中需要先將下一次循環計算所需的Gloabal Memory中的資料load 到暫存器,然後進行本次計算,之後再將load到暫存器中的資料寫到Shared Memory,這樣在LDG指令向Global Memory做load時,不會影響後續FFMA及其它運算指令的launch 執行,也就達到了Double Buffering的目的。
__shared__ float s_a[2][BK][BM];
__shared__ float s_b[2][BK][BN];
float r_load_a[4];
float r_load_b[4];
float r_comp_a[TM];
float r_comp_b[TN];
float r_c[TM][TN] = {0.0};
int load_a_smem_m = tid >> 1;
int load_a_smem_k = (tid & 1) << 2;
int load_b_smem_k = tid >> 5;
int load_b_smem_n = (tid & 31) << 2;
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
{
int load_a_gmem_k = load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
s_a[0][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
}
for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {
int smem_sel = (bk - 1) & 1;
int smem_sel_next = bk & 1;
int load_a_gmem_k = bk * BK + load_a_smem_k;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
int load_b_gmem_k = bk * BK + load_b_smem_k;
int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
s_a[smem_sel_next][load_a_smem_k ][load_a_smem_m] = r_load_a[0];
s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
__syncthreads();
}
#pragma unroll
for (int tk = 0; tk < BK; tk++) {
FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2 ]);
FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][tk][ty * TM / 2 + BM / 2]);
FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][tk][tx * TN / 2 ]);
FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][tk][tx * TN / 2 + BN / 2]);
#pragma unroll
for (int tm = 0; tm < TM; tm++) {
#pragma unroll
for (int tn = 0; tn < TN; tn++) {
r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
}
}
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
}
#pragma unroll
for (int i = 0; i < TM / 2; i++) {
int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
int store_c_gmem_n = bx * BN + tx * TN / 2;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
}
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 18.
- 19.
- 20.
- 21.
- 22.
- 23.
- 24.
- 25.
- 26.
- 27.
- 28.
- 29.
- 30.
- 31.
- 32.
- 33.
- 34.
- 35.
- 36.
- 37.
- 38.
- 39.
- 40.
- 41.
- 42.
- 43.
- 44.
- 45.
- 46.
- 47.
- 48.
- 49.
- 50.
- 51.
- 52.
- 53.
- 54.
- 55.
- 56.
- 57.
- 58.
- 59.
- 60.
- 61.
- 62.
- 63.
- 64.
- 65.
- 66.
- 67.
- 68.
- 69.
- 70.
- 71.
- 72.
- 73.
- 74.
- 75.
- 76.
- 77.
- 78.
- 79.
- 80.
- 81.
- 82.
- 83.
- 84.
- 85.
- 86.
- 87.
- 88.
- 89.
- 90.
- 91.
- 92.
- 93.
- 94.
- 95.
- 96.
- 97.
- 98.
- 99.
- 100.
- 101.
性能如下所示,達到了理論峰值的80.6%。
M N K = 128 128 1024, Time = 0.00024000 0.00024240 0.00025792 s, AVG Performance = 128.9191 Gflops
M N K = 192 192 1024, Time = 0.00024000 0.00024048 0.00024125 s, AVG Performance = 292.3840 Gflops
M N K = 256 256 1024, Time = 0.00024029 0.00024114 0.00024272 s, AVG Performance = 518.3728 Gflops
M N K = 384 384 1024, Time = 0.00024070 0.00024145 0.00024198 s, AVG Performance = 1164.8394 Gflops
M N K = 512 512 1024, Time = 0.00024173 0.00024237 0.00024477 s, AVG Performance = 2062.9786 Gflops
M N K = 768 768 1024, Time = 0.00024291 0.00024540 0.00026010 s, AVG Performance = 4584.3820 Gflops
M N K = 1024 1024 1024, Time = 0.00024534 0.00024631 0.00024941 s, AVG Performance = 8119.7302 Gflops
M N K = 1536 1536 1024, Time = 0.00045712 0.00045780 0.00045872 s, AVG Performance = 9829.5167 Gflops
M N K = 2048 2048 1024, Time = 0.00089632 0.00089970 0.00090656 s, AVG Performance = 8891.8924 Gflops
M N K = 3072 3072 1024, Time = 0.00177891 0.00178289 0.00178592 s, AVG Performance = 10095.9883 Gflops
M N K = 4096 4096 1024, Time = 0.00309763 0.00310057 0.00310451 s, AVG Performance = 10320.6843 Gflops
M N K = 6144 6144 1024, Time = 0.00604826 0.00619887 0.00663078 s, AVG Performance = 11615.0253 Gflops
M N K = 8192 8192 1024, Time = 0.01031738 0.01045051 0.01048861 s, AVG Performance = 12248.2036 Gflops
M N K = 12288 12288 1024, Time = 0.02283978 0.02285837 0.02298272 s, AVG Performance = 12599.3212 Gflops
M N K = 16384 16384 1024, Time = 0.04043287 0.04044823 0.04046151 s, AVG Performance = 12658.1556 Gflops
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
從profiling 可以看到雙倍的Shared Memory 的佔用
三、cuBLAS 實現方式探究
本節我們將認識CUDA的標準函式庫-cuBLAS,即NVIDIA版本的基本線性代數子程式(Basic Linear Algebra Subprograms, BLAS) 規範實作程式碼。它支援Level 1 (向量與向量運算) ,Level 2 (向量與矩陣運算) ,Level 3 (矩陣與矩陣運算) 層級的標準矩陣運算。
cuBLAS/CUTLASS GEMM的基本過程
如上圖所示,計算過程分解成線程塊片(thread block tile)、線程束片(warp tile)和線程片(thread tile)的層次結構並將AMP的策略應用於此層次結構來高效率的完成基於GPU的拆分成tile的GEMM。這個層次結構緊密地反映了NVIDIA CUDA程式設計模型。可以看到從global memory到shared memory的資料移動(矩陣到thread block tile);從shared memory到暫存器的資料移動(thread block tile到warp tile);從暫存器到CUDA core的計算(warp tile到thread tile )。
cuBLAS 實現了單精度矩陣乘的函數cublasSgemm,其主要參數如下:
cublasStatus_t cublasSgemm( cublasHandle_t handle, // 调用 cuBLAS 库时的句柄
cublasOperation_t transa, // A 矩阵是否需要转置
cublasOperation_t transb, // B 矩阵是否需要转置
int m, // A 的行数
int n, // B 的列数
int k, // A 的列数
const float *alpha, // 系数 α, host or device pointer
const float *A, // 矩阵 A 的指针,device pointer
int lda, // 矩阵 A 的主维,if A 转置, lda = max(1, k), else max(1, m)
const float *B, // 矩阵 B 的指针, device pointer
int ldb, // 矩阵 B 的主维,if B 转置, ldb = max(1, n), else max(1, k)
const float *beta, // 系数 β, host or device pointer
float *C, // 矩阵 C 的指针,device pointer
int ldc // 矩阵 C 的主维,ldc >= max(1, m) );
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
調用方式如下:
cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);
float cublas_alpha = 1.0;
float cublas_beta = 0;
cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &cublas_alpha, d_b, N, d_a, K, &cublas_beta, d_c, N);
- 1.
- 2.
- 3.
- 4.
- 5.
性能如下所示,達到了理論峰值的82.4%。
M N K = 128 128 1024, Time = 0.00002704 0.00003634 0.00010822 s, AVG Performance = 860.0286 Gflops
M N K = 192 192 1024, Time = 0.00003155 0.00003773 0.00007267 s, AVG Performance = 1863.6689 Gflops
M N K = 256 256 1024, Time = 0.00003917 0.00004524 0.00007747 s, AVG Performance = 2762.9438 Gflops
M N K = 384 384 1024, Time = 0.00005318 0.00005978 0.00009120 s, AVG Performance = 4705.0655 Gflops
M N K = 512 512 1024, Time = 0.00008326 0.00010280 0.00013840 s, AVG Performance = 4863.9646 Gflops
M N K = 768 768 1024, Time = 0.00014278 0.00014867 0.00018816 s, AVG Performance = 7567.1560 Gflops
M N K = 1024 1024 1024, Time = 0.00023485 0.00024460 0.00028150 s, AVG Performance = 8176.5614 Gflops
M N K = 1536 1536 1024, Time = 0.00046474 0.00047607 0.00051181 s, AVG Performance = 9452.3201 Gflops
M N K = 2048 2048 1024, Time = 0.00077930 0.00087862 0.00092307 s, AVG Performance = 9105.2126 Gflops
M N K = 3072 3072 1024, Time = 0.00167904 0.00168434 0.00171114 s, AVG Performance = 10686.6837 Gflops
M N K = 4096 4096 1024, Time = 0.00289619 0.00291068 0.00295904 s, AVG Performance = 10994.0128 Gflops
M N K = 6144 6144 1024, Time = 0.00591766 0.00594586 0.00596915 s, AVG Performance = 12109.2611 Gflops
M N K = 8192 8192 1024, Time = 0.01002384 0.01017465 0.01028435 s, AVG Performance = 12580.2896 Gflops
M N K = 12288 12288 1024, Time = 0.02231159 0.02233805 0.02245619 s, AVG Performance = 12892.7969 Gflops
M N K = 16384 16384 1024, Time = 0.03954650 0.03959291 0.03967242 s, AVG Performance = 12931.6086 Gflops
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
由此可以比較以上各種方法的性能情況,可見手動實現的性能已接近官方的性能,如下: