In this project, I wrote a custom matrix multiplication kernel in Triton that fuses three operations into a single GPU kernel: the matrix product A @ B, the addition of a matrix C, and a ReLU activation, i.e. relu(A @ B + C). The goal was to outperform PyTorch's cuBLAS-backed implementation by taking advantage of kernel fusion and manual tiling. My Triton implementation achieved a 1.4× speedup over the equivalent sequence of PyTorch operations on fp16 tensors.

Kernel Design

The kernel uses several optimization strategies commonly found in high-performance GPU code:

  • Shared memory tiling: Tiles from matrices A and B are cooperatively loaded into shared memory to reduce redundant global memory access.
  • Register tiling: Partial results are accumulated in registers using tl.dot() for efficient computation.
  • Fused epilogue: The C add and the ReLU are applied to the accumulator in registers, immediately before the single write to global memory, eliminating the extra kernel launches and global-memory round trips that separate add and activation kernels would incur.
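The semantics of the fused operation can be pinned down with a plain-Python reference implementation (illustrative only; the real kernel computes the same thing tile-by-tile in fp16 on the GPU):

```python
def matmul_add_relu_ref(A, B, C):
    """Reference for the fused op relu(A @ B + C) on nested lists.

    A is m x k, B is k x n, C is m x n. This pure-Python version exists
    only to define the expected output; it is not the GPU kernel.
    """
    m, k = len(A), len(A[0])
    n = len(B[0])
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for p in range(k):
                acc += A[i][p] * B[p][j]   # matmul accumulation
            acc += C[i][j]                 # fused add epilogue
            out[i][j] = max(acc, 0.0)      # fused ReLU epilogue
    return out
```

A reference like this is also what the benchmark harness can validate the Triton output against (up to fp16 tolerance).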

I implemented the kernel using Triton’s block pointer abstraction (tl.make_block_ptr and tl.advance) to tile and navigate memory across program IDs. The tile dimensions BLOCK_M, BLOCK_N, and BLOCK_K were declared as tl.constexpr parameters on the @triton.jit kernel, so Triton specializes and compiles a kernel variant for each configuration.
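A sketch of that kernel structure, assuming row-major fp16 inputs (parameter names and the exact epilogue details are illustrative, not the project's actual code, and running it requires a CUDA GPU with Triton installed):

```python
import triton
import triton.language as tl

@triton.jit
def matmul_add_relu_kernel(
    a_ptr, b_ptr, c_ptr, out_ptr,
    M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn,
    stride_cm, stride_cn, stride_om, stride_on,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # Block pointers describe one BLOCK_M x BLOCK_K (resp. BLOCK_K x BLOCK_N)
    # tile; tl.advance slides them along K each iteration.
    a_blk = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                              (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K), (1, 0))
    b_blk = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                              (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N), (1, 0))

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_blk, boundary_check=(0, 1))
        b = tl.load(b_blk, boundary_check=(0, 1))
        acc += tl.dot(a, b)                      # register-tiled accumulation
        a_blk = tl.advance(a_blk, (0, BLOCK_K))
        b_blk = tl.advance(b_blk, (BLOCK_K, 0))

    # Fused epilogue: add C and apply ReLU before the single store.
    c_blk = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                              (pid_m * BLOCK_M, pid_n * BLOCK_N),
                              (BLOCK_M, BLOCK_N), (1, 0))
    acc += tl.load(c_blk, boundary_check=(0, 1)).to(tl.float32)
    acc = tl.maximum(acc, 0.0)

    out_blk = tl.make_block_ptr(out_ptr, (M, N), (stride_om, stride_on),
                                (pid_m * BLOCK_M, pid_n * BLOCK_N),
                                (BLOCK_M, BLOCK_N), (1, 0))
    tl.store(out_blk, acc.to(tl.float16), boundary_check=(0, 1))
```

Accumulating in fp32 and casting to fp16 only at the store is the standard way to keep precision acceptable for fp16 inputs.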

Automatic Code Generation & Launch

I wrote a wrapper function matmul_add_relu_fp16() that:

  • Allocates the output tensor
  • Calculates grid size based on BLOCK_M and BLOCK_N
  • Launches the Triton kernel with all required strides and dimensions

This function works with ordinary PyTorch tensors and executes entirely on the GPU.
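The grid-size calculation in that wrapper reduces to ceiling division: one program instance per output tile, rounding up at the matrix edges. A standalone sketch (the helper name is hypothetical):

```python
def launch_grid(M, N, BLOCK_M, BLOCK_N):
    """Number of Triton program instances needed to cover an M x N output:
    one program per BLOCK_M x BLOCK_N tile, rounding up so partial edge
    tiles are still covered (equivalent to triton.cdiv on each axis)."""
    cdiv = lambda a, b: (a + b - 1) // b
    return (cdiv(M, BLOCK_M), cdiv(N, BLOCK_N))
```

The edge tiles produced by the round-up are where the kernel's boundary checks do their work, masking out-of-range loads and stores.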

Grid Search for Performance Tuning

To optimize the kernel, I implemented a grid search over the tile sizes (BLOCK_M, BLOCK_N, BLOCK_K) and num_warps. I measured execution time across many configurations to find the best-performing kernel variant. Through this search, I found that using:

  • BLOCK_M = 128
  • BLOCK_N = 256
  • BLOCK_K = 32

yielded the best performance on my GPU, consistently outperforming PyTorch by over 1.4× in matmul + add + ReLU throughput on large matrices.
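The tuning loop was essentially an exhaustive search over the Cartesian product of candidate configurations. A minimal harness sketch, with benchmark_config standing in for the real timed kernel launch (the candidate values shown are illustrative):

```python
import itertools

def grid_search(benchmark_config):
    """Try every (BLOCK_M, BLOCK_N, BLOCK_K, num_warps) combination and
    return the fastest one. benchmark_config is assumed to launch the
    kernel with that configuration and return a time in milliseconds."""
    candidates = itertools.product(
        [64, 128, 256],   # BLOCK_M
        [64, 128, 256],   # BLOCK_N
        [32, 64],         # BLOCK_K
        [4, 8],           # num_warps
    )
    best_cfg, best_ms = None, float("inf")
    for cfg in candidates:
        ms = benchmark_config(*cfg)
        if ms < best_ms:
            best_cfg, best_ms = cfg, ms
    return best_cfg, best_ms
```

In practice some configurations exceed shared-memory or register limits on a given GPU, so the real benchmark function should catch launch failures and report them as infinitely slow.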

Outcome & Learnings

This project gave me hands-on experience with low-level GPU programming using Triton. I learned to:

  • Write performant GPU kernels with memory coalescing and tiling
  • Apply fusion to reduce overhead and memory traffic
  • Benchmark and tune GPU workloads using grid search
  • Use Triton’s abstractions for pointer arithmetic, program IDs, and boundary checking

Writing a fused kernel by hand helped me understand the memory and compute bottlenecks in neural network operations and why kernel fusion plays a central role in ML model acceleration.