General Matrix Multiplication in CUDA: From Beginner to Proficient

2024.03.25

This article is reposted with permission from the 自動駕駛之心 (Autonomous Driving Heart) WeChat public account; please contact the original source for reposting.

General Matrix Multiplication (GEMM) lies at the core of many models and computations, and it is also the standard workload for evaluating the compute performance (FLOPS) of hardware. This article works through implementing and optimizing GEMM as a way to understand high-performance computing and the interplay of hardware and software.

1. Basic Characteristics of GEMM

1.1 GEMM Computation and Complexity

GEMM is defined as

$$C = \alpha A B + \beta C, \qquad A \in \mathbb{R}^{M \times K},\ B \in \mathbb{R}^{K \times N},\ C \in \mathbb{R}^{M \times N}.$$

In this article $\alpha = 1$ and $\beta = 0$, so each output element is $C_{m,n} = \sum_{k=1}^{K} A_{m,k} B_{k,n}$. Every one of the $M \times N$ output elements needs $K$ multiplications and $K$ additions, so one GEMM costs roughly $2MNK$ floating-point operations; this is the operation count used for the performance figures below.

(Figure: schematic of the matrix multiplication computation)

1.2 A Simple Implementation and Workflow Analysis


Below is a CPU implementation that follows the definition directly; it is used later as the accuracy reference.

#define OFFSET(row, col, ld) ((row) * (ld) + (col))

void cpuSgemm(
    float *a, float *b, float *c, const int M, const int N, const int K) {

    for (int m = 0; m < M; m++) {
        for (int n = 0; n < N; n++) {
            float psum = 0.0;
            for (int k = 0; k < K; k++) {
                psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
            }
            c[OFFSET(m, n, N)] = psum;
        }
    }
}

Next, we implement the simplest matrix-multiplication kernel in CUDA, using a total of M * N threads for the whole multiplication. Each thread computes one element of matrix C and therefore performs K multiply-accumulates. Matrices A, B, and C all reside in global memory; the kernel is declared with the __global__ qualifier. The complete code is in sgemm_naive.cu.

__global__ void naiveSgemm(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    int n = blockIdx.x * blockDim.x + threadIdx.x;
    int m = blockIdx.y * blockDim.y + threadIdx.y;
    if (m < M && n < N) {
        float psum = 0.0;
        #pragma unroll
        for (int k = 0; k < K; k++) {
            psum += a[OFFSET(m, k, K)] * b[OFFSET(k, n, N)];
        }
        c[OFFSET(m, n, N)] = psum;
    }
}

const int BM = 32, BN = 32;
const int M = 512, N = 512, K = 512;
dim3 blockDim(BN, BM);
dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
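
Before measuring performance, correctness can be checked against the cpuSgemm reference from Section 1.2. The following is a minimal host-side harness (a sketch, not taken from sgemm_naive.cu; it assumes cpuSgemm and naiveSgemm are defined in the same file) that allocates the matrices, launches the kernel with the configuration above, and compares against the CPU result:

#include <cstdio>
#include <cstdlib>
#include <cuda_runtime.h>

int main() {
    const int M = 512, N = 512, K = 512;
    const size_t bytes_a = M * K * sizeof(float);
    const size_t bytes_b = K * N * sizeof(float);
    const size_t bytes_c = M * N * sizeof(float);

    // Host buffers: inputs, GPU result, CPU reference result
    float *h_a = (float *)malloc(bytes_a);
    float *h_b = (float *)malloc(bytes_b);
    float *h_c = (float *)malloc(bytes_c);
    float *h_ref = (float *)malloc(bytes_c);
    for (int i = 0; i < M * K; i++) h_a[i] = rand() / (float)RAND_MAX;
    for (int i = 0; i < K * N; i++) h_b[i] = rand() / (float)RAND_MAX;

    // Device buffers and input transfer
    float *d_a, *d_b, *d_c;
    cudaMalloc(&d_a, bytes_a);
    cudaMalloc(&d_b, bytes_b);
    cudaMalloc(&d_c, bytes_c);
    cudaMemcpy(d_a, h_a, bytes_a, cudaMemcpyHostToDevice);
    cudaMemcpy(d_b, h_b, bytes_b, cudaMemcpyHostToDevice);

    // Launch with the configuration shown above
    const int BM = 32, BN = 32;
    dim3 blockDim(BN, BM);
    dim3 gridDim((N + BN - 1) / BN, (M + BM - 1) / BM);
    naiveSgemm<<<gridDim, blockDim>>>(d_a, d_b, d_c, M, N, K);
    cudaMemcpy(h_c, d_c, bytes_c, cudaMemcpyDeviceToHost);

    // Compare against the CPU reference
    cpuSgemm(h_a, h_b, h_ref, M, N, K);
    float max_err = 0.0f;
    for (int i = 0; i < M * N; i++) {
        float diff = h_c[i] - h_ref[i];
        if (diff < 0) diff = -diff;
        if (diff > max_err) max_err = diff;
    }
    printf("max abs error = %g\n", max_err);
    return 0;
}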

After compilation, the results on a Tesla V100-PCIE-32GB are shown below. According to the V100 whitepaper the FP32 peak is 15.7 TFLOPS, so this version achieves only about 11.5% of the available compute.

M N K =    128    128   1024, Time =   0.00010083   0.00010260   0.00010874 s, AVG Performance =   304.5951 Gflops
M N K =    192    192   1024, Time =   0.00010173   0.00010198   0.00010253 s, AVG Performance =   689.4680 Gflops
M N K =    256    256   1024, Time =   0.00010266   0.00010318   0.00010384 s, AVG Performance =  1211.4281 Gflops
M N K =    384    384   1024, Time =   0.00019475   0.00019535   0.00019594 s, AVG Performance =  1439.7206 Gflops
M N K =    512    512   1024, Time =   0.00037693   0.00037794   0.00037850 s, AVG Performance =  1322.9753 Gflops
M N K =    768    768   1024, Time =   0.00075238   0.00075558   0.00075776 s, AVG Performance =  1488.9271 Gflops
M N K =   1024   1024   1024, Time =   0.00121562   0.00121669   0.00121789 s, AVG Performance =  1643.8068 Gflops
M N K =   1536   1536   1024, Time =   0.00273072   0.00275611   0.00280208 s, AVG Performance =  1632.7386 Gflops
M N K =   2048   2048   1024, Time =   0.00487622   0.00488028   0.00488614 s, AVG Performance =  1639.2518 Gflops
M N K =   3072   3072   1024, Time =   0.01001603   0.01071136   0.01099990 s, AVG Performance =  1680.4589 Gflops
M N K =   4096   4096   1024, Time =   0.01771046   0.01792170   0.01803462 s, AVG Performance =  1785.5450 Gflops
M N K =   6144   6144   1024, Time =   0.03988969   0.03993405   0.04000595 s, AVG Performance =  1802.9724 Gflops
M N K =   8192   8192   1024, Time =   0.07119219   0.07139694   0.07160816 s, AVG Performance =  1792.7940 Gflops
M N K =  12288  12288   1024, Time =   0.15978026   0.15993242   0.16043369 s, AVG Performance =  1800.7606 Gflops
M N K =  16384  16384   1024, Time =   0.28559187   0.28567238   0.28573316 s, AVG Performance =  1792.2629 Gflops
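
For reference, the AVG Performance column follows from the standard GEMM operation count, assuming $2MNK$ floating-point operations per multiplication:

$$\text{Performance} = \frac{2\,M N K}{t}.$$

At the larger sizes this naive kernel tops out around 1.8 TFLOPS, i.e. roughly 1.8 / 15.7 ≈ 11.5% of the V100's FP32 peak.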

Taking M = 512, K = 512, N = 512 as an example, the workflow of the computation above is as follows:

  1. Allocate storage for matrices A, B, and C in global memory.
  2. Since every element of C is computed independently, the parallel mapping assigns each thread to one element of C.
  3. The execution configuration's gridSize and blockSize each have an x (column) and a y (row) dimension; from the launch configuration above, blockDim = (BN, BM) = (32, 32) and gridDim = ((N + BN - 1) / BN, (M + BM - 1) / BM).

(Figure: nsys profiling of the naive version)

2. Exploring GEMM Optimizations

The previous section implemented GEMM only functionally; its performance is far below expectations. This section focuses on optimizing GEMM performance.

2.1 Matrix Tiling with Shared Memory

The computation above needs two global-memory loads for every multiply-accumulate, so the compute-to-memory-access ratio is extremely low and there is no effective data reuse. Shared memory can be used to cut down the redundant memory reads.

First, matrix C is partitioned into BM×BN tiles, each computed by one block, within which each thread computes TM×TN elements of C. All data needed for the computation is then read from shared memory, which removes part of the repeated reads of A and B. Because shared-memory capacity is limited, the K dimension is processed in slices of size BK; K / BK such iterations complete the whole multiplication and produce the block's result. The process is illustrated below:

(Figure: tiled GEMM using shared memory)

With the shared-memory tiling, for each BK-slice a block loads (BM + BN) × BK floats from global memory and performs BM × BN × BK multiply-accumulates, so the compute-to-memory-access ratio per tile is

$$\frac{2 \cdot BM \cdot BN \cdot BK\ \text{FLOPs}}{4 \cdot (BM + BN) \cdot BK\ \text{bytes}} = \frac{BM \cdot BN}{2\,(BM + BN)}\ \text{FLOP/byte}.$$

The larger BM and BN, the higher this ratio and the better the performance. However, shared-memory capacity is limited (only 96 KB per SM on V100), and one block needs BK * (BM + BN) * 4 bytes of it.

TM and TN are also constrained in two ways. First, the thread count: a block contains (BM / TM) * (BN / TN) threads, which must not exceed 1024 and should not be too large, or it limits how many blocks can run concurrently on an SM. Second, the register count: each thread needs at least TM * TN registers to hold its partial sums of C, plus several others; the total must stay within the 255-registers-per-thread limit and should not be too high, or it reduces the number of threads that can be resident on an SM at the same time.

In the end, BM = BN = 128, BK = 8, TM = TN = 8 are chosen, giving a compute-to-memory-access ratio of 32 FLOP/byte. At the V100's theoretical 15.7 TFLOPS this corresponds to 15.7 TFLOPS / 32 ≈ 490 GB/s of required bandwidth; the measured HBM bandwidth is 763 GB/s, so bandwidth no longer limits performance.
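
Substituting the chosen parameters (a quick check of the numbers above):

$$\frac{BM}{TM} \cdot \frac{BN}{TN} = 16 \times 16 = 256\ \text{threads/block} \le 1024, \qquad \frac{BM \cdot BN}{2\,(BM + BN)} = \frac{128 \times 128}{2 \times 256} = 32\ \text{FLOP/byte}, \qquad \frac{15.7\ \text{TFLOPS}}{32\ \text{FLOP/byte}} \approx 490\ \text{GB/s}.$$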

Based on the above analysis, the kernel is implemented as follows (complete code in sgemm_v1.cu); the main steps are described below.

(Figure: thread-index mapping for the A and B tiles)

With the behavior of a single block settled, we next need the mapping between the tiles handled by different blocks and their locations in global memory; again take A as the example. A's tile slides along the row (K) direction, so the row index is fixed first: from the grid's 2-D global indexing, by * BM is the tile's starting row, and load_a_smem_m is the row inside the tile, so the global row is load_a_gmem_m = by * BM + load_a_smem_m. The column changes from iteration to iteration and is therefore computed inside the loop, as the slice's starting column bk * BK plus the column inside the tile load_a_smem_k, i.e. load_a_gmem_k = bk * BK + load_a_smem_k. The tile's position in the original data is thus OFFSET(load_a_gmem_m, load_a_gmem_k, K). Matrix B's tile is handled analogously and is not repeated here.


After the computation, the results still have to be written back to global memory, which requires their global coordinates. Because each thread owns a small sub-tile, both row and column are made of three parts: the global row store_c_gmem_m equals the block tile's starting row by * BM, plus the thread sub-tile's starting row ty * TM, plus the relative row i inside the sub-tile, i.e. store_c_gmem_m = by * BM + ty * TM + i. The column is analogous.
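
The kernels below also use a FLOAT4 macro (alongside the OFFSET macro defined earlier) to issue 128-bit vectorized loads and stores. Its definition is not shown in these excerpts; a common definition, assumed here, is:

// Treat 4 consecutive floats as one float4 so that a single instruction moves
// 16 bytes; the address must be 16-byte aligned (true for the tile sizes used here).
#define FLOAT4(value) (reinterpret_cast<float4 *>(&(value))[0])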

__global__ void sgemm_V1(
    float * __restrict__ a, float * __restrict__ b, float * __restrict__ c,
    const int M, const int N, const int K) {

    const int BM = 128;
    const int BN = 128;
    const int BK = 8;
    const int TM = 8;
    const int TN = 8;

    const int bx = blockIdx.x;
    const int by = blockIdx.y;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int tid = ty * blockDim.x + tx;

    __shared__ float s_a[BM][BK];
    __shared__ float s_b[BK][BN];

    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;  // tid/2, row of s_a
    int load_a_smem_k = (tid & 1) << 2;  // (tid % 2 == 0) ? 0 : 4, col of s_a
    int load_b_smem_k = tid >> 5;   // tid/32, row of s_b
    int load_b_smem_n = (tid & 31) << 2;  // (tid % 32) * 4, col of s_b

    int load_a_gmem_m = by * BM + load_a_smem_m;  // global row of a
    int load_b_gmem_n = bx * BN + load_b_smem_n;  // global col of b

    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {
        int load_a_gmem_k = bk * BK + load_a_smem_k;   // global col of a
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        FLOAT4(s_a[load_a_smem_m][load_a_smem_k]) = FLOAT4(a[load_a_gmem_addr]);
        int load_b_gmem_k = bk * BK + load_b_smem_k;   // global row of b
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(b[load_b_gmem_addr]);
        __syncthreads();

        #pragma unroll
        for (int k = 0; k < BK; k++) {
            #pragma unroll
            for (int m = 0; m < TM; m++) {
                #pragma unroll
                for (int n = 0; n < TN; n++) {
                    int comp_a_smem_m = ty * TM + m;
                    int comp_b_smem_n = tx * TN + n;
                    r_c[m][n] += s_a[comp_a_smem_m][k] * s_b[k][comp_b_smem_n];
                }
            }
        }

        __syncthreads();
    }

    #pragma unroll
    for (int i = 0; i < TM; i++) {
        int store_c_gmem_m = by * BM + ty * TM + i;
        #pragma unroll
        for (int j = 0; j < TN; j += 4) {
            int store_c_gmem_n = bx * BN + tx * TN + j;
            int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
            FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][j]);
        }
    }
}

The results are as follows; performance reaches 51.7% of the theoretical peak:

M N K =    128    128   1024, Time =   0.00031578   0.00031727   0.00032288 s, AVG Performance =    98.4974 Gflops
M N K =    192    192   1024, Time =   0.00031638   0.00031720   0.00031754 s, AVG Performance =   221.6661 Gflops
M N K =    256    256   1024, Time =   0.00031488   0.00031532   0.00031606 s, AVG Performance =   396.4287 Gflops
M N K =    384    384   1024, Time =   0.00031686   0.00031814   0.00032080 s, AVG Performance =   884.0425 Gflops
M N K =    512    512   1024, Time =   0.00031814   0.00032007   0.00032493 s, AVG Performance =  1562.1563 Gflops
M N K =    768    768   1024, Time =   0.00032397   0.00034419   0.00034848 s, AVG Performance =  3268.5245 Gflops
M N K =   1024   1024   1024, Time =   0.00034570   0.00034792   0.00035331 s, AVG Performance =  5748.3952 Gflops
M N K =   1536   1536   1024, Time =   0.00068797   0.00068983   0.00069094 s, AVG Performance =  6523.3424 Gflops
M N K =   2048   2048   1024, Time =   0.00136173   0.00136552   0.00136899 s, AVG Performance =  5858.5604 Gflops
M N K =   3072   3072   1024, Time =   0.00271910   0.00273115   0.00274006 s, AVG Performance =  6590.6331 Gflops
M N K =   4096   4096   1024, Time =   0.00443805   0.00445964   0.00446883 s, AVG Performance =  7175.4698 Gflops
M N K =   6144   6144   1024, Time =   0.00917891   0.00950608   0.00996963 s, AVG Performance =  7574.0999 Gflops
M N K =   8192   8192   1024, Time =   0.01628838   0.01645271   0.01660790 s, AVG Performance =  7779.8733 Gflops
M N K =  12288  12288   1024, Time =   0.03592557   0.03597434   0.03614323 s, AVG Performance =  8005.7066 Gflops
M N K =  16384  16384   1024, Time =   0.06304122   0.06306373   0.06309302 s, AVG Performance =  8118.7715 Gflops

Again taking M = 512, K = 512, N = 512 as an example, the profiling shows a shared-memory usage of 8192 bytes, exactly matching the theoretical (128 + 128) × 8 × 4 bytes.

(Figure: nsys profiling of the V1 version)

The profiling shows an occupancy of 12.5%, which can be verified with the CUDA occupancy calculator: with 256 threads per block and 136 registers per thread, each SM can keep 8 warps active, and since a V100 SM supports 64 resident warps, occupancy is 8/64 = 12.5%.
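
The arithmetic behind that estimate, ignoring register-allocation granularity (a rough sketch, using the V100's 65,536 registers and 64 resident warps per SM):

$$\left\lfloor \frac{65536\ \text{registers/SM}}{136 \times 256\ \text{registers/block}} \right\rfloor = 1\ \text{block/SM} \;\Rightarrow\; \frac{256/32\ \text{active warps}}{64\ \text{warps/SM}} = \frac{8}{64} = 12.5\%.$$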

2.2 Eliminating Bank Conflicts

The previous section used shared memory to greatly improve access efficiency and, in turn, performance; this section further optimizes how shared memory is used.

Shared memory is divided into 32 banks, each 4 bytes wide; if a warp needs multiple pieces of data from the same bank, a bank conflict occurs. For example, if the 32 threads of a warp access addresses 0, 4, 8, ..., 124, there is no conflict and the access takes a single shared-memory cycle. If they instead access addresses 0, 8, 16, ..., 248, then address 0 and address 128 fall in the same bank, address 8 and address 136 fall in the same bank, and so on, so two shared-memory cycles are needed to read the data.

(Figure: with bank conflicts vs. without bank conflicts)
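
To make the bank mapping concrete, here is a small host-side sketch (not part of the original kernels) that prints, for the two access patterns above, which bank each thread of a warp hits; the bank index of a 4-byte word is (byte address / 4) mod 32:

#include <cstdio>

// Bank hit by each of the 32 threads of a warp when thread t accesses
// byte address t * stride_bytes in shared memory (32 banks, 4 bytes each).
void printBanks(int stride_bytes) {
    for (int t = 0; t < 32; t++)
        printf("%2d ", (t * stride_bytes / 4) % 32);
    printf("\n");
}

int main() {
    printBanks(4);  // addresses 0, 4, ..., 124  -> banks 0..31, conflict-free
    printBanks(8);  // addresses 0, 8, ..., 248  -> every bank hit by two threads, 2-way conflict
    return 0;
}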

Looking again at the three nested compute loops of V1: in each iteration a thread reads a length-TM vector of A and a length-TN vector of B from shared memory, forms their outer product, and accumulates it into the partial sums. One outer product is TM * TN multiply-accumulates, and BK such load-plus-outer-product steps are needed per tile.

Now consider the bank conflicts that occur when loading from shared memory:

i) Reading matrix A requires a column vector, but A is stored row-major in shared memory;

ii) With TM = TN = 8, each thread has to read 8 consecutive floats from shared memory for both A and B. Even with the LDS.128 instruction, which reads four floats at once, two instructions are required; and because one thread's two load addresses are consecutive, the addresses of the same load instruction across the threads of a warp end up strided apart, which produces bank conflicts.

To eliminate these two sources of shared-memory bank conflicts, the following two optimizations are applied:

i) Allocate A's shared-memory tile with shape [BK][BM], i.e. store A column-major in shared memory;

ii) Split the TM * TN tile of C computed by each thread into two TM/2 * TN tiles; since TM/2 = 4, a single instruction can load a thread's slice of A, and the two loads can be issued together.

The core of the kernel is implemented below; see sgemm_v2.cu for the complete code.

    __shared__ float s_a[BK][BM];
    __shared__ float s_b[BK][BN];

    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];
    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    for (int bk = 0; bk < (K + BK - 1) / BK; bk++) {

        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

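        // Store A transposed into s_a ([BK][BM]) so the compute loop later reads
        // contiguous columns of A, avoiding the bank conflicts analyzed above.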
        s_a[load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();

        #pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[tk][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int tm = 0; tm < TM; tm++) {
                #pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        __syncthreads();
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }

The results are below. Relative to the version with unresolved bank conflicts (V1), peak throughput improves by roughly 44% (about 11.7 vs 8.1 TFLOPS at the largest sizes), reaching 74.3% of the theoretical peak.

M N K =    128    128   1024, Time =   0.00029699   0.00029918   0.00030989 s, AVG Performance =   104.4530 Gflops
M N K =    192    192   1024, Time =   0.00029776   0.00029828   0.00029882 s, AVG Performance =   235.7252 Gflops
M N K =    256    256   1024, Time =   0.00029485   0.00029530   0.00029619 s, AVG Performance =   423.2949 Gflops
M N K =    384    384   1024, Time =   0.00029734   0.00029848   0.00030090 s, AVG Performance =   942.2843 Gflops
M N K =    512    512   1024, Time =   0.00029853   0.00029945   0.00030070 s, AVG Performance =  1669.7479 Gflops
M N K =    768    768   1024, Time =   0.00030458   0.00032467   0.00032790 s, AVG Performance =  3465.1038 Gflops
M N K =   1024   1024   1024, Time =   0.00032406   0.00032494   0.00032621 s, AVG Performance =  6155.0281 Gflops
M N K =   1536   1536   1024, Time =   0.00047990   0.00048224   0.00048461 s, AVG Performance =  9331.3912 Gflops
M N K =   2048   2048   1024, Time =   0.00094426   0.00094636   0.00094992 s, AVG Performance =  8453.4569 Gflops
M N K =   3072   3072   1024, Time =   0.00187866   0.00188096   0.00188538 s, AVG Performance =  9569.5816 Gflops
M N K =   4096   4096   1024, Time =   0.00312589   0.00319050   0.00328147 s, AVG Performance = 10029.7885 Gflops
M N K =   6144   6144   1024, Time =   0.00641280   0.00658940   0.00703498 s, AVG Performance = 10926.6372 Gflops
M N K =   8192   8192   1024, Time =   0.01101130   0.01116194   0.01122950 s, AVG Performance = 11467.5446 Gflops
M N K =  12288  12288   1024, Time =   0.02464854   0.02466705   0.02469344 s, AVG Performance = 11675.4946 Gflops
M N K =  16384  16384   1024, Time =   0.04385955   0.04387468   0.04388355 s, AVG Performance = 11669.5995 Gflops

The profiling shows that static shared memory is still 8192 bytes; oddly, "shared memory executed" doubles to 16384 bytes (if any reader knows the reason, please let me know).

(Figure: profiling of the V2 version)

2.3 Pipelining: Double Buffering

Double buffering adds a second buffer so that the serial load-then-compute pattern becomes a pipeline, reducing wait time and improving efficiency. The principle is illustrated below:

(Figure: single buffering vs. double buffering)

Applied to GEMM, this means twice the shared memory: instead of BK * (BM + BN) * 4 bytes, double buffering needs 2 * BK * (BM + BN) * 4 bytes, which then lets the pipeline keep flowing.
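
With the chosen tile sizes this is still a modest amount (a quick check):

$$2 \times BK \times (BM + BN) \times 4\ \text{bytes} = 2 \times 8 \times 256 \times 4 = 16\ \text{KB per block},$$

well under the 96 KB of shared memory available per SM on V100.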

The core of the code is shown below; see sgemm_v3.cu for the complete version. A few points deserve attention:

1) The main loop starts at bk = 1; the first data load happens before the main loop and the last computation happens after it, which is dictated by the pipelined structure;

2) Because the computation and the next load use different shared-memory buffers, only one __syncthreads() per iteration of the main loop is needed;

3) Since a GPU cannot execute out of order the way a CPU can, the main loop first loads the data needed by the next iteration from global memory into registers, then performs the current iteration's computation, and only afterwards writes those registers into shared memory. This way the LDG loads from global memory do not block the issue of the subsequent FFMA and other compute instructions, which is exactly the point of double buffering.

    __shared__ float s_a[2][BK][BM];
    __shared__ float s_b[2][BK][BN];

    float r_load_a[4];
    float r_load_b[4];
    float r_comp_a[TM];
    float r_comp_b[TN];
    float r_c[TM][TN] = {0.0};

    int load_a_smem_m = tid >> 1;
    int load_a_smem_k = (tid & 1) << 2;
    int load_b_smem_k = tid >> 5;
    int load_b_smem_n = (tid & 31) << 2;

    int load_a_gmem_m = by * BM + load_a_smem_m;
    int load_b_gmem_n = bx * BN + load_b_smem_n;

    {
        int load_a_gmem_k = load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        s_a[0][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[0][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[0][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[0][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[0][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);
    }

    for (int bk = 1; bk < (K + BK - 1) / BK; bk++) {

        int smem_sel = (bk - 1) & 1;
        int smem_sel_next = bk & 1;

        int load_a_gmem_k = bk * BK + load_a_smem_k;
        int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_gmem_k, K);
        int load_b_gmem_k = bk * BK + load_b_smem_k;
        int load_b_gmem_addr = OFFSET(load_b_gmem_k, load_b_gmem_n, N);
        FLOAT4(r_load_a[0]) = FLOAT4(a[load_a_gmem_addr]);
        FLOAT4(r_load_b[0]) = FLOAT4(b[load_b_gmem_addr]);

        #pragma unroll
        for (int tk = 0; tk < BK; tk++) {
            FLOAT4(r_comp_a[0]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2         ]);
            FLOAT4(r_comp_a[4]) = FLOAT4(s_a[smem_sel][tk][ty * TM / 2 + BM / 2]);
            FLOAT4(r_comp_b[0]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2         ]);
            FLOAT4(r_comp_b[4]) = FLOAT4(s_b[smem_sel][tk][tx * TN / 2 + BN / 2]);

            #pragma unroll
            for (int tm = 0; tm < TM; tm++) {
                #pragma unroll
                for (int tn = 0; tn < TN; tn++) {
                    r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
                }
            }
        }

        s_a[smem_sel_next][load_a_smem_k    ][load_a_smem_m] = r_load_a[0];
        s_a[smem_sel_next][load_a_smem_k + 1][load_a_smem_m] = r_load_a[1];
        s_a[smem_sel_next][load_a_smem_k + 2][load_a_smem_m] = r_load_a[2];
        s_a[smem_sel_next][load_a_smem_k + 3][load_a_smem_m] = r_load_a[3];
        FLOAT4(s_b[smem_sel_next][load_b_smem_k][load_b_smem_n]) = FLOAT4(r_load_b[0]);

        __syncthreads();
    }

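    // Epilogue: consume the last tile, which the final loop iteration wrote into
    // buffer index 1 (this assumes an even number of K/BK iterations, which holds
    // for the K values used in these tests).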
    #pragma unroll
    for (int tk = 0; tk < BK; tk++) {
        FLOAT4(r_comp_a[0]) = FLOAT4(s_a[1][tk][ty * TM / 2         ]);
        FLOAT4(r_comp_a[4]) = FLOAT4(s_a[1][tk][ty * TM / 2 + BM / 2]);
        FLOAT4(r_comp_b[0]) = FLOAT4(s_b[1][tk][tx * TN / 2         ]);
        FLOAT4(r_comp_b[4]) = FLOAT4(s_b[1][tk][tx * TN / 2 + BN / 2]);

        #pragma unroll
        for (int tm = 0; tm < TM; tm++) {
            #pragma unroll
            for (int tn = 0; tn < TN; tn++) {
                r_c[tm][tn] += r_comp_a[tm] * r_comp_b[tn];
            }
        }
    }

    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i][4]);
    }
    #pragma unroll
    for (int i = 0; i < TM / 2; i++) {
        int store_c_gmem_m = by * BM + BM / 2 + ty * TM / 2 + i;
        int store_c_gmem_n = bx * BN + tx * TN / 2;
        int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
        FLOAT4(c[store_c_gmem_addr]) = FLOAT4(r_c[i + TM / 2][0]);
        FLOAT4(c[store_c_gmem_addr + BN / 2]) = FLOAT4(r_c[i + TM / 2][4]);
    }

Performance is shown below, reaching 80.6% of the theoretical peak.

M N K =    128    128   1024, Time =   0.00024000   0.00024240   0.00025792 s, AVG Performance =   128.9191 Gflops
M N K =    192    192   1024, Time =   0.00024000   0.00024048   0.00024125 s, AVG Performance =   292.3840 Gflops
M N K =    256    256   1024, Time =   0.00024029   0.00024114   0.00024272 s, AVG Performance =   518.3728 Gflops
M N K =    384    384   1024, Time =   0.00024070   0.00024145   0.00024198 s, AVG Performance =  1164.8394 Gflops
M N K =    512    512   1024, Time =   0.00024173   0.00024237   0.00024477 s, AVG Performance =  2062.9786 Gflops
M N K =    768    768   1024, Time =   0.00024291   0.00024540   0.00026010 s, AVG Performance =  4584.3820 Gflops
M N K =   1024   1024   1024, Time =   0.00024534   0.00024631   0.00024941 s, AVG Performance =  8119.7302 Gflops
M N K =   1536   1536   1024, Time =   0.00045712   0.00045780   0.00045872 s, AVG Performance =  9829.5167 Gflops
M N K =   2048   2048   1024, Time =   0.00089632   0.00089970   0.00090656 s, AVG Performance =  8891.8924 Gflops
M N K =   3072   3072   1024, Time =   0.00177891   0.00178289   0.00178592 s, AVG Performance = 10095.9883 Gflops
M N K =   4096   4096   1024, Time =   0.00309763   0.00310057   0.00310451 s, AVG Performance = 10320.6843 Gflops
M N K =   6144   6144   1024, Time =   0.00604826   0.00619887   0.00663078 s, AVG Performance = 11615.0253 Gflops
M N K =   8192   8192   1024, Time =   0.01031738   0.01045051   0.01048861 s, AVG Performance = 12248.2036 Gflops
M N K =  12288  12288   1024, Time =   0.02283978   0.02285837   0.02298272 s, AVG Performance = 12599.3212 Gflops
M N K =  16384  16384   1024, Time =   0.04043287   0.04044823   0.04046151 s, AVG Performance = 12658.1556 Gflops
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.

The profiling shows the doubled shared-memory usage.

3. A Look at the cuBLAS Implementation

This section introduces CUDA's standard library cuBLAS, NVIDIA's implementation of the Basic Linear Algebra Subprograms (BLAS) specification. It supports the standard operations at Level 1 (vector-vector), Level 2 (matrix-vector), and Level 3 (matrix-matrix).

(Figure: the basic flow of a cuBLAS/CUTLASS GEMM)

As shown in the figure above, the computation is decomposed into a hierarchy of thread block tiles, warp tiles, and thread tiles, and the AMP strategy is applied to this hierarchy to carry out the tiled GEMM efficiently on the GPU. The hierarchy closely mirrors the NVIDIA CUDA programming model: data moves from global memory to shared memory (matrix to thread block tile), then from shared memory to registers (thread block tile to warp tile), and the computation runs from registers on the CUDA cores (warp tile to thread tile).

cuBLAS implements single-precision matrix multiplication in the function cublasSgemm, whose main parameters are as follows:

cublasStatus_t cublasSgemm( cublasHandle_t handle,    // cuBLAS library handle
                            cublasOperation_t transa, // whether to transpose A
                            cublasOperation_t transb, // whether to transpose B
                            int m,                    // rows of op(A) and of C
                            int n,                    // columns of op(B) and of C
                            int k,                    // columns of op(A) / rows of op(B)
                            const float *alpha,       // scalar alpha, host or device pointer
                            const float *A,           // pointer to matrix A, device pointer
                            int lda,                  // leading dimension of A; if A is transposed, lda >= max(1, k), else lda >= max(1, m)
                            const float *B,           // pointer to matrix B, device pointer
                            int ldb,                  // leading dimension of B; if B is transposed, ldb >= max(1, n), else ldb >= max(1, k)
                            const float *beta,        // scalar beta, host or device pointer
                            float *C,                 // pointer to matrix C, device pointer
                            int ldc);                 // leading dimension of C, ldc >= max(1, m)

It is called as follows. Note that cuBLAS assumes column-major storage; for our row-major arrays we use the identity Cᵀ = Bᵀ·Aᵀ, passing d_b as the first operand and swapping the M and N arguments, so that the column-major result cuBLAS writes out is exactly the row-major C = A·B:

cublasHandle_t cublas_handle;
cublasCreate(&cublas_handle);
float cublas_alpha = 1.0;
float cublas_beta = 0;
cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &cublas_alpha, d_b, N, d_a, K, &cublas_beta, d_c, N);

Performance is shown below, reaching 82.4% of the theoretical peak.

M N K =    128    128   1024, Time =   0.00002704   0.00003634   0.00010822 s, AVG Performance =   860.0286 Gflops
M N K =    192    192   1024, Time =   0.00003155   0.00003773   0.00007267 s, AVG Performance =  1863.6689 Gflops
M N K =    256    256   1024, Time =   0.00003917   0.00004524   0.00007747 s, AVG Performance =  2762.9438 Gflops
M N K =    384    384   1024, Time =   0.00005318   0.00005978   0.00009120 s, AVG Performance =  4705.0655 Gflops
M N K =    512    512   1024, Time =   0.00008326   0.00010280   0.00013840 s, AVG Performance =  4863.9646 Gflops
M N K =    768    768   1024, Time =   0.00014278   0.00014867   0.00018816 s, AVG Performance =  7567.1560 Gflops
M N K =   1024   1024   1024, Time =   0.00023485   0.00024460   0.00028150 s, AVG Performance =  8176.5614 Gflops
M N K =   1536   1536   1024, Time =   0.00046474   0.00047607   0.00051181 s, AVG Performance =  9452.3201 Gflops
M N K =   2048   2048   1024, Time =   0.00077930   0.00087862   0.00092307 s, AVG Performance =  9105.2126 Gflops
M N K =   3072   3072   1024, Time =   0.00167904   0.00168434   0.00171114 s, AVG Performance = 10686.6837 Gflops
M N K =   4096   4096   1024, Time =   0.00289619   0.00291068   0.00295904 s, AVG Performance = 10994.0128 Gflops
M N K =   6144   6144   1024, Time =   0.00591766   0.00594586   0.00596915 s, AVG Performance = 12109.2611 Gflops
M N K =   8192   8192   1024, Time =   0.01002384   0.01017465   0.01028435 s, AVG Performance = 12580.2896 Gflops
M N K =  12288  12288   1024, Time =   0.02231159   0.02233805   0.02245619 s, AVG Performance = 12892.7969 Gflops
M N K =  16384  16384   1024, Time =   0.03954650   0.03959291   0.03967242 s, AVG Performance = 12931.6086 Gflops

This allows a comparison of all of the methods above; the hand-written kernels come close to the official library. Summarizing the peak measured throughput from the logs above:

naive                      ~1.8 TFLOPS   (11.5% of FP32 peak)
V1  shared-memory tiling   ~8.1 TFLOPS   (51.7%)
V2  bank-conflict-free     ~11.7 TFLOPS  (74.3%)
V3  double buffering       ~12.7 TFLOPS  (80.6%)
cuBLAS                     ~12.9 TFLOPS  (82.4%)


Editor in charge: 張燕妮. Source: 自動駕駛之心