Comparing CUDA cores and tensor cores for training neural networks

What are tensor cores?

Tensor cores are specialized units of a GPU that are optimized for matrix multiplication. Each tensor core can perform the matrix operation $$AB + C$$ where \(A, B\) and \(C\) are all \(4 \times 4\) matrices, in a single computational step (clock cycle).

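(However, there is currently no API for accessing individual tensor cores. Instead, the programming interface exposes operations performed jointly by the threads of a warp, which carry out the above computation on larger tiles, such as \(16 \times 16\) matrices, rather than on \(4 \times 4\) matrices.)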

Tensor cores use mixed-precision arithmetic, meaning that different steps of the computation use different floating-point precisions. For floating-point work on Volta and Turing GPUs, the most computation-intensive step, namely the product \(AB\), always takes the entries of \(A\) and \(B\) in float16; the entries of \(C\) and of the result can be either float16 or float32. Float16 has only 10 fraction bits and a narrow exponent range, which can cause problems with numerical stability (and in particular with the stability of neural net training). For neural nets, NVIDIA suggests a technique called loss scaling to make training more stable: the loss is multiplied by a constant factor before the backward pass, so that small gradients do not underflow to zero in float16, and the gradients are divided by the same factor before the weight update. [This situation is likely to improve in the future. Newer Ampere GPUs add tensor core support for bfloat16 and TF32 inputs, and the data-center A100 also supports float64 on its tensor cores (at lower throughput than with lower precisions). However, these options are not available on my GPU (RTX 2060 Super), so I cannot test them.]
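To make the precision problem concrete, here is a CPU sketch that emulates rounding a value to float16 and shows how loss scaling keeps a small gradient from underflowing. It is only an illustration under stated assumptions: round_to_half and store_scaled_gradient are hypothetical helpers (not part of any NVIDIA API), rounding is assumed to be IEEE round-to-nearest-even, and NaNs are not handled.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>

// Simplified software emulation (not the GPU's conversion routine): round a
// float to the nearest IEEE 754 half-precision value and return it as a float.
float round_to_half(float x) {
    assert(!std::isnan(x));                 // NaN handling is omitted in this sketch
    if (x == 0.0f || std::isinf(x)) return x;
    int e;
    std::frexp(x, &e);                      // x = f * 2^e with 0.5 <= |f| < 1
    // float16 keeps 10 fraction bits, so representable values near x are spaced
    // 2^(e - 11) apart; below 2^-14 (subnormals) the spacing stays at 2^-24.
    int q = std::max(e - 11, -24);
    // nearbyint uses the default rounding mode (round to nearest, ties to even)
    double r = std::ldexp(std::nearbyint(std::ldexp(static_cast<double>(x), -q)), q);
    if (std::fabs(r) > 65504.0)             // 65504 is the largest finite float16
        return std::copysign(std::numeric_limits<float>::infinity(), x);
    return static_cast<float>(r);
}

// Loss scaling in miniature: multiply a gradient by `scale` before it is stored
// in float16, and divide by `scale` afterwards in float32.
float store_scaled_gradient(float grad, float scale) {
    return round_to_half(grad * scale) / scale;
}
```

For example, round_to_half(1e-8f) is zero, because 1e-8 lies below half of the smallest float16 subnormal (2^-24, about 6e-8), whereas store_scaled_gradient(1e-8f, 1024.0f) recovers the gradient with a relative error of roughly 0.1%.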

There are several ways of accessing tensor cores in CUDA C++: I will compare access using the WMMA namespace, and access through the cuBLAS linear algebra subroutine library.

Accessing tensor cores using the WMMA namespace

To use the WMMA API, include the header mma.h. It is also convenient to include cuda_fp16.h, which provides the half type and functions for working with half-precision floating-point numbers. The API lives in the namespace nvcuda::wmma. When compiling with NVCC, the target architecture must have compute capability 7.0 or higher (for example, -arch=sm_70; the RTX 2060 Super, a Turing GPU, has compute capability 7.5).

The interface consists of a class template, fragment, and four main functions.

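Each of the matrices involved in the computation \(D = AB + C\) is held in a fragment, an object whose type is a specialization of the class template below (usually the \(16 \times 16\) matrices are submatrices, or fragments, of a larger matrix). The contents of a fragment are distributed across the registers of the 32 threads of a warp, so all threads of the warp must take part in each operation. The first template parameter selects the role: \(A\) uses wmma::matrix_a, \(B\) uses wmma::matrix_b, and \(C\) and \(D\) both use wmma::accumulator. The integers m, n, k give the tile shape, T is the element type, and Layout (wmma::row_major or wmma::col_major) is required for matrix_a and matrix_b fragments and omitted for accumulators.
```cuda
template<typename Use, int m, int n, int k, typename T, typename Layout = void> class fragment;
```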
The four functions have the following signatures.
```cuda
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
```
mptr points to the first element of the tile in memory (the larger matrix is assumed to be stored in contiguous memory) and must be 256-bit (32-byte) aligned. ldm, the leading dimension, is the stride in elements between consecutive rows (for a row-major layout) or consecutive columns (for a column-major layout) of the larger matrix; it must be a multiple of 8 for half elements and a multiple of 4 for float elements. For matrix_a and matrix_b fragments the layout is part of the fragment type, so the first overload is used; for accumulator fragments the layout (wmma::mem_row_major or wmma::mem_col_major) must be passed explicitly.
```cuda
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
```
Stores an accumulator fragment to memory. mptr points to the first element of the destination tile, and ldm has the same meaning as for load_matrix_sync. The layout must always be specified.
```cuda
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf = false);
```
Computes \(D = AB + C\) from the fragments a, b, c and writes the result to d (d and c may be the same fragment). The flag satf means "saturate to finite value": when it is true, infinite results are replaced by the largest finite value of the same sign, and NaN results are replaced by +0.
```cuda
void fill_fragment(fragment<...> &a, const T& v);
```
This one is rather self-explanatory: every element of the fragment a is set to the value v (typically used to zero an accumulator).
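The meaning of mptr, ldm and the layout can be illustrated with a small CPU sketch. It is not how WMMA works internally (a real fragment is spread across the registers of a warp); load_tile, store_tile and Layout are hypothetical names that only mimic the addressing. For a row-major matrix with cols columns, the tile in block row a and block column b starts at M + a*16*cols + b*16 with ldm = cols.

```cpp
#include <cassert>
#include <vector>

// Hypothetical CPU stand-ins for load_matrix_sync / store_matrix_sync.
// A "tile" here is just a 16 x 16 array; only the addressing is illustrated.
constexpr int TILE = 16;
enum class Layout { RowMajor, ColMajor };

// Offset of tile element (r, c) relative to mptr.
inline unsigned offset(int r, int c, unsigned ldm, Layout layout) {
    return layout == Layout::RowMajor ? r * ldm + c   // ldm = elements between rows
                                      : c * ldm + r;  // ldm = elements between columns
}

void load_tile(float tile[TILE][TILE], const float* mptr, unsigned ldm, Layout layout) {
    assert(ldm >= static_cast<unsigned>(TILE));  // a full tile row/column must fit
    for (int r = 0; r < TILE; ++r)
        for (int c = 0; c < TILE; ++c) tile[r][c] = mptr[offset(r, c, ldm, layout)];
}

void store_tile(float* mptr, const float tile[TILE][TILE], unsigned ldm, Layout layout) {
    assert(ldm >= static_cast<unsigned>(TILE));
    for (int r = 0; r < TILE; ++r)
        for (int c = 0; c < TILE; ++c) mptr[offset(r, c, ldm, layout)] = tile[r][c];
}
```

With a 32 × 48 row-major matrix, for instance, the tile in block row 1 and block column 2 is loaded with load_tile(t, M + 16 * 48 + 32, 48, Layout::RowMajor).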

Accessing tensor cores through the cuBLAS library

When compiling, link the library with the flag -lcublas. The relevant header is cublas_v2.h.

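A handle (cublasHandle_t) to a cuBLAS library context must be created with cublasCreate(cublasHandle_t*) before any other cuBLAS call, passed to every library function call, and released with cublasDestroy(cublasHandle_t) when it is no longer needed. Since creating and destroying a handle is relatively expensive, it is best to create one handle at the start of the program and reuse it, rather than create a new one in every subroutine.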

Matrices are copied to and from the GPU with the functions cublasSetMatrix and cublasGetMatrix, respectively. Matrix multiplication is carried out by a family of functions called GEMM (General Matrix-Matrix Multiply), which compute \(C \leftarrow \alpha AB + \beta C\), where \(\alpha\) and \(\beta\) are scalars whose precision matches that of the matrices. The GEMM variants differ in their element type: cublasSgemm for float32, cublasDgemm for float64, cublasHgemm for float16, cublasCgemm for single-precision complex numbers (cuComplex), and cublasZgemm for double-precision complex numbers (cuDoubleComplex). There is also cublasGemmEx, which allows mixed precision (for example, float16 inputs with float32 accumulation).

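Unlike C/C++ arrays, cuBLAS assumes that matrices are stored in column-major order (a convention inherited from the Fortran BLAS). However, every GEMM function takes transa and transb arguments (CUBLAS_OP_N, CUBLAS_OP_T or CUBLAS_OP_C) that transpose the inputs, so this is not a major inconvenience.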
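A CPU reference makes the column-major convention concrete. gemm_colmajor below is a hypothetical helper with the same argument order as a non-transposed cublasSgemm call (it is not cuBLAS). matmul_rowmajor shows a common trick for row-major data: a row-major array read as column-major is the transpose, so \(C = AB\) can be obtained as \(C^T = B^T A^T\) by swapping the operands and the dimensions m and n, with no explicit transposition.

```cpp
#include <cassert>

// Hypothetical CPU reference with the argument order of a non-transposed
// cublasSgemm call: C = alpha * A * B + beta * C, column-major storage.
// A is m x k (leading dimension lda), B is k x n (ldb), C is m x n (ldc);
// element (i, j) of a column-major matrix X is X[i + j * ldX].
void gemm_colmajor(int m, int n, int k, float alpha, const float* A, int lda,
                   const float* B, int ldb, float beta, float* C, int ldc) {
    assert(lda >= m && ldb >= k && ldc >= m);
    for (int j = 0; j < n; ++j)
        for (int i = 0; i < m; ++i) {
            float sum = 0.0f;
            for (int p = 0; p < k; ++p) sum += A[i + p * lda] * B[p + j * ldb];
            // As in cuBLAS, C is not read when beta == 0, so it may be uninitialized.
            C[i + j * ldc] = alpha * sum + (beta == 0.0f ? 0.0f : beta * C[i + j * ldc]);
        }
}

// Row-major C = A * B (A is m x k, B is k x n, C is m x n, all row-major).
// Read as column-major, these arrays hold A^T, B^T and C^T, and C^T = B^T A^T,
// so the column-major routine is called with the operands and m, n swapped.
void matmul_rowmajor(int m, int n, int k, const float* A, const float* B, float* C) {
    gemm_colmajor(n, m, k, 1.0f, B, n, A, k, 0.0f, C, n);
}
```

The same swap should work with cublasSgemm itself: for row-major device arrays, passing B before A and exchanging m and n yields the row-major product without using CUBLAS_OP_T.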

Comparing matrix multiplication times

The CUDA and WMMA algorithms for multiplying \(AB\) that I will use on this page will apply the idea of splitting up both \(A\) and \(B\) into submatrices called blocks.

The following is a simple non-trivial example of how matrix products can be computed using a block-decomposition. Suppose that \(A = (a_{ij})_{1 \leq i, j \leq 4}\) and \(B = (b_{ij})_{1 \leq i,j \leq 4}\) are both \(4 \times 4\) matrices, and we would like to compute \(AB\). One way to do this is to first divide \(A\) and \(B\) into \(2 \times 2\) blocks. Formally, on the level of blocks the product of \(A\) and \(B\) is similar to the usual product of two \(2 \times 2\) matrices (but the blocks are multiplied instead of the matrix entries).

$$ A = \begin{pmatrix} A_{11} & A_{12} \\ A_{21} & A_{22} \end{pmatrix}, \quad B = \begin{pmatrix} B_{11} & B_{12} \\ B_{21} & B_{22} \end{pmatrix}, \quad AB = \begin{pmatrix} A_{11} B_{11} + A_{12} B_{21} & A_{11} B_{12} + A_{12}B_{22} \\ A_{21} B_{11} + A_{22} B_{21} & A_{21} B_{12} + A_{22}B_{22} \end{pmatrix}, \qquad \text{where } A_{ij} = \begin{pmatrix} a_{2i-1, 2j-1} & a_{2i-1, 2j} \\ a_{2i, 2j-1} & a_{2i, 2j} \end{pmatrix} \text{ and similarly for } B_{ij} $$

More generally, if \(A\) and \(B\) are each split into submatrices (not necessarily of equal sizes, as long as the products of the appropriate blocks make sense), then the product \(AB\) may be computed as above by computing sums of products of the blocks.
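The block formula can be checked with a short CPU sketch (block_matmul and naive_matmul are illustrative helpers, and row-major storage is an assumption). Because the last block in each direction may be smaller than bs, it also illustrates that the blocks need not all have equal sizes, only compatible ones.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<double>;  // row-major storage

// Blocked product C = A * B for A (m x k) and B (k x n). Each index range is cut
// into blocks of size bs (the last one may be smaller): C_IJ = sum_S A_IS * B_SJ.
Matrix block_matmul(const Matrix& A, const Matrix& B, int m, int k, int n, int bs) {
    assert(bs > 0 && A.size() == std::size_t(m) * k && B.size() == std::size_t(k) * n);
    Matrix C(std::size_t(m) * n, 0.0);
    for (int I = 0; I < m; I += bs)              // block row of C
        for (int J = 0; J < n; J += bs)          // block column of C
            for (int S = 0; S < k; S += bs)      // add the block product A_IS * B_SJ
                for (int i = I; i < std::min(I + bs, m); ++i)
                    for (int j = J; j < std::min(J + bs, n); ++j)
                        for (int s = S; s < std::min(S + bs, k); ++s)
                            C[i * n + j] += A[i * k + s] * B[s * n + j];
    return C;
}

// Unblocked triple loop, for comparison.
Matrix naive_matmul(const Matrix& A, const Matrix& B, int m, int k, int n) {
    Matrix C(std::size_t(m) * n, 0.0);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int s = 0; s < k; ++s) C[i * n + j] += A[i * k + s] * B[s * n + j];
    return C;
}
```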

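Splitting the matrices into blocks first makes it possible to use CUDA shared memory, which has much lower latency than global memory: each block of \(A\) and \(B\) is loaded from global memory once and then reused by all the threads that need it. The same decomposition also matches the fixed tile sizes on which tensor cores operate.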

The cuBLAS routines are much faster than the naive CUDA or WMMA algorithms. The following bar graph summarizes the comparison. The bars are scaled by the cube of the matrix side length (otherwise, the times would differ too much to be compared on the same scale). Each of the times is an arithmetic average over 100 multiplications, and the standard deviations were at most about 12% of the mean in the worst case (and usually smaller).
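As a worked example of this scaling: a product of two \(n \times n\) matrices has \(n^2\) entries, each requiring \(n\) multiplications and \(n - 1\) additions. Dividing a time \(t\) by \(n^3\) therefore gives roughly the time per multiply-add, and \(2n^3/t\) is a common measure of effective throughput:

```latex
\text{FLOPs} = n^2 (2n - 1) \approx 2n^3, \qquad
\text{throughput} \approx \frac{2n^3}{t}, \qquad
n = 4096 \;\Rightarrow\; 2n^3 = 2 \cdot (2^{12})^3 = 2^{37} \approx 1.37 \times 10^{11}.
```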

The following are short implementations of matrix multiplication using CUDA cores with shared memory, WMMA, and cuBLAS. For compactness, the displayed code does not include error checking, but a complete program should check the status returned by every CUDA and cuBLAS call.

Using CUDA cores

For simplicity, all of the blocks are chosen to have the same dimension \(b \times b\); moreover, \(b\) is assumed to divide the dimensions of both \(A\) and \(B\) evenly (the latter can always be achieved by padding \(A\) and \(B\) by zeroes if necessary, or adding some index-checking). The product \(AB\) is divided into a grid of \(b \times b\) submatrices as well --- this grid is realized as a grid of thread blocks in CUDA. If \(A\) is \(m \times k\) and \(B\) is \(k \times n\), the dimensions of the thread block grid are \(m/b \times n/b\). For each element of the grid, there are \(b \times b\) threads, each corresponding to a matrix element; each of these threads is responsible for three major operations: i) loading the corresponding entry of a block of \(A\) into shared memory; ii) loading the corresponding entry of a block of \(B\) into shared memory; iii) computing the corresponding entry of \(AB\).
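The zero-padding remark can be sketched on the CPU (pad_to_multiple and matmul are illustrative helpers, assuming row-major storage). Padding the shared dimension of \(A\) and \(B\) with the same number of zeros only adds zero terms to each dot product, so the top-left \(m \times n\) block of the padded product equals \(AB\), and the padded sizes divided by \(b\) give the dimensions of the thread block grid.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: zero-pad a row-major rows x cols matrix so that both
// dimensions become multiples of b; the new sizes are returned by reference.
std::vector<float> pad_to_multiple(const std::vector<float>& M, int rows, int cols, int b,
                                   int& padded_rows, int& padded_cols) {
    assert(b > 0 && M.size() == std::size_t(rows) * cols);
    padded_rows = (rows + b - 1) / b * b;  // round up to a multiple of b
    padded_cols = (cols + b - 1) / b * b;
    std::vector<float> P(std::size_t(padded_rows) * padded_cols, 0.0f);
    for (int r = 0; r < rows; ++r)
        for (int c = 0; c < cols; ++c)
            P[std::size_t(r) * padded_cols + c] = M[std::size_t(r) * cols + c];
    return P;
}

// Plain row-major product (m x k) * (k x n), used to check the padded result.
std::vector<float> matmul(const std::vector<float>& A, const std::vector<float>& B,
                          int m, int k, int n) {
    std::vector<float> C(std::size_t(m) * n, 0.0f);
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
            for (int s = 0; s < k; ++s)
                C[std::size_t(i) * n + j] += A[std::size_t(i) * k + s] * B[std::size_t(s) * n + j];
    return C;
}
```

With \(b = 32\), as in the kernel below, a \(1000 \times 1000\) matrix would be padded to \(1024 \times 1024\).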

```cuda
#include <cuda.h>
#include <cuda_runtime.h>

const int BLOCK_SIZE = 32;
const int MATRIX_SIZE = 4096;

// Row-major index of entry (i, j) of block (a, b) in a matrix with COLS columns
__host__ __device__ int flat_index(const int i, const int j, const int a, const int b, const int COLS) {
    int row = a * BLOCK_SIZE + i, col = b * BLOCK_SIZE + j;
    return row * COLS + col;
}

__global__ void matrix_mult_ker(const float* A, const float* B, float* C,
                                const int A_COLS, const int B_COLS, const int C_COLS) {
    // i, j are indices of entries within a block;
    // a, b are indices of the block within the larger matrix AB
    int i = threadIdx.x, j = threadIdx.y, a = blockIdx.x, b = blockIdx.y;
    __shared__ float Asub[BLOCK_SIZE][BLOCK_SIZE];
    __shared__ float Bsub[BLOCK_SIZE][BLOCK_SIZE];
    float sum = 0.0f;  // accumulate in a register; C is written once at the end
    // One step per block along the shared dimension (A_COLS / BLOCK_SIZE steps)
    for (int s = 0; s < A_COLS / BLOCK_SIZE; s++) {
        Asub[i][j] = A[flat_index(i, j, a, s, A_COLS)];
        Bsub[i][j] = B[flat_index(i, j, s, b, B_COLS)];
        __syncthreads();  // both blocks are fully loaded
        for (int x = 0; x < BLOCK_SIZE; x++)
            sum += Asub[i][x] * Bsub[x][j];
        __syncthreads();  // all threads are done before the blocks are overwritten
    }
    C[flat_index(i, j, a, b, C_COLS)] = sum;
}

int main() {
    int m = MATRIX_SIZE, k = m, n = m;
    float* A = new float[m * k];
    float* B = new float[k * n];
    float* C = new float[m * n];
    for (int i = 0; i < m * k; ++i) A[i] = 1.0f;
    for (int i = 0; i < k * n; ++i) B[i] = 1.0f;

    float *A_dev = 0, *B_dev = 0, *C_dev = 0;
    cudaMalloc(&A_dev, m * k * sizeof(float));
    cudaMalloc(&B_dev, k * n * sizeof(float));
    cudaMalloc(&C_dev, m * n * sizeof(float));
    cudaMemcpy(A_dev, A, m * k * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(B_dev, B, k * n * sizeof(float), cudaMemcpyHostToDevice);

    dim3 block_grid_dim(m / BLOCK_SIZE, n / BLOCK_SIZE);
    dim3 block_dim(BLOCK_SIZE, BLOCK_SIZE);
    matrix_mult_ker<<<block_grid_dim, block_dim>>>(A_dev, B_dev, C_dev, k, n, n);
    cudaDeviceSynchronize();

    cudaMemcpy(C, C_dev, m * n * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(A_dev); cudaFree(B_dev); cudaFree(C_dev);
    delete[] A; delete[] B; delete[] C;
    return 0;
}
```

Using tensor cores with WMMA

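The idea is similar to the previous algorithm. Again, the matrices \(A\), \(B\) and \(AB\) are split into equal blocks of size \(b \times b\); here \(b = 16\), matching the \(16 \times 16 \times 16\) fragment shape (for half-precision inputs, WMMA also supports the shapes \(32 \times 8 \times 16\) and \(8 \times 32 \times 16\)). Each thread block computes one \(16 \times 16\) block of \(AB\) and should consist of a single warp (32 threads), since a fragment is shared by the threads of one warp. The matrix multiplication kernel now looks as follows. main is similar to the previous example, except that \(A\) and \(B\) are converted to half precision (for instance with __float2half) and the kernel is launched with 32 threads per block.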

```cuda
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

const int BLOCK_SIZE = 16;  // must be 16 for the 16x16x16 fragment shape

// Launch with a grid of (m/16) x (n/16) thread blocks of 32 threads (one warp) each.
__global__ void tensor_matrix_mult_ker(const half* A, const half* B, float* C,
                                       const int A_COLS, const int B_COLS, const int C_COLS) {
    // a, b are indices of a 16x16 block inside the larger matrix AB
    int a = blockIdx.x, b = blockIdx.y;
    wmma::fragment<wmma::matrix_a, BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE, half, wmma::row_major> A_frag;
    wmma::fragment<wmma::matrix_b, BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE, half, wmma::row_major> B_frag;
    wmma::fragment<wmma::accumulator, BLOCK_SIZE, BLOCK_SIZE, BLOCK_SIZE, float> acc;
    wmma::fill_fragment(acc, 0.0f);
    // One step per 16-wide block along the shared dimension
    for (int s = 0; s < A_COLS / BLOCK_SIZE; s++) {
        const half* A_ptr = A + (a * A_COLS + s) * BLOCK_SIZE;  // block (a, s) of A
        const half* B_ptr = B + (s * B_COLS + b) * BLOCK_SIZE;  // block (s, b) of B
        wmma::load_matrix_sync(A_frag, A_ptr, A_COLS);
        wmma::load_matrix_sync(B_frag, B_ptr, B_COLS);
        wmma::mma_sync(acc, A_frag, B_frag, acc);
    }
    float* C_ptr = C + (a * C_COLS + b) * BLOCK_SIZE;          // block (a, b) of C
    wmma::store_matrix_sync(C_ptr, acc, C_COLS, wmma::mem_row_major);
}
```

WMMA is an improvement over the naive CUDA approach, but not as large an improvement as one might hope. Interestingly, using a column-major layout significantly improves performance over a row-major layout. This is slightly unfortunate, because the latter is more natural in C/C++ (and to anyone who learned in linear algebra class that matrix entries are indexed by row first, then by column).

Using tensor cores with cuBLAS

```cuda
#include <cstdlib>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cublas_v2.h>

const int MATRIX_SIZE = 4096;

int main() {
    int m = MATRIX_SIZE, k = m, n = m;
    cublasHandle_t handle;
    float* A = new float[m * k];
    float* B = new float[k * n];
    float* C = new float[m * n];
    for (int i = 0; i < m * k; ++i) A[i] = (float) (rand() % 1000) / 1000;
    for (int i = 0; i < k * n; ++i) B[i] = (float) (rand() % 1000) / 1000;

    float *A_dev = 0, *B_dev = 0, *C_dev = 0;
    cudaMalloc(&A_dev, m * k * sizeof(float));
    cudaMalloc(&B_dev, k * n * sizeof(float));
    cudaMalloc(&C_dev, m * n * sizeof(float));

    cublasCreate(&handle);
    // Copy the (column-major) host matrices to the device
    cublasSetMatrix(m, k, sizeof(float), A, m, A_dev, m);
    cublasSetMatrix(k, n, sizeof(float), B, k, B_dev, k);

    const float alpha = 1.0f;
    const float beta = 0.0f;
    // GEMM computes C = alpha * AB + beta * C; with beta = 0, C_dev need not be initialized
    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
                &alpha, A_dev, m, B_dev, k, &beta, C_dev, m);
    cudaDeviceSynchronize();
    cublasGetMatrix(m, n, sizeof(float), C_dev, m, C, m);

    cublasDestroy(handle);
    cudaFree(A_dev); cudaFree(B_dev); cudaFree(C_dev);
    delete[] A; delete[] B; delete[] C;
    return 0;
}
```

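cuBLAS really is a great improvement over the naive algorithms above. It would be quite interesting to study its implementation, but unfortunately its source code does not seem to be public as of November 2020. Note that the example above calls cublasSgemm on float32 data; with the default math mode on a Turing GPU, cuBLAS most likely runs this product on CUDA cores rather than on tensor cores. To let cuBLAS use tensor cores, the inputs can be stored in half precision and multiplied with cublasHgemm or cublasGemmEx (with CUDA versions before 11.0, tensor cores also have to be enabled explicitly, for example with cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)).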

Comparing training times for fully-connected neural nets

The forward and backward passes of a fully-connected neural network can be formulated as matrix multiplications as follows.

Consider a layer of a fully-connected neural net whose input consists of \(M\) nodes and whose output consists of \(N\) nodes, together with bias nodes (treated as nodes whose value is always 1). Let \(x[m]\) denote the values of the input layer, \(y[n]\) denote the values of the output layer, and \(\theta[n,m]\) denote the collection of weights, where \(0 \leq m \leq M,\ 0 \leq n \leq N\).

Then, the forward pass of the network is defined as $$ y[n] = \sum_{m=0}^M \theta[n,m] x[m] $$ Thinking of \(\mathbf{x} = (x[m])_{m=0}^M\) as a vector in \(\mathbf{R}^{M+1}\) and of \(\mathbf{y} = (y[n])_{n=0}^N\) as a vector in \(\mathbf{R}^{N+1}\) (both are considered as column vectors), the above expression can be rewritten as $$\mathbf{y} = \Theta \mathbf{x} $$ where $$ \Theta = \begin{pmatrix} \theta[0,0] & \theta[0,1] & \cdots & \theta[0,M] \\ \theta[1,0] & \theta[1,1] & \cdots & \theta[1,M] \\ \vdots & \vdots & \ddots & \vdots \\ \theta[N,0] & \theta[N,1] & \cdots & \theta[N,M] \end{pmatrix} $$ (this explains the choice of indexing the weights by \([n,m]\) instead of \([m,n]\)).

For the backward pass, we would like to compute \(\frac{\partial L}{\partial \theta[n,m]}\) and \(\frac{\partial L}{\partial x[m]}\), where \(L\) is our choice of loss function, assuming that \(\frac{\partial L}{\partial y[n]}\) is known for each \(n\).

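By the chain rule, $$ \frac{\partial L}{\partial x[m]} = \sum_{n=0}^N \frac{\partial L}{\partial y[n]} \frac{\partial y[n]}{\partial x[m]} = \sum_{n=0}^N \frac{\partial L}{\partial y[n]} \theta[n,m] $$ Since \(y[n'] = \sum_{m'=0}^M \theta[n',m'] x[m']\) depends on \(\theta[n,m]\) only when \(n' = n\), we have \(\frac{\partial y[n']}{\partial \theta[n,m]} = \delta_{n n'} x[m]\), where \(\delta\) is the Kronecker delta, and therefore $$ \frac{\partial L}{\partial \theta[n,m]} = \sum_{n'=0}^N \frac{\partial L}{\partial y[n']} \frac{\partial y[n']}{\partial \theta[n,m]} = \frac{\partial L}{\partial y[n]} x[m]$$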

Introducing the vector notation \( \displaystyle \frac{\partial L}{\partial \mathbf{x}} = \left( \frac{\partial L}{\partial x[m]} \right)_{m=0}^M \) and \( \displaystyle \frac{\partial L}{\partial \mathbf{y}} = \left( \frac{\partial L}{\partial y[n]} \right)_{n=0}^N \) the above expressions can be rewritten as $$ \frac{\partial L}{\partial \mathbf{x}} = \Theta^T \frac{\partial L}{\partial \mathbf{y}} $$ and $$ \frac{\partial L}{\partial \Theta} = \frac{\partial L}{\partial \mathbf{y}} \mathbf{x}^T $$
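These formulas can be verified numerically with a small gradient check, sketched here on the CPU. The squared-error loss and all function names are assumptions made for the illustration; the analytic gradients \(\Theta^T \frac{\partial L}{\partial \mathbf{y}}\) and \(\frac{\partial L}{\partial \mathbf{y}} \mathbf{x}^T\) are compared against central finite differences.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

using Vec = std::vector<double>;
// Theta is stored row-major with rows = N+1 outputs and cols = M+1 inputs,
// so Theta[n * cols + m] is theta[n, m] in the notation of the text.

Vec forward(const Vec& Theta, const Vec& x, int rows, int cols) {   // y = Theta x
    Vec y(rows, 0.0);
    for (int n = 0; n < rows; ++n)
        for (int m = 0; m < cols; ++m) y[n] += Theta[n * cols + m] * x[m];
    return y;
}

Vec grad_x(const Vec& Theta, const Vec& dy, int rows, int cols) {  // Theta^T dL/dy
    Vec dx(cols, 0.0);
    for (int n = 0; n < rows; ++n)
        for (int m = 0; m < cols; ++m) dx[m] += Theta[n * cols + m] * dy[n];
    return dx;
}

Vec grad_theta(const Vec& dy, const Vec& x, int rows, int cols) {  // dL/dy x^T
    Vec dT(rows * cols);
    for (int n = 0; n < rows; ++n)
        for (int m = 0; m < cols; ++m) dT[n * cols + m] = dy[n] * x[m];
    return dT;
}

// Assumed example loss for the check: L = 0.5 * sum_n (y[n] - t[n])^2,
// whose gradient with respect to y is dL/dy = y - t.
double loss(const Vec& Theta, const Vec& x, const Vec& t, int rows, int cols) {
    Vec y = forward(Theta, x, rows, cols);
    double L = 0.0;
    for (int n = 0; n < rows; ++n) L += 0.5 * (y[n] - t[n]) * (y[n] - t[n]);
    return L;
}
```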

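Two of the three expressions above are matrix-vector products (and the third is an outer product), but the vectors of a minibatch can be grouped together as the columns of a matrix, so that all three become matrix-matrix products to which the multiplication routines above apply; the weight gradient then becomes the sum of the per-sample gradients over the minibatch. Since tensor cores work on \(16 \times 16\) tiles, the minibatch size should be at least 16, and preferably a multiple of 16.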
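Grouping a minibatch into matrices can be sketched as follows (a CPU illustration with hypothetical helper names, not the GPU code used for the timings). The samples are the columns of X, so each expression becomes a single matrix-matrix product, and the weight gradient sums the per-sample outer products, assuming the loss is summed over the minibatch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Mat = std::vector<double>;  // row-major storage

// C (r x c) = A (r x inner) * B (inner x c)
Mat matmul(const Mat& A, const Mat& B, int r, int inner, int c) {
    assert(A.size() == std::size_t(r) * inner && B.size() == std::size_t(inner) * c);
    Mat C(std::size_t(r) * c, 0.0);
    for (int i = 0; i < r; ++i)
        for (int j = 0; j < c; ++j)
            for (int p = 0; p < inner; ++p) C[i * c + j] += A[i * inner + p] * B[p * c + j];
    return C;
}

Mat transpose(const Mat& A, int rows, int cols) {
    Mat T(A.size());
    for (int i = 0; i < rows; ++i)
        for (int j = 0; j < cols; ++j) T[j * rows + i] = A[i * cols + j];
    return T;
}

// X is (M+1) x batch (one sample per column), Theta is (N+1) x (M+1),
// and dY = dL/dY is (N+1) x batch. N1 = N+1, M1 = M+1.
Mat forward_batch(const Mat& Theta, const Mat& X, int N1, int M1, int batch) {
    return matmul(Theta, X, N1, M1, batch);                       // Y = Theta X
}
Mat grad_X_batch(const Mat& Theta, const Mat& dY, int N1, int M1, int batch) {
    return matmul(transpose(Theta, N1, M1), dY, M1, N1, batch);   // Theta^T dL/dY
}
Mat grad_Theta_batch(const Mat& dY, const Mat& X, int N1, int M1, int batch) {
    return matmul(dY, transpose(X, M1, batch), N1, batch, M1);    // dL/dY X^T
}
```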

I would now like to compare the times for forward and backward passes of fully-connected neural nets of various sizes using the GPU matrix multiplication methods described above, and to compare them with PyTorch (which relies on NVIDIA libraries such as cuBLAS and cuDNN, the CUDA Deep Neural Network library).

The results are as follows (the times are in microseconds, in format (mean ± stdev)); the minibatch size is 64.

Forward pass (times in microseconds, mean ± standard deviation over 1000 passes)

| Layers | Architecture | Naive CUDA (shared memory) | WMMA | cuBLAS | PyTorch |
|---|---|---|---|---|---|
| 4 | 16 x 128 x 128 x 16 | 55.79 ± 1.02 | 29.05 ± 0.38 | 26.01 ± 1.12 | 104.21 ± 28.41 |
| 4 | 16 x 256 x 256 x 16 | 124.82 ± 0.74 | 50.06 ± 0.32 | 31.60 ± 2.66 | 121.19 ± 27.98 |
| 4 | 16 x 512 x 512 x 16 | 323.86 ± 45.70 | 99.06 ± 0.34 | 31.15 ± 2.44 | 130.67 ± 37.60 |
| 4 | 16 x 1024 x 1024 x 16 | 999.68 ± 107.90 | 313.77 ± 0.43 | 39.02 ± 2.34 | 145.29 ± 55.66 |
| 4 | 16 x 2048 x 2048 x 16 | 3,533.00 ± 27.71 | 1,160.85 ± 2.19 | 64.60 ± 2.53 | 174.14 ± 58.19 |
| 4 | 16 x 4096 x 4096 x 16 | 13,382.70 ± 424.4 | 4,334.52 ± 8.85 | 135.48 ± 3.28 | 376.36 ± 662.91 |
| 5 | 16 x 128 x 128 x 128 x 16 | 81.50 ± 0.57 | 41.58 ± 0.59 | 34.35 ± 1.19 | 134.72 ± 28.41 |
| 5 | 16 x 256 x 256 x 256 x 16 | 193.34 ± 18.53 | 64.10 ± 0.32 | 47.18 ± 2.45 | 161.39 ± 33.55 |
| 5 | 16 x 512 x 512 x 512 x 16 | 539.61 ± 71.04 | 173.22 ± 0.52 | 45.38 ± 2.14 | 172.96 ± 50.01 |
| 5 | 16 x 1024 x 1024 x 1024 x 16 | 1,840.00 ± 150.16 | 606.98 ± 1.00 | 63.24 ± 2.31 | 191.14 ± 56.32 |
| 5 | 16 x 2048 x 2048 x 2048 x 16 | 6,804.29 ± 305.18 | 2,239.27 ± 4.46 | 103.53 ± 2.55 | 260.28 ± 488.70 |
| 5 | 16 x 4096 x 4096 x 4096 x 16 | 26,386.7 ± 590.87 | 8,559.45 ± 13.23 | 242.09 ± 2.68 | 741.58 ± 197.25 |
| 6 | 16 x 128 x 128 x 128 x 128 x 16 | 107.54 ± 0.61 | 53.66 ± 2.26 | 39.16 ± 1.24 | 170.63 ± 34.23 |
| 6 | 16 x 256 x 256 x 256 x 256 x 16 | 254.57 ± 31.63 | 86.63 ± 0.54 | 60.93 ± 2.42 | 204.70 ± 38.05 |
| 6 | 16 x 512 x 512 x 512 x 512 x 16 | 755.31 ± 88.54 | 248.49 ± 0.56 | 60.01 ± 2.03 | 219.04 ± 58.95 |
| 6 | 16 x 1024 x 1024 x 1024 x 1024 x 16 | 2,685.68 ± 185.20 | 884.141 ± 0.43 | 85.65 ± 2.60 | 248.19 ± 58.95 |
| 6 | 16 x 2048 x 2048 x 2048 x 2048 x 16 | 10,115.60 ± 359.22 | 3,324.50 ± 6.96 | 142.14 ± 2.66 | 349.86 ± 554.91 |
| 6 | 16 x 4096 x 4096 x 4096 x 4096 x 16 | 39,376.20 ± 698.68 | 12,752.50 ± 16.17 | 349.29 ± 2.96 | n/a |
The picture for the backward pass is similar: the naive CUDA and WMMA algorithms are close to cuBLAS for smaller networks, but cuBLAS is much faster for larger networks.

Comparing with PyTorch

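PyTorch runs its operations through NVIDIA libraries (for fully-connected layers, the matrix products are computed with cuBLAS; cuDNN handles operations such as convolutions), but every call also incurs overhead from the Python interpreter and from the framework itself. This overhead likely explains why the PyTorch times lose to the cuBLAS times, and even to the more naive implementations on smaller networks.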

November 2020. Ilia Smirnov (iliathesmirnov@gmail.com)