In this project, I wrote a custom matrix multiplication kernel in Triton that fuses three operations, `A @ B + C` followed by a ReLU activation, into a single GPU kernel. The goal was to outperform PyTorch's high-performance cuBLAS backend by taking advantage of kernel fusion and manual tiling. My Triton implementation achieved a 1.4× speedup over the equivalent PyTorch function on fp16 tensors.
Kernel Design
The kernel uses several optimization strategies commonly found in high-performance GPU code:
- Shared memory tiling: Tiles from matrices A and B are cooperatively loaded into shared memory to reduce redundant global memory access.
- Register tiling: Partial results are accumulated in registers using `tl.dot()` for efficient computation.
- Fused epilogue: The kernel adds `C` to the matmul result and applies `ReLU` directly before writing to global memory, reducing latency and memory bandwidth usage.
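The tiling and fused-epilogue structure can be illustrated on the CPU with a NumPy sketch (block sizes here are illustrative, not the kernel's actual tile shapes):

```python
import numpy as np

def fused_matmul_add_relu_blocked(A, B, C, bm=32, bn=32, bk=16):
    """CPU sketch of the fused kernel: blocked A @ B, with + C and ReLU
    applied in the epilogue before each output tile is stored."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2 and C.shape == (M, N)
    out = np.zeros((M, N), dtype=A.dtype)
    for m in range(0, M, bm):
        for n in range(0, N, bn):
            # The accumulator plays the role of the register tile in the GPU kernel.
            acc = np.zeros((min(bm, M - m), min(bn, N - n)), dtype=np.float32)
            for k in range(0, K, bk):
                acc += A[m:m+bm, k:k+bk].astype(np.float32) @ B[k:k+bk, n:n+bn].astype(np.float32)
            # Fused epilogue: add C and apply ReLU before the single store.
            out[m:m+bm, n:n+bn] = np.maximum(acc + C[m:m+bm, n:n+bn], 0.0).astype(A.dtype)
    return out
```

Because the add and ReLU happen on the tile that is already in fast storage, the fused version avoids two extra round trips to global memory compared with running matmul, add, and ReLU as separate kernels.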
I implemented the kernel using Triton's block pointer abstraction (`tl.make_block_ptr` and `tl.advance`) to tile and navigate memory across program IDs. The tiling dimensions `BLOCK_M`, `BLOCK_N`, and `BLOCK_K` were passed to the `@triton.jit`-decorated kernel as `tl.constexpr` compile-time constants.
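A condensed sketch of this kernel structure, using the block-pointer API described above (names, argument order, and the fp16 output cast are illustrative; the real kernel may differ in details such as grouping of program IDs):

```python
import triton
import triton.language as tl

@triton.jit
def matmul_add_relu_kernel(a_ptr, b_ptr, c_ptr, out_ptr,
                           M, N, K,
                           stride_am, stride_ak, stride_bk, stride_bn,
                           stride_cm, stride_cn, stride_om, stride_on,
                           BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                           BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Block pointers describe a logical tile view over global memory.
    a_blk = tl.make_block_ptr(a_ptr, shape=(M, K), strides=(stride_am, stride_ak),
                              offsets=(pid_m * BLOCK_M, 0),
                              block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))
    b_blk = tl.make_block_ptr(b_ptr, shape=(K, N), strides=(stride_bk, stride_bn),
                              offsets=(0, pid_n * BLOCK_N),
                              block_shape=(BLOCK_K, BLOCK_N), order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_blk, boundary_check=(0, 1))
        b = tl.load(b_blk, boundary_check=(0, 1))
        acc += tl.dot(a, b)  # register-tile accumulation
        a_blk = tl.advance(a_blk, (0, BLOCK_K))
        b_blk = tl.advance(b_blk, (BLOCK_K, 0))
    c_blk = tl.make_block_ptr(c_ptr, shape=(M, N), strides=(stride_cm, stride_cn),
                              offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    o_blk = tl.make_block_ptr(out_ptr, shape=(M, N), strides=(stride_om, stride_on),
                              offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),
                              block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))
    # Fused epilogue: add C and apply ReLU before the single store to global memory.
    c = tl.load(c_blk, boundary_check=(0, 1))
    tl.store(o_blk, tl.maximum(acc + c, 0.0).to(tl.float16), boundary_check=(0, 1))
```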
Automatic Code Generation & Launch
I wrote a wrapper function `matmul_add_relu_fp16()` that:
- Allocates the output tensor
- Calculates the launch grid size based on `BLOCK_M` and `BLOCK_N`
- Launches the Triton kernel with all required strides and dimensions

This function works directly with PyTorch tensors and executes entirely on the GPU.
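The grid-size calculation the wrapper performs can be shown in isolation (a sketch; `triton.cdiv` provides the same ceiling division in real Triton code, and `launch_grid` is a hypothetical helper name):

```python
def cdiv(x: int, y: int) -> int:
    """Ceiling division, equivalent to triton.cdiv."""
    return (x + y - 1) // y

def launch_grid(M: int, N: int, BLOCK_M: int, BLOCK_N: int) -> tuple:
    """One program instance per (BLOCK_M x BLOCK_N) output tile."""
    return (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))
```

For example, `launch_grid(4096, 4096, 128, 256)` returns `(32, 16)`, i.e. 512 program instances, each responsible for one output tile.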
Grid Search for Performance Tuning
To optimize the kernel, I implemented a grid search over the tile sizes (`BLOCK_M`, `BLOCK_N`, `BLOCK_K`) and `num_warps`, measuring execution time across many configurations to find the best-performing kernel variant. Through this search, I found that `BLOCK_M = 128`, `BLOCK_N = 256`, `BLOCK_K = 32` yielded the best performance on my GPU, consistently outperforming PyTorch by over 1.4× in `matmul + add + ReLU` throughput on large matrices.
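The tuning loop can be sketched generically. Here `bench` is a stand-in for a function that times one launch of the real kernel (e.g. via `triton.testing.do_bench`); the candidate value sets are illustrative:

```python
import itertools

def grid_search(bench, block_ms=(64, 128), block_ns=(64, 128, 256),
                block_ks=(32, 64), warps=(4, 8)):
    """Try every (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) combination and
    return the configuration with the lowest measured time."""
    best_cfg, best_t = None, float("inf")
    for cfg in itertools.product(block_ms, block_ns, block_ks, warps):
        t = bench(*cfg)  # measured time for this configuration
        if t < best_t:
            best_cfg, best_t = cfg, t
    return best_cfg, best_t
```

One caveat worth noting when doing this kind of search: each timing should be taken after a warm-up launch, since Triton JIT-compiles a new kernel variant for every distinct `tl.constexpr` combination.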
Outcome & Learnings
This project gave me hands-on experience with low-level GPU programming using Triton. I learned to:
- Write performant GPU kernels with memory coalescing and tiling
- Apply fusion to reduce overhead and memory traffic
- Benchmark and tune GPU workloads using grid search
- Use Triton’s abstractions for pointer arithmetic, program IDs, and boundary checking
Writing a fused kernel by hand helped me understand the memory and compute bottlenecks in neural network operations and why kernel fusion plays a central role in ML model acceleration.