1. Why Triton?
Triton is an open-source language and compiler by OpenAI that makes writing GPU kernels as accessible as writing NumPy code, while achieving performance competitive with hand-tuned CUDA.
The key insight: most GPU optimization comes from memory access patterns, not arithmetic. Triton handles the hard parts (coalescing, bank conflicts, synchronization) while you focus on the algorithm.
When to use Triton:
- Custom fused operations - Combine multiple ops to reduce memory traffic
- Exotic activations/losses - When PyTorch doesn't have what you need
- Research prototyping - Test kernel ideas 10x faster than CUDA
- Flash Attention variants - Memory-efficient attention mechanisms
PyTorch 2.0's torch.compile() generates Triton kernels through its default GPU backend, TorchInductor. When you compile a model, Triton kernels are produced for you automatically!
2. Triton vs CUDA: Key Differences
| Aspect | CUDA C++ | Triton |
|---|---|---|
| Programming Model | Thread-level (SIMT) | Block-level (operate on tiles) |
| Memory Management | Manual (shared mem, coalescing) | Automatic (compiler optimizes) |
| Synchronization | Explicit __syncthreads() | Implicit (compiler inserts) |
| Bank Conflicts | Manual padding/swizzling | Automatic avoidance |
| Tensor Cores | Manual wmma intrinsics | Automatic when shapes match |
| Development Time | Days to weeks | Hours to days |
| Performance | 100% (baseline) | 80-100% of hand-tuned CUDA |
3. Core Concepts
Programs & Blocks
In Triton, you write a program that operates on a block of data:
- Program ID (pid): Which chunk of data this instance handles (like blockIdx)
- Block Size: How many elements each program processes (a compile-time constant)
- Offsets: The indices into your tensors (computed from pid + arange)
import triton
import triton.language as tl
@triton.jit
def my_kernel(
input_ptr, # Pointer to input tensor
output_ptr, # Pointer to output tensor
n_elements, # Total number of elements
BLOCK_SIZE: tl.constexpr, # Compile-time constant
):
# Which program instance am I?
pid = tl.program_id(0) # Like blockIdx.x
# What are my offsets into the data?
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# Create mask for boundary handling
mask = offsets < n_elements
# Load a block of data (with masking)
x = tl.load(input_ptr + offsets, mask=mask)
# Compute
y = x * 2
# Store results
tl.store(output_ptr + offsets, y, mask=mask)
Pointers & Masks
Triton uses pointer arithmetic (like C) rather than tensor indexing:
# 1D access
ptr = base_ptr + offset # Single element
ptrs = base_ptr + offsets # Vector of pointers
# 2D access (row-major)
ptr = base_ptr + row * stride + col # Element at [row, col]
# Masked load/store (CRUCIAL for correctness)
mask = offsets < n_elements
x = tl.load(ptr, mask=mask, other=0.0) # Load 0 where mask is False
tl.store(ptr, x, mask=mask) # Only store where mask is True
# 2D mask example
row_mask = rows[:, None] < M
col_mask = cols[None, :] < N
mask = row_mask & col_mask
Unlike CUDA, where a thread can early-return, a Triton program must handle every element in its block. Use masks to prevent out-of-bounds access - a forgotten mask causes silent memory corruption!
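To build intuition for the program/offset/mask machinery, the pattern can be emulated on the CPU in plain Python. This is an illustrative sketch with made-up names (`cdiv`, `run_blocked`), not the Triton API:

```python
def cdiv(a, b):
    """Ceiling division: how many programs are needed to cover a elements."""
    return (a + b - 1) // b

def run_blocked(data, block_size):
    """CPU emulation of the masked block pattern from my_kernel: y = 2 * x."""
    n = len(data)
    out = [0.0] * n
    for pid in range(cdiv(n, block_size)):  # one "program" per block
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        for off in offsets:
            if off < n:            # mask = offsets < n_elements
                out[off] = data[off] * 2
    return out
```

Note how the last program covers offsets past the end of the data; the mask is what keeps those lanes from reading or writing out of bounds.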
4. Hello World: Vector Addition
Let's implement c = a + b - the "hello world" of GPU programming:
import torch
import triton
import triton.language as tl
@triton.jit
def vector_add_kernel(
a_ptr, b_ptr, c_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
# Program ID determines which block of elements we process
pid = tl.program_id(0)
# Compute offsets for this program
offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
# Mask for boundary
mask = offsets < n_elements
# Load vectors a and b
a = tl.load(a_ptr + offsets, mask=mask)
b = tl.load(b_ptr + offsets, mask=mask)
# Add them
c = a + b
# Store result
tl.store(c_ptr + offsets, c, mask=mask)
def vector_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""Triton vector addition wrapper."""
assert a.shape == b.shape
assert a.is_cuda and b.is_cuda
c = torch.empty_like(a)
n_elements = a.numel()
# Grid: how many programs to launch
BLOCK_SIZE = 1024
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
# Launch kernel
vector_add_kernel[grid](
a, b, c,
n_elements,
BLOCK_SIZE=BLOCK_SIZE,
)
return c
# Test it!
if __name__ == "__main__":
a = torch.randn(10000, device="cuda")
b = torch.randn(10000, device="cuda")
c_triton = vector_add(a, b)
c_torch = a + b
print(f"Max error: {(c_triton - c_torch).abs().max()}")
# Max error: 0.0
5. Matrix Multiplication with Tiling
Matrix multiplication is the workhorse of deep learning. Here's how to implement a tiled matmul in Triton:
import torch
import triton
import triton.language as tl
@triton.jit
def matmul_kernel(
# Pointers to matrices
a_ptr, b_ptr, c_ptr,
# Matrix dimensions
M, N, K,
# Strides (elements to skip for next row/col)
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
# Block sizes (compile-time constants)
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
"""Compute C = A @ B for one output tile."""
# Program ID for 2D grid
pid_m = tl.program_id(0) # Which row block
pid_n = tl.program_id(1) # Which col block
# Compute starting positions
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# Pointers to first tiles of A and B
a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
# Accumulator for output tile (in registers)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# Loop over K dimension in blocks
for k in range(0, K, BLOCK_K):
# Load tiles with boundary masking
a_mask = (offs_m[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_n[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.0)
b = tl.load(b_ptrs, mask=b_mask, other=0.0)
# Multiply and accumulate
acc += tl.dot(a, b) # Uses Tensor Cores if available!
# Advance pointers to next K block
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
# Store output tile
c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc.to(c_ptr.dtype.element_ty), mask=c_mask)  # cast FP32 accumulator to output dtype
def matmul(a, b):
    """Triton matrix multiplication: C = A @ B"""
    assert a.shape[1] == b.shape[0], "Inner dimensions must match"
    M, K = a.shape
    _, N = b.shape
c = torch.empty((M, N), device=a.device, dtype=a.dtype)
# Block sizes (tunable)
BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
# 2D grid of programs
grid = (
triton.cdiv(M, BLOCK_M),
triton.cdiv(N, BLOCK_N),
)
matmul_kernel[grid](
a, b, c,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c.stride(0), c.stride(1),
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
)
return c
When block sizes are multiples of 16 and inputs are FP16/BF16, Triton's tl.dot() automatically targets Tensor Cores, often yielding close to an order-of-magnitude speedup over FP32 CUDA cores.
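The tiling loop above can be sketched on the CPU with plain Python lists, which makes the `other=0.0` zero-padding and the K-dimension accumulation easy to see. This is an illustrative sketch only (`tiled_matmul` is a made-up helper, not Triton code):

```python
def tiled_matmul(A, B, BLOCK_M=2, BLOCK_N=2, BLOCK_K=2):
    """CPU sketch of the tiled kernel: C = A @ B on lists of lists."""
    M, K = len(A), len(A[0])
    N = len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, BLOCK_M):          # plays the role of pid_m
        for n0 in range(0, N, BLOCK_N):      # plays the role of pid_n
            # accumulator for this output tile (FP32 in the real kernel)
            acc = [[0.0] * BLOCK_N for _ in range(BLOCK_M)]
            for k0 in range(0, K, BLOCK_K):  # loop over K in blocks
                for i in range(BLOCK_M):
                    for j in range(BLOCK_N):
                        for k in range(BLOCK_K):
                            m, n, kk = m0 + i, n0 + j, k0 + k
                            # masked loads: out-of-bounds elements read as 0.0
                            a = A[m][kk] if m < M and kk < K else 0.0
                            b = B[kk][n] if kk < K and n < N else 0.0
                            acc[i][j] += a * b
            # masked store of the tile
            for i in range(BLOCK_M):
                for j in range(BLOCK_N):
                    if m0 + i < M and n0 + j < N:
                        C[m0 + i][n0 + j] = acc[i][j]
    return C
```

The zero-padding via `other=0.0` is what lets the inner loop run over a full BLOCK_K even when K is not a multiple of it: padded elements contribute nothing to the dot product.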
6. Fused Operations
The real power of Triton is kernel fusion - combining multiple operations to avoid memory round-trips:
Fused Softmax
@triton.jit
def softmax_kernel(
input_ptr, output_ptr,
n_cols,
input_row_stride, output_row_stride,
BLOCK_SIZE: tl.constexpr,
):
"""Compute softmax for one row."""
# Each program handles one row
row_idx = tl.program_id(0)
# Pointers to row start
row_start_ptr = input_ptr + row_idx * input_row_stride
col_offsets = tl.arange(0, BLOCK_SIZE)
# Load row with masking
mask = col_offsets < n_cols
row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))
# Numerically stable softmax
# 1. Subtract max for numerical stability
row_max = tl.max(row, axis=0)
row = row - row_max
# 2. Exponentiate
numerator = tl.exp(row)
# 3. Sum and divide
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# Store result
output_row_ptr = output_ptr + row_idx * output_row_stride
tl.store(output_row_ptr + col_offsets, softmax_output, mask=mask)
def softmax(x):
"""Fused softmax over last dimension."""
n_rows, n_cols = x.shape
# BLOCK_SIZE must be power of 2 and >= n_cols
BLOCK_SIZE = triton.next_power_of_2(n_cols)
y = torch.empty_like(x)
# One program per row
grid = (n_rows,)
softmax_kernel[grid](
x, y,
n_cols,
x.stride(0), y.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)
return y
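The three numbered steps in the kernel are the standard numerically stable softmax recipe; a plain-Python reference (illustrative, assumed name `softmax_ref`) shows the same math:

```python
import math

def softmax_ref(row):
    m = max(row)                            # 1. subtract max (stability)
    exps = [math.exp(v - m) for v in row]   # 2. exponentiate
    total = sum(exps)                       # 3. sum and divide
    return [e / total for e in exps]
```

Without step 1, math.exp(1000.0) overflows to infinity; with it, the largest exponent is exp(0) = 1 and the result is well defined.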
Fused LayerNorm
@triton.jit
def layernorm_kernel(
x_ptr, y_ptr, weight_ptr, bias_ptr,
n_cols, eps,
x_row_stride,
BLOCK_SIZE: tl.constexpr,
):
"""LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias"""
row_idx = tl.program_id(0)
# Offsets for this row
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
# Load input row
x_ptr += row_idx * x_row_stride
x = tl.load(x_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
# Compute mean
mean = tl.sum(x, axis=0) / n_cols
# Compute variance
x_centered = tl.where(mask, x - mean, 0.0)
var = tl.sum(x_centered * x_centered, axis=0) / n_cols
# Normalize
rstd = 1.0 / tl.sqrt(var + eps)
x_norm = x_centered * rstd
# Scale and shift (load weight and bias)
weight = tl.load(weight_ptr + col_offsets, mask=mask)
bias = tl.load(bias_ptr + col_offsets, mask=mask)
y = x_norm * weight + bias
    # Store (y is allocated contiguous, so it shares x's row stride)
    y_ptr += row_idx * x_row_stride
tl.store(y_ptr + col_offsets, y, mask=mask)
def layer_norm(x, weight, bias, eps=1e-5):
"""Fused LayerNorm."""
n_rows = x.numel() // x.shape[-1]
n_cols = x.shape[-1]
x_flat = x.view(n_rows, n_cols)
y = torch.empty_like(x_flat)
BLOCK_SIZE = triton.next_power_of_2(n_cols)
grid = (n_rows,)
layernorm_kernel[grid](
x_flat, y, weight, bias,
n_cols, eps,
x_flat.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)
return y.view_as(x)
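The per-row math in the kernel mirrors this plain-Python reference (illustrative, assumed name `layernorm_ref`):

```python
import math

def layernorm_ref(x, weight, bias, eps=1e-5):
    """y = (x - mean) / sqrt(var + eps) * weight + bias, for one row."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n   # biased variance, as in the kernel
    rstd = 1.0 / math.sqrt(var + eps)
    return [(v - mean) * rstd * w + b for v, w, b in zip(x, weight, bias)]
```

With unit weight and zero bias, the output row has (approximately) zero mean and unit variance, which is a quick sanity check for any LayerNorm kernel.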
7. Autotuning for Performance
Triton's killer feature: automatic tuning of block sizes and other parameters:
@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_warps=8),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4),
],
key=['M', 'N', 'K'], # Retune when these change
)
@triton.jit
def matmul_kernel_autotuned(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
):
# Same kernel code as before...
# Triton will try all configs and pick the fastest!
...
The first time your kernel runs with a new key (here, a new M, N, K), Triton benchmarks every listed config - a one-time cost, typically tens of milliseconds - and caches the fastest. Subsequent calls with the same key reuse the cached config with no tuning overhead.
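Conceptually, the key-based caching is a memoized argmin over the config list. A toy sketch with made-up names (`pick_config`, `_best`), not Triton internals:

```python
_best = {}  # key (e.g. (M, N, K)) -> chosen config

def pick_config(key, configs, bench):
    """Benchmark each config once per key, then reuse the cached winner.

    bench(config) returns a measured cost (lower is better)."""
    if key not in _best:
        _best[key] = min(configs, key=bench)   # one-time tuning cost
    return _best[key]
```

A second call with the same key returns instantly from the cache; only a new shape triggers re-benchmarking.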
8. PyTorch Integration
Triton kernels integrate seamlessly with PyTorch autograd:
import torch
from torch.autograd import Function
class TritonSoftmax(Function):
@staticmethod
def forward(ctx, x):
# Run forward kernel
y = softmax(x) # Our Triton kernel
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, grad_output):
        # Softmax backward: grad_x = y * (dy - sum(dy * y))
y, = ctx.saved_tensors
# Could also implement backward as Triton kernel!
grad_input = grad_output * y - y * (grad_output * y).sum(dim=-1, keepdim=True)
return grad_input
# Use in models
def triton_softmax(x):
return TritonSoftmax.apply(x)
# Example: Custom attention with Triton
class TritonAttention(torch.nn.Module):
def __init__(self, d_model, n_heads):
super().__init__()
self.n_heads = n_heads
self.d_head = d_model // n_heads
self.qkv = torch.nn.Linear(d_model, 3 * d_model)
self.out = torch.nn.Linear(d_model, d_model)
def forward(self, x):
B, T, C = x.shape
qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.d_head)
q, k, v = qkv.unbind(2)
# Attention scores
scores = torch.einsum('bthd,bshd->bhts', q, k) / (self.d_head ** 0.5)
# Use our Triton softmax!
attn = triton_softmax(scores)
out = torch.einsum('bhts,bshd->bthd', attn, v)
return self.out(out.reshape(B, T, C))
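The backward formula used in TritonSoftmax can be checked against a finite-difference gradient in plain Python (illustrative sketch, assumed names `softmax_fwd` / `softmax_grad`):

```python
import math

def softmax_fwd(x):
    m = max(x)
    e = [math.exp(v - m) for v in x]
    s = sum(e)
    return [v / s for v in e]

def softmax_grad(y, dy):
    # grad_x = y * (dy - sum(dy * y)), the same formula as the Function above
    dot = sum(d * v for d, v in zip(dy, y))
    return [v * (d - dot) for v, d in zip(y, dy)]
```

Comparing the analytic gradient against central finite differences on a small input is a cheap correctness test worth running before trusting any hand-written backward.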
9. Debugging & Profiling
Print Debugging
@triton.jit
def debug_kernel(x_ptr, n, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
x = tl.load(x_ptr + offs, mask=offs < n)
# Debug print (only on first program!)
if pid == 0:
tl.device_print("First block values:", x)
tl.device_print("Max:", tl.max(x))
# Set environment for debugging
import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1" # Show autotuning
os.environ["TRITON_INTERPRET"] = "1"       # Run kernels in the CPU interpreter (set before importing triton)
Profiling with Nsight
# Profile with Nsight Systems
nsys profile -o triton_trace python my_triton_script.py
# Profile with Nsight Compute (detailed kernel analysis)
ncu --set full -o triton_kernel python my_triton_script.py
# In Python: use built-in timing
import triton.testing
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],
x_vals=[2**i for i in range(10, 25)],
line_arg='provider',
line_vals=['triton', 'torch'],
line_names=['Triton', 'PyTorch'],
ylabel='GB/s',
plot_name='vector-add-performance',
)
)
def benchmark(N, provider):
x = torch.randn(N, device='cuda', dtype=torch.float32)
y = torch.randn(N, device='cuda', dtype=torch.float32)
if provider == 'triton':
ms = triton.testing.do_bench(lambda: vector_add(x, y))
else:
ms = triton.testing.do_bench(lambda: x + y)
gbps = 3 * N * 4 / ms * 1e-6 # 3 arrays × N elements × 4 bytes
return gbps
benchmark.run(print_data=True, show_plots=True)
10. Best Practices & Patterns
✓ Use power-of-2 block sizes - tl.arange and many other ops require them
✓ Always use masks - Even if you think data is aligned
✓ Accumulate in FP32 - Then cast to FP16/BF16 for output
✓ Fuse memory-bound ops - Biggest wins come from avoiding memory traffic
✓ Use autotuning - Let Triton find optimal parameters
✗ Don't use runtime values for tile shapes - tl.arange bounds and block shapes must be compile-time constants (tl.constexpr)
✗ Don't forget stride parameters - Non-contiguous tensors need them
✗ Don't mix integer and float carelessly - Explicit casts required
✗ Don't ignore numerical stability - Subtract max before softmax, etc.
Common Patterns
# Pattern 1: Row-wise operations (one program per row)
row_idx = tl.program_id(0)
col_offs = tl.arange(0, BLOCK_SIZE)
ptr = base_ptr + row_idx * row_stride + col_offs
# Pattern 2: 2D tiling (programs form a grid)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# Pattern 3: Reduction within block
x = tl.load(...)
total = tl.sum(x, axis=0) # Reduces BLOCK_SIZE elements to 1
# Pattern 4: Broadcasting
a = tl.load(...) # shape: (BLOCK_M,)
b = tl.load(...) # shape: (BLOCK_N,)
c = a[:, None] + b[None, :] # shape: (BLOCK_M, BLOCK_N)
# Pattern 5: Conditional computation
result = tl.where(condition, value_if_true, value_if_false)
# Pattern 6: Atomic operations (for reductions across programs)
tl.atomic_add(ptr, value) # Thread-safe accumulation
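Patterns 3-5 have direct plain-Python analogues, which can help when reasoning about what a kernel computes (illustrative names, not Triton API):

```python
def block_sum(x):
    """Pattern 3: reduce a block of values to one (like tl.sum(x, axis=0))."""
    return sum(x)

def broadcast_add(a, b):
    """Pattern 4: a[:, None] + b[None, :] -> (len(a), len(b)) grid."""
    return [[ai + bj for bj in b] for ai in a]

def where(cond, if_true, if_false):
    """Pattern 5: elementwise select, like tl.where."""
    return [t if c else f for c, t, f in zip(cond, if_true, if_false)]
```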
Summary & Resources
Key takeaways:
- Triton makes GPU programming accessible with Python-like syntax
- Think in blocks, not threads - Triton handles the details
- Kernel fusion is where the real gains are (avoid memory round-trips)
- Use autotuning - don't guess block sizes
- Always mask boundary accesses
Resources: