Kavya G

Matrix multiplication in Triton (understanding tl.dot)

Continuing from my last Triton post on matrix addition, I’ve now turned to understanding matrix multiplication and deriving the optimizations from scratch.

This post is a first attempt at naive matrix multiplication (each thread computes a single output cell), which turned out to be more complex in Triton than expected, since the language really prefers that you use tiling.

The idea is simple - given two input matrices, A (M x K) and B (K x N), we want to compute the matrix product C (M x N). For the naive approach, we launch a grid of size (M x N) where each thread, responsible for a single cell, computes the dot product of a row of A and a column of B. Simple.
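Before touching Triton, the per-cell work each program will do can be sketched on the host in plain PyTorch (this is just the mental model, not kernel code):

```python
import torch

M, K, N = 3, 4, 5
A = torch.rand(M, K)
B = torch.rand(K, N)
C = torch.empty(M, N)

# One "thread" per output cell: C[i, j] is the dot product
# of row i of A and column j of B.
for i in range(M):
    for j in range(N):
        C[i, j] = torch.dot(A[i, :], B[:, j])

assert torch.allclose(C, A @ B)
```

Each Triton program in the (M x N) grid will do exactly one iteration of this double loop.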

We start as before with the header code:

import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

Let’s dive right into the multiply kernel. We add the @triton.jit annotation, and the kernel takes the matrix sizes as inputs, followed by the matrix pointers themselves.

An important point to note (and something I forgot myself) is that these arguments are not the tensors themselves, just pointers to them. To avoid mixing them up, it helps to name the variables the way the Triton tutorials do (e.g. A_ptr).

@triton.jit
def naive_multiply(M: tl.constexpr,
                   N: tl.constexpr,
                   K: tl.constexpr,
                   A_ptr,
                   B_ptr,
                   C_ptr,
    ):

M and N are the grid dimensions; all three sizes are passed as tl.constexpr so they are known at compile time.

We first read the program IDs, xid and yid:

    xid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)

Then we use those values to load the xid-th row of the A matrix and the yid-th column of the B matrix.

    a_ids = tl.arange(0, K) + xid * K
    a = tl.load(A_ptr + a_ids)
    print(f"size: {a.shape}")

    b_ids = tl.arange(0, K) * N + yid
    b = tl.load(B_ptr + b_ids)
    print(f"size: {b.shape}")

I derived the indexing math by hand, with some trial and error.
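The indexing can be sanity-checked on the host: for a row-major M x K matrix, row i lives at flat offsets arange(K) + i*K, and for a K x N matrix, column j lives at arange(K)*N + j (a quick PyTorch check, not part of the kernel):

```python
import torch

M, K, N = 3, 4, 5
A = torch.arange(M * K).reshape(M, K)
B = torch.arange(K * N).reshape(K, N)

i, j = 2, 3
# Row i of A: K consecutive elements starting at i*K.
a_ids = torch.arange(K) + i * K
assert torch.equal(A.flatten()[a_ids], A[i, :])

# Column j of B: elements j, j+N, j+2N, ... (stride N).
b_ids = torch.arange(K) * N + j
assert torch.equal(B.flatten()[b_ids], B[:, j])
```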

The print statements are for debugging; as mentioned in the previous post, setting the TRITON_INTERPRET='1' environment variable makes the print outputs appear.

Otherwise we get the following error regarding the print statements:

Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type <class 'triton.language.core.tuple'>

Just for completeness, let’s finish the calling function as well to run and see the prints in action.

def multiply(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), dtype=a.dtype, device=a.device)
    naive_multiply[(M, N)](
        M, N, K,
        a, b, c,
    )
    return c

M = 2
N = 2
K = 2

A = torch.rand(M, K, device=DEVICE)
B = torch.rand(K, N, device=DEVICE)

out = multiply(A, B)

Of course, we don’t get any output yet, but we can see that a_ids and b_ids are simple 1D tensors containing the indices each thread will operate on.

Let’s continue by running the dot product on these loaded tensors.

    tmp = tl.dot(a, b)
    print(f"size: {tmp.shape}")

For now we don’t store the result back; we just print the intermediate output. Write this to naive.py and run it as:

TRITON_INTERPRET='1' python3 naive.py

Unfortunately (and the reason for this blog) is that we don’t get what we are looking for. We get this error:

  File "/usr/local/lib/python3.12/dist-packages/triton/language/semantic.py", line 1497, in dot
    assert lhs_rank == rhs_rank == 2 or lhs_rank == rhs_rank == 3, f"Both inputs must be either 2D or 3D; (lhs: {lhs.shape} vs rhs: {rhs.shape})"
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Both inputs must be either 2D or 3D; (lhs: ['constexpr[2]'] vs rhs: ['constexpr[2]'])

We can check the tl.dot docs to understand this: tl.dot expects its inputs to be at least 2D matrices. Which leads to the question: what does tl.dot even do?

Let’s do a quick test, giving it a 2D matrix that it expects and see what happens.

import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def try_dot(M: tl.constexpr,
            A,
            B,
            O_ptr
            ):
    xid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)
    a_ids = tl.arange(0, M)
    b_ids = tl.arange(0, M)
    ids = a_ids[None, :] + b_ids[:, None] * M
    a = tl.load(A + ids)
    b = tl.load(B + ids)
    x = tl.dot(a, b)
    tl.store(O_ptr + ids, x)

The program is quite similar to our matrix multiplication. We take a square matrix (for simplicity) of size M. a_ids and b_ids are each just 1D tensors of size M; we convert them to a 2D tensor using [None, :] indexing to add a dimension. At the end of that maneuver we get a simple 2D matrix of increasing row-major indices.
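The broadcasting step can be checked outside the kernel too: a_ids[None, :] + b_ids[:, None] * M produces exactly the row-major flat indices of an M x M matrix (PyTorch sketch):

```python
import torch

M = 4
a_ids = torch.arange(M)
b_ids = torch.arange(M)

# (1, M) + (M, 1) broadcasts to an (M, M) grid of flat indices:
# ids[r, c] = r * M + c, i.e. row-major order.
ids = a_ids[None, :] + b_ids[:, None] * M
assert torch.equal(ids, torch.arange(M * M).reshape(M, M))
```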

We load both A and B with the same set of indices and perform dot product on them.

def dot(a, b, M):
    c = torch.empty((M, M), dtype=a.dtype, device=a.device)
    try_dot[(M, M)](M, a, b, c)
    return c

M = 128
A = torch.rand((M, M), device=DEVICE)
B = torch.rand((M, M), device=DEVICE)

C = dot(A, B, M)
act = torch.matmul(A, B)
print(torch.allclose(C, act, rtol=1e-2))

The runner code is again similar, but the last two lines give us what we have been looking for. tl.dot on 2D matrices is nothing but matrix multiplication!

With that detour, we continue with our awkward task of implementing matrix multiplication while using matrix multiplication.

tl.dot expects 2D inputs, so let’s give it some, using the same [None, :] / [:, None] trick to add a dimension to a_ids and b_ids.

    a_ids = tl.arange(0, K) + xid * K
    a_ids = a_ids[None, :]
    a = tl.load(A_ptr + a_ids)
    print(f"size: {a.shape}")

    b_ids = tl.arange(0, K) * N + yid
    b_ids = b_ids[:, None]
    b = tl.load(B_ptr + b_ids)
    print(f"size: {b.shape}")

With that done, let’s complete the code by storing the result back in the output tensor.

    o_id = xid * N + yid
    tl.store(C_ptr + o_id, tmp)

Unfortunately, we get an error:

           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/triton/language/semantic.py", line 1248, in _store_legacy
    raise ValueError("Value argument cannot be block type if pointer argument is not a block")
ValueError: Value argument cannot be block type if pointer argument is not a block

Turns out we are trying to write a 2D block tensor to a location where we only want to write a single value.

Using a Claude-suggested hack to collapse the block tensor into a single value:

    tmp = tl.sum(tmp)

Let’s remove the print statements, increase the matrix size, and run without debug mode. And get... another error!

  File "/usr/local/lib/python3.12/dist-packages/triton/language/semantic.py", line 1503, in dot
    assert lhs.shape[-2].value >= min_dot_size[0] and lhs.shape[-1].value >= min_dot_size[2] \
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: Input shapes should have M >= 16, N >= 16 and K >= 16

Examining the code further shows that Triton’s semantic checks require the backend to specify a minimum matrix size for the dot operation. For NVIDIA backends, this is set to 16 so that tensor cores can be used for the matrix multiplication. Unfortunately, there is no fallback to FMA for smaller sizes (an ongoing issue).

One way to solve this is to pad the skinny tensors up to size 16 (on the short dimension), but this is obviously inefficient due to wasted computation.
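To see why padding works (and why it is wasteful), here is a host-side sketch: embed the length-K row and column into 16-wide zero-padded blocks, multiply, and only cell [0, 0] of the 16 x 16 result is useful:

```python
import torch

K = 32
a = torch.rand(K)          # one row of A
b = torch.rand(K)          # one column of B

# Pad up to the minimum tile size tl.dot accepts on NVIDIA (16).
A_pad = torch.zeros(16, K)
A_pad[0] = a
B_pad = torch.zeros(K, 16)
B_pad[:, 0] = b

out = A_pad @ B_pad        # 16 x 16 result
# Only 1 of the 256 output cells carries the value we want.
assert torch.allclose(out[0, 0], torch.dot(a, b))
```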

Another is to do the dot product manually instead of relying on the tl-provided function. I tried this with for loops, but Triton doesn’t really let you work on tensor elements one by one.

Claude suggests element-wise multiplication followed by tl.sum to achieve the same result. Let’s verify that with a test script.

import torch
import triton
import triton.language as tl
DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def verify_dot(M: tl.constexpr,
               A,
               B,
               C,
               ):
    xid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)
    a_ids = tl.arange(0, M)
    b_ids = tl.arange(0, M)
    a = tl.load(A + a_ids)
    b = tl.load(B + b_ids)

    x = tl.sum(a * b, axis=0)
    print(f"x: {x}")
    tl.store(C, x)

def dot(A, B, M):
    C = torch.zeros(1, device=DEVICE)
    verify_dot[(1, 1)](M, A, B, C)
    return C

M = 16
A = torch.rand(M, device=DEVICE)
B = torch.rand(M, device=DEVICE)
C = dot(A, B, M)
act = torch.dot(A, B)
print(torch.allclose(C, act, rtol=1e-2))

We get True, showing that this logic should work. So let’s replace our tl.dot section with it. The entire kernel now looks like this:

@triton.jit
def naive_multiply(M: tl.constexpr,
                   N: tl.constexpr,
                   K: tl.constexpr,
                   A,
                   B,
                   C,
    ):
    xid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)

    a_ids = tl.arange(0, K) + xid * K
    a = tl.load(A + a_ids)
    print(f"A: {a_ids}")

    b_ids = tl.arange(0, K) * N + yid
    b = tl.load(B + b_ids)
    print(f"B: {b_ids}")

    tmp = tl.sum(a * b, axis=0)

    o_id = xid * N + yid
    print(f"C[{o_id}] = {tmp}")
    tl.store(C + o_id, tmp)

And, running with the rest of the runner code gives us SUCCESS!!!

To continue towards my original goal of reproducing the blog, I’ll need to add benchmarking/profiling code to measure the time this kernel takes.

But for now, the complete final code is below:

import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()

@triton.jit
def naive_multiply(M: tl.constexpr,
                   N: tl.constexpr,
                   K: tl.constexpr,
                   A,
                   B,
                   C,
    ):
    xid = tl.program_id(axis=0)
    yid = tl.program_id(axis=1)

    a_ids = tl.arange(0, K) + xid * K
    a = tl.load(A + a_ids)

    b_ids = tl.arange(0, K) * N + yid
    b = tl.load(B + b_ids)

    tmp = tl.sum(a * b, axis=0)

    o_id = xid * N + yid
    tl.store(C + o_id, tmp)
	
def multiply(a, b):
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), dtype=a.dtype, device=a.device)
    naive_multiply[(M, N)](
        M, N, K,
        a, b, c,
    )
    return c
	
M = 128
N = 128
K = 128

A = torch.rand(M, K, device=DEVICE)
B = torch.rand(K, N, device=DEVICE)

out = multiply(A, B)
act = A @ B
print(torch.allclose(out, act, rtol=1e-2))