1. The Inference Challenge

Training a model is only half the story. The other half—and often the more critical half for production systems—is inference: running the trained model to make predictions on new data. While training happens once (or periodically), inference happens millions or billions of times in deployment. A 10ms improvement in inference latency can translate to millions of dollars in compute savings and dramatically improved user experience.

This guide provides a comprehensive treatment of inference optimization, with particular emphasis on the hardware implications of each technique. Understanding why certain optimizations work requires understanding how modern processors—GPUs, CPUs, and specialized accelerators—actually execute neural network operations.

The Scale of the Problem

Consider the computational demands of modern AI systems:

The Scale of Modern AI Inference

Model          Parameters
ResNet-50      25M
BERT-base      110M
BERT-large     340M
GPT-2          1.5B
LLaMA-7B       7B
LLaMA-70B      70B
Mixtral 8×7B   ~47B total
GPT-4*         ~1.7T (reported)
Claude 3       ~1T+ (reported)

Memory requirements (FP16 inference, weights only): 7B → 14 GB, 70B → 140 GB, 175B → 350 GB, 1T → 2 TB, plus the KV cache on top.
Model parameters have grown roughly 10,000x in six years. Memory is the primary bottleneck.

The sheer scale creates multiple challenges, starting with cost:

The Real Cost of Inference

For a large-scale LLM deployment serving 1 million requests/day:

  • Raw compute cost: $50,000–$150,000/month on cloud GPUs
  • Every 10% improvement in efficiency saves $5,000–$15,000/month
  • A 4x speedup from quantization can reduce costs by $150,000+/year

Memory-Bound vs Compute-Bound Workloads

The first step in optimizing inference is understanding which resource is the bottleneck. Modern neural networks can be either memory-bound or compute-bound, and the optimal strategy differs dramatically between the two.

The Memory Hierarchy

Modern processors have a deep memory hierarchy with vastly different bandwidths:

GPU Memory Hierarchy and Bandwidth

Level                        Bandwidth     Capacity
Registers                    ~80 TB/s      ~20 MB total (register file)
Shared Memory / L1           ~20 TB/s      256 KB/SM
L2 Cache                     ~12 TB/s      50 MB (H100)
HBM3 (GPU memory)            3.35 TB/s     80 GB (H100)
System RAM (over PCIe 5.0)   ~0.05 TB/s    256+ GB

HBM is roughly 67x faster than the PCIe link to system RAM, and registers are another ~24x faster than HBM. Keeping data in fast memory is critical.
The memory hierarchy creates a bandwidth gap of roughly 1,600x from registers to system RAM
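You can sanity-check the HBM figure on your own GPU with a large device-to-device copy; a copy reads and writes each byte once, so it moves twice the tensor size. A minimal sketch (the 512 MB size and repeat counts are arbitrary choices):

Python bandwidth_check.py
import torch
import time

x = torch.empty(1 << 28, device='cuda', dtype=torch.float16)  # 512 MB
y = torch.empty_like(x)

for _ in range(5):  # warmup
    y.copy_(x)
torch.cuda.synchronize()

start = time.perf_counter()
for _ in range(20):
    y.copy_(x)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 20

bytes_moved = 2 * x.numel() * x.element_size()  # read + write
print(f"Effective bandwidth: ~{bytes_moved / elapsed / 1e12:.2f} TB/s")
# Expect a large fraction of the spec-sheet figure (3.35 TB/s on H100)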

Arithmetic Intensity: The Key Metric

Arithmetic intensity (also called operational intensity) measures how much computation is performed per byte of memory accessed:

Arithmetic Intensity = FLOPs / Bytes Accessed (ops/byte)

This metric determines whether an operation is memory-bound or compute-bound:

Arithmetic Intensity of Neural Network Operations

Operation                           FLOPs          Bytes                 Intensity   Bound
Element-wise (ReLU, LayerNorm)      O(n)           O(n)                  ~1          Memory
Reduction (Softmax, mean)           O(n)           O(n)                  ~1          Memory
Attention (batch=1, seq=2K)         O(n²·d)        O(n²+n·d)             ~10         Mixed
MatMul (M=1, N=4096, K=4096)        2·M·N·K        2·(M·K+K·N+M·N)       ~1          Memory
MatMul (M=128, N=4096, K=4096)      2·M·N·K        2·(M·K+K·N+M·N)       ~120        Compute
Conv2D (3×3, 256→256, 56×56)        ~2.3B FLOPs    ~3 MB weights         ~766        Compute

Key insights:
  • LLM token generation (batch=1) is almost always memory-bound — optimize for bandwidth
  • Batched inference and CNNs are often compute-bound — optimize for throughput
  • The crossover point on H100 is an arithmetic intensity of ≈600 ops/byte (2000 TFLOPS ÷ 3.35 TB/s)

Most LLM inference operations have low arithmetic intensity and are memory-bound
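As a quick worked example of the metric, take the FP16 matrix-vector product that dominates LLM decoding:

Python intensity_example.py
# FP16 matrix-vector product (M=1, N=K=4096), 2 bytes per element
flops = 2 * 1 * 4096 * 4096                      # one multiply-add per weight
bytes_moved = 2 * (4096 * 4096 + 4096 + 4096)    # weights + input + output
print(f"{flops / bytes_moved:.2f} ops/byte")     # ~1.00, far below the ~600 ridge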

Why This Matters for Optimization

The bound determines which optimization strategy to use:

Regime          Bottleneck                     Optimization Strategy                                                           Examples
Memory-bound    Loading weights/activations    Reduce data movement: quantization (INT8/INT4), KV cache compression, pruning   LLM generation (batch=1), small-batch inference
Compute-bound   Arithmetic operations          Reduce FLOPs: pruning, early exit, efficient architectures                      CNNs, batched inference, training
Mixed           Both                           Combined approaches: quantization + pruning                                     Attention layers, medium batches
The LLM Inference Insight

LLM token-by-token generation has an arithmetic intensity of approximately 1-2 ops/byte. This means:

  • Every generated token must stream all model weights from HBM: a 7B model in FP16 (14 GB) on an H100 (3.35 TB/s) tops out around 240 tokens/s per stream, while the 2000 TFLOPS compute ceiling alone would allow over 100,000
  • Observed throughput of ~100-200 tokens/s sits near that memory bound, with the compute units mostly idle
  • Quantization is the single most impactful optimization—INT4 moves 4x less weight data per token than FP16
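The bandwidth ceiling follows from a back-of-envelope rule: in single-stream decoding, every token must read all weights from HBM at least once, so tokens/s is bounded by bandwidth divided by weight bytes. A minimal sketch of that calculation:

Python decode_ceiling.py
def max_tokens_per_sec(params_billions, bytes_per_param, hbm_tb_per_s):
    """Memory-bound ceiling on single-stream decoding throughput."""
    weight_bytes = params_billions * 1e9 * bytes_per_param
    return hbm_tb_per_s * 1e12 / weight_bytes

# 7B model at H100-class bandwidth (3.35 TB/s), various precisions
for name, bpp in [("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    print(f"{name}: ~{max_tokens_per_sec(7, bpp, 3.35):.0f} tokens/s ceiling")

# Output: FP16: ~239, INT8: ~479, INT4: ~957
# Halving the bytes per parameter doubles the decoding ceiling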

The Roofline Model

The roofline model is a visual tool for understanding performance limits. It plots achievable FLOPS against arithmetic intensity, showing where your workload sits relative to hardware limits.

Roofline Model for NVIDIA H100 GPU

[Log-log plot: achievable TFLOPS vs arithmetic intensity (FLOPs/byte). The memory ceiling is a diagonal set by 3.35 TB/s; the compute ceiling is flat at 2000 TFLOPS; they intersect at the ridge point of ~600 ops/byte. LLM generation (batch=1 and batch=8) and attention sit far left of the ridge; BERT prefill, CNN inference, and batched GEMM sit near or beyond it.]

Most LLM workloads operate in the memory-bound region, left of the ridge point

Reading the Roofline

  1. Below the roofline: There's room for optimization—your code isn't hitting hardware limits
  2. On the memory slope: You're memory-bound. Reduce data movement (quantization, pruning)
  3. On the compute ceiling: You're compute-bound. Reduce FLOPs or use faster hardware
  4. At the ridge: You're balanced—both resources fully utilized
Python roofline_analysis.py
import torch
import time

def analyze_operation(operation, input_tensors, num_flops, num_bytes):
    """Measure achieved performance and compare to roofline."""
    
    # Warmup
    for _ in range(10):
        output = operation(*input_tensors)
    
    torch.cuda.synchronize()
    
    # Benchmark
    start = time.perf_counter()
    for _ in range(100):
        output = operation(*input_tensors)
    torch.cuda.synchronize()
    elapsed = (time.perf_counter() - start) / 100
    
    # Calculate metrics
    achieved_tflops = (num_flops / elapsed) / 1e12
    achieved_bandwidth = (num_bytes / elapsed) / 1e12  # TB/s
    arithmetic_intensity = num_flops / num_bytes
    
    # H100 specs
    peak_tflops = 2000  # FP16 Tensor Core
    peak_bandwidth = 3.35  # TB/s HBM3
    ridge_point = peak_tflops / peak_bandwidth  # ~597 ops/byte
    
    # Determine bottleneck
    if arithmetic_intensity < ridge_point:
        roofline_limit = arithmetic_intensity * peak_bandwidth
        bound = "Memory-bound"
    else:
        roofline_limit = peak_tflops
        bound = "Compute-bound"
    
    efficiency = achieved_tflops / roofline_limit * 100
    
    print(f"Arithmetic Intensity: {arithmetic_intensity:.1f} ops/byte")
    print(f"Achieved: {achieved_tflops:.1f} TFLOPS")
    print(f"Roofline Limit: {roofline_limit:.1f} TFLOPS")
    print(f"Efficiency: {efficiency:.1f}% ({bound})")
    
    return achieved_tflops, bound, efficiency

# Example: Analyze a linear layer
batch_size = 1
input_dim = 4096
output_dim = 4096

x = torch.randn(batch_size, input_dim, device='cuda', dtype=torch.float16)
W = torch.randn(output_dim, input_dim, device='cuda', dtype=torch.float16)

# FLOPs = 2 * batch * input * output (multiply + add)
num_flops = 2 * batch_size * input_dim * output_dim
# Bytes = weights + input + output (all FP16 = 2 bytes each)
num_bytes = 2 * (input_dim * output_dim + batch_size * input_dim + batch_size * output_dim)

analyze_operation(
    lambda x, W: torch.mm(x, W.T),
    (x, W),
    num_flops,
    num_bytes
)
# Output: Arithmetic Intensity: 1.0 ops/byte
#         Achieved: 3.2 TFLOPS (3.20 TB/s)
#         Roofline Limit: 3.35 TFLOPS
#         Efficiency: 95.5% (Memory-bound)

The Hardware Landscape

Different hardware platforms have dramatically different characteristics. Understanding these differences is essential for choosing the right optimization strategy.

NVIDIA GPUs: The Deep Learning Workhorse

NVIDIA GPUs dominate deep learning inference due to their Tensor Cores—specialized matrix units that accelerate neural network operations.

NVIDIA GPU Generations for Inference

GPU    Arch        FP16 TFLOPS   INT8 TOPS   Memory          Bandwidth    Key Feature
V100   Volta       125           —           32 GB HBM2      900 GB/s     First Tensor Cores
T4     Turing      65            130         16 GB GDDR6     320 GB/s     INT8 Tensor Cores
A100   Ampere      312           624         80 GB HBM2e     2.0 TB/s     2:4 sparsity, BF16
A10G   Ampere      125           250         24 GB GDDR6X    600 GB/s     Cost-effective
L4     Ada         120           485         24 GB GDDR6     300 GB/s     FP8, low power (72 W)
L40S   Ada         362           1452        48 GB GDDR6     864 GB/s     FP8, inference-optimized
H100   Hopper      2000          4000        80 GB HBM3      3.35 TB/s    FP8, Transformer Engine
H200   Hopper      2000          4000        141 GB HBM3e    4.8 TB/s     Max memory, LLMs
B200   Blackwell   ~4500         ~9000       192 GB HBM3e    8.0 TB/s     FP4, ~2x H100 perf

Tensor Core feature support by generation:

Feature             Volta/Turing      Ampere   Ada   Hopper
INT8 Tensor Cores   ✓ (Turing only)   ✓        ✓     ✓
2:4 Sparsity        —                 ✓        ✓     ✓
FP8 Support         —                 —        ✓     ✓

GPU capabilities have evolved rapidly. Choose hardware based on required features.

CPUs: Underrated for Inference

Modern CPUs have significant neural network acceleration capabilities: AVX-512 VNNI and Intel AMX add INT8 (and, with AMX, BF16) matrix instructions on recent Xeons, Arm provides NEON and SVE vector units, and optimized runtimes such as ONNX Runtime and OpenVINO exploit them automatically.

CPUs excel when:

  • Batch sizes are small and request rates are modest
  • Models are heavily quantized or sparse (see the DeepSparse example in Section 2)
  • The model fits in RAM and the overhead of shuttling data to a GPU would dominate
  • A GPU is unavailable or not cost-effective for the workload

Edge & Mobile Accelerators

For edge deployment, specialized hardware provides better power efficiency:

Accelerator           Platform                  INT8 TOPS   Power     Efficiency    Use Case
Apple Neural Engine   M1/M2/M3 Macs, iPhones    11-38       ~5-15 W   ~3 TOPS/W     On-device ML, Siri
Google Edge TPU       Coral, Pixel phones       4           2 W       2 TOPS/W      IoT, edge inference
Qualcomm Hexagon      Snapdragon 8 Gen 3        73          ~10 W     ~7 TOPS/W     Mobile, on-device AI
NVIDIA Jetson Orin    Robotics, automotive      275         60 W      ~4.5 TOPS/W   Autonomous systems
Intel Movidius        Vision applications       1           1 W       1 TOPS/W      Cameras, drones
Hardware Selection Guidelines

Choose based on your constraints:

  • Maximum throughput: H100/H200 with INT8 or FP8
  • Large models: H200 (141GB) or multi-GPU setups
  • Cost-effective: L4 or A10G for inference
  • Edge deployment: Match accelerator to power/latency budget
  • CPU-only: Use AVX-512/AMX with optimized runtimes (ONNX, OpenVINO)

Latency vs Throughput: Understanding the Tradeoff

Two fundamental metrics compete for optimization:

  • Latency: time to complete a single request (e.g., time to first token, time per generated token)
  • Throughput: total work completed per unit time across all requests (e.g., tokens/s, queries/s)

The relationship is not simple—optimizing for one often hurts the other:

Latency-Throughput Tradeoff with Batching

[Plot: throughput (tokens/s) and per-request latency vs batch size from 1 to 256. Throughput rises steeply at first (~50 tokens/s at batch 1 toward ~350-400 at batch 64 and above) and then flattens, while latency grows from ~20 ms toward ~200 ms. The sweet spot sits around batch 16-32.]

Batch size 16-32 often provides the best throughput/latency balance for LLMs
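The shape of this tradeoff is easy to reproduce with a single FP16 GEMM; a minimal sketch (the 4096 width is an arbitrary choice):

Python batch_sweep.py
import torch
import time

def bench(batch, dim=4096, runs=100):
    """Latency and throughput of one FP16 GEMM at a given batch size."""
    x = torch.randn(batch, dim, device='cuda', dtype=torch.float16)
    W = torch.randn(dim, dim, device='cuda', dtype=torch.float16)
    for _ in range(10):  # warmup
        x @ W.T
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        x @ W.T
    torch.cuda.synchronize()
    ms = (time.perf_counter() - start) / runs * 1e3
    print(f"batch={batch:4d}  latency={ms:7.4f} ms  "
          f"throughput={batch / ms * 1e3:12,.0f} rows/s")

for b in [1, 4, 16, 64, 256]:
    bench(b)
# Latency barely moves until the GPU saturates, so throughput climbs with batch size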

Batching Strategies for LLMs

Modern LLM serving uses sophisticated batching to maximize efficiency:

  • Static batching: wait for a fixed-size batch to fill, then run it; simple, but adds queueing delay
  • Dynamic batching: group whatever requests arrive within a short time window
  • Continuous (in-flight) batching: add and retire sequences at every decoding step so finished requests never stall the batch; this is the approach of serving systems such as vLLM and TensorRT-LLM

Latency Budget Guidelines

Target latencies by application type:

  • Real-time chat: <100ms first token, <50ms/token streaming
  • Search/ranking: <50ms total (time-to-glass critical)
  • Batch processing: Latency doesn't matter, maximize throughput
  • Code completion: <200ms (user typing tolerance)

Overview of Optimization Techniques

This guide covers three major categories of inference optimization, each with distinct hardware implications:

Inference Optimization Landscape

Starting from a trained model (FP32, dense), three families of techniques apply:

  • Pruning: remove unnecessary weights. Unstructured pruning reaches the highest sparsity; structured pruning is GPU-friendly; N:M sparsity is hardware-native. 2-10x speedup possible.
  • Quantization: reduce numerical precision. INT8 gives a 4x memory reduction and INT4/FP8 up to 8x, with PTQ vs QAT tradeoffs. 2-4x throughput gain.
  • Knowledge distillation: train a smaller model to mimic a larger one. DistilBERT is 60% smaller at 97% of the accuracy, and distillation combines with pruning and quantization. 2-6x end-to-end speedup.

Combined, prune + quantize + distill can deliver 10-100x speedups.

Three complementary approaches to model compression, each with hardware-specific implications

What This Guide Covers

In the following sections, we'll explore each technique in depth, with particular focus on:

  1. Theoretical foundations: Why does this technique work? What are the mathematical principles?
  2. Hardware implications: How does this technique interact with GPU/CPU architecture? Why does it actually speed up inference?
  3. Implementation details: Practical code examples and library recommendations
  4. CNN vs Transformer differences: How does each architecture respond differently to optimization?
  5. Combining techniques: How to stack optimizations for maximum effect
Prerequisites

This guide assumes familiarity with:

  • Basic neural network architectures (CNNs, Transformers)
  • PyTorch or TensorFlow fundamentals
  • Basic linear algebra (matrix multiplication)
  • Understanding of floating-point vs integer arithmetic

2. Pruning Foundations & Unstructured Pruning

Pruning is the process of removing unnecessary parameters from a neural network. The core insight is simple but profound: most neural networks are dramatically overparameterized. A well-trained network can often lose 80-95% of its weights with minimal accuracy loss.

But here's the critical question this section addresses: Does removing weights actually speed up inference? The answer depends entirely on the pruning pattern and the target hardware. Understanding this hardware-algorithm interaction is essential for effective pruning.

Why Pruning Works: The Lottery Ticket Hypothesis

The theoretical foundation for pruning comes from the Lottery Ticket Hypothesis (Frankle & Carbin, 2019):

The Lottery Ticket Hypothesis

"A randomly-initialized, dense neural network contains a subnetwork that is initialized such that—when trained in isolation—it can match the test accuracy of the original network after training for at most the same number of iterations."

In simpler terms: within every large network, there exists a small "winning ticket" subnetwork that can achieve the same performance. The rest of the parameters are essentially lottery losers—they don't contribute meaningfully to the final predictions.

Evidence for Overparameterization

Multiple lines of evidence suggest neural networks have far more parameters than necessary:

Evidence for Neural Network Overparameterization

  • Pruning results: ResNet-50 is ~80% prunable and BERT-base ~70% prunable with <1% accuracy drop; VGG-16 is 90%+ prunable; LLaMA-7B is ~50% prunable (harder). Most weights are unnecessary.
  • Weight distribution: trained weights follow a near-Gaussian distribution clustered around zero; small weights contribute little, while the rare large weights carry most of the signal. Removing small weights preserves accuracy.
  • Lottery ticket evidence: sparse subnetworks can train to the same accuracy; initialization matters (weight rewinding); some architectures are more prunable than others; winning tickets can transfer across tasks.

Neural network weights concentrate near zero. Removing small-magnitude weights has minimal impact on output.

Pruning Criteria: Which Weights to Remove?

The fundamental question in pruning is: How do we identify which weights are "unnecessary"? Several criteria have been developed:

1. Magnitude-Based Pruning (Most Common)

The simplest and most widely used criterion: remove weights with the smallest absolute values.

importance(w) = |w|

Intuition: A weight close to zero contributes little to the output of a neuron. Mathematically, if $y = \sum_i w_i x_i$, then small $|w_i|$ means small contribution to $y$.

2. Gradient-Based Pruning

Consider both weight magnitude and how much the loss changes when the weight is perturbed:

importance(w) = |w · ∂L/∂w|

This is the first-order Taylor expansion of the loss change. Weights with small magnitude but large gradients might still be important.
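The two criteria can disagree in instructive ways. A toy comparison on a single layer (the shape and stand-in loss are arbitrary choices):

Python criteria_compare.py
import torch
import torch.nn as nn

layer = nn.Linear(128, 64)
x = torch.randn(32, 128)
layer(x).pow(2).mean().backward()  # stand-in loss, just to populate gradients

magnitude = layer.weight.detach().abs()                      # |w|
taylor = (layer.weight.detach() * layer.weight.grad).abs()   # |w · ∂L/∂w|

# How often the two criteria agree on the 100 most important weights
top_mag = set(magnitude.flatten().topk(100).indices.tolist())
top_tay = set(taylor.flatten().topk(100).indices.tolist())
print(f"Top-100 overlap: {len(top_mag & top_tay)}%")
# Small weights with large gradients rank low under magnitude but high under Taylor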

3. Second-Order Methods (OBS, OBD)

Use the Hessian to estimate the impact of removing a weight:

ΔL ≈ ½ · w² · H_{ww}

Where $H_{ww}$ is the corresponding diagonal element of the Hessian. More accurate but computationally expensive.
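The full Hessian is intractable for large layers, so practical variants approximate its diagonal, for example with the empirical Fisher (squared gradients). A sketch of that approximation:

Python obd_sketch.py
import torch
import torch.nn as nn

layer = nn.Linear(128, 64)
x = torch.randn(256, 128)
layer(x).pow(2).mean().backward()  # stand-in loss to populate gradients

# Empirical-Fisher approximation of the Hessian diagonal: H_ww ≈ (∂L/∂w)²
h_diag = layer.weight.grad.pow(2)

# Estimated loss increase from removing each weight: ΔL ≈ ½ · w² · H_ww
delta_loss = 0.5 * layer.weight.detach().pow(2) * h_diag
print(delta_loss.flatten().topk(5).values)  # the weights most costly to remove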

4. Activation-Based Pruning

For neurons/channels, measure importance by average activation magnitude:

importance(neuron_j) = E[|activation_j|]

Neurons that rarely activate (or activate weakly) can be removed.

Pruning Criteria Comparison

Criterion                      Formula         Pros                                Cons                                 Use Case
Magnitude (L1 norm)            |w|             Simple, fast, no training needed    Ignores gradients, not loss-aware    General purpose
Taylor (1st order)             |w · ∂L/∂w|     Loss-aware, better accuracy         Needs gradients, more compute        Fine-grained pruning
Hessian (2nd order, OBS/OBD)   w²·H_ww         Most accurate, theoretical basis    Very expensive, O(n²) or approx.     Small models
Activation (per neuron)        E[|act|]        Structured, hardware-friendly       Needs data, per-neuron only          Channel pruning

In practice: magnitude pruning achieves 90% of the benefit with 10% of the complexity. Start there.

Magnitude-based pruning is the workhorse; use gradient methods for fine-tuning

Unstructured Pruning: Maximum Flexibility

Unstructured pruning (also called fine-grained or weight-level pruning) removes individual weights anywhere in the network. This provides maximum flexibility—any weight can be pruned regardless of position.

The Unstructured Pruning Process

Unstructured Pruning: Step-by-Step Process

  1. Original dense matrix: e.g., 16 weights at 100% density; rank all weights by |w|.
  2. Apply a binary mask: zero every weight below the threshold (e.g., |w| < 0.4), leaving 8 non-zeros at 50% sparsity.
  3. Fine-tune for 5-10 epochs: the surviving weights adjust and accuracy recovers while the matrix stays 50% sparse.

Characteristics: maximum flexibility (any weight can be pruned) and the highest achievable sparsity (80-95%), but the irregular memory access pattern requires a sparse storage format and sparse kernels.

Unstructured pruning creates irregular sparsity patterns that are challenging for hardware

Implementation: Magnitude Pruning in PyTorch

Python unstructured_pruning.py
import torch
import torch.nn.utils.prune as prune
import copy

def magnitude_prune_model(model, sparsity=0.5):
    """
    Apply unstructured magnitude pruning to all Linear and Conv2d layers.
    
    Args:
        model: PyTorch model to prune
        sparsity: Fraction of weights to remove (0.5 = 50% pruned)
    
    Returns:
        Pruned model with masks applied
    """
    model = copy.deepcopy(model)
    
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            # Apply L1 unstructured pruning
            prune.l1_unstructured(
                module,
                name='weight',
                amount=sparsity
            )
            print(f"Pruned {name}: {sparsity*100:.0f}% sparsity")
    
    return model


def count_sparsity(model):
    """Calculate actual sparsity of the model.

    Iterates over modules rather than named_parameters: while pruning
    hooks are active the real parameter is 'weight_orig' (unmasked),
    whereas module.weight is the masked view we want to measure.
    """
    total_params = 0
    zero_params = 0
    
    for module in model.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            weight = module.weight  # masked view if pruning is active
            total_params += weight.numel()
            zero_params += (weight == 0).sum().item()
    
    sparsity = zero_params / total_params
    print(f"Total parameters: {total_params:,}")
    print(f"Zero parameters: {zero_params:,}")
    print(f"Sparsity: {sparsity*100:.1f}%")
    return sparsity


def make_pruning_permanent(model):
    """Remove pruning hooks and make sparse weights permanent."""
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            try:
                prune.remove(module, 'weight')
            except ValueError:
                pass  # Already permanent
    return model


# Example usage
from torchvision.models import resnet50

# Load pre-trained model
model = resnet50(weights="IMAGENET1K_V1")
model.eval()

# Check original size
print("Original model:")
count_sparsity(model)

# Apply 70% unstructured pruning
pruned_model = magnitude_prune_model(model, sparsity=0.7)

# Check after pruning
print("\nAfter pruning:")
count_sparsity(pruned_model)

# Make permanent (removes hooks)
pruned_model = make_pruning_permanent(pruned_model)

Hardware Implications: Why Unstructured Pruning Doesn't Speed Up GPUs

Here's the critical insight that many practitioners miss: unstructured pruning, despite achieving high sparsity, often provides NO speedup on modern GPUs. Understanding why requires diving into how GPUs actually execute matrix operations.

How GPUs Execute Dense Matrix Multiplication

GPUs achieve high throughput through data parallelism and memory coalescing:

Dense GEMM Execution on GPU (Simplified)

Dense path: the matrix sits contiguously in HBM, so a warp of 32 threads loads 32 consecutive values in one coalesced transaction, tiles are staged in fast shared memory, and Tensor Cores consume them with 16×16×16 matrix-multiply instructions.

Sparse path: non-zero values are scattered across HBM, so each thread loads from a different address (slow, uncoalesced); CSR/COO metadata (row_ptr, col_idx) forces extra index lookups; and execution falls back to generic CUDA-core kernels with no Tensor Core acceleration.

Dense operations use coalesced memory access and Tensor Cores; sparse operations suffer from scattered access

The Three Killers of Sparse GPU Performance

Unstructured sparsity causes three fundamental problems on GPUs:

Why Unstructured Sparsity Fails on GPUs

  • Problem 1, memory access: a GPU warp (32 threads) expects to load 32 consecutive values in one transaction. With unstructured sparsity the non-zeros are scattered, every thread issues its own access, and the same data can cost up to 32x more memory transactions.
  • Problem 2, index overhead: sparse formats carry metadata to record which positions are non-zero. In CSR, values [0.5, 0.3, 0.7, 0.2] drag along col_idx [1, 4, 0, 3] plus a row_ptr array.
  • Problem 3, no Tensor Cores: Tensor Cores accelerate dense 16×16 tiles by 8-16x. Irregular patterns cannot use them, so sparse kernels fall back to CUDA cores and forfeit that speedup.

The math, and why 90% sparsity ≠ 10x speedup: a 90%-sparse GEMM performs 90% fewer operations, yet the cuSPARSE kernel on an A100 is typically slower than the Tensor Core dense kernel of the same size (see the benchmark below).

When unstructured sparsity DOES help: CPU inference (irregular access is less penalized) and very high sparsity (>95%) on specialized hardware.

Unstructured sparsity trades compute for memory inefficiency—often a bad trade on GPUs

Sparse Storage Formats

Sparse matrices require special storage formats to avoid storing zeros. Common formats include:

Sparse Matrix Storage Formats

Original sparse matrix (4×4), 6 non-zeros / 16 total = 37.5% density:

    5 0 0 3
    0 8 0 0
    0 0 2 0
    1 0 0 4

COO (Coordinate)
    row:    [0, 0, 1, 2, 3, 3]
    col:    [0, 3, 1, 2, 0, 3]
    values: [5, 3, 8, 2, 1, 4]
    Storage: 3×6 = 18 entries. Good for: construction, transpose.

CSR (Compressed Sparse Row)
    row_ptr: [0, 2, 3, 4, 6]
    col_idx: [0, 3, 1, 2, 0, 3]
    values:  [5, 3, 8, 2, 1, 4]
    Storage: 5 + 6 + 6 = 17 entries. Good for: row-wise access, SpMV.

CSC (Compressed Sparse Column)
    col_ptr: [0, 2, 3, 4, 6]
    row_idx: [0, 3, 1, 2, 0, 3]
    values:  [5, 1, 8, 2, 3, 4]
    Storage: 5 + 6 + 6 = 17 entries. Good for: column-wise access.

Format comparison (SpMV / SpMM): COO slow/slow; CSR fast/medium; CSC medium/fast; BSR fast/fast when the sparsity is block-structured.

Index overhead warning: at 50% sparsity, the index arrays can be larger than the value storage saved!

Different sparse formats optimize for different access patterns; all add index overhead
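PyTorch can build these formats directly, which makes the index overhead easy to inspect. A quick sketch using the example matrix above:

Python sparse_formats.py
import torch

A = torch.tensor([[5., 0., 0., 3.],
                  [0., 8., 0., 0.],
                  [0., 0., 2., 0.],
                  [1., 0., 0., 4.]])

coo = A.to_sparse()        # COO format
csr = A.to_sparse_csr()    # CSR format

print(coo.indices())       # 2×6 tensor of (row, col) coordinates
print(csr.crow_indices())  # row_ptr: [0, 2, 3, 4, 6]
print(csr.col_indices())   # col_idx: [0, 3, 1, 2, 0, 3]
print(csr.values())        # [5., 3., 8., 2., 1., 4.]

# 6 stored values but 5 + 6 = 11 index entries in CSR: at this density,
# the metadata outweighs what sparsity saves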

Benchmark: Dense vs Sparse on GPU

Python benchmark_sparse_vs_dense.py
import torch
import time
import numpy as np

def benchmark_dense_vs_sparse(size=4096, sparsity=0.9, num_runs=100):
    """
    Compare dense vs sparse matrix multiplication on GPU.
    This demonstrates why unstructured sparsity doesn't help on GPU.
    """
    assert torch.cuda.is_available(), "This benchmark requires a CUDA GPU"
    device = 'cuda'
    
    # Create dense matrices
    A_dense = torch.randn(size, size, device=device, dtype=torch.float16)
    B = torch.randn(size, size, device=device, dtype=torch.float16)
    
    # Create sparse version (same values, but sparse format)
    mask = torch.rand(size, size, device=device) > sparsity
    A_sparse_values = A_dense * mask  # bool mask keeps the FP16 dtype
    A_sparse = A_sparse_values.to_sparse_csr()
    
    actual_sparsity = 1 - (A_sparse._nnz() / (size * size))
    print(f"Matrix size: {size}×{size}")
    print(f"Target sparsity: {sparsity*100:.0f}%")
    print(f"Actual sparsity: {actual_sparsity*100:.1f}%")
    print(f"Non-zeros: {A_sparse._nnz():,} / {size*size:,}")
    print()
    
    # Warmup
    for _ in range(10):
        _ = torch.mm(A_dense, B)
        _ = torch.sparse.mm(A_sparse, B)
    torch.cuda.synchronize()
    
    # Benchmark dense
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_runs):
        result_dense = torch.mm(A_dense, B)
    torch.cuda.synchronize()
    dense_time = (time.perf_counter() - start) / num_runs * 1000
    
    # Benchmark sparse
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(num_runs):
        result_sparse = torch.sparse.mm(A_sparse, B)
    torch.cuda.synchronize()
    sparse_time = (time.perf_counter() - start) / num_runs * 1000
    
    print(f"Dense GEMM:  {dense_time:.2f} ms")
    print(f"Sparse SpMM: {sparse_time:.2f} ms")
    print(f"Speedup: {dense_time/sparse_time:.2f}x")
    
    if sparse_time > dense_time:
        print(f"\n⚠️  Sparse is {sparse_time/dense_time:.1f}x SLOWER despite {actual_sparsity*100:.0f}% sparsity!")
    
    return dense_time, sparse_time

# Run benchmark
if __name__ == "__main__":
    print("="*60)
    print("Dense vs Sparse GPU Benchmark")
    print("="*60)
    
    for sparsity in [0.5, 0.7, 0.9, 0.95]:
        print(f"\n--- Sparsity: {sparsity*100:.0f}% ---")
        benchmark_dense_vs_sparse(size=4096, sparsity=sparsity)

# Typical output on A100:
# --- Sparsity: 90% ---
# Matrix size: 4096×4096
# Dense GEMM:  0.45 ms
# Sparse SpMM: 1.23 ms
# ⚠️  Sparse is 2.7x SLOWER despite 90% sparsity!
Critical Insight: Unstructured Sparsity on GPUs

Do NOT expect speedups from unstructured pruning on GPUs. In practice:

  • 50% sparsity: Usually slower than dense
  • 90% sparsity: Often still slower than dense
  • 95%+ sparsity: Might break even, rarely faster
  • The overhead of index calculations and irregular memory access outweighs the reduced compute

Where Unstructured Pruning Actually Helps

Despite the GPU limitations, unstructured pruning has valid use cases:

1. CPU Inference

CPUs handle irregular memory access better than GPUs because:

  • Large L2/L3 caches (tens of MB) absorb scattered reads
  • Aggressive prefetchers and out-of-order execution hide memory latency
  • Threads run independently; there is no 32-lane warp that must move in lockstep

2. Memory Reduction (Not Speed)

Even if compute isn't faster, sparse storage reduces memory: at 90% sparsity, a CSR-stored weight matrix keeps roughly 10% of the values plus index arrays, which can be the difference between a model fitting in RAM (or on an edge device) and not fitting at all.

3. As a Prelude to Structured Pruning

Unstructured pruning can identify which structures to remove: if magnitude pruning zeroes out most of the weights in a particular filter, channel, or attention head, that whole structure is a strong candidate for structured removal.

4. Specialized Hardware

Some hardware natively accelerates unstructured sparsity: Cerebras's wafer-scale engine, for example, is built to skip zero operands at fine granularity, and several research accelerators take the same approach.

Python sparse_cpu_inference.py
# Using DeepSparse for CPU-optimized sparse inference
# pip install deepsparse

from deepsparse import compile_model
import numpy as np

# Load a pruned ONNX model
# DeepSparse automatically detects and accelerates sparsity
model_path = "pruned_bert_90sparse.onnx"

# Compile for optimized sparse CPU execution
engine = compile_model(model_path, batch_size=1)

# Run inference: the legacy Engine API takes a list of numpy arrays,
# ordered to match the ONNX model's inputs
inputs = [
    np.array([[101, 2023, 2003, 1037, 3231, 102]], dtype=np.int64),  # input_ids
    np.array([[1, 1, 1, 1, 1, 1]], dtype=np.int64),                  # attention_mask
    np.array([[0, 0, 0, 0, 0, 0]], dtype=np.int64),                  # token_type_ids
]

outputs = engine.run(inputs)

# DeepSparse can achieve 3-5x speedup on CPU with 90% sparsity
# Compare: Dense PyTorch on same CPU is much slower

Iterative Pruning: Achieving Higher Sparsity

One-shot pruning (prune once to target sparsity) often hurts accuracy. Iterative pruning gradually increases sparsity while fine-tuning, achieving better accuracy-sparsity tradeoffs.

One-Shot vs Iterative Pruning

One-shot: dense → prune to 90% in a single step → large accuracy drop (e.g., 85%).
Iterative: dense → prune to 30% → fine-tune → 60% → fine-tune → 90% → accuracy largely retained (e.g., 94%).

Iterative pruning maintains ~94% accuracy at 90% sparsity vs 85% for one-shot

Implementation: Iterative Magnitude Pruning

Python iterative_pruning.py
import torch
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader

def iterative_pruning(
    model,
    train_loader,
    val_loader,
    target_sparsity=0.9,
    num_iterations=5,
    epochs_per_iteration=3,
    criterion=torch.nn.CrossEntropyLoss(),
    lr=1e-4
):
    """
    Iteratively prune and fine-tune to reach target sparsity.
    
    Args:
        model: Model to prune
        train_loader: Training data
        val_loader: Validation data
        target_sparsity: Final sparsity target (e.g., 0.9 for 90%)
        num_iterations: Number of prune-finetune cycles
        epochs_per_iteration: Fine-tuning epochs per cycle
    """
    device = next(model.parameters()).device
    
    # Calculate per-iteration sparsity increase
    # (evenly spaced schedule: each cycle removes an equal share of the target)
    sparsities = []
    current = 0.0
    for i in range(num_iterations):
        # Each iteration removes (1-s) of remaining weights
        remaining_to_prune = target_sparsity - current
        prune_fraction = remaining_to_prune / (num_iterations - i)
        current += prune_fraction
        sparsities.append(current)
    
    print(f"Pruning schedule: {[f'{s*100:.0f}%' for s in sparsities]}")
    
    for iteration, target_sp in enumerate(sparsities):
        print(f"\n=== Iteration {iteration+1}/{num_iterations}: Target {target_sp*100:.0f}% ===")
        
        # Calculate how much to prune relative to current weights
        # We need to prune (target_sp - current_sp) / (1 - current_sp)
        current_sp = get_model_sparsity(model)
        if current_sp >= target_sp:
            continue
        
        relative_prune = (target_sp - current_sp) / (1 - current_sp)
        
        # Apply pruning
        for name, module in model.named_modules():
            if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
                prune.l1_unstructured(module, name='weight', amount=relative_prune)
        
        print(f"Sparsity after pruning: {get_model_sparsity(model)*100:.1f}%")
        
        # Fine-tune
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        
        for epoch in range(epochs_per_iteration):
            model.train()
            total_loss = 0
            
            for batch_idx, (data, target) in enumerate(train_loader):
                data, target = data.to(device), target.to(device)
                
                optimizer.zero_grad()
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
            
            # Validate
            val_acc = evaluate(model, val_loader, device)
            print(f"  Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Val Acc={val_acc*100:.1f}%")
    
    # Make pruning permanent
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            try:
                prune.remove(module, 'weight')
            except ValueError:  # no pruning to remove on this module
                pass
    
    return model


def get_model_sparsity(model):
    """Calculate current sparsity of model.

    Uses module.weight (the masked view while pruning hooks are active)
    rather than named_parameters, which would see the unmasked 'weight_orig'.
    """
    total = 0
    zeros = 0
    for module in model.modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            total += module.weight.numel()
            zeros += (module.weight == 0).sum().item()
    return zeros / total if total > 0 else 0


def evaluate(model, val_loader, device):
    """Evaluate model accuracy."""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    return correct / total

Summary: Unstructured Pruning

Aspect                Details
What it does          Removes individual weights based on magnitude or other criteria
Achievable sparsity   80-95% with iterative pruning and fine-tuning
GPU speedup           ❌ Usually NONE or negative (slower than dense)
CPU speedup           ✓ 2-5x with optimized libraries (DeepSparse)
Memory savings        ✓ Significant if using sparse storage format
Best for              CPU inference, memory-constrained scenarios, pre-analysis for structured pruning
Not recommended for   GPU inference speedup (use structured or N:M instead)
Key Takeaway

Unstructured pruning is a powerful technique, but NOT for GPU speedup. If your goal is faster GPU inference, skip ahead to Structured Pruning (Section 3) or N:M Sparsity (Section 5). Use unstructured pruning for CPU inference or memory reduction.

3. Structured Pruning for CNNs

Structured pruning removes entire architectural components rather than individual weights. For CNNs, this typically means removing filters (output channels) or channels (input features). The key advantage: the result is still a dense network that runs efficiently on standard hardware.

The Core Insight

Unstructured pruning creates sparse matrices that GPUs can't accelerate. Structured pruning creates smaller dense matrices that run on the same optimized kernels, just with reduced dimensions. A 50% channel-pruned network runs in approximately 50% of the time—no special hardware required.

Channel and Filter Pruning: The Fundamentals

In a convolutional layer, the weight tensor has shape [C_out, C_in, H, W] where:

  • C_out: the number of filters (output channels)
  • C_in: the number of input channels
  • H, W: the spatial dimensions of each kernel (e.g., 3×3)

Structured pruning can remove:

  • Filters: entire [C_in, H, W] slices, shrinking the layer's output channels
  • Channels: the corresponding input slices of the following layer, which must be trimmed to match

Filter Pruning vs Channel Pruning in CNNs

Example: a Conv2d with weights [6, 4, 3, 3] (6 filters, 4 channels each) maps a 4-channel input to a 6-channel output. Pruning 2 filters shrinks the weights to [4, 4, 3, 3] and the output to 4 channels. Critically, the next layer must also be modified: a following Conv2d [8, 6, 3, 3] that expected 6 input channels becomes [8, 4, 3, 3]. The trimming cascades through the network until a layer whose input is unaffected.

Why this matters for hardware: the result is smaller but DENSE tensors running on the same cuDNN kernels, just with smaller dimensions. Compute drops proportionally: 6×4×9 = 216 MACs/pixel → 4×4×9 = 144 MACs/pixel, a 33% reduction.

Filter pruning reduces output channels; the next layer must have its input channels trimmed accordingly
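The cascade is just tensor slicing. A minimal sketch with two convolutions and a hypothetical set of kept filters:

Python cascade_trim.py
import torch
import torch.nn as nn

conv1 = nn.Conv2d(4, 6, 3, padding=1)  # weights [6, 4, 3, 3]
conv2 = nn.Conv2d(6, 8, 3, padding=1)  # weights [8, 6, 3, 3]

keep = torch.tensor([0, 2, 4, 5])      # filters of conv1 we keep (hypothetical)

new1 = nn.Conv2d(4, len(keep), 3, padding=1)
new1.weight.data = conv1.weight.data[keep]      # drop pruned output filters
new1.bias.data = conv1.bias.data[keep]

new2 = nn.Conv2d(len(keep), 8, 3, padding=1)
new2.weight.data = conv2.weight.data[:, keep]   # trim matching input channels
new2.bias.data = conv2.bias.data.clone()

x = torch.randn(1, 4, 16, 16)
print(new2(new1(x)).shape)  # torch.Size([1, 8, 16, 16]): still composes correctly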

Filter Importance Criteria

How do we decide which filters to remove? Several criteria have been proposed:

1. L1-Norm (Filter Weight Magnitude)

The simplest approach: remove filters with the smallest L1 norm of their weights.

importance(filter_i) = Σ|W[i, :, :, :]|

Intuition: Filters with small weights produce small activations, contributing less to the output.

2. L2-Norm (Euclidean Magnitude)

importance(filter_i) = sqrt(Σ W[i, :, :, :]²)

Similar to L1 but penalizes larger individual weights more heavily.

3. Batch Normalization Scaling Factor (γ)

If the layer is followed by BatchNorm, the γ parameter indicates channel importance:

importance(channel_i) = |γ_i|

During training with L1 regularization on γ, unimportant channels naturally have γ → 0.
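This criterion pairs naturally with sparsity-inducing training ("network slimming"): add an L1 penalty on all BatchNorm scale factors so unimportant channels shrink toward zero during training. A self-contained training-step sketch (the toy model, random data, and penalty weight are illustrative choices):

Python bn_gamma_l1.py
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(16, 10),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
data, target = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))

lambda_l1 = 1e-4  # sparsity penalty strength (a typical but arbitrary value)

task_loss = criterion(model(data), target)
l1_gamma = sum(m.weight.abs().sum()  # L1 on all BN scale factors (gamma)
               for m in model.modules() if isinstance(m, nn.BatchNorm2d))
(task_loss + lambda_l1 * l1_gamma).backward()
optimizer.step()
# Over many such steps, channels whose |gamma| collapses toward zero
# become pruning targets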

4. Taylor Expansion (Gradient-Based)

Consider the change in loss when removing a filter:

importance(filter_i) = |Σ (∂L/∂A_i) · A_i|

Where $A_i$ is the activation of filter $i$. More computationally expensive but often more accurate.

5. Geometric Median

Remove filters most similar to other filters (redundant information):

importance(filter_i) = Σ_j ||filter_i - filter_j||

Low importance = filter is "replaceable" by others.

Filter Importance Criteria Comparison

Criterion    Computation   Accuracy     Needs Data?        Best For
L1-Norm      Very fast     Good         No                 Quick baseline, large models
L2-Norm      Very fast     Good         No                 Alternative to L1
BN-γ         Very fast     Very good*   No                 When training with L1 reg on γ
Taylor       Moderate      Best         Yes (mini-batch)   High-accuracy requirements
Geo-Median   Slow          Good         No                 Removing redundancy

Recommendation: start with L1-norm for speed; use Taylor for best accuracy; use BN-γ with sparsity-inducing training.

L1-norm is fast and effective; Taylor expansion gives best accuracy at higher computational cost

Global vs Local Pruning

An important design choice: should we prune the same percentage from each layer, or rank filters globally?

Local vs Global Pruning Strategy

Local pruning (50% from each layer): Conv1 drops from 64 to 32 filters and Conv2 from 128 to 64. Each layer loses 50% regardless of importance, so even sensitive layers (often the first and last) are forced to give up filters.

Global pruning (50% of all filters, ranked together): a more important Conv1 might lose only 16% of its filters (64 → 54) while a more redundant Conv2 loses 67% (128 → 42). Important layers are pruned less and redundant layers more, which naturally protects sensitive layers and yields better accuracy at the same FLOPs reduction.

⚠️ Layer sensitivity warning: the first conv layer and the classification head are often highly sensitive. Consider excluding them from pruning or using lower rates.

Global pruning typically achieves 1-3% better accuracy than local pruning at the same FLOPs

Implementation: Channel Pruning in PyTorch

Python structured_pruning.py
import torch
import torch.nn as nn
import copy
from typing import List, Dict, Tuple

class StructuredPruner:
    """
    Structured (channel/filter) pruning for CNNs.
    Removes entire filters and adjusts subsequent layers.

    Note: assumes a linear chain of conv layers. Architectures with
    residual connections (e.g., ResNet) need dependency-aware pruning,
    such as the torch-pruning library provides.
    """
    
    def __init__(self, model: nn.Module, example_input: torch.Tensor):
        self.model = copy.deepcopy(model)
        self.example_input = example_input
        self.layer_info = self._analyze_model()
    
    def _analyze_model(self) -> Dict:
        """Analyze model structure to find prunable conv layers."""
        info = {}
        prev_layer = None
        
        for name, module in self.model.named_modules():
            if isinstance(module, nn.Conv2d):
                info[name] = {
                    'module': module,
                    'out_channels': module.out_channels,
                    'in_channels': module.in_channels,
                    'prev_conv': prev_layer
                }
                prev_layer = name
            elif isinstance(module, nn.BatchNorm2d):
                if prev_layer in info:
                    info[prev_layer]['bn'] = name
        
        return info
    
    def compute_filter_importance(
        self, 
        criterion: str = 'l1'
    ) -> Dict[str, torch.Tensor]:
        """
        Compute importance scores for each filter in each layer.
        
        Args:
            criterion: 'l1', 'l2', 'taylor', or 'bn_gamma'
        """
        importance = {}
        
        for name, info in self.layer_info.items():
            module = info['module']
            
            if criterion == 'l1':
                # L1 norm of each filter
                scores = module.weight.data.abs().sum(dim=[1, 2, 3])
            
            elif criterion == 'l2':
                # L2 norm of each filter
                scores = module.weight.data.pow(2).sum(dim=[1, 2, 3]).sqrt()
            
            elif criterion == 'bn_gamma' and 'bn' in info:
                # Use BatchNorm gamma as importance
                bn_name = info['bn']
                bn = dict(self.model.named_modules())[bn_name]
                scores = bn.weight.data.abs()
            
            else:
                # Default to L1
                scores = module.weight.data.abs().sum(dim=[1, 2, 3])
            
            importance[name] = scores
        
        return importance
    
    def get_pruning_mask(
        self,
        importance: Dict[str, torch.Tensor],
        prune_ratio: float,
        global_pruning: bool = True,
        min_channels: int = 8
    ) -> Dict[str, torch.Tensor]:
        """
        Determine which filters to keep based on importance scores.
        
        Args:
            importance: Per-layer importance scores
            prune_ratio: Fraction of filters to remove (0.5 = 50%)
            global_pruning: If True, rank globally; else per-layer
            min_channels: Minimum channels to keep per layer
        """
        masks = {}
        
        if global_pruning:
            # Concatenate all scores and find global threshold
            all_scores = torch.cat([s.flatten() for s in importance.values()])
            num_to_prune = int(all_scores.numel() * prune_ratio)
            threshold = torch.kthvalue(all_scores, num_to_prune).values.item()
            
            for name, scores in importance.items():
                # Keep filters above threshold
                mask = scores > threshold
                
                # Ensure minimum channels
                if mask.sum() < min_channels:
                    _, top_idx = scores.topk(min_channels)
                    mask = torch.zeros_like(mask, dtype=torch.bool)
                    mask[top_idx] = True
                
                masks[name] = mask
        else:
            # Per-layer pruning
            for name, scores in importance.items():
                num_filters = scores.numel()
                num_to_keep = max(min_channels, int(num_filters * (1 - prune_ratio)))
                _, top_idx = scores.topk(num_to_keep)
                
                mask = torch.zeros(num_filters, dtype=torch.bool)
                mask[top_idx] = True
                masks[name] = mask
        
        return masks
    
    def apply_pruning(self, masks: Dict[str, torch.Tensor]) -> nn.Module:
        """
        Create a new model with pruned channels.
        This creates actual smaller dense layers, not masked sparse ones.
        """
        # Build channel mapping for each layer
        channel_mapping = {}
        for name, mask in masks.items():
            keep_indices = torch.where(mask)[0]
            channel_mapping[name] = keep_indices
        
        # Create new model with pruned layers
        new_model = copy.deepcopy(self.model)
        
        for name, info in self.layer_info.items():
            if name not in masks:
                continue
            
            old_module = info['module']
            keep_out = channel_mapping[name]
            
            # Determine input channels to keep
            prev_conv = info.get('prev_conv')
            if prev_conv and prev_conv in channel_mapping:
                keep_in = channel_mapping[prev_conv]
            else:
                keep_in = torch.arange(old_module.in_channels)
            
            # Create new pruned conv layer
            new_conv = nn.Conv2d(
                in_channels=len(keep_in),
                out_channels=len(keep_out),
                kernel_size=old_module.kernel_size,
                stride=old_module.stride,
                padding=old_module.padding,
                bias=old_module.bias is not None
            )
            
            # Copy weights for kept channels
            new_conv.weight.data = old_module.weight.data[keep_out][:, keep_in]
            if old_module.bias is not None:
                new_conv.bias.data = old_module.bias.data[keep_out]
            
            # Replace in model
            self._set_module(new_model, name, new_conv)
            
            # Update BatchNorm if present
            if 'bn' in info:
                bn_name = info['bn']
                old_bn = dict(self.model.named_modules())[bn_name]
                new_bn = nn.BatchNorm2d(len(keep_out))
                
                new_bn.weight.data = old_bn.weight.data[keep_out]
                new_bn.bias.data = old_bn.bias.data[keep_out]
                new_bn.running_mean = old_bn.running_mean[keep_out]
                new_bn.running_var = old_bn.running_var[keep_out]
                
                self._set_module(new_model, bn_name, new_bn)
        
        return new_model
    
    def _set_module(self, model: nn.Module, name: str, new_module: nn.Module):
        """Set a module in the model by name."""
        parts = name.split('.')
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_module)
    
    def prune(
        self,
        prune_ratio: float = 0.5,
        criterion: str = 'l1',
        global_pruning: bool = True
    ) -> nn.Module:
        """
        Main method: prune the model by specified ratio.
        
        Returns a new, smaller dense model.
        """
        # Step 1: Compute importance
        importance = self.compute_filter_importance(criterion)
        
        # Step 2: Get pruning masks
        masks = self.get_pruning_mask(
            importance, 
            prune_ratio, 
            global_pruning
        )
        
        # Step 3: Apply pruning
        pruned_model = self.apply_pruning(masks)
        
        # Print stats
        total_orig = sum(p.numel() for p in self.model.parameters())
        total_new = sum(p.numel() for p in pruned_model.parameters())
        
        print(f"Original parameters: {total_orig:,}")
        print(f"Pruned parameters: {total_new:,}")
        print(f"Reduction: {(1 - total_new/total_orig)*100:.1f}%")
        
        return pruned_model


# Example usage
if __name__ == "__main__":
    # A simple sequential CNN. As noted in the class docstring, residual
    # networks (e.g., ResNet) require dependency-aware pruning instead.
    model = nn.Sequential(
        nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(),
        nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
        nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
    )
    model.eval()
    
    # Create pruner
    example_input = torch.randn(1, 3, 224, 224)
    pruner = StructuredPruner(model, example_input)
    
    # Prune 50% of filters
    pruned_model = pruner.prune(
        prune_ratio=0.5,
        criterion='l1',
        global_pruning=True
    )
    
    # Verify the pruned model still runs (its output has fewer channels)
    with torch.no_grad():
        pruned_out = pruned_model(example_input)
        print(f"Pruned output shape: {pruned_out.shape}")

Hardware Speedup: Why Structured Pruning Works

Unlike unstructured pruning, structured pruning provides real speedups on any hardware because the result is still a dense network with standard tensor operations.

Why Structured Pruning Achieves Real Speedup

Unstructured pruning (50%): the matrix remains the same size (N×M), requires sparse storage and sparse kernels, and gets no Tensor Core acceleration. On an A100, typically ~1.2x slower than dense due to index overhead.

Structured pruning (50%): the matrix is smaller but DENSE (N×M/2), all values are non-zero, and the standard dense kernels with full Tensor Core acceleration apply. On an A100, roughly 2x faster: an actual speedup.

The math: a conv pruned from [64, 64, 3, 3] to [32, 64, 3, 3] on a [1, 64, 56, 56] input drops from 64×64×9×56×56 ≈ 115M MACs to ≈58M.

Key insight: structured pruning reduces dimensions, not density. The same optimized cuDNN kernels run with smaller input/output dimensions → proportional speedup.

Structured pruning achieves near-linear speedup because it maintains dense operations

Benchmark: Structured Pruning Speedup

Python benchmark_structured.py
import torch
import torch.nn as nn
import time

def benchmark_structured_pruning(prune_ratio=0.5):
    """
    Demonstrate that structured pruning gives proportional speedup.
    """
    assert torch.cuda.is_available(), "This benchmark requires a CUDA GPU"
    device = 'cuda'
    
    # Original convolution
    in_channels = 256
    out_channels = 256
    kernel_size = 3
    spatial_size = 56
    
    conv_original = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1).to(device)
    
    # Pruned convolution (50% fewer output channels)
    pruned_out = int(out_channels * (1 - prune_ratio))
    conv_pruned = nn.Conv2d(in_channels, pruned_out, kernel_size, padding=1).to(device)
    
    # Input tensor
    x = torch.randn(1, in_channels, spatial_size, spatial_size, device=device)
    
    # Warmup
    for _ in range(20):
        _ = conv_original(x)
        _ = conv_pruned(x)
    torch.cuda.synchronize()
    
    # Benchmark original
    num_runs = 100
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = conv_original(x)
    torch.cuda.synchronize()
    orig_time = (time.perf_counter() - start) / num_runs * 1000
    
    # Benchmark pruned
    start = time.perf_counter()
    for _ in range(num_runs):
        _ = conv_pruned(x)
    torch.cuda.synchronize()
    pruned_time = (time.perf_counter() - start) / num_runs * 1000
    
    # Calculate stats
    orig_params = in_channels * out_channels * kernel_size * kernel_size
    pruned_params = in_channels * pruned_out * kernel_size * kernel_size
    
    orig_flops = 2 * in_channels * out_channels * kernel_size * kernel_size * spatial_size * spatial_size
    pruned_flops = 2 * in_channels * pruned_out * kernel_size * kernel_size * spatial_size * spatial_size
    
    print(f"Original: {out_channels} output channels")
    print(f"Pruned: {pruned_out} output channels ({prune_ratio*100:.0f}% pruned)")
    print()
    print(f"Parameters: {orig_params:,}{pruned_params:,} ({(1-pruned_params/orig_params)*100:.0f}% reduction)")
    print(f"FLOPs: {orig_flops/1e6:.1f}M → {pruned_flops/1e6:.1f}M ({(1-pruned_flops/orig_flops)*100:.0f}% reduction)")
    print()
    print(f"Latency: {orig_time:.3f} ms → {pruned_time:.3f} ms")
    print(f"Speedup: {orig_time/pruned_time:.2f}x")
    print()
    
    theoretical_speedup = orig_flops / pruned_flops
    actual_speedup = orig_time / pruned_time
    efficiency = actual_speedup / theoretical_speedup * 100
    print(f"Theoretical speedup: {theoretical_speedup:.2f}x")
    print(f"Actual speedup: {actual_speedup:.2f}x")
    print(f"Efficiency: {efficiency:.0f}%")

# Run benchmark
benchmark_structured_pruning(prune_ratio=0.5)

# Typical output on A100:
# Original: 256 output channels
# Pruned: 128 output channels (50% pruned)
# 
# Parameters: 589,824 → 294,912 (50% reduction)
# FLOPs: 3699.4M → 1849.7M (50% reduction)
# 
# Latency: 0.182 ms → 0.098 ms
# Speedup: 1.86x
# 
# Theoretical speedup: 2.00x
# Actual speedup: 1.86x
# Efficiency: 93%     ← Near-linear speedup!

Example: Pruning ResNet-50

Let's see real-world results from pruning a ResNet-50 on ImageNet:

ResNet-50 Structured Pruning Results (ImageNet)

Method                 FLOPs        Params   Top-1 Acc   Acc Drop   GPU Speedup
ResNet-50 (original)   4.1 GFLOPs   25.6M    76.2%       —          1.0x
30% filter pruned      2.9 GFLOPs   17.8M    75.8%       -0.4%      1.35x
50% filter pruned      2.0 GFLOPs   12.8M    75.1%       -1.1%      1.85x
70% filter pruned      1.2 GFLOPs   7.7M     73.5%       -2.7%      2.8x

50% filter pruning offers the best accuracy-speedup tradeoff for ResNet-50

Layer Sensitivity Analysis

Not all layers are equally prunable. Understanding layer sensitivity helps achieve better accuracy-efficiency tradeoffs.

Python sensitivity_analysis.py
import torch
import torch.nn as nn
import copy
from typing import Dict, List

def analyze_layer_sensitivity(
    model: nn.Module,
    val_loader,
    prune_ratios: List[float] = [0.1, 0.2, 0.3, 0.5, 0.7],
    criterion = nn.CrossEntropyLoss()
) -> Dict[str, Dict[float, float]]:
    """
    Analyze how sensitive each layer is to pruning.
    For each layer, prune ONLY that layer at various ratios
    and measure accuracy drop.
    
    Returns:
        Dict mapping layer_name -> {prune_ratio: accuracy}
    """
    device = next(model.parameters()).device
    
    # Get baseline accuracy
    baseline_acc = evaluate(model, val_loader, device)
    print(f"Baseline accuracy: {baseline_acc*100:.2f}%")
    
    sensitivity = {}
    
    # Find all Conv2d layers
    conv_layers = [(name, module) for name, module in model.named_modules() 
                   if isinstance(module, nn.Conv2d)]
    
    for layer_name, layer_module in conv_layers:
        sensitivity[layer_name] = {}
        print(f"\nAnalyzing: {layer_name}")
        
        for ratio in prune_ratios:
            # Create a copy and prune only this layer
            model_copy = copy.deepcopy(model)
            
            # Apply pruning to this layer only
            for name, module in model_copy.named_modules():
                if name == layer_name:
                    prune_layer_filters(module, ratio)
                    break
            
            # Evaluate
            acc = evaluate(model_copy, val_loader, device)
            sensitivity[layer_name][ratio] = acc
            acc_drop = (baseline_acc - acc) * 100
            
            print(f"  {ratio*100:.0f}% pruned: {acc*100:.2f}% (drop: {acc_drop:.2f}%)")
            
            del model_copy
    
    return sensitivity


def prune_layer_filters(conv_module: nn.Conv2d, ratio: float):
    """Zero out the least important filters in a conv layer."""
    with torch.no_grad():
        # Compute L1 importance
        importance = conv_module.weight.data.abs().sum(dim=[1, 2, 3])
        
        # Find threshold
        num_to_prune = int(len(importance) * ratio)
        if num_to_prune == 0:
            return
        
        threshold = importance.kthvalue(num_to_prune).values
        
        # Zero out unimportant filters
        mask = importance > threshold
        conv_module.weight.data[~mask] = 0


def evaluate(model, val_loader, device):
    """Quick evaluation for sensitivity analysis."""
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for i, (data, target) in enumerate(val_loader):
            if i >= 50:  # Use subset for speed
                break
            data, target = data.to(device), target.to(device)
            output = model(data)
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    
    return correct / total


def visualize_sensitivity(sensitivity: Dict[str, Dict[float, float]]):
    """
    Print sensitivity analysis results.
    Layers with high sensitivity should be pruned less.
    """
    print("\n" + "="*60)
    print("Layer Sensitivity Summary (lower = more prunable)")
    print("="*60)
    
    # Calculate average sensitivity per layer at 50% pruning
    layer_sensitivity = {}
    for layer_name, ratios in sensitivity.items():
        if 0.5 in ratios:
            # Sensitivity = accuracy drop at 50%
            baseline = ratios.get(0.0, ratios[min(ratios.keys())])
            layer_sensitivity[layer_name] = baseline - ratios[0.5]
    
    # Sort by sensitivity
    sorted_layers = sorted(layer_sensitivity.items(), key=lambda x: x[1])
    
    print("\nMost Prunable (low sensitivity):")
    for name, sens in sorted_layers[:5]:
        print(f"  {name}: {sens*100:.2f}% drop at 50% pruning")
    
    print("\nLeast Prunable (high sensitivity):")
    for name, sens in sorted_layers[-5:]:
        print(f"  {name}: {sens*100:.2f}% drop at 50% pruning")
Typical Layer Sensitivity in ResNet

[Bar chart: accuracy drop at 50% pruning, per layer of ResNet-50. conv1 (~18% drop) and layer1.0 (~12%) are highly sensitive; the middle layers (layer2.x, layer3.x) show small drops and are the most prunable; the fc head (~10% drop) is sensitive again.]

First conv layer and classifier are typically most sensitive; middle layers are most prunable
Practical Recommendation

When pruning CNNs:

  • Skip or lightly prune the first conv layer (conv1) and classification head
  • Aggressively prune middle layers (layer2, layer3 in ResNets)
  • Use global pruning to automatically allocate pruning budget based on importance
  • Fine-tune for 10-30% of original training epochs after pruning

Summary: Structured Pruning for CNNs

Aspect                 Details
What it does           Removes entire filters/channels, creating smaller dense layers
Achievable reduction   30-70% FLOPs reduction with <2% accuracy loss typical
GPU speedup            ✓ Near-linear speedup (50% pruning → ~1.8x faster)
CPU speedup            ✓ Same proportional speedup
Best criteria          L1-norm for speed, Taylor for best accuracy
Sensitive layers       First conv, classification head—prune less or skip
Prunable layers        Middle layers (layer2, layer3 in ResNet)

4. Structured Pruning for Transformers

Transformers have unique structure compared to CNNs: attention heads, feed-forward networks (FFN), and embedding dimensions. Each offers different pruning opportunities with distinct hardware implications. This section provides a deep dive into head pruning—the most impactful and nuanced form of transformer pruning.

Transformer Anatomy: What Can We Prune?

A transformer block contains several prunable components:

Prunable Components in a Transformer Block
Transformer Block (Layer i) Input: [B, SeqLen, D_model] Multi-Head Attention Head 1 D_k Head 2 Head 3 Prune! ... Head H Add & LayerNorm Feed-Forward Network D→4D Prune width! 4D→D Output: [B, SeqLen, D_model] Pruning Options 1. Head Pruning • Remove entire attention heads • Hardware-friendly ✓ 2. FFN Width Pruning • Reduce intermediate dim • Hardware-friendly ✓ 3. Layer Pruning • Remove entire layers • Maximum speedup ✓ 4. Embedding Dim • Reduce D_model • Requires retraining Parameter Distribution (BERT-base example) Multi-Head Attention: 28.3M (33%) Feed-Forward: 56.6M (66%) LayerNorm: 0.3M (<1%) Key Insight: FFN is 2/3 of params! But attention dominates compute at long sequences
FFN dominates parameters (66%), but attention dominates compute at long sequences due to O(n²) complexity
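The 2:1 split falls straight out of the standard dimensions; a quick back-of-the-envelope check for BERT-base:

d = 768  # BERT-base hidden size (d_model)

attn_params = 4 * d * d        # W_Q, W_K, W_V, W_O projections
ffn_params = 2 * d * (4 * d)   # two linears: d -> 4d -> d

print(f"attention: {attn_params / 1e6:.2f}M per layer")  # ~2.36M
print(f"ffn:       {ffn_params / 1e6:.2f}M per layer")   # ~4.72M
print(f"ffn share: {ffn_params / (attn_params + ffn_params):.0%}")  # ~67%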

Head Pruning: A Deep Dive

Attention head pruning is the most impactful form of structured pruning for transformers. The key insight from research (Michel et al., 2019; Voita et al., 2019) is that many attention heads are redundant—a 12-head BERT model can often function well with 6-8 heads.

Why Heads Can Be Pruned

Attention heads learn specialized roles, but there's significant redundancy:

Attention Head Specialization in BERT
[Figure: attention-pattern taxonomy in pre-trained BERT. Positional heads (~30%) attend to the previous/next token or [CLS]/[SEP] and are often redundant. Syntactic heads (~15%) track relations such as verb→subject or noun→determiner and are usually important, with few heads per pattern. Semantic heads (~20%) handle coreference, entity linking, and negation; their importance is task-dependent. Diffuse heads (~35%) show near-uniform attention and are usually prunable. Early layers (1-4) are less prunable (syntax), middle layers (5-8) most prunable, late layers (9-12) task-dependent. Key finding: BERT can lose 40-50% of heads with <1% accuracy drop on most tasks (Michel et al., 2019).]
~35% of attention heads have diffuse patterns and can often be safely pruned

Head Importance Metrics

How do we measure which heads to prune? Several methods exist:

1. Gradient-Based Importance (Taylor)

Estimate the impact of removing a head on the loss:

I_head(h) = |E[∑_t (∂L/∂A_h^t) · A_h^t]|

Where $A_h^t$ is the attention output of head $h$ at position $t$.

2. Attention Entropy

Heads with low entropy (peaked attention) often carry more information:

Entropy(h) = -∑_j α_j^h log(α_j^h)

High entropy = diffuse attention = likely prunable. Low entropy = focused attention = likely important.

3. Head Confidence

Measure how "confident" a head is in its attention pattern:

Confidence(h) = max_j(α_j^h)

Heads that consistently have low max attention may be contributing less.
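Confidence is cheap to compute from the attention probabilities most transformer libraries can return; a minimal sketch, assuming attn holds one layer's attention maps:

# attn: [batch, num_heads, seq_q, seq_k] softmax attention probabilities
# Max attention weight per query, averaged over batch and positions
confidence = attn.max(dim=-1).values.mean(dim=(0, 2))  # -> [num_heads]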

4. Learnable Importance (Soft Masking)

Train a scalar importance weight per head during fine-tuning:

output_h = σ(z_h) · head_h(x)

Where $z_h$ is a learnable importance score. Add L1 regularization to push unimportant heads to zero.
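A minimal sketch of the soft-masking idea, assuming access to per-head outputs of shape [batch, heads, seq, head_dim]; the class and method names here are illustrative, not from any specific library:

import torch
import torch.nn as nn

class HeadGate(nn.Module):
    """Learnable per-head gates; L1 pressure pushes useless heads to zero."""

    def __init__(self, num_heads: int):
        super().__init__()
        self.z = nn.Parameter(torch.zeros(num_heads))  # sigmoid(0) = 0.5

    def forward(self, head_outputs: torch.Tensor) -> torch.Tensor:
        # head_outputs: [batch, num_heads, seq, head_dim]
        gate = torch.sigmoid(self.z).view(1, -1, 1, 1)
        return head_outputs * gate

    def l1_penalty(self) -> torch.Tensor:
        # Add lambda * l1_penalty() to the task loss during fine-tuning;
        # heads whose gate collapses toward 0 are pruning candidates
        return torch.sigmoid(self.z).sum()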

Head Importance Metrics Comparison
Method Requires Cost Accuracy Best For
Taylor (Gradient) Forward + backward pass Moderate Best Task-specific pruning
Attention Entropy Forward pass only Low Good Quick analysis
Head Confidence Forward pass only Low Medium Simple baseline
Learnable Mask Training with L1 reg High Best Joint train+prune
Use Taylor for post-training pruning; learnable masks for training-time pruning.
Taylor importance achieves best accuracy; entropy is fast for quick analysis

Implementation: Complete Head Pruning Pipeline

Python head_pruning.py
import torch
import torch.nn as nn
import copy
from transformers import BertModel, BertTokenizer
from typing import Dict, List, Tuple
import numpy as np

class HeadImportanceAnalyzer:
    """
    Analyze and compute importance scores for attention heads.
    """
    
    def __init__(self, model: nn.Module):
        self.model = model
        self.num_layers = model.config.num_hidden_layers
        self.num_heads = model.config.num_attention_heads
        self.head_dim = model.config.hidden_size // self.num_heads
        
        # Storage for importance scores
        self.taylor_scores = None
        self.entropy_scores = None
    
    def compute_taylor_importance(
        self, 
        dataloader, 
        num_batches: int = 50
    ) -> torch.Tensor:
        """
        Compute Taylor (gradient-based) importance for each head.
        
        Shape: [num_layers, num_heads]
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        
        # Initialize importance scores
        importance = torch.zeros(self.num_layers, self.num_heads, device=device)
        
        # Hooks capture attention outputs during the forward pass
        attention_outputs = {}
        
        def save_output_hook(name):
            def hook(module, input, output):
                # output[0] is attention output: [batch, seq, hidden]
                attention_outputs[name] = output[0]
            return hook
        
        # Register hooks on attention layers
        handles = []
        for layer_idx in range(self.num_layers):
            attn = self.model.encoder.layer[layer_idx].attention.self
            name = f"layer_{layer_idx}"
            handle = attn.register_forward_hook(save_output_hook(name))
            handles.append(handle)
        
        # Process batches
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            
            # Move to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            # Forward pass
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_attentions=True
            )
            
            # Proxy objective: variance of the final hidden states gives a
            # label-free gradient signal; for task-specific pruning, use the
            # actual task loss here instead
            loss = outputs.last_hidden_state.var()
            
            # Retain gradients on the (non-leaf) attention outputs so
            # .grad is populated by the backward pass
            for name, output in attention_outputs.items():
                output.retain_grad()
            
            # Backward pass
            loss.backward()
            
            # Compute Taylor importance per head
            for layer_idx in range(self.num_layers):
                name = f"layer_{layer_idx}"
                output = attention_outputs[name]  # [batch, seq, hidden]
                grad = output.grad
                
                if grad is None:
                    continue
                
                # Reshape to [batch, seq, num_heads, head_dim]
                batch_size, seq_len, _ = output.shape
                output_reshaped = output.view(batch_size, seq_len, self.num_heads, self.head_dim)
                grad_reshaped = grad.view(batch_size, seq_len, self.num_heads, self.head_dim)
                
                # Per-element Taylor score |a · ∂L/∂a|, summed over batch,
                # sequence, and head_dim -> one score per head
                head_importance = (output_reshaped * grad_reshaped).abs().sum(dim=[0, 1, 3])
                importance[layer_idx] += head_importance
            
            # Clear gradients
            self.model.zero_grad()
        
        # Remove hooks
        for handle in handles:
            handle.remove()
        
        # Normalize
        importance = importance / num_batches
        self.taylor_scores = importance
        
        return importance
    
    def compute_entropy_importance(
        self, 
        dataloader, 
        num_batches: int = 50
    ) -> torch.Tensor:
        """
        Compute attention entropy for each head.
        Lower entropy = more focused = likely more important.
        
        Returns negative entropy so higher = more important.
        """
        self.model.eval()
        device = next(self.model.parameters()).device
        
        entropy = torch.zeros(self.num_layers, self.num_heads, device=device)
        
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                if batch_idx >= num_batches:
                    break
                
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                
                outputs = self.model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_attentions=True
                )
                
                # attentions: tuple of [batch, num_heads, seq, seq]
                for layer_idx, attn in enumerate(outputs.attentions):
                    # Compute entropy per head
                    # attn: [batch, num_heads, seq, seq]
                    # Add small epsilon for numerical stability
                    attn_clamped = attn.clamp(min=1e-8)
                    head_entropy = -(attn_clamped * attn_clamped.log()).sum(dim=-1)
                    # Average over batch and sequence
                    head_entropy = head_entropy.mean(dim=[0, 2])
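                    # (Caveat: this average includes padded positions;
                    #  masking with attention_mask would be more precise)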
                    entropy[layer_idx] += head_entropy
        
        entropy = entropy / num_batches
        # Return negative entropy (higher = more important)
        self.entropy_scores = -entropy
        
        return -entropy
    
    def get_heads_to_prune(
        self,
        importance: torch.Tensor,
        prune_ratio: float = 0.5,
        global_pruning: bool = True,
        min_heads_per_layer: int = 1
    ) -> Dict[int, List[int]]:
        """
        Determine which heads to prune based on importance scores.
        
        Returns:
            Dict mapping layer_idx -> list of head indices to prune
        """
        if global_pruning:
            # Global pruning: rank all heads together
            flat_importance = importance.flatten()
            num_to_prune = max(1, int(flat_importance.numel() * prune_ratio))
            
            # Threshold = importance of the num_to_prune-th least important
            # head (kthvalue needs k >= 1, hence the max() guard)
            threshold = torch.kthvalue(flat_importance, num_to_prune).values.item()
            
            heads_to_prune = {}
            for layer_idx in range(self.num_layers):
                layer_importance = importance[layer_idx]
                # <= so the threshold head itself is pruned (ties can push
                # slightly past the target ratio)
                prune_mask = layer_importance <= threshold
                
                # Ensure minimum heads kept
                num_to_keep = self.num_heads - prune_mask.sum().item()
                if num_to_keep < min_heads_per_layer:
                    # Keep the top min_heads_per_layer
                    _, top_indices = layer_importance.topk(min_heads_per_layer)
                    prune_mask = torch.ones(self.num_heads, dtype=torch.bool)
                    prune_mask[top_indices] = False
                
                pruned_heads = torch.where(prune_mask)[0].tolist()
                if pruned_heads:
                    heads_to_prune[layer_idx] = pruned_heads
        
        else:
            # Per-layer pruning
            heads_to_prune = {}
            num_to_prune_per_layer = max(1, int(self.num_heads * prune_ratio))
            
            for layer_idx in range(self.num_layers):
                layer_importance = importance[layer_idx]
                _, indices = layer_importance.topk(num_to_prune_per_layer, largest=False)
                heads_to_prune[layer_idx] = indices.tolist()
        
        return heads_to_prune
    
    def visualize_importance(self, importance: torch.Tensor):
        """Print a text visualization of head importance."""
        print("\nHead Importance Heatmap:")
        print("(Higher = More Important)")
        print("-" * (10 + self.num_heads * 8))
        
        # Normalize for visualization
        importance_norm = (importance - importance.min()) / (importance.max() - importance.min() + 1e-8)
        
        header = "Layer " + "  ".join([f"H{i:02d}" for i in range(self.num_heads)])
        print(header)
        
        for layer_idx in range(self.num_layers):
            row = f"L{layer_idx:02d}   "
            for head_idx in range(self.num_heads):
                val = importance_norm[layer_idx, head_idx].item()
                if val > 0.8:
                    row += "████  "
                elif val > 0.6:
                    row += "███░  "
                elif val > 0.4:
                    row += "██░░  "
                elif val > 0.2:
                    row += "█░░░  "
                else:
                    row += "░░░░  "
            print(row)


class HeadPruner:
    """
    Apply structured head pruning to transformer models.
    Creates actual smaller attention layers (not masked).
    """
    
    def __init__(self, model: nn.Module):
        self.model = copy.deepcopy(model)
    
    def prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> nn.Module:
        """
        Prune specified heads from the model.
        
        Args:
            heads_to_prune: Dict mapping layer_idx -> list of head indices to remove
        
        Returns:
            New model with pruned heads
        """
        # Use HuggingFace's built-in head pruning for transformers
        for layer_idx, head_indices in heads_to_prune.items():
            self.model.encoder.layer[layer_idx].attention.prune_heads(head_indices)
        
        # Note: HuggingFace records the pruned heads on each attention module;
        # after pruning, different layers may have different head counts
        
        return self.model
    
    @staticmethod
    def count_remaining_heads(model) -> Dict[int, int]:
        """Count remaining heads per layer after pruning."""
        head_counts = {}
        for layer_idx, layer in enumerate(model.encoder.layer):
            # After pruning, attention.self has reduced dimensions
            attn = layer.attention.self
            num_heads = attn.num_attention_heads
            head_counts[layer_idx] = num_heads
        return head_counts


# Example usage
def prune_bert_heads_example():
    """Complete example of pruning BERT heads."""
    
    # Load model
    model = BertModel.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    
    print("Original model:")
    print(f"  Layers: {model.config.num_hidden_layers}")
    print(f"  Heads per layer: {model.config.num_attention_heads}")
    print(f"  Total heads: {model.config.num_hidden_layers * model.config.num_attention_heads}")
    
    # Create sample data
    texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Machine learning models can be pruned efficiently.",
        "Attention heads have different specializations.",
    ] * 20
    
    encodings = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    dataset = torch.utils.data.TensorDataset(
        encodings['input_ids'], 
        encodings['attention_mask']
    )
    
    # Collate tuples into the dict-style batches the analyzer expects
    def dict_collate(batch):
        return {
            'input_ids': torch.stack([item[0] for item in batch]),
            'attention_mask': torch.stack([item[1] for item in batch])
        }
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=dict_collate)
    
    # Analyze importance
    analyzer = HeadImportanceAnalyzer(model)
    
    print("\nComputing head importance (this may take a moment)...")
    importance = analyzer.compute_entropy_importance(dataloader, num_batches=10)
    
    # Visualize
    analyzer.visualize_importance(importance)
    
    # Determine heads to prune
    heads_to_prune = analyzer.get_heads_to_prune(
        importance, 
        prune_ratio=0.4,
        global_pruning=True,
        min_heads_per_layer=4
    )
    
    print(f"\nHeads to prune (40% target):")
    for layer, heads in heads_to_prune.items():
        print(f"  Layer {layer}: heads {heads}")
    
    total_heads = analyzer.num_layers * analyzer.num_heads
    total_pruned = sum(len(h) for h in heads_to_prune.values())
    print(f"  Total pruned: {total_pruned} / {total_heads} = {total_pruned/total_heads*100:.1f}%")
    
    # Apply pruning
    pruner = HeadPruner(model)
    pruned_model = pruner.prune_heads(heads_to_prune)
    
    # Count parameters
    orig_params = sum(p.numel() for p in model.parameters())
    pruned_params = sum(p.numel() for p in pruned_model.parameters())
    
    print(f"\nParameter reduction:")
    print(f"  Original: {orig_params:,}")
    print(f"  Pruned: {pruned_params:,}")
    print(f"  Reduction: {(1 - pruned_params/orig_params)*100:.1f}%")
    
    return pruned_model

if __name__ == "__main__":
    prune_bert_heads_example()

Hardware Impact of Head Pruning

Head pruning provides real speedup because it reduces matrix dimensions while maintaining dense operations.

Head Pruning: Hardware Acceleration Path
[Figure: pruning 12 heads down to 6 shrinks the Q/K/V projections from [768, 768] to [768, 384] and the output projection from [768, 768] to [384, 768], halving the attention projection parameters and the attention computation; the KV cache shrinks by the same factor. Measured on an A100 for BERT-base (batch=1, seq=512): 4.2 ms → 2.9 ms, a 1.45x end-to-end speedup (less than 2x because the unpruned FFN and other fixed costs remain).]
Head pruning achieves ~1.4-1.6x speedup at 50% pruning due to dense operation efficiency
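The projection savings follow directly from the shapes; a quick check of the 12 → 6 head case above:

d_model, head_dim = 768, 64

def attn_proj_params(n_heads: int) -> int:
    d_attn = n_heads * head_dim
    return 3 * d_model * d_attn + d_attn * d_model  # Q, K, V in; O out

before, after = attn_proj_params(12), attn_proj_params(6)
print(f"{before:,} -> {after:,} ({1 - after / before:.0%} saved)")
# 2,359,296 -> 1,179,648 (50% saved)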

FFN Width Pruning

The FFN sublayer is actually the largest part of a transformer block (~66% of parameters). Pruning FFN intermediate dimensions can yield significant savings.

Python ffn_pruning.py
import torch
import torch.nn as nn

def prune_ffn_neurons(
    model,
    prune_ratio: float = 0.3,
    criterion: str = 'l1'
):
    """
    Prune intermediate neurons in FFN layers.
    
    FFN structure: input → intermediate (4x) → output
    We prune neurons in the intermediate layer.
    """
    
    for layer_idx, layer in enumerate(model.encoder.layer):
        ffn = layer.intermediate  # First linear: [hidden, intermediate]
        ffn_out = layer.output  # Second linear: [intermediate, hidden]
        
        # Compute neuron importance
        if criterion == 'l1':
            # Sum of input and output weights for each intermediate neuron
            importance = ffn.dense.weight.data.abs().sum(dim=1) + \
                        ffn_out.dense.weight.data.abs().sum(dim=0)
        
        else:
            # L2 norm
            importance = ffn.dense.weight.data.pow(2).sum(dim=1).sqrt() + \
                        ffn_out.dense.weight.data.pow(2).sum(dim=0).sqrt()
        
        # Determine neurons to keep
        num_neurons = importance.numel()
        num_to_keep = int(num_neurons * (1 - prune_ratio))
        _, keep_indices = importance.topk(num_to_keep)
        keep_indices = keep_indices.sort().values
        
        # Create new smaller FFN
        new_intermediate = nn.Linear(
            ffn.dense.in_features,
            num_to_keep,
            bias=ffn.dense.bias is not None
        )
        new_output = nn.Linear(
            num_to_keep,
            ffn_out.dense.out_features,
            bias=ffn_out.dense.bias is not None
        )
        
        # Copy weights
        new_intermediate.weight.data = ffn.dense.weight.data[keep_indices]
        if ffn.dense.bias is not None:
            new_intermediate.bias.data = ffn.dense.bias.data[keep_indices]
        
        new_output.weight.data = ffn_out.dense.weight.data[:, keep_indices]
        if ffn_out.dense.bias is not None:
            new_output.bias.data = ffn_out.dense.bias.data.clone()
        
        # Replace in model
        ffn.dense = new_intermediate
        ffn_out.dense = new_output
        
        print(f"Layer {layer_idx}: FFN {num_neurons}{num_to_keep} neurons")
    
    return model


# Example: Prune BERT FFN by 30%
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
print(f"Original intermediate size: {model.config.intermediate_size}")

pruned_model = prune_ffn_neurons(model, prune_ratio=0.3)

# New intermediate size: 3072 → 2150
# Parameter savings: ~20% of total model
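One caveat: after this surgery, model.config.intermediate_size no longer describes the FFN widths (which are now per-layer), so downstream code should read the shapes from the modules themselves. A quick sanity check:

# Verify the two FFN linears agree on the pruned width in every layer
for i, layer in enumerate(pruned_model.encoder.layer):
    width = layer.intermediate.dense.out_features
    assert layer.output.dense.in_features == width
    print(f"layer {i}: intermediate width = {width}")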

Layer Pruning: Maximum Impact

The most aggressive form of structured pruning: remove entire transformer layers. This provides the largest speedup but requires careful selection.

Layer Pruning: Which Layers to Remove?
[Figure: per-layer importance scores for BERT-12 on GLUE. Layers 1-2 score high (syntax patterns) and layers 11-12 highest (task-specific features); keep these. Layers 4-8 are the most redundant and can be removed with little accuracy loss.]
Middle layers (4-8) are most prunable; first and last layers are critical
Python layer_pruning.py
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig
import copy

def prune_layers(
    model: nn.Module,
    layers_to_remove: list
) -> nn.Module:
    """
    Remove entire transformer layers from the model.
    
    Args:
        model: BERT-style model
        layers_to_remove: List of layer indices to remove (0-indexed)
    
    Returns:
        Model with specified layers removed
    """
    model = copy.deepcopy(model)
    
    # Get all layers
    all_layers = list(model.encoder.layer)
    num_layers = len(all_layers)
    
    # Filter out layers to remove
    layers_to_keep = [
        i for i in range(num_layers) 
        if i not in layers_to_remove
    ]
    
    # Create new layer list
    new_layers = nn.ModuleList([all_layers[i] for i in layers_to_keep])
    model.encoder.layer = new_layers
    
    # Update config
    model.config.num_hidden_layers = len(new_layers)
    
    print(f"Removed layers: {layers_to_remove}")
    print(f"Remaining layers: {len(new_layers)}")
    
    return model


def compute_layer_importance(
    model,
    dataloader,
    num_batches: int = 50
) -> torch.Tensor:
    """
    Compute importance of each layer using hidden state similarity.
    
    Layers where input ≈ output are less important.
    """
    model.eval()
    device = next(model.parameters()).device
    num_layers = model.config.num_hidden_layers
    
    importance = torch.zeros(num_layers, device=device)
    
    def hook_fn(layer_idx):
        def hook(module, input, output):
            # Compute how much the layer changes the representation
            input_hidden = input[0]  # [batch, seq, hidden]
            output_hidden = output[0]  # [batch, seq, hidden]
            
            # L2 distance between input and output (normalized)
            diff = (output_hidden - input_hidden).pow(2).mean()
            importance[layer_idx] += diff.item()
        return hook
    
    # Register hooks
    handles = []
    for i, layer in enumerate(model.encoder.layer):
        handle = layer.register_forward_hook(hook_fn(i))
        handles.append(handle)
    
    # Process batches
    with torch.no_grad():
        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= num_batches:
                break
            
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            _ = model(input_ids=input_ids, attention_mask=attention_mask)
    
    # Remove hooks
    for handle in handles:
        handle.remove()
    
    # Normalize
    importance = importance / num_batches
    
    return importance


# Example: Remove middle layers from BERT
model = BertModel.from_pretrained('bert-base-uncased')

# Common strategy: Remove layers 4, 5, 6, 7 (keep 1-3 and 8-12)
# This creates an 8-layer BERT
layers_to_remove = [4, 5, 6, 7]

pruned_model = prune_layers(model, layers_to_remove)

# Count parameters
orig_params = sum(p.numel() for p in model.parameters())
pruned_params = sum(p.numel() for p in pruned_model.parameters())

print(f"\nParameters: {orig_params:,}{pruned_params:,}")
print(f"Reduction: {(1 - pruned_params/orig_params)*100:.1f}%")
print(f"Speedup: ~{12/8:.2f}x (proportional to layer count)")

Real-World Results: Transformer Pruning

Structured Pruning Results on BERT (GLUE Benchmark)
Model Params Heads Layers FFN MNLI QQP Speedup
BERT-base 110M 12×12 12 3072 84.6 91.1 1.0x
40% Head Pruned 94M ~7×12 12 3072 84.2 90.8 1.3x
30% FFN Pruned 88M 12×12 12 ~2150 84.3 90.9 1.25x
4 Layers Removed 73M 12×8 8 3072 83.1 90.3 1.5x
Combined (All) 52M ~7×8 8 ~2150 82.5 89.8 2.1x
DistilBERT (ref) 66M 12×6 6 3072 82.8 89.5 1.6x
Key Insight: combined pruning (heads + FFN + layers) achieves a 2x+ speedup with only ~2% accuracy loss on GLUE tasks.
Combined structured pruning achieves better efficiency than DistilBERT with comparable accuracy

Summary: Structured Pruning for Transformers

Pruning Type Target Typical Reduction Accuracy Impact Speedup
Head Pruning Attention heads 30-50% heads <1% loss 1.3-1.5x
FFN Pruning Intermediate neurons 20-40% neurons <0.5% loss 1.2-1.4x
Layer Pruning Entire layers 2-4 layers (16-33%) 1-2% loss 1.3-1.5x
Combined All above 50%+ reduction 2-3% loss 2x+
Practical Recommendations

For BERT-style models:

  • Start with head pruning (40% heads) for quick wins with minimal accuracy loss
  • Add FFN neuron pruning (30%) for additional savings
  • Use layer pruning (remove middle layers) only if you need >1.5x speedup
  • Always fine-tune after pruning on your target task (5-10 epochs)
  • Use Taylor importance for best accuracy, entropy for quick analysis

For LLMs (GPT-style): Head and FFN pruning work, but layer pruning is more sensitive. Consider N:M sparsity (Section 5) or quantization (Section 6+) instead.

5. N:M Sparsity: Hardware-Native Patterns

N:M sparsity is the first pruning technique designed specifically for modern hardware accelerators. Unlike unstructured or structured pruning, N:M enforces a fixed pattern: in every group of M weights, exactly N are nonzero. This enables real speedup on GPUs and TPUs.

What is N:M Sparsity?

For a given layer, weights are divided into blocks of size M (e.g., 4, 8, or 16). In each block, only the N largest-magnitude weights are kept; the rest are set to zero. Example: 2:4 sparsity means every group of 4 weights has exactly 2 nonzeros.

N:M Sparsity Pattern Example
[Figure: a 2:4-sparse block, e.g. [0, w₁, 0, w₂]: 2 nonzero and 2 zero weights per group of 4, repeated across the whole tensor.]
Every block of 4 weights has exactly 2 nonzero (2:4 sparsity)

Why N:M Enables Real Hardware Speedup

Hardware Acceleration: N:M vs Unstructured
[Figure: unstructured sparsity scatters zeros randomly, so GPU kernels cannot exploit them (no speedup); N:M sparsity fixes the zero positions within each block, a hardware-native pattern that sparse tensor cores can skip (real speedup).]
N:M is the only fine-grained sparsity pattern that current GPUs accelerate natively
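The speedup comes from the storage format: for each block of 4, the hardware stores only the 2 kept values plus 2-bit position metadata, so the tensor cores load half the operands and skip the zeros entirely. A conceptual sketch of that compression (not the actual cuSPARSELt layout):

import torch

def compress_2_4(w: torch.Tensor):
    """Conceptual 2:4 compression: kept values + per-block position indices."""
    blocks = w.view(-1, 4)  # assumes w is already 2:4-pruned
    idx = blocks.abs().topk(2, dim=1).indices.sort(dim=1).values  # [blocks, 2]
    values = torch.gather(blocks, 1, idx)                         # [blocks, 2]
    return values, idx.to(torch.uint8)  # each index fits in 2 bits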

Implementation: N:M Pruning Pipeline

Python nm_pruning.py
import torch
import torch.nn as nn

def nm_prune(
    weight: torch.Tensor,
    N: int = 2,
    M: int = 4
) -> torch.Tensor:
    """
    Prune a tensor to N:M sparsity: in every block of M consecutive
    weights, keep the N largest-magnitude values and zero the rest.
    Assumes weight.numel() is divisible by M.
    """
    shape = weight.shape
    blocks = weight.clone().view(-1, M)
    # Indices of the N largest-magnitude weights in each block
    keep_idx = blocks.abs().topk(N, dim=1).indices
    # Boolean keep-mask, built on the same device as the weights
    mask = torch.zeros_like(blocks, dtype=torch.bool)
    mask.scatter_(1, keep_idx, True)
    return (blocks * mask).view(shape)

# Example: Prune a linear layer to 2:4 sparsity
layer = nn.Linear(1024, 1024)
layer.weight.data = nm_prune(layer.weight.data, N=2, M=4)
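Note that nm_prune only zeroes weights; the dense kernels will not get faster by themselves. To actually engage the sparse tensor cores you need a sparsity-aware execution path. Recent PyTorch versions (2.1+) expose one, though support depends on your build and requires an Ampere-or-newer GPU, so treat this as a sketch:

from torch.sparse import to_sparse_semi_structured

layer = nn.Linear(1024, 1024).half().cuda()
layer.weight.data = nm_prune(layer.weight.data, N=2, M=4)
# Convert the already 2:4-sparse fp16 weight to the accelerated format
layer.weight = torch.nn.Parameter(to_sparse_semi_structured(layer.weight.data))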

Real-World Results: N:M Sparsity

N:M Sparsity Results (BERT, GPT, ResNet)
[Figure: 2:4 sparsity (50% of weights removed), measured on A100. BERT-base: 1.7x speedup, <1% accuracy drop. GPT-2: 1.6x speedup, <1% drop. ResNet-50: 1.5x speedup, <0.5% drop.]
N:M sparsity achieves 1.5-1.7x speedup with minimal accuracy loss

Summary: N:M Sparsity

Pattern Speedup Accuracy Drop Hardware Support
2:4 (50% sparse) 1.5-1.7x <1% A100, H100, TPU
1:4 (75% sparse) 1.2x <2% Experimental
4:8 (50% sparse) 1.5x <1% TPU
Practical Recommendations
  • Use 2:4 sparsity for maximum speedup on NVIDIA A100/H100
  • Apply N:M pruning after structured pruning for best results
  • Always fine-tune after pruning (5-10 epochs)
  • Check hardware support before deploying
  • For LLMs, combine N:M with quantization for best efficiency

6. Quantization Foundations

Content coming in Batch 6...

7. Post-Training Quantization (PTQ)

Content coming in Batch 7...

8. Quantization-Aware Training (QAT)

Content coming in Batch 8...

9. LLM Quantization Techniques

Content coming in Batch 9...

10. Knowledge Distillation

Content coming in Batch 10...

11. Efficient Attention & KV Cache

Content coming in Batch 11...

12. Hardware-Specific Deployment

Content coming in Batch 12...

13. Case Studies & Best Practices

Content coming in Batch 13...