The Evolution of Matrix Multiplication (fastai course Part 2 Lessons 11 and 12)

python
deep learning
fastai
In this blog post I walk through 10 implementations of matrix multiplication in Python, Numba, and PyTorch. I compare execution times for multiplying a 5-digit subset of MNIST by a single weight matrix, and for multiplying the full 50k-image MNIST dataset by the same weight matrix. My two main takeaways: 1) when in doubt, use PyTorch’s .cuda with the @ operator, and 2) different matrix multiplication algorithms scale differently!
Author

Vishal Bakshi

Published

May 21, 2025

# !conda install -y -c nvidia/label/cuda-12.8.0 cuda-toolkit
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
# !conda install -y numba
# !conda install -y fastcore -c fastai

from pathlib import Path
import pickle, gzip, math, os, time, shutil
from urllib.request import urlretrieve
import torch
from torch import tensor
import numpy as np
from numba import njit
from numpy import array
from fastcore.test import *

from numba import cuda

torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
np.set_printoptions(precision=2, linewidth=140)

MNIST_URL='https://github.com/mnielsen/neural-networks-and-deep-learning/blob/master/data/mnist.pkl.gz?raw=true'
path_data = Path('data')
path_data.mkdir(exist_ok=True)
path_gz = path_data/'mnist.pkl.gz'


if not path_gz.exists(): urlretrieve(MNIST_URL, path_gz)

with gzip.open(path_gz, 'rb') as f: ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')

x_train,y_train,x_valid,y_valid = map(tensor, (x_train,y_train,x_valid,y_valid))
x_train.shape
torch.Size([50000, 784])
torch.manual_seed(1)
weights = torch.randn(784,10)
bias = torch.zeros(10)

m1 = x_valid[:5]
m2 = weights

m1.shape,m2.shape
(torch.Size([5, 784]), torch.Size([784, 10]))

Results

5-digit Subset

Method | Time
PyTorch @ Op | 18.1 μs
Numba Broadcasting | 69.2 μs
Einstein Summation | 83.1 μs
PyTorch Cuda | 108 μs
Numba Cuda | 108 μs
PyTorch Broadcasting | 203 μs
Numba Dot Product | 542 μs
torch.dot | 1.19 ms
Element-wise PyTorch Ops | 1.49 ms
Nested for-loops | 604 ms

Full Dataset (50k images)

Method | Time
PyTorch cuda | 541 μs
Numba Cuda | 3.91 ms
PyTorch @ Op | 5.8 ms
Einstein Summation | 5.87 ms
Numba Broadcasting | 663 ms
PyTorch Broadcasting | 1.26 s
Numba Dot Product | 3.71 s

Version 0: Nested For-Loops

Excalidraw diagram showing nested for-loop implementation of matrix multiplication

ar,ac = m1.shape # n_rows * n_cols
br,bc = m2.shape
(ar,ac),(br,bc)
((5, 784), (784, 10))
t1 = torch.zeros(ar, bc)
t1.shape
torch.Size([5, 10])
for i in range(ar):         # 5
    for j in range(bc):     # 10
        for k in range(ac): # 784
            t1[i,j] += m1[i,k] * m2[k,j]
t1.shape
torch.Size([5, 10])
t1
tensor([[-10.94,  -0.68,  -7.00,  -4.01,  -2.09,  -3.36,   3.91,  -3.44, -11.47,  -2.12],
        [ 14.54,   6.00,   2.89,  -4.08,   6.59, -14.74,  -9.28,   2.16, -15.28,  -2.68],
        [  2.22,  -3.22,  -4.80,  -6.05,  14.17,  -8.98,  -4.79,  -5.44, -20.68,  13.57],
        [ -6.71,   8.90,  -7.46,  -7.90,   2.70,  -4.73, -11.03, -12.98,  -6.44,   3.64],
        [ -2.44,  -6.40,  -2.40,  -9.04,  11.18,  -5.77,  -8.92,  -3.79,  -8.98,   5.28]])
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc):
            for k in range(ac): c[i,j] += a[i,k] * b[k,j]
    return c
%time _=matmul(m1, m2)
CPU times: user 675 ms, sys: 0 ns, total: 675 ms
Wall time: 674 ms
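To put that wall time in perspective, the triple loop performs one multiply-add per (i, j, k) combination. The calculation below is only a rough sketch: it uses the shapes and the %time result reported above, assumes the time is spread evenly across iterations, and the variable names (`n_inner`, `wall_time_s`, `per_iter_us`) are illustrative.

```python
# Shapes from the post: m1 is 5x784, m2 is 784x10.
ar, ac = 5, 784
br, bc = 784, 10

# One multiply-add per (i, j, k) triple in the nested loops.
n_inner = ar * bc * ac

# Wall time reported by %time above (assumed evenly spread over iterations).
wall_time_s = 674e-3
per_iter_us = wall_time_s / n_inner * 1e6

print(f"{n_inner:,} inner iterations, about {per_iter_us:.1f} μs each")
```

Tens of microseconds for a single multiply-add suggests that interpreter and tensor-indexing overhead, rather than arithmetic, accounts for most of the time, although this estimate alone does not prove it.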

Version 1: Numba Dot Product

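This version replaces the innermost for-loop with a Numba-compiled dot product. The first call to dot below includes JIT compilation time (123 ms); the second call reuses the compiled code (32.4 μs).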

Excalidraw diagram showing dot-product implementation of matrix multiplication

@njit
def dot(a,b):
    res = 0.
    for i in range(len(a)): res+=a[i]*b[i]
    return res
%time dot(array([1.,2,3]),array([2.,3,4]))
CPU times: user 124 ms, sys: 0 ns, total: 124 ms
Wall time: 123 ms
20.0
%time dot(array([1.,2,3]),array([2.,3,4]))
CPU times: user 26 μs, sys: 2 μs, total: 28 μs
Wall time: 32.4 μs
20.0
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc): c[i,j] = dot(a[i,:], b[:,j])
    return c
m1a,m2a = m1.numpy(),m2.numpy()
test_close(t1,matmul(m1a, m2a))
%timeit -n 50 matmul(m1a,m2a)
495 μs ± 39.4 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Version 2: Element-wise Operations

def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc): c[i,j] = (a[i,:] * b[:,j]).sum()
    return c
test_close(t1,matmul(m1, m2))
%timeit -n 50 _=matmul(m1, m2)
1.48 ms ± 354 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Version 3: torch.dot

def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc): c[i,j] = torch.dot(a[i,:], b[:,j])
    return c
test_close(t1,matmul(m1, m2))
%timeit -n 50 _=matmul(m1, m2)
1.23 ms ± 380 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Version 4: PyTorch Broadcasting

Excalidraw diagram showing broadcasting implementation of matrix multiplication

def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(dim=0)
    return c
test_close(t1,matmul(m1, m2))
%timeit -n 50 _=matmul(m1, m2)
314 μs ± 92.1 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
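The shapes involved in the broadcasting step can be traced with a small sketch. It uses NumPy as a stand-in for PyTorch (the indexing and broadcasting rules used here are the same), random data instead of MNIST, and illustrative variable names.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 784))    # stand-in for m1
b = rng.standard_normal((784, 10))   # stand-in for m2

row = a[0]                    # shape (784,): one row of a
col_vec = row[:, None]        # shape (784, 1): same as a[0, :, None]
prod = col_vec * b            # broadcasts across columns to (784, 10)
out_row = prod.sum(axis=0)    # shape (10,): one full row of the result

print(col_vec.shape, prod.shape, out_row.shape)
```

Each loop iteration therefore produces a whole output row at once, which is why only the loop over rows remains.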

Version 5: Numba Broadcasting

@njit
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = np.zeros((ar, bc))
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(axis=0)
    return c
test_close(t1,matmul(m1a, m2a))
%timeit -n 50 _=matmul(m1a, m2a)
69 μs ± 1.96 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Version 6: Einstein Summation

Excalidraw diagram showing einsum implementation of matrix multiplication

def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)
test_close(t1,matmul(m1, m2))
%timeit -n 50 _=matmul(m1, m2)
80.8 μs ± 4.18 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)
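One way to read the subscript string 'ik,kj->ij' is as a set of loops: indices that appear in the output (i and j) are kept, and the index that is missing from the output (k) is summed over. The sketch below spells that out with NumPy, whose einsum accepts the same subscript string; `einsum_loops` is a hypothetical helper written for this illustration, not part of any library.

```python
import numpy as np

def einsum_loops(a, b):
    """Explicit-loop reading of 'ik,kj->ij' (illustrative helper)."""
    n_i, n_k = a.shape
    n_k2, n_j = b.shape
    assert n_k == n_k2, "inner dimensions must match"
    out = np.zeros((n_i, n_j))
    for i in range(n_i):
        for j in range(n_j):
            for k in range(n_k):   # k is absent from the output, so it is summed
                out[i, j] += a[i, k] * b[k, j]
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((3, 4))
b = rng.standard_normal((4, 2))

# Compare the loop version with the library einsum call.
same = np.allclose(einsum_loops(a, b), np.einsum('ik,kj->ij', a, b))
print(same)
```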

Version 7: PyTorch @ Operator

test_close(t1,m1@m2)
%timeit -n 50 _=m1@m2
16.7 μs ± 1.96 μs per loop (mean ± std. dev. of 7 runs, 50 loops each)

Version 8: Numba CUDA

@cuda.jit
def matmul(a,b,c):
    i, j = cuda.grid(2)
    if i < c.shape[0] and j < c.shape[1]:
        tmp = 0.
        for k in range(a.shape[1]): tmp += a[i, k] * b[k, j]
        c[i,j] = tmp
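# Pure-Python helper that emulates a kernel launch by looping over a grid.
# It is not used below: the @cuda.jit kernel is launched as matmul[blocks, threads](...).
def launch_kernel(kernel, grid_x, grid_y, *args, **kwargs):
    for i in range(grid_x):
        for j in range(grid_y): kernel((i,j), *args, **kwargs)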
r = np.zeros(t1.shape)
m1g,m2g,rg = map(cuda.to_device, (m1,m2,r))
m1g.shape, m2g.shape, rg.shape
((5, 784), (784, 10), (5, 10))
TPB = 16
rr,rc = r.shape
blockspergrid = (math.ceil(rr / TPB), math.ceil(rc / TPB))
blockspergrid
(1, 1)
matmul[blockspergrid, (TPB,TPB)](m1g,m2g,rg)
r = rg.copy_to_host()
test_close(t1, r, eps=1e-3)
/mnt/my4tb/vishal_data/miniconda3/envs/course-numba/lib/python3.10/site-packages/numba/cuda/dispatcher.py:536: NumbaPerformanceWarning: Grid size 1 will likely result in GPU under-utilization due to low occupancy.
  warn(NumbaPerformanceWarning(msg))
%%timeit -n 10
matmul[blockspergrid, (TPB,TPB)](m1g,m2g,rg)
r = rg.copy_to_host()
245 μs ± 47.4 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
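The launch configuration above can be emulated on the CPU to see which threads actually do work. This is a sketch only: `simulate_launch` is a hypothetical helper, not Numba's API, and it assumes the usual mapping in which the global index returned by cuda.grid(2) is the block index times the block size plus the thread index, per axis.

```python
import math
import numpy as np

def simulate_launch(a, b, tpb=16):
    """CPU-only emulation of the 2D launch (illustrative; not Numba's API)."""
    c = np.zeros((a.shape[0], b.shape[1]))
    blocks = (math.ceil(c.shape[0] / tpb), math.ceil(c.shape[1] / tpb))
    active = 0
    for bx in range(blocks[0]):
        for by in range(blocks[1]):
            for tx in range(tpb):
                for ty in range(tpb):
                    # Global thread index, as cuda.grid(2) would return it.
                    i, j = bx * tpb + tx, by * tpb + ty
                    if i < c.shape[0] and j < c.shape[1]:   # same bounds guard as the kernel
                        c[i, j] = a[i, :] @ b[:, j]
                        active += 1
    return c, blocks, active

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 784))
b = rng.standard_normal((784, 10))
c, blocks, active = simulate_launch(a, b)
total_threads = blocks[0] * blocks[1] * 16 * 16
print(blocks, total_threads, active)
```

For the 5x10 output, a single 16x16 block is launched and only the threads that pass the bounds check write a result, which is consistent with the under-utilization warning shown above.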

Version 9: PyTorch .cuda

m1c,m2c = m1.cuda(),m2.cuda()
r=(m1c@m2c).cpu()
test_close(t1, r)
%timeit -n 10 r=(m1c@m2c).cpu()
113 μs ± 26.4 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Comparing Fastest Versions on Full Dataset

Method | Time
PyTorch cuda | 541 μs
Numba Cuda | 3.91 ms
PyTorch @ Op | 5.8 ms
Einstein Summation | 5.87 ms
Numba Broadcasting | 663 ms
PyTorch Broadcasting | 1.26 s
Numba Dot Product | 3.71 s

Numba Dot Product

def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar):
        for j in range(bc): c[i,j] = dot(a[i,:], b[:,j])
    return c

x_train_a,weights_a = x_train.numpy(),weights.numpy()
%timeit _ = matmul(x_train_a, weights_a)
3.71 s ± 20.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

PyTorch Broadcasting

def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = torch.zeros(ar, bc)
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(dim=0)
    return c

%timeit _ = matmul(x_train, weights)
1.26 s ± 1.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
x_train.shape, weights.shape
(torch.Size([50000, 784]), torch.Size([784, 10]))
%timeit _ = matmul(x_train.cuda(), weights.cuda())
2.86 s ± 4.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

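Interestingly, putting the tensors on the GPU and then broadcasting is slower than running on the CPU (2.86 s vs. 1.26 s). Two things in this timing likely contribute: the .cuda() calls inside %timeit copy both tensors to the GPU on every run, and c is allocated on the CPU, so each of the 50,000 row assignments copies a result back from the GPU.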
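The loop-based versions in this section still iterate in Python, so it can help to divide each total time by the number of Python-level iterations. The sketch below does that arithmetic using the timings reported above; the dictionary layout and names are illustrative, and the even split across iterations is an assumption.

```python
# Shapes reported in the post: 50,000 input rows, 10 output columns.
n_rows, n_cols_out = 50_000, 10

loops = {
    # name: (total seconds, number of Python-level loop iterations)
    "Numba Dot Product":          (3.71, n_rows * n_cols_out),  # one dot() call per output cell
    "PyTorch Broadcasting (CPU)": (1.26, n_rows),               # one iteration per row
    "PyTorch Broadcasting (GPU)": (2.86, n_rows),
}

# Average cost of one iteration, in microseconds.
per_iter_us = {name: secs / n * 1e6 for name, (secs, n) in loops.items()}

for name, us in per_iter_us.items():
    print(f"{name:28s} {us:6.1f} μs per iteration")
```

Under this rough accounting, each GPU iteration costs more than twice as much as a CPU iteration, which fits the idea that per-iteration overhead, not arithmetic, dominates the loop.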

Numba Cuda

@cuda.jit
def matmul(a,b,c):
    i, j = cuda.grid(2)
    if i < c.shape[0] and j < c.shape[1]:
        tmp = 0.
        for k in range(a.shape[1]): tmp += a[i, k] * b[k, j]
        c[i,j] = tmp
r = np.zeros((50000, 10))
m1g,m2g,rg = map(cuda.to_device, (x_train,weights,r))
TPB = 16
rr,rc = r.shape
blockspergrid = (math.ceil(rr / TPB), math.ceil(rc / TPB))
blockspergrid
(3125, 1)
%%timeit -n 10
matmul[blockspergrid, (TPB,TPB)](m1g,m2g,rg)
r = rg.copy_to_host()
3.91 ms ± 68.6 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

PyTorch cuda

m1c,m2c = x_train.cuda(),weights.cuda()
%timeit -n 10 r=(m1c@m2c).cpu()
541 μs ± 6.82 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Einstein Summation

def matmul(a,b): return torch.einsum('ik,kj->ij', a, b)
%timeit -n 10 _=matmul(x_train, weights)
5.87 ms ± 229 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Numba Broadcasting

@njit
def matmul(a,b):
    (ar,ac),(br,bc) = a.shape,b.shape
    c = np.zeros((ar, bc))
    for i in range(ar): c[i] = (a[i,:,None] * b).sum(axis=0)
    return c
_=matmul(x_train.numpy(), weights.numpy())
%timeit _=matmul(x_train.numpy(), weights.numpy())
663 ms ± 378 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)

PyTorch @ Op

%timeit -n 10 _=x_train@weights
5.8 ms ± 212 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
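Another way to compare the fastest full-dataset timings is to convert them to arithmetic throughput. The sketch below counts two floating-point operations (one multiply, one add) per term of the product and divides by the timings reported above. `matmul_flops` is an illustrative helper, and the two GPU timings include copying the result back to the host, so they understate raw kernel throughput.

```python
def matmul_flops(m, k, n):
    """FLOPs for an (m x k) @ (k x n) product: one multiply and one add per term."""
    return 2 * m * k * n

# Full-dataset timings in seconds, copied from the results above.
timings_s = {
    "PyTorch @ Op (CPU)": 5.8e-3,
    "Einstein Summation": 5.87e-3,
    "Numba Cuda": 3.91e-3,
    "PyTorch cuda": 541e-6,
}

flops = matmul_flops(50_000, 784, 10)
gflops = {name: flops / t / 1e9 for name, t in timings_s.items()}

for name, rate in gflops.items():
    print(f"{name:20s} {rate:8.1f} GFLOP/s")
```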

Comparing 5-digit Subset to Full Dataset Times

Method | Full Dataset Time | 5-digit Subset Time (Rank)
PyTorch cuda | 541 μs | 108 μs (4)
Numba Cuda | 3.91 ms | 108 μs (4)
PyTorch @ Op | 5.8 ms | 18.1 μs (1)
Einstein Summation | 5.87 ms | 83.1 μs (3)
Numba Broadcasting | 663 ms | 69.2 μs (2)
PyTorch Broadcasting | 1.26 s | 203 μs (6)
Numba Dot Product | 3.71 s | 542 μs (7)
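The table can be turned into a scaling factor per method: how many times longer the full dataset takes than the 5-digit subset, compared with the 10,000x growth in the number of rows. This is a sketch over the reported timings only; the dictionary and the names `row_growth` and `slowdown` are illustrative.

```python
# (full-dataset time, 5-digit subset time) in microseconds, from the table above.
timings_us = {
    "PyTorch cuda":         (541,       108),
    "Numba Cuda":           (3_910,     108),
    "PyTorch @ Op":         (5_800,     18.1),
    "Einstein Summation":   (5_870,     83.1),
    "Numba Broadcasting":   (663_000,   69.2),
    "PyTorch Broadcasting": (1_260_000, 203),
    "Numba Dot Product":    (3_710_000, 542),
}

row_growth = 50_000 / 5   # the full dataset has 10,000x more rows

# How many times longer each method takes on the full dataset.
slowdown = {name: full / subset for name, (full, subset) in timings_us.items()}

for name, factor in slowdown.items():
    print(f"{name:22s} {factor:9.1f}x slower for {row_growth:,.0f}x more rows")
```

The methods that loop over rows in Python slow down by thousands of times, close to the growth in rows, while the single-call methods slow down far less. One plausible reading is that fixed per-call overhead dominates their small-input timings.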

Closing Thoughts

I initially ran into some problems on Colab when implementing @cuda.jit (an error about compute compatibility) so I switched to an RTX 3090 machine and installed the following, which let me successfully run this notebook:

conda install -y -c nvidia/label/cuda-12.8.0 cuda-toolkit
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128
conda install -y numba
conda install -y fastcore -c fastai

The glaring takeaway from this exercise is that these methods all scale differently. For the 5-digit subset, PyTorch cuda was about 6 times slower than PyTorch on the CPU with the @ operator (108 μs vs. 18.1 μs). Numba cuda and PyTorch cuda were tied on the small subset, but PyTorch cuda was about 7 times faster on the full dataset (541 μs vs. 3.91 ms). I don’t yet understand why these differences exist, so that’s something I’ll keep an eye out for as I learn more about how GPUs work!