1. Why Triton?

Triton is an open-source language and compiler by OpenAI that makes writing GPU kernels as accessible as writing NumPy code, while achieving performance competitive with hand-tuned CUDA.

The key insight: most GPU optimization comes from memory access patterns, not arithmetic. Triton handles the hard parts (coalescing, bank conflicts, synchronization) while you focus on the algorithm.

The GPU Programming Abstraction Ladder
PyTorch / JAX (high level): torch.matmul(), F.softmax(), no GPU knowledge needed
Triton (sweet spot) ⭐: Python-like syntax, block-level programming, automatic memory optimization
CUDA C++ (low level): full control, manual memory management, thread-level programming
PTX / SASS (assembly): hand-written GPU assembly, maximum control
Triton sits at the sweet spot: Python ease with near-CUDA performance

When to use Triton: custom fused kernels, research prototypes, and operations (like Flash Attention) that PyTorch doesn't provide as a single op.

Triton Powers Production Systems

PyTorch 2.0's torch.compile() generates Triton kernels by default for GPU code via its TorchInductor backend. When you use torch.compile(), Triton kernels are written for you automatically!

2. Triton vs CUDA: Key Differences

| Aspect | CUDA C++ | Triton |
| --- | --- | --- |
| Programming model | Thread-level (SIMT) | Block-level (operates on tiles) |
| Memory management | Manual (shared memory, coalescing) | Automatic (compiler optimizes) |
| Synchronization | Explicit __syncthreads() | Implicit (compiler inserts) |
| Bank conflicts | Manual padding/swizzling | Automatic avoidance |
| Tensor Cores | Manual wmma intrinsics | Automatic when shapes match |
| Development time | Days to weeks | Hours to days |
| Performance | 100% (baseline) | 80-100% of hand-tuned CUDA |
CUDA vs Triton: Mental Model

In CUDA you think in threads: within a thread block of, say, 256 threads, each thread computes one element:

    int tid = threadIdx.x;
    output[tid] = input[tid] * 2;  // each thread: 1 element

In Triton you think in blocks: one program instance computes a whole block of, say, 1024 elements:

    offs = pid * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(input + offs)
    tl.store(output + offs, x * 2)
CUDA: one thread handles one element. Triton: one program handles a block of elements.

3. Core Concepts

Programs & Blocks

In Triton, you write a program that operates on a block of data:

Python triton_concepts.py
import triton
import triton.language as tl

@triton.jit
def my_kernel(
    input_ptr,    # Pointer to input tensor
    output_ptr,   # Pointer to output tensor
    n_elements,   # Total number of elements
    BLOCK_SIZE: tl.constexpr,  # Compile-time constant
):
    # Which program instance am I?
    pid = tl.program_id(0)  # Like blockIdx.x
    
    # What are my offsets into the data?
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    
    # Create mask for boundary handling
    mask = offsets < n_elements
    
    # Load a block of data (with masking)
    x = tl.load(input_ptr + offsets, mask=mask)
    
    # Compute
    y = x * 2
    
    # Store results
    tl.store(output_ptr + offsets, y, mask=mask)

Pointers & Masks

Triton uses pointer arithmetic (like C) rather than tensor indexing:

Python pointers_masks.py
# 1D access
ptr = base_ptr + offset              # Single element
ptrs = base_ptr + offsets            # Vector of pointers

# 2D access (row-major)
ptr = base_ptr + row * stride + col  # Element at [row, col]

# Masked load/store (CRUCIAL for correctness)
mask = offsets < n_elements
x = tl.load(ptr, mask=mask, other=0.0)  # Load 0 where mask is False
tl.store(ptr, x, mask=mask)              # Only store where mask is True

# 2D mask example
row_mask = rows[:, None] < M
col_mask = cols[None, :] < N
mask = row_mask & col_mask
Always Use Masks!

Unlike CUDA, where an out-of-range thread can simply exit early, a Triton program operates on an entire block at once. Use masks to prevent out-of-bounds access: forgetting a mask causes silent memory corruption!
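The masked-load semantics are easy to build intuition for in plain NumPy: out-of-bounds lanes receive the `other` value instead of touching memory. This is an illustrative sketch of the behavior, not Triton itself:

```python
import numpy as np

def masked_load(data, offsets, mask, other=0.0):
    """Mimic tl.load(ptr + offsets, mask=mask, other=other) in NumPy."""
    out = np.full(offsets.shape, other, dtype=data.dtype)
    # Only in-bounds lanes actually read memory
    out[mask] = data[offsets[mask]]
    return out

n_elements = 5
data = np.array([10.0, 20.0, 30.0, 40.0, 50.0])
BLOCK_SIZE = 8
offsets = np.arange(BLOCK_SIZE)      # lanes 0..7, but only 0..4 are valid
mask = offsets < n_elements
x = masked_load(data, offsets, mask)  # [10, 20, 30, 40, 50, 0, 0, 0]
```

Without the mask, lanes 5-7 would read past the end of `data`: exactly the silent corruption the warning above describes.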

4. Hello World: Vector Addition

Let's implement c = a + b - the "hello world" of GPU programming:

Python vector_add.py
import torch
import triton
import triton.language as tl

@triton.jit
def vector_add_kernel(
    a_ptr, b_ptr, c_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Program ID determines which block of elements we process
    pid = tl.program_id(0)
    
    # Compute offsets for this program
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    
    # Mask for boundary
    mask = offsets < n_elements
    
    # Load vectors a and b
    a = tl.load(a_ptr + offsets, mask=mask)
    b = tl.load(b_ptr + offsets, mask=mask)
    
    # Add them
    c = a + b
    
    # Store result
    tl.store(c_ptr + offsets, c, mask=mask)


def vector_add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Triton vector addition wrapper."""
    assert a.shape == b.shape
    assert a.is_cuda and b.is_cuda
    
    c = torch.empty_like(a)
    n_elements = a.numel()
    
    # Grid: how many programs to launch
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
    
    # Launch kernel
    vector_add_kernel[grid](
        a, b, c,
        n_elements,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    
    return c


# Test it!
if __name__ == "__main__":
    a = torch.randn(10000, device="cuda")
    b = torch.randn(10000, device="cuda")
    
    c_triton = vector_add(a, b)
    c_torch = a + b
    
    print(f"Max error: {(c_triton - c_torch).abs().max()}")
    # Max error: 0.0
How Triton Launches Programs

For n_elements = 3500 and BLOCK_SIZE = 1024, grid = (ceil(3500 / 1024),) = (4,), so four program instances are launched:

Block 0: offsets [0, 1024)
Block 1: offsets [1024, 2048)
Block 2: offsets [2048, 3072)
Block 3: offsets [3072, 4096), masked! Only 428 elements are real; the remaining 596 lanes are masked out.
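The grid arithmetic above can be checked in a few lines of plain Python (triton.cdiv is just ceiling division):

```python
def cdiv(a, b):
    """Ceiling division, equivalent to triton.cdiv."""
    return (a + b - 1) // b

n_elements, BLOCK_SIZE = 3500, 1024
grid = (cdiv(n_elements, BLOCK_SIZE),)  # (4,) -> 4 program instances

# Count the real (unmasked) elements handled by each program
real = [min(BLOCK_SIZE, n_elements - pid * BLOCK_SIZE) for pid in range(grid[0])]
# real == [1024, 1024, 1024, 428]; the last program masks 1024 - 428 = 596 lanes
```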

5. Matrix Multiplication with Tiling

Matrix multiplication is the workhorse of deep learning. Here's how to implement a tiled matmul in Triton:

Tiled Matrix Multiplication: C = A @ B

Each program computes one BLOCK_M × BLOCK_N tile of the output C (M×N). It sweeps across the K dimension, loading a BLOCK_M × BLOCK_K tile of A (M×K) and a BLOCK_K × BLOCK_N tile of B (K×N) at each step and accumulating:

C tile = A tile 0 × B tile 0 + A tile 1 × B tile 1 + ...
Python matmul_triton.py
import torch
import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr, b_ptr, c_ptr,
    # Matrix dimensions
    M, N, K,
    # Strides (elements to skip for next row/col)
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    # Block sizes (compile-time constants)
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """Compute C = A @ B for one output tile."""
    # Program ID for 2D grid
    pid_m = tl.program_id(0)  # Which row block
    pid_n = tl.program_id(1)  # Which col block
    
    # Compute starting positions
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    
    # Pointers to first tiles of A and B
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    
    # Accumulator for output tile (in registers)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    
    # Loop over K dimension in blocks
    for k in range(0, K, BLOCK_K):
        # Load tiles with boundary masking
        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] + k < K)
        b_mask = (offs_k[:, None] + k < K) & (offs_n[None, :] < N)
        
        a = tl.load(a_ptrs, mask=a_mask, other=0.0)
        b = tl.load(b_ptrs, mask=b_mask, other=0.0)
        
        # Multiply and accumulate
        acc += tl.dot(a, b)  # Uses Tensor Cores if available!
        
        # Advance pointers to next K block
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    
    # Store output tile
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    c_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptrs, acc, mask=c_mask)


def matmul(a, b):
    """Triton matrix multiplication: C = A @ B"""
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    
    # Block sizes (tunable)
    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 32
    
    # 2D grid of programs
    grid = (
        triton.cdiv(M, BLOCK_M),
        triton.cdiv(N, BLOCK_N),
    )
    
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return c
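To convince yourself that the tiling plus zero-padded masked loads is exact, the kernel's loop structure can be mirrored in NumPy (block sizes here are small, hypothetical values chosen for illustration):

```python
import numpy as np

def tiled_matmul(A, B, BLOCK_M=4, BLOCK_N=4, BLOCK_K=2):
    """NumPy mirror of the Triton kernel: one (BLOCK_M, BLOCK_N) tile at a time."""
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    for m0 in range(0, M, BLOCK_M):          # pid_m
        for n0 in range(0, N, BLOCK_N):      # pid_n
            acc = np.zeros((BLOCK_M, BLOCK_N), dtype=np.float32)
            for k0 in range(0, K, BLOCK_K):  # inner K loop
                # Masked loads: out-of-bounds entries become 0.0, like other=0.0
                a = np.zeros((BLOCK_M, BLOCK_K), dtype=np.float32)
                b = np.zeros((BLOCK_K, BLOCK_N), dtype=np.float32)
                am, ak = min(BLOCK_M, M - m0), min(BLOCK_K, K - k0)
                bn = min(BLOCK_N, N - n0)
                a[:am, :ak] = A[m0:m0 + am, k0:k0 + ak]
                b[:ak, :bn] = B[k0:k0 + ak, n0:n0 + bn]
                acc += a @ b                 # the tl.dot analogue
            cm, cn = min(BLOCK_M, M - m0), min(BLOCK_N, N - n0)
            C[m0:m0 + cm, n0:n0 + cn] = acc[:cm, :cn]  # masked store
    return C

A = np.random.randn(7, 5).astype(np.float32)
B = np.random.randn(5, 6).astype(np.float32)
assert np.allclose(tiled_matmul(A, B), A @ B, atol=1e-5)
```

The zero padding is why `other=0.0` is safe in the kernel: padded entries contribute nothing to the dot product.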
tl.dot() Uses Tensor Cores

When block sizes are multiples of 16 and using FP16/BF16 inputs, Triton's tl.dot() automatically uses Tensor Cores for massive speedup (>10x over FP32 CUDA cores).

6. Fused Operations

The real power of Triton is kernel fusion - combining multiple operations to avoid memory round-trips:

Why Fusion Matters

Unfused softmax takes three kernel launches: exp(x), sum(), div(), each reading its input from HBM and writing its output back, for 6 memory operations (one read and one write per kernel). The fused version computes exp(x) / sum(exp(x)) in a single launch: 1 read + 1 write = 2 memory operations, roughly 3x faster for this memory-bound operation.

Fused Softmax

Python fused_softmax.py
@triton.jit
def softmax_kernel(
    input_ptr, output_ptr,
    n_cols,
    input_row_stride, output_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute softmax for one row."""
    # Each program handles one row
    row_idx = tl.program_id(0)
    
    # Pointers to row start
    row_start_ptr = input_ptr + row_idx * input_row_stride
    col_offsets = tl.arange(0, BLOCK_SIZE)
    
    # Load row with masking
    mask = col_offsets < n_cols
    row = tl.load(row_start_ptr + col_offsets, mask=mask, other=-float('inf'))
    
    # Numerically stable softmax
    # 1. Subtract max for numerical stability
    row_max = tl.max(row, axis=0)
    row = row - row_max
    
    # 2. Exponentiate
    numerator = tl.exp(row)
    
    # 3. Sum and divide
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator
    
    # Store result
    output_row_ptr = output_ptr + row_idx * output_row_stride
    tl.store(output_row_ptr + col_offsets, softmax_output, mask=mask)


def softmax(x):
    """Fused softmax over last dimension."""
    n_rows, n_cols = x.shape
    
    # BLOCK_SIZE must be power of 2 and >= n_cols
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    
    y = torch.empty_like(x)
    
    # One program per row
    grid = (n_rows,)
    
    softmax_kernel[grid](
        x, y,
        n_cols,
        x.stride(0), y.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y
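The stability trick in step 1 matters: subtracting the row max leaves the result mathematically unchanged but keeps exp from overflowing. A quick NumPy check of the same three steps the kernel performs:

```python
import numpy as np

def stable_softmax(row):
    """Same steps as the kernel: subtract max, exponentiate, normalize."""
    shifted = row - row.max()
    e = np.exp(shifted)
    return e / e.sum()

row = np.array([1000.0, 1001.0, 1002.0])  # naive exp(1000) overflows to inf
out = stable_softmax(row)
assert np.all(np.isfinite(out)) and np.isclose(out.sum(), 1.0)
# Shifting by any constant leaves softmax unchanged
assert np.allclose(out, stable_softmax(row - 500.0))
```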

Fused LayerNorm

Python fused_layernorm.py
@triton.jit
def layernorm_kernel(
    x_ptr, y_ptr, weight_ptr, bias_ptr,
    n_cols, eps,
    x_row_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """LayerNorm: y = (x - mean) / sqrt(var + eps) * weight + bias"""
    row_idx = tl.program_id(0)
    
    # Offsets for this row
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    
    # Load input row
    x_ptr += row_idx * x_row_stride
    x = tl.load(x_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
    
    # Compute mean
    mean = tl.sum(x, axis=0) / n_cols
    
    # Compute variance
    x_centered = tl.where(mask, x - mean, 0.0)
    var = tl.sum(x_centered * x_centered, axis=0) / n_cols
    
    # Normalize
    rstd = 1.0 / tl.sqrt(var + eps)
    x_norm = x_centered * rstd
    
    # Scale and shift (load weight and bias)
    weight = tl.load(weight_ptr + col_offsets, mask=mask)
    bias = tl.load(bias_ptr + col_offsets, mask=mask)
    y = x_norm * weight + bias
    
    # Store (y was allocated with empty_like, so it shares x's row stride)
    y_ptr += row_idx * x_row_stride
    tl.store(y_ptr + col_offsets, y, mask=mask)


def layer_norm(x, weight, bias, eps=1e-5):
    """Fused LayerNorm."""
    n_rows = x.numel() // x.shape[-1]
    n_cols = x.shape[-1]
    
    x_flat = x.reshape(n_rows, n_cols)  # reshape, unlike view, handles non-contiguous x
    y = torch.empty_like(x_flat)
    
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    grid = (n_rows,)
    
    layernorm_kernel[grid](
        x_flat, y, weight, bias,
        n_cols, eps,
        x_flat.stride(0),
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return y.view_as(x)
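A NumPy reference for the same formula is handy for validating the kernel's output on small inputs:

```python
import numpy as np

def layer_norm_ref(x, weight, bias, eps=1e-5):
    """Reference: y = (x - mean) / sqrt(var + eps) * weight + bias, per row."""
    mean = x.mean(axis=-1, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=-1, keepdims=True)  # biased variance, as in the kernel
    return (x - mean) / np.sqrt(var + eps) * weight + bias

x = np.random.randn(4, 16).astype(np.float32)
w = np.ones(16, dtype=np.float32)
b = np.zeros(16, dtype=np.float32)
y = layer_norm_ref(x, w, b)
# With unit weight and zero bias, each row is normalized to ~zero mean, ~unit variance
assert np.allclose(y.mean(axis=-1), 0.0, atol=1e-5)
assert np.allclose(y.var(axis=-1), 1.0, atol=1e-2)
```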

7. Autotuning for Performance

Triton's killer feature: automatic tuning of block sizes and other parameters:

Python autotuned_matmul.py
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_warps=8),
        triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=8),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 256, 'BLOCK_K': 32}, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_warps=4),
    ],
    key=['M', 'N', 'K'],  # Retune when these change
)
@triton.jit
def matmul_kernel_autotuned(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Same kernel code as before...
    # Triton will try all configs and pick the fastest!
    ...
How Autotuning Works

The first time your kernel runs with a new combination of the key values (here M, N, K), Triton benchmarks every config in the list and caches the fastest one. Subsequent calls with the same key reuse the cached config with no additional overhead.
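One practical consequence: once BLOCK_M and BLOCK_N come from the autotuner, the host code no longer knows them at launch time, so the grid is passed as a callable that Triton invokes with the chosen meta-parameters. The lookup logic itself is trivial, sketched here in plain Python:

```python
def cdiv(a, b):
    """Ceiling division, equivalent to triton.cdiv."""
    return (a + b - 1) // b

M, N = 1000, 2000

# With @triton.autotune, launch as matmul_kernel_autotuned[grid](...), where
# grid is a callable receiving the winning config's meta-parameters:
grid = lambda META: (cdiv(M, META['BLOCK_M']), cdiv(N, META['BLOCK_N']))

# Triton calls it with whichever config won the benchmark, e.g.:
assert grid({'BLOCK_M': 128, 'BLOCK_N': 256}) == (8, 8)
assert grid({'BLOCK_M': 64, 'BLOCK_N': 128}) == (16, 16)
```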

8. PyTorch Integration

Triton kernels integrate seamlessly with PyTorch autograd:

Python pytorch_integration.py
import torch
from torch.autograd import Function

class TritonSoftmax(Function):
    @staticmethod
    def forward(ctx, x):
        # Run forward kernel
        y = softmax(x)  # Our Triton kernel
        ctx.save_for_backward(y)
        return y
    
    @staticmethod
    def backward(ctx, grad_output):
        # Softmax backward: grad_input = y * (dy - sum(dy * y))
        y, = ctx.saved_tensors
        # Could also implement backward as Triton kernel!
        grad_input = grad_output * y - y * (grad_output * y).sum(dim=-1, keepdim=True)
        return grad_input


# Use in models
def triton_softmax(x):
    return TritonSoftmax.apply(x)
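The closed-form backward above can be verified against the full softmax Jacobian, J = diag(y) - y yᵀ, in NumPy:

```python
import numpy as np

y = np.array([0.2, 0.3, 0.5])    # a softmax output (sums to 1)
dy = np.array([1.0, -2.0, 0.5])  # upstream gradient

# Full Jacobian of softmax: J = diag(y) - outer(y, y)
J = np.diag(y) - np.outer(y, y)
grad_full = J @ dy

# Closed form used in backward(): y * (dy - sum(dy * y))
grad_fused = y * (dy - (dy * y).sum())
assert np.allclose(grad_full, grad_fused)
```

The fused form avoids materializing the n×n Jacobian, which is exactly why it is the shape you would implement as a kernel.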


# Example: Custom attention with Triton
class TritonAttention(torch.nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = torch.nn.Linear(d_model, 3 * d_model)
        self.out = torch.nn.Linear(d_model, d_model)
    
    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).view(B, T, 3, self.n_heads, self.d_head)
        q, k, v = qkv.unbind(2)
        
        # Attention scores
        scores = torch.einsum('bthd,bshd->bhts', q, k) / (self.d_head ** 0.5)
        
        # Use our Triton softmax!
        attn = triton_softmax(scores)
        
        out = torch.einsum('bhts,bshd->bthd', attn, v)
        return self.out(out.reshape(B, T, C))

9. Debugging & Profiling

Print Debugging

Python debugging.py
@triton.jit
def debug_kernel(x_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    
    x = tl.load(x_ptr + offs, mask=offs < n)
    
    # Debug print (only on first program!)
    if pid == 0:
        tl.device_print("First block values:", x)
        tl.device_print("Max:", tl.max(x))


# Environment variables for debugging (set these BEFORE importing triton)
import os
os.environ["TRITON_PRINT_AUTOTUNING"] = "1"  # Print autotuning results
os.environ["TRITON_INTERPRET"] = "1"         # Run kernels in the CPU interpreter for debugging

Profiling with Nsight

Bash profile.sh
# Profile with Nsight Systems
nsys profile -o triton_trace python my_triton_script.py

# Profile with Nsight Compute (detailed kernel analysis)
ncu --set full -o triton_kernel python my_triton_script.py

# In Python: use built-in timing
import triton.testing

@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['N'],
        x_vals=[2**i for i in range(10, 25)],
        line_arg='provider',
        line_vals=['triton', 'torch'],
        line_names=['Triton', 'PyTorch'],
        ylabel='GB/s',
        plot_name='vector-add-performance',
    )
)
def benchmark(N, provider):
    x = torch.randn(N, device='cuda', dtype=torch.float32)
    y = torch.randn(N, device='cuda', dtype=torch.float32)
    
    if provider == 'triton':
        ms = triton.testing.do_bench(lambda: vector_add(x, y))
    else:
        ms = triton.testing.do_bench(lambda: x + y)
    
    gbps = 3 * N * 4 / ms * 1e-6  # 3 arrays × N elements × 4 bytes
    return gbps

benchmark.run(print_data=True, show_plots=True)

10. Best Practices & Patterns

Do's

✓ Use power-of-2 block sizes - Required for many operations
✓ Always use masks - Even if you think data is aligned
✓ Accumulate in FP32 - Then cast to FP16/BF16 for output
✓ Fuse memory-bound ops - Biggest wins come from avoiding memory traffic
✓ Use autotuning - Let Triton find optimal parameters

Don'ts

✗ Don't use dynamic shapes in loops - Must be compile-time constant
✗ Don't forget stride parameters - Non-contiguous tensors need them
✗ Don't mix integer and float carelessly - Explicit casts required
✗ Don't ignore numerical stability - Subtract max before softmax, etc.

Common Patterns

Python common_patterns.py
# Pattern 1: Row-wise operations (one program per row)
row_idx = tl.program_id(0)
col_offs = tl.arange(0, BLOCK_SIZE)
ptr = base_ptr + row_idx * row_stride + col_offs

# Pattern 2: 2D tiling (programs form a grid)
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

# Pattern 3: Reduction within block
x = tl.load(...)
total = tl.sum(x, axis=0)  # Reduces BLOCK_SIZE elements to 1

# Pattern 4: Broadcasting
a = tl.load(...)  # shape: (BLOCK_M,)
b = tl.load(...)  # shape: (BLOCK_N,)
c = a[:, None] + b[None, :]  # shape: (BLOCK_M, BLOCK_N)

# Pattern 5: Conditional computation
result = tl.where(condition, value_if_true, value_if_false)

# Pattern 6: Atomic operations (for reductions across programs)
tl.atomic_add(ptr, value)  # Thread-safe accumulation
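Pattern 4's broadcasting behaves exactly like NumPy's, which makes it easy to prototype 2D index math on the CPU before writing the kernel. A sketch with an illustrative stride:

```python
import numpy as np

BLOCK_M, BLOCK_N = 4, 3
offs_m = np.arange(BLOCK_M)  # like tl.arange(0, BLOCK_M)
offs_n = np.arange(BLOCK_N)  # like tl.arange(0, BLOCK_N)

# Pattern 4: outer broadcast -> a (BLOCK_M, BLOCK_N) grid of flat offsets
stride = 10  # hypothetical row stride (elements per row)
ptrs = offs_m[:, None] * stride + offs_n[None, :]
assert ptrs.shape == (4, 3)
assert ptrs[2, 1] == 21  # row 2, col 1 -> 2*10 + 1
```

This is the same arithmetic the matmul kernel uses to build a_ptrs and b_ptrs.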

Summary & Resources

When to Use Triton vs Alternatives

Use PyTorch / JAX for standard ops and quick prototyping.
Use Triton ⭐ for custom fusions, research kernels, and Flash-Attention-style ops.
Use CUDA C++ for maximum control or exotic hardware.

A quick decision path: not performance-critical → PyTorch. Need custom ops for research → Triton. Need peak performance or special hardware features in production → CUDA (or Triton with autotuning).

Key takeaways:

Triton lets you write block-level GPU kernels in Python with near-CUDA performance.
Masks are mandatory for correctness; fusion and autotuning are where the big speedups come from.
Triton kernels plug into PyTorch autograd via torch.autograd.Function.

Resources: