- The Inference Challenge
- Pruning Foundations & Unstructured Pruning
- Structured Pruning for CNNs
- Structured Pruning for Transformers
- N:M Sparsity: Hardware-Native Patterns
- Quantization Foundations
- Post-Training Quantization (PTQ)
- Quantization-Aware Training (QAT)
- LLM Quantization Techniques
- Knowledge Distillation
- Efficient Attention & KV Cache
- Hardware-Specific Deployment
- Case Studies & Best Practices
1. The Inference Challenge
Training a model is only half the story. The other half—and often the more critical half for production systems—is inference: running the trained model to make predictions on new data. While training happens once (or periodically), inference happens millions or billions of times in deployment. A 10ms improvement in inference latency can translate to millions of dollars in compute savings and dramatically improved user experience.
This guide provides a comprehensive treatment of inference optimization, with particular emphasis on the hardware implications of each technique. Understanding why certain optimizations work requires understanding how modern processors—GPUs, CPUs, and specialized accelerators—actually execute neural network operations.
The Scale of the Problem
Consider the computational demands of modern AI systems: state-of-the-art LLMs run to tens or hundreds of billions of parameters, and serving them means streaming those weights through the processor for every request. The sheer scale creates multiple challenges:
- Memory capacity: A 70B parameter model requires 140GB of memory in FP16—more than nearly any single GPU can hold
- Memory bandwidth: Even if the model fits, loading weights from memory is often slower than computation
- Compute cost: Generating 1000 tokens from GPT-4 costs approximately $0.03-0.06 in compute
- Latency requirements: Interactive applications demand sub-100ms response times
- Energy consumption: A single H100 GPU draws 700W under load; large deployments cost millions in electricity
For a large-scale LLM deployment serving 1 million requests/day:
- Raw compute cost: $50,000–$150,000/month on cloud GPUs
- Every 10% improvement in efficiency saves $5,000–$15,000/month
- A 4x speedup from quantization can reduce costs by $150,000+/year
Memory-Bound vs Compute-Bound Workloads
The first step in optimizing inference is understanding which resource is the bottleneck. Modern neural networks can be either memory-bound or compute-bound, and the optimal strategy differs dramatically between the two.
The Memory Hierarchy
Modern processors have a deep memory hierarchy with vastly different bandwidths: registers and L1 cache respond in a few cycles at tens of TB/s, L2 cache is several times slower, and off-chip DRAM/HBM—where model weights actually live—is slower still. Every access that falls through to DRAM costs hundreds of cycles, which is why data movement, not arithmetic, dominates inference time.
Arithmetic Intensity: The Key Metric
Arithmetic intensity (also called operational intensity) measures how much computation is performed per byte of memory accessed:

$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes accessed}} \quad \text{(ops/byte)}$$
This metric determines whether an operation is memory-bound or compute-bound: below the hardware's ridge point (peak FLOPS ÷ peak bandwidth), the memory system cannot feed the compute units fast enough and the operation is memory-bound; above it, the operation is compute-bound.
Why This Matters for Optimization
The bound determines which optimization strategy to use:
| Regime | Bottleneck | Optimization Strategy | Examples |
|---|---|---|---|
| Memory-bound | Loading weights/activations | Reduce data movement: quantization (INT8/INT4), KV cache compression, pruning | LLM generation (batch=1), small-batch inference |
| Compute-bound | Arithmetic operations | Reduce FLOPs: pruning, early exit, efficient architectures | CNNs, batched inference, training |
| Mixed | Both | Combined approaches: quantization + pruning | Attention layers, medium batches |
LLM token-by-token generation has arithmetic intensity of approximately 1-2 ops/byte—far below the H100's ridge point of ~600 ops/byte. This means:
- Every generated token must stream the full weight set through the GPU: a 7B model in FP16 moves ~14 GB per token
- At 3.35 TB/s of HBM bandwidth, that caps single-stream generation at roughly 240 tokens/s, even though 2000 TFLOPS could compute the required FLOPs for over 100,000 tokens/s
- Observed throughput of ~100-200 tokens/s sits near this memory ceiling while the compute units are mostly idle
- Quantization is the single most impactful optimization—INT4 delivers 4x less memory traffic and raises the ceiling proportionally (worked out below)
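A quick back-of-envelope check makes the bound concrete (a minimal sketch assuming H100-class specs and a 7B-parameter model; the numbers are rounded):
params = 7e9
bandwidth = 3.35e12  # bytes/s (H100 HBM3, approximate)
# Each decoded token must read every weight once, so bandwidth divided
# by the weight footprint gives the throughput ceiling.
for fmt, bytes_per_param in [("FP16", 2.0), ("INT8", 1.0), ("INT4", 0.5)]:
    ceiling = bandwidth / (params * bytes_per_param)
    print(f"{fmt}: ~{ceiling:.0f} tokens/s memory-bound ceiling")
# FP16: ~239 tokens/s, INT8: ~479, INT4: ~957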
The Roofline Model
The roofline model is a visual tool for understanding performance limits. It plots achievable FLOPS against arithmetic intensity, showing where your workload sits relative to hardware limits.
Reading the Roofline
- Below the roofline: There's room for optimization—your code isn't hitting hardware limits
- On the memory slope: You're memory-bound. Reduce data movement (quantization, pruning)
- On the compute ceiling: You're compute-bound. Reduce FLOPs or use faster hardware
- At the ridge: You're balanced—both resources fully utilized
import torch
import time
def analyze_operation(operation, input_tensors, num_flops, num_bytes):
"""Measure achieved performance and compare to roofline."""
# Warmup
for _ in range(10):
output = operation(*input_tensors)
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(100):
output = operation(*input_tensors)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / 100
# Calculate metrics
achieved_tflops = (num_flops / elapsed) / 1e12
achieved_bandwidth = (num_bytes / elapsed) / 1e12 # TB/s
arithmetic_intensity = num_flops / num_bytes
# H100 specs
    peak_tflops = 2000  # FP16 Tensor Core, 2:4-sparse peak (~990 TFLOPS dense)
peak_bandwidth = 3.35 # TB/s HBM3
ridge_point = peak_tflops / peak_bandwidth # ~597 ops/byte
# Determine bottleneck
if arithmetic_intensity < ridge_point:
roofline_limit = arithmetic_intensity * peak_bandwidth
bound = "Memory-bound"
else:
roofline_limit = peak_tflops
bound = "Compute-bound"
efficiency = achieved_tflops / roofline_limit * 100
print(f"Arithmetic Intensity: {arithmetic_intensity:.1f} ops/byte")
print(f"Achieved: {achieved_tflops:.1f} TFLOPS")
print(f"Roofline Limit: {roofline_limit:.1f} TFLOPS")
print(f"Efficiency: {efficiency:.1f}% ({bound})")
return achieved_tflops, bound, efficiency
# Example: Analyze a linear layer
batch_size = 1
input_dim = 4096
output_dim = 4096
x = torch.randn(batch_size, input_dim, device='cuda', dtype=torch.float16)
W = torch.randn(output_dim, input_dim, device='cuda', dtype=torch.float16)
# FLOPs = 2 * batch * input * output (multiply + add)
num_flops = 2 * batch_size * input_dim * output_dim
# Bytes = weights + input + output (all FP16 = 2 bytes each)
num_bytes = 2 * (input_dim * output_dim + batch_size * input_dim + batch_size * output_dim)
analyze_operation(
lambda x, W: torch.mm(x, W.T),
(x, W),
num_flops,
num_bytes
)
# Output: Arithmetic Intensity: 1.0 ops/byte
# Achieved: 3.2 TFLOPS
# Roofline Limit: 3.35 TFLOPS
# Efficiency: 95.5% (Memory-bound)
The Hardware Landscape
Different hardware platforms have dramatically different characteristics. Understanding these differences is essential for choosing the right optimization strategy.
NVIDIA GPUs: The Deep Learning Workhorse
NVIDIA GPUs dominate deep learning inference due to their Tensor Cores—specialized matrix units that accelerate neural network operations.
CPUs: Underrated for Inference
Modern CPUs have significant neural network acceleration capabilities:
- AVX-512: 512-bit vector operations, good for batch processing
- AVX-512 VNNI: INT8 dot products optimized for inference (Cascade Lake+)
- Intel AMX: Dedicated matrix accelerator on Sapphire Rapids (up to 2x speedup)
- ARM SVE/SME: Scalable vector and matrix extensions for ARM servers
CPUs excel when:
- Batch size is very small (1-4)
- Model fits in CPU cache
- GPU is unavailable or too expensive
- Unstructured sparsity is used (CPUs handle irregular memory patterns better)
Edge & Mobile Accelerators
For edge deployment, specialized hardware provides better power efficiency:
| Accelerator | Platform | INT8 TOPS | Power | Efficiency | Use Case |
|---|---|---|---|---|---|
| Apple Neural Engine | M1/M2/M3 Macs, iPhones | 11-38 | ~5-15W | ~3 TOPS/W | On-device ML, Siri |
| Google Edge TPU | Coral, Pixel phones | 4 | 2W | 2 TOPS/W | IoT, edge inference |
| Qualcomm Hexagon | Snapdragon 8 Gen 3 | 73 | ~10W | ~7 TOPS/W | Mobile, on-device AI |
| NVIDIA Jetson Orin | Robotics, automotive | 275 | 60W | ~4.5 TOPS/W | Autonomous systems |
| Intel Movidius | Vision applications | 1 | 1W | 1 TOPS/W | Cameras, drones |
Choose based on your constraints:
- Maximum throughput: H100/H200 with INT8 or FP8
- Large models: H200 (141GB) or multi-GPU setups
- Cost-effective: L4 or A10G for inference
- Edge deployment: Match accelerator to power/latency budget
- CPU-only: Use AVX-512/AMX with optimized runtimes (ONNX, OpenVINO)
Latency vs Throughput: Understanding the Tradeoff
Two fundamental metrics compete for optimization:
- Latency: Time to complete a single request (ms). Critical for interactive applications.
- Throughput: Requests processed per second. Critical for batch processing and cost efficiency.
The relationship is not simple—optimizing for one often hurts the other: batching requests together amortizes weight loads and raises throughput, but every request in the batch now waits on the whole batch, increasing per-request latency.
Batching Strategies for LLMs
Modern LLM serving uses sophisticated batching to maximize efficiency:
- Static batching: Wait for N requests, process together. Simple but adds latency.
- Dynamic batching: Process as soon as GPU has capacity. Used by TensorRT, Triton.
- Continuous batching: Add new requests mid-generation. Used by vLLM, TGI. Best for LLMs.
- Speculative decoding: Use a small model to draft, large model to verify. Reduces latency.
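The effect is easy to see on a single matmul: weights are read once per batch, so larger batches raise arithmetic intensity and samples/s while per-request latency grows much more slowly—until the kernel becomes compute-bound. A minimal sketch (assumes a CUDA device; the linear layer stands in for one transformer matmul):
import time
import torch

layer = torch.nn.Linear(4096, 4096).half().cuda()
with torch.no_grad():
    for batch in [1, 8, 32, 128]:
        x = torch.randn(batch, 4096, device='cuda', dtype=torch.float16)
        for _ in range(10):  # warmup
            layer(x)
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(100):
            layer(x)
        torch.cuda.synchronize()
        latency_ms = (time.perf_counter() - start) / 100 * 1000
        print(f"batch={batch:4d}  latency={latency_ms:.3f} ms  "
              f"throughput={batch / latency_ms * 1000:,.0f} samples/s")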
Target latencies by application type:
- Real-time chat: <100ms first token, <50ms/token streaming
- Search/ranking: <50ms total (time-to-glass critical)
- Batch processing: Latency doesn't matter, maximize throughput
- Code completion: <200ms (user typing tolerance)
Overview of Optimization Techniques
This guide covers three major categories of inference optimization—pruning (Sections 2-5), quantization (Sections 6-9), and knowledge distillation (Section 10)—each with distinct hardware implications, plus serving-level techniques such as efficient attention and KV-cache management (Section 11).
What This Guide Covers
In the following sections, we'll explore each technique in depth, with particular focus on:
- Theoretical foundations: Why does this technique work? What are the mathematical principles?
- Hardware implications: How does this technique interact with GPU/CPU architecture? Why does it actually speed up inference?
- Implementation details: Practical code examples and library recommendations
- CNN vs Transformer differences: How does each architecture respond differently to optimization?
- Combining techniques: How to stack optimizations for maximum effect
This guide assumes familiarity with:
- Basic neural network architectures (CNNs, Transformers)
- PyTorch or TensorFlow fundamentals
- Basic linear algebra (matrix multiplication)
- Understanding of floating-point vs integer arithmetic
2. Pruning Foundations & Unstructured Pruning
Pruning is the process of removing unnecessary parameters from a neural network. The core insight is simple but profound: most neural networks are dramatically overparameterized. A well-trained network can often lose 80-95% of its weights with minimal accuracy loss.
But here's the critical question this section addresses: Does removing weights actually speed up inference? The answer depends entirely on the pruning pattern and the target hardware. Understanding this hardware-algorithm interaction is essential for effective pruning.
Why Pruning Works: The Lottery Ticket Hypothesis
The theoretical foundation for pruning comes from the Lottery Ticket Hypothesis (Frankle & Carbin, 2019):
"A randomly-initialized, dense neural network contains a subnetwork that is initialized such that—when trained in isolation—it can match the test accuracy of the original network after training for at most the same number of iterations."
In simpler terms: within every large network, there exists a small "winning ticket" subnetwork that can achieve the same performance. The rest of the parameters are essentially lottery losers—they don't contribute meaningfully to the final predictions.
Evidence for Overparameterization
Multiple lines of evidence suggest neural networks have far more parameters than necessary: trained networks can typically be pruned to high sparsity with little accuracy loss, small distilled students can match much larger teachers, and weight matrices are often well-approximated by low-rank factorizations.
Pruning Criteria: Which Weights to Remove?
The fundamental question in pruning is: How do we identify which weights are "unnecessary"? Several criteria have been developed:
1. Magnitude-Based Pruning (Most Common)
The simplest and most widely used criterion: remove weights with the smallest absolute values.
Intuition: A weight close to zero contributes little to the output of a neuron. Mathematically, if $y = \sum_i w_i x_i$, then small $|w_i|$ means small contribution to $y$.
2. Gradient-Based Pruning
Consider both weight magnitude and how much the loss changes when the weight is perturbed:

$$I_i = \left| w_i \cdot \frac{\partial \mathcal{L}}{\partial w_i} \right|$$

This is the first-order Taylor expansion of the loss change. Weights with small magnitude but large gradients might still be important.
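A minimal, self-contained sketch of the criterion—score each weight by $|w \cdot \partial \mathcal{L} / \partial w|$ after one backward pass on a batch (the toy model and data here are illustrative):
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
# Per-weight Taylor importance scores, shape [4, 16]
taylor_importance = (model.weight * model.weight.grad).abs()
print(taylor_importance.shape)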
3. Second-Order Methods (OBS, OBD)
Use the Hessian to estimate the impact of removing a weight:

$$\Delta \mathcal{L} \approx \frac{1}{2} H_{ww}\, w^2$$

Where $H_{ww}$ is the corresponding diagonal element of the Hessian. More accurate but computationally expensive.
4. Activation-Based Pruning
For neurons/channels, measure importance by average activation magnitude:

$$I_j = \mathbb{E}_{x \sim \mathcal{D}}\big[\,|a_j(x)|\,\big]$$

Neurons that rarely activate (or activate weakly) can be removed.
Unstructured Pruning: Maximum Flexibility
Unstructured pruning (also called fine-grained or weight-level pruning) removes individual weights anywhere in the network. This provides maximum flexibility—any weight can be pruned regardless of position.
The Unstructured Pruning Process
Implementation: Magnitude Pruning in PyTorch
import torch
import torch.nn.utils.prune as prune
import copy
def magnitude_prune_model(model, sparsity=0.5):
"""
Apply unstructured magnitude pruning to all Linear and Conv2d layers.
Args:
model: PyTorch model to prune
sparsity: Fraction of weights to remove (0.5 = 50% pruned)
Returns:
Pruned model with masks applied
"""
model = copy.deepcopy(model)
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
# Apply L1 unstructured pruning
prune.l1_unstructured(
module,
name='weight',
amount=sparsity
)
print(f"Pruned {name}: {sparsity*100:.0f}% sparsity")
return model
def count_sparsity(model):
    """Calculate actual sparsity of the model.
    Reads module.weight (the effective, possibly masked weight), so the
    count is correct both while the pruning reparametrization is active
    and after prune.remove() makes it permanent. (Iterating
    named_parameters would see the dense 'weight_orig' instead.)
    """
    total_params = 0
    zero_params = 0
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            weight = module.weight
            total_params += weight.numel()
            zero_params += (weight == 0).sum().item()
    sparsity = zero_params / total_params
    print(f"Total parameters: {total_params:,}")
    print(f"Zero parameters: {zero_params:,}")
    print(f"Sparsity: {sparsity*100:.1f}%")
    return sparsity
def make_pruning_permanent(model):
"""Remove pruning hooks and make sparse weights permanent."""
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
try:
prune.remove(module, 'weight')
except ValueError:
pass # Already permanent
return model
# Example usage
from torchvision.models import resnet50
# Load pre-trained model
model = resnet50(pretrained=True)
model.eval()
# Check original size
print("Original model:")
count_sparsity(model)
# Apply 70% unstructured pruning
pruned_model = magnitude_prune_model(model, sparsity=0.7)
# Check after pruning
print("\nAfter pruning:")
count_sparsity(pruned_model)
# Make permanent (removes hooks)
pruned_model = make_pruning_permanent(pruned_model)
Hardware Implications: Why Unstructured Pruning Doesn't Speed Up GPUs
Here's the critical insight that many practitioners miss: unstructured pruning, despite achieving high sparsity, often provides NO speedup on modern GPUs. Understanding why requires diving into how GPUs actually execute matrix operations.
How GPUs Execute Dense Matrix Multiplication
GPUs achieve high throughput through data parallelism and memory coalescing: threads in a warp execute in lockstep and read adjacent addresses in a single wide transaction, tiles of both matrices are staged in shared memory and reused many times, and Tensor Cores consume those dense tiles as fixed-shape matrix fragments. All of this assumes contiguous data and uniform work.
The Three Killers of Sparse GPU Performance
Unstructured sparsity causes three fundamental problems on GPUs:
- Uncoalesced memory access: nonzeros sit at arbitrary addresses, so a warp's loads can no longer be combined into wide transactions
- Index overhead: sparse formats must store and chase row/column indices, adding memory traffic and integer work for every nonzero
- Load imbalance and lost Tensor Cores: rows have uneven nonzero counts, leaving threads idle, and the fixed-shape dense tiles that Tensor Cores consume no longer exist
Sparse Storage Formats
Sparse matrices require special storage formats to avoid storing zeros. Common formats include COO (a list of (row, col, value) triples), CSR (values plus column indices, with one pointer per row—the standard for sparse-dense matmul), and CSC (the column-wise mirror of CSR).
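A small CPU-runnable illustration of what CSR actually stores:
import torch

A = torch.tensor([[0., 2., 0., 0.],
                  [1., 0., 0., 3.],
                  [0., 0., 0., 0.],
                  [0., 4., 5., 0.]])
A_csr = A.to_sparse_csr()
print(A_csr.crow_indices())  # row pointers: [0, 1, 3, 3, 5]
print(A_csr.col_indices())   # column of each nonzero: [1, 0, 3, 1, 2]
print(A_csr.values())        # the nonzero values: [2., 1., 3., 4., 5.]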
Benchmark: Dense vs Sparse on GPU
import torch
import time
import numpy as np
def benchmark_dense_vs_sparse(size=4096, sparsity=0.9, num_runs=100):
"""
Compare dense vs sparse matrix multiplication on GPU.
This demonstrates why unstructured sparsity doesn't help on GPU.
"""
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Create dense matrices
A_dense = torch.randn(size, size, device=device, dtype=torch.float16)
B = torch.randn(size, size, device=device, dtype=torch.float16)
# Create sparse version (same values, but sparse format)
mask = torch.rand(size, size, device=device) > sparsity
    A_sparse_values = A_dense * mask  # bool mask keeps FP16 dtype (mask.float() would upcast to FP32)
A_sparse = A_sparse_values.to_sparse_csr()
actual_sparsity = 1 - (A_sparse._nnz() / (size * size))
print(f"Matrix size: {size}×{size}")
print(f"Target sparsity: {sparsity*100:.0f}%")
print(f"Actual sparsity: {actual_sparsity*100:.1f}%")
print(f"Non-zeros: {A_sparse._nnz():,} / {size*size:,}")
print()
# Warmup
for _ in range(10):
_ = torch.mm(A_dense, B)
_ = torch.sparse.mm(A_sparse, B)
torch.cuda.synchronize()
# Benchmark dense
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(num_runs):
result_dense = torch.mm(A_dense, B)
torch.cuda.synchronize()
dense_time = (time.perf_counter() - start) / num_runs * 1000
# Benchmark sparse
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(num_runs):
result_sparse = torch.sparse.mm(A_sparse, B)
torch.cuda.synchronize()
sparse_time = (time.perf_counter() - start) / num_runs * 1000
print(f"Dense GEMM: {dense_time:.2f} ms")
print(f"Sparse SpMM: {sparse_time:.2f} ms")
print(f"Speedup: {dense_time/sparse_time:.2f}x")
if sparse_time > dense_time:
print(f"\n⚠️ Sparse is {sparse_time/dense_time:.1f}x SLOWER despite {actual_sparsity*100:.0f}% sparsity!")
return dense_time, sparse_time
# Run benchmark
if __name__ == "__main__":
print("="*60)
print("Dense vs Sparse GPU Benchmark")
print("="*60)
for sparsity in [0.5, 0.7, 0.9, 0.95]:
print(f"\n--- Sparsity: {sparsity*100:.0f}% ---")
benchmark_dense_vs_sparse(size=4096, sparsity=sparsity)
# Typical output on A100:
# --- Sparsity: 90% ---
# Matrix size: 4096×4096
# Dense GEMM: 0.45 ms
# Sparse SpMM: 1.23 ms
# ⚠️ Sparse is 2.7x SLOWER despite 90% sparsity!
Do NOT expect speedups from unstructured pruning on GPUs. In practice:
- 50% sparsity: Usually slower than dense
- 90% sparsity: Often still slower than dense
- 95%+ sparsity: Might break even, rarely faster
- The overhead of index calculations and irregular memory access outweighs the reduced compute
Where Unstructured Pruning Actually Helps
Despite the GPU limitations, unstructured pruning has valid use cases:
1. CPU Inference
CPUs handle irregular memory access better than GPUs because:
- Larger caches can hide latency
- Branch prediction helps with conditional execution
- Libraries like DeepSparse (Neural Magic) optimize for CPU sparse execution
2. Memory Reduction (Not Speed)
Even if compute isn't faster, sparse storage reduces memory:
- 90% sparse model uses ~10% of the memory for weights
- Useful when memory is the bottleneck (edge devices, large models)
- Must use sparse format to realize savings (not just zero values)
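A rough sketch of the accounting (CSR with FP32 values and PyTorch's default int64 indices; exact numbers depend on dtype choices):
import torch

dense = torch.randn(4096, 4096)
dense[torch.rand_like(dense) < 0.9] = 0.0   # ~90% zeros
csr = dense.to_sparse_csr()
nnz = csr.values().numel()
dense_mib = dense.numel() * 4 / 2**20
# CSR stores one value + one column index per nonzero, plus row pointers
csr_mib = (nnz * 4 + nnz * 8 + (4096 + 1) * 8) / 2**20
print(f"Dense: {dense_mib:.1f} MiB  CSR: {csr_mib:.1f} MiB")
# ~64 MiB dense vs ~19 MiB CSR; int32 indices would roughly halve the CSR size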
3. As a Prelude to Structured Pruning
Unstructured pruning can identify which structures to remove:
- Find channels/heads with mostly-zero weights
- Convert to structured sparsity for actual speedup
- Used in many pruning pipelines
4. Specialized Hardware
Some hardware natively accelerates unstructured sparsity:
- Cerebras CS-2: Wafer-scale chip with native sparse support
- Graphcore IPU: Designed for irregular computation
- SambaNova: Dataflow architecture handles sparsity well
# Using DeepSparse for CPU-optimized sparse inference
# pip install deepsparse
from deepsparse import compile_model
import numpy as np
# Load a pruned ONNX model
# DeepSparse automatically detects and accelerates sparsity
model_path = "pruned_bert_90sparse.onnx"
# Compile for optimized sparse CPU execution
engine = compile_model(model_path, batch_size=1)
# Run inference (Engine.run takes the inputs as a list, in the order
# the ONNX graph declares them)
inputs = [
    np.array([[101, 2023, 2003, 1037, 3231, 102]], dtype=np.int64),  # input_ids
    np.array([[1, 1, 1, 1, 1, 1]], dtype=np.int64),                  # attention_mask
    np.array([[0, 0, 0, 0, 0, 0]], dtype=np.int64),                  # token_type_ids
]
outputs = engine.run(inputs)
# DeepSparse can achieve 3-5x speedup on CPU with 90% sparsity
# Compare: Dense PyTorch on same CPU is much slower
Iterative Pruning: Achieving Higher Sparsity
One-shot pruning (prune once to target sparsity) often hurts accuracy. Iterative pruning gradually increases sparsity while fine-tuning, achieving better accuracy-sparsity tradeoffs.
Implementation: Iterative Magnitude Pruning
import torch
import torch.nn.utils.prune as prune
from torch.utils.data import DataLoader
def iterative_pruning(
model,
train_loader,
val_loader,
target_sparsity=0.9,
num_iterations=5,
epochs_per_iteration=3,
criterion=torch.nn.CrossEntropyLoss(),
lr=1e-4
):
"""
Iteratively prune and fine-tune to reach target sparsity.
Args:
model: Model to prune
train_loader: Training data
val_loader: Validation data
target_sparsity: Final sparsity target (e.g., 0.9 for 90%)
num_iterations: Number of prune-finetune cycles
epochs_per_iteration: Fine-tuning epochs per cycle
"""
device = next(model.parameters()).device
# Calculate per-iteration sparsity increase
    # Linear schedule: equal sparsity increments per iteration
sparsities = []
current = 0.0
for i in range(num_iterations):
# Each iteration removes (1-s) of remaining weights
remaining_to_prune = target_sparsity - current
prune_fraction = remaining_to_prune / (num_iterations - i)
current += prune_fraction
sparsities.append(current)
print(f"Pruning schedule: {[f'{s*100:.0f}%' for s in sparsities]}")
for iteration, target_sp in enumerate(sparsities):
print(f"\n=== Iteration {iteration+1}/{num_iterations}: Target {target_sp*100:.0f}% ===")
# Calculate how much to prune relative to current weights
# We need to prune (target_sp - current_sp) / (1 - current_sp)
current_sp = get_model_sparsity(model)
if current_sp >= target_sp:
continue
relative_prune = (target_sp - current_sp) / (1 - current_sp)
# Apply pruning
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
prune.l1_unstructured(module, name='weight', amount=relative_prune)
print(f"Sparsity after pruning: {get_model_sparsity(model)*100:.1f}%")
# Fine-tune
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs_per_iteration):
model.train()
total_loss = 0
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
total_loss += loss.item()
# Validate
val_acc = evaluate(model, val_loader, device)
print(f" Epoch {epoch+1}: Loss={total_loss/len(train_loader):.4f}, Val Acc={val_acc*100:.1f}%")
# Make pruning permanent
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
try:
prune.remove(module, 'weight')
            except ValueError:
                pass  # Layer was never pruned
return model
def get_model_sparsity(model):
    """Calculate current sparsity via module.weight, which reflects the
    applied mask even while the pruning reparametrization is active."""
    total = 0
    zeros = 0
    for name, module in model.named_modules():
        if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)):
            total += module.weight.numel()
            zeros += (module.weight == 0).sum().item()
    return zeros / total if total > 0 else 0
def evaluate(model, val_loader, device):
"""Evaluate model accuracy."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data, target in val_loader:
data, target = data.to(device), target.to(device)
output = model(data)
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
return correct / total
Summary: Unstructured Pruning
| Aspect | Details |
|---|---|
| What it does | Removes individual weights based on magnitude or other criteria |
| Achievable sparsity | 80-95% with iterative pruning and fine-tuning |
| GPU speedup | ❌ Usually NONE or negative (slower than dense) |
| CPU speedup | ✓ 2-5x with optimized libraries (DeepSparse) |
| Memory savings | ✓ Significant if using sparse storage format |
| Best for | CPU inference, memory-constrained scenarios, pre-analysis for structured pruning |
| Not recommended for | GPU inference speedup (use structured or N:M instead) |
Unstructured pruning is a powerful technique, but NOT for GPU speedup. If your goal is faster GPU inference, skip ahead to Structured Pruning (Section 3) or N:M Sparsity (Section 5). Use unstructured pruning for CPU inference or memory reduction.
3. Structured Pruning for CNNs
Structured pruning removes entire architectural components rather than individual weights. For CNNs, this typically means removing filters (output channels) or channels (input features). The key advantage: the result is still a dense network that runs efficiently on standard hardware.
Unstructured pruning creates sparse matrices that GPUs can't accelerate. Structured pruning creates smaller dense matrices that run on the same optimized kernels, just with reduced dimensions. A 50% channel-pruned network runs in approximately 50% of the time—no special hardware required.
Channel and Filter Pruning: The Fundamentals
In a convolutional layer, the weight tensor has shape [C_out, C_in, H, W] where:
- C_out: Number of output channels (filters)
- C_in: Number of input channels
- H × W: Kernel spatial dimensions (e.g., 3×3)
Structured pruning can remove:
- Filter pruning: Remove entire filters (rows of C_out), reducing output channels
- Channel pruning: Remove input channels (columns of C_in), requires corresponding output pruning in previous layer
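In tensor terms, the two cuts look like this (a toy illustration of slicing the weight tensor; real pruning selects indices by importance rather than taking a prefix):
import torch
import torch.nn as nn

conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
w = conv.weight                  # [C_out=128, C_in=64, 3, 3]
w_filter_pruned = w[:96]         # drop 32 filters        -> [96, 64, 3, 3]
w_channel_pruned = w[:, :48]     # drop 16 input channels -> [128, 48, 3, 3]
print(w_filter_pruned.shape, w_channel_pruned.shape)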
Filter Importance Criteria
How do we decide which filters to remove? Several criteria have been proposed:
1. L1-Norm (Filter Weight Magnitude)
The simplest approach: remove filters with the smallest L1 norm of their weights.
Intuition: Filters with small weights produce small activations, contributing less to the output.
2. L2-Norm (Euclidean Magnitude)
Similar to L1 but penalizes larger individual weights more heavily.
3. Batch Normalization Scaling Factor (γ)
If the layer is followed by BatchNorm, the γ parameter indicates channel importance, since each channel's output is scaled as:

$$y_c = \gamma_c \hat{x}_c + \beta_c$$

A channel with $\gamma_c \approx 0$ produces a nearly constant output regardless of its input. During training with L1 regularization on γ, unimportant channels naturally have γ → 0.
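A minimal sketch of that training-time penalty (Network Slimming-style; the weight `lam` is a typical value, not a prescription):
import torch
import torch.nn as nn
from torchvision.models import resnet18

model = resnet18()

def bn_l1_penalty(model, lam=1e-4):
    # L1 on all BatchNorm scale factors pushes unused channels' gammas to zero
    return lam * sum(m.weight.abs().sum()
                     for m in model.modules() if isinstance(m, nn.BatchNorm2d))

# During training: loss = criterion(output, target) + bn_l1_penalty(model)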
4. Taylor Expansion (Gradient-Based)
Consider the change in loss when removing a filter:

$$I_i = \left| \frac{\partial \mathcal{L}}{\partial A_i} \cdot A_i \right|$$

Where $A_i$ is the activation of filter $i$. More computationally expensive but often more accurate.
5. Geometric Median
Remove filters most similar to other filters (redundant information):

$$I_i = \sum_{j} \left\| F_i - F_j \right\|_2$$

Low importance = filter is "replaceable" by others.
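A quick sketch of that redundancy score using pairwise distances (random filters stand in for real trained weights):
import torch

filters = torch.randn(128, 64 * 3 * 3)        # [C_out, C_in*k*k], flattened
dists = torch.cdist(filters, filters)          # pairwise L2 distances
redundancy = dists.sum(dim=1)                  # small = similar to the others
prune_candidates = redundancy.argsort()[:32]   # the 32 most replaceable filters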
Global vs Local Pruning
An important design choice: should we prune the same percentage from each layer (local), or rank filters globally across the whole network? Global pruning usually wins—it automatically spends the pruning budget on over-provisioned layers and spares sensitive ones—but it needs a safeguard (a minimum channel count per layer) so no layer collapses entirely.
Implementation: Channel Pruning in PyTorch
import torch
import torch.nn as nn
import copy
from typing import List, Dict, Tuple
class StructuredPruner:
"""
Structured (channel/filter) pruning for CNNs.
Removes entire filters and adjusts subsequent layers.
"""
def __init__(self, model: nn.Module, example_input: torch.Tensor):
self.model = copy.deepcopy(model)
self.example_input = example_input
self.layer_info = self._analyze_model()
def _analyze_model(self) -> Dict:
"""Analyze model structure to find prunable conv layers."""
info = {}
prev_layer = None
for name, module in self.model.named_modules():
if isinstance(module, nn.Conv2d):
info[name] = {
'module': module,
'out_channels': module.out_channels,
'in_channels': module.in_channels,
'prev_conv': prev_layer
}
prev_layer = name
elif isinstance(module, nn.BatchNorm2d):
if prev_layer in info:
info[prev_layer]['bn'] = name
return info
def compute_filter_importance(
self,
criterion: str = 'l1'
) -> Dict[str, torch.Tensor]:
"""
Compute importance scores for each filter in each layer.
Args:
criterion: 'l1', 'l2', 'taylor', or 'bn_gamma'
"""
importance = {}
for name, info in self.layer_info.items():
module = info['module']
if criterion == 'l1':
# L1 norm of each filter
scores = module.weight.data.abs().sum(dim=[1, 2, 3])
elif criterion == 'l2':
# L2 norm of each filter
scores = module.weight.data.pow(2).sum(dim=[1, 2, 3]).sqrt()
elif criterion == 'bn_gamma' and 'bn' in info:
# Use BatchNorm gamma as importance
bn_name = info['bn']
bn = dict(self.model.named_modules())[bn_name]
scores = bn.weight.data.abs()
else:
# Default to L1
scores = module.weight.data.abs().sum(dim=[1, 2, 3])
importance[name] = scores
return importance
def get_pruning_mask(
self,
importance: Dict[str, torch.Tensor],
prune_ratio: float,
global_pruning: bool = True,
min_channels: int = 8
) -> Dict[str, torch.Tensor]:
"""
Determine which filters to keep based on importance scores.
Args:
importance: Per-layer importance scores
prune_ratio: Fraction of filters to remove (0.5 = 50%)
global_pruning: If True, rank globally; else per-layer
min_channels: Minimum channels to keep per layer
"""
masks = {}
if global_pruning:
# Concatenate all scores and find global threshold
all_scores = torch.cat([s.flatten() for s in importance.values()])
num_to_prune = int(all_scores.numel() * prune_ratio)
threshold = torch.kthvalue(all_scores, num_to_prune).values.item()
for name, scores in importance.items():
# Keep filters above threshold
mask = scores > threshold
# Ensure minimum channels
if mask.sum() < min_channels:
_, top_idx = scores.topk(min_channels)
mask = torch.zeros_like(mask, dtype=torch.bool)
mask[top_idx] = True
masks[name] = mask
else:
# Per-layer pruning
for name, scores in importance.items():
num_filters = scores.numel()
num_to_keep = max(min_channels, int(num_filters * (1 - prune_ratio)))
_, top_idx = scores.topk(num_to_keep)
mask = torch.zeros(num_filters, dtype=torch.bool)
mask[top_idx] = True
masks[name] = mask
return masks
def apply_pruning(self, masks: Dict[str, torch.Tensor]) -> nn.Module:
"""
Create a new model with pruned channels.
This creates actual smaller dense layers, not masked sparse ones.
"""
# Build channel mapping for each layer
channel_mapping = {}
for name, mask in masks.items():
keep_indices = torch.where(mask)[0]
channel_mapping[name] = keep_indices
# Create new model with pruned layers
new_model = copy.deepcopy(self.model)
for name, info in self.layer_info.items():
if name not in masks:
continue
old_module = info['module']
keep_out = channel_mapping[name]
# Determine input channels to keep
prev_conv = info.get('prev_conv')
if prev_conv and prev_conv in channel_mapping:
keep_in = channel_mapping[prev_conv]
else:
keep_in = torch.arange(old_module.in_channels)
# Create new pruned conv layer
new_conv = nn.Conv2d(
in_channels=len(keep_in),
out_channels=len(keep_out),
kernel_size=old_module.kernel_size,
stride=old_module.stride,
padding=old_module.padding,
bias=old_module.bias is not None
)
# Copy weights for kept channels
new_conv.weight.data = old_module.weight.data[keep_out][:, keep_in]
if old_module.bias is not None:
new_conv.bias.data = old_module.bias.data[keep_out]
# Replace in model
self._set_module(new_model, name, new_conv)
# Update BatchNorm if present
if 'bn' in info:
bn_name = info['bn']
old_bn = dict(self.model.named_modules())[bn_name]
new_bn = nn.BatchNorm2d(len(keep_out))
new_bn.weight.data = old_bn.weight.data[keep_out]
new_bn.bias.data = old_bn.bias.data[keep_out]
new_bn.running_mean = old_bn.running_mean[keep_out]
new_bn.running_var = old_bn.running_var[keep_out]
self._set_module(new_model, bn_name, new_bn)
return new_model
def _set_module(self, model: nn.Module, name: str, new_module: nn.Module):
"""Set a module in the model by name."""
parts = name.split('.')
parent = model
for part in parts[:-1]:
parent = getattr(parent, part)
setattr(parent, parts[-1], new_module)
def prune(
self,
prune_ratio: float = 0.5,
criterion: str = 'l1',
global_pruning: bool = True
) -> nn.Module:
"""
Main method: prune the model by specified ratio.
Returns a new, smaller dense model.
"""
# Step 1: Compute importance
importance = self.compute_filter_importance(criterion)
# Step 2: Get pruning masks
masks = self.get_pruning_mask(
importance,
prune_ratio,
global_pruning
)
# Step 3: Apply pruning
pruned_model = self.apply_pruning(masks)
# Print stats
total_orig = sum(p.numel() for p in self.model.parameters())
total_new = sum(p.numel() for p in pruned_model.parameters())
print(f"Original parameters: {total_orig:,}")
print(f"Pruned parameters: {total_new:,}")
print(f"Reduction: {(1 - total_new/total_orig)*100:.1f}%")
return pruned_model
# Example usage
if __name__ == "__main__":
    # Note: the simple prev-conv chaining in _analyze_model assumes a
    # sequential chain of convolutions. Architectures with residual
    # connections (e.g. ResNets) additionally require the skip-add
    # channels to stay consistent, which this sketch does not handle.
    model = nn.Sequential(
        nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(),
        nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(),
        nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(),
    )
    model.eval()
    # Create pruner
    example_input = torch.randn(1, 3, 224, 224)
    pruner = StructuredPruner(model, example_input)
    # Prune 50% of filters
    pruned_model = pruner.prune(
        prune_ratio=0.5,
        criterion='l1',
        global_pruning=True
    )
    # Verify the pruned model still runs (channel counts have shrunk)
    with torch.no_grad():
        out = pruned_model(example_input)
    print(f"Pruned output shape: {out.shape}")
Hardware Speedup: Why Structured Pruning Works
Unlike unstructured pruning, structured pruning provides real speedups on any hardware because the result is still a dense network with standard tensor operations.
Benchmark: Structured Pruning Speedup
import torch
import torch.nn as nn
import time
def benchmark_structured_pruning(prune_ratio=0.5):
"""
Demonstrate that structured pruning gives proportional speedup.
"""
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def sync():
        # Guard so the benchmark also runs on CPU-only machines
        if device == 'cuda':
            torch.cuda.synchronize()
# Original convolution
in_channels = 256
out_channels = 256
kernel_size = 3
spatial_size = 56
conv_original = nn.Conv2d(in_channels, out_channels, kernel_size, padding=1).to(device)
# Pruned convolution (50% fewer output channels)
pruned_out = int(out_channels * (1 - prune_ratio))
conv_pruned = nn.Conv2d(in_channels, pruned_out, kernel_size, padding=1).to(device)
# Input tensor
x = torch.randn(1, in_channels, spatial_size, spatial_size, device=device)
# Warmup
for _ in range(20):
_ = conv_original(x)
_ = conv_pruned(x)
    sync()
# Benchmark original
num_runs = 100
start = time.perf_counter()
for _ in range(num_runs):
_ = conv_original(x)
    sync()
orig_time = (time.perf_counter() - start) / num_runs * 1000
# Benchmark pruned
start = time.perf_counter()
for _ in range(num_runs):
_ = conv_pruned(x)
    sync()
pruned_time = (time.perf_counter() - start) / num_runs * 1000
# Calculate stats
orig_params = in_channels * out_channels * kernel_size * kernel_size
pruned_params = in_channels * pruned_out * kernel_size * kernel_size
orig_flops = 2 * in_channels * out_channels * kernel_size * kernel_size * spatial_size * spatial_size
pruned_flops = 2 * in_channels * pruned_out * kernel_size * kernel_size * spatial_size * spatial_size
print(f"Original: {out_channels} output channels")
print(f"Pruned: {pruned_out} output channels ({prune_ratio*100:.0f}% pruned)")
print()
print(f"Parameters: {orig_params:,} → {pruned_params:,} ({(1-pruned_params/orig_params)*100:.0f}% reduction)")
print(f"FLOPs: {orig_flops/1e6:.1f}M → {pruned_flops/1e6:.1f}M ({(1-pruned_flops/orig_flops)*100:.0f}% reduction)")
print()
print(f"Latency: {orig_time:.3f} ms → {pruned_time:.3f} ms")
print(f"Speedup: {orig_time/pruned_time:.2f}x")
print()
theoretical_speedup = orig_flops / pruned_flops
actual_speedup = orig_time / pruned_time
efficiency = actual_speedup / theoretical_speedup * 100
print(f"Theoretical speedup: {theoretical_speedup:.2f}x")
print(f"Actual speedup: {actual_speedup:.2f}x")
print(f"Efficiency: {efficiency:.0f}%")
# Run benchmark
benchmark_structured_pruning(prune_ratio=0.5)
# Typical output on A100:
# Original: 256 output channels
# Pruned: 128 output channels (50% pruned)
#
# Parameters: 589,824 → 294,912 (50% reduction)
# FLOPs: 1849.7M → 924.8M (50% reduction)
#
# Latency: 0.182 ms → 0.098 ms
# Speedup: 1.86x
#
# Theoretical speedup: 2.00x
# Actual speedup: 1.86x
# Efficiency: 93% ← Near-linear speedup!
Example: Pruning ResNet-50
Published results from pruning ResNet-50 on ImageNet are broadly consistent across methods: roughly half the FLOPs can be removed with on the order of a 1% top-1 accuracy drop after fine-tuning, with L1-norm, Taylor, and geometric-median criteria all landing in that range.
Layer Sensitivity Analysis
Not all layers are equally prunable. Understanding layer sensitivity helps achieve better accuracy-efficiency tradeoffs.
import torch
import torch.nn as nn
import copy
from typing import Dict, List
def analyze_layer_sensitivity(
model: nn.Module,
val_loader,
prune_ratios: List[float] = [0.1, 0.2, 0.3, 0.5, 0.7],
criterion = nn.CrossEntropyLoss()
) -> Dict[str, Dict[float, float]]:
"""
Analyze how sensitive each layer is to pruning.
For each layer, prune ONLY that layer at various ratios
and measure accuracy drop.
Returns:
Dict mapping layer_name -> {prune_ratio: accuracy}
"""
device = next(model.parameters()).device
# Get baseline accuracy
baseline_acc = evaluate(model, val_loader, device)
print(f"Baseline accuracy: {baseline_acc*100:.2f}%")
sensitivity = {}
# Find all Conv2d layers
conv_layers = [(name, module) for name, module in model.named_modules()
if isinstance(module, nn.Conv2d)]
for layer_name, layer_module in conv_layers:
sensitivity[layer_name] = {}
print(f"\nAnalyzing: {layer_name}")
for ratio in prune_ratios:
# Create a copy and prune only this layer
model_copy = copy.deepcopy(model)
# Apply pruning to this layer only
for name, module in model_copy.named_modules():
if name == layer_name:
prune_layer_filters(module, ratio)
break
# Evaluate
acc = evaluate(model_copy, val_loader, device)
sensitivity[layer_name][ratio] = acc
acc_drop = (baseline_acc - acc) * 100
print(f" {ratio*100:.0f}% pruned: {acc*100:.2f}% (drop: {acc_drop:.2f}%)")
del model_copy
return sensitivity
def prune_layer_filters(conv_module: nn.Conv2d, ratio: float):
"""Zero out the least important filters in a conv layer."""
with torch.no_grad():
# Compute L1 importance
importance = conv_module.weight.data.abs().sum(dim=[1, 2, 3])
# Find threshold
num_to_prune = int(len(importance) * ratio)
if num_to_prune == 0:
return
threshold = importance.kthvalue(num_to_prune).values
# Zero out unimportant filters
mask = importance > threshold
conv_module.weight.data[~mask] = 0
def evaluate(model, val_loader, device):
"""Quick evaluation for sensitivity analysis."""
model.eval()
correct = 0
total = 0
with torch.no_grad():
for i, (data, target) in enumerate(val_loader):
if i >= 50: # Use subset for speed
break
data, target = data.to(device), target.to(device)
output = model(data)
_, predicted = output.max(1)
total += target.size(0)
correct += predicted.eq(target).sum().item()
return correct / total
def visualize_sensitivity(sensitivity: Dict[str, Dict[float, float]]):
"""
Print sensitivity analysis results.
Layers with high sensitivity should be pruned less.
"""
print("\n" + "="*60)
print("Layer Sensitivity Summary (lower = more prunable)")
print("="*60)
# Calculate average sensitivity per layer at 50% pruning
layer_sensitivity = {}
for layer_name, ratios in sensitivity.items():
if 0.5 in ratios:
# Sensitivity = accuracy drop at 50%
baseline = ratios.get(0.0, ratios[min(ratios.keys())])
layer_sensitivity[layer_name] = baseline - ratios[0.5]
# Sort by sensitivity
sorted_layers = sorted(layer_sensitivity.items(), key=lambda x: x[1])
print("\nMost Prunable (low sensitivity):")
for name, sens in sorted_layers[:5]:
print(f" {name}: {sens*100:.2f}% drop at 50% pruning")
print("\nLeast Prunable (high sensitivity):")
for name, sens in sorted_layers[-5:]:
print(f" {name}: {sens*100:.2f}% drop at 50% pruning")
When pruning CNNs:
- Skip or lightly prune the first conv layer (conv1) and classification head
- Aggressively prune middle layers (layer2, layer3 in ResNets)
- Use global pruning to automatically allocate pruning budget based on importance
- Fine-tune for 10-30% of original training epochs after pruning
Summary: Structured Pruning for CNNs
| Aspect | Details |
|---|---|
| What it does | Removes entire filters/channels, creating smaller dense layers |
| Achievable reduction | 30-70% FLOPs reduction with <2% accuracy loss typical |
| GPU speedup | ✓ Near-linear speedup (50% pruning → ~1.8x faster) |
| CPU speedup | ✓ Same proportional speedup |
| Best criteria | L1-norm for speed, Taylor for best accuracy |
| Sensitive layers | First conv, classification head—prune less or skip |
| Prunable layers | Middle layers (layer2, layer3 in ResNet) |
4. Structured Pruning for Transformers
Transformers have unique structure compared to CNNs: attention heads, feed-forward networks (FFN), and embedding dimensions. Each offers different pruning opportunities with distinct hardware implications. This section provides a deep dive into head pruning—the most impactful and nuanced form of transformer pruning.
Transformer Anatomy: What Can We Prune?
A transformer block contains several prunable components: attention heads (each a slice of the Q/K/V and output projections), the FFN intermediate dimension (typically 4x the hidden size), the hidden/embedding dimension itself, and entire layers. The subsections below cover heads, FFN width, and whole layers.
Head Pruning: A Deep Dive
Attention head pruning is the most impactful form of structured pruning for transformers. The key insight from research (Michel et al., 2019; Voita et al., 2019) is that many attention heads are redundant—a 12-head BERT model can often function well with 6-8 heads.
Why Heads Can Be Pruned
Attention heads learn specialized roles, but there's significant redundancy: many heads in the same layer converge on near-duplicate positional or syntactic patterns, and Michel et al. (2019) showed that at test time many layers can be reduced to a single head with little accuracy loss.
Head Importance Metrics
How do we measure which heads to prune? Several methods exist:
1. Gradient-Based Importance (Taylor)
Estimate the impact of removing a head on the loss:

$$I_h = \mathbb{E}_{x}\left[\,\left|\sum_t A_h^t \cdot \frac{\partial \mathcal{L}}{\partial A_h^t}\right|\,\right]$$

Where $A_h^t$ is the attention output of head $h$ at position $t$.
2. Attention Entropy
Heads with low entropy (peaked attention) often carry more information:

$$H_h = -\frac{1}{T}\sum_{i=1}^{T} \sum_{j} \alpha^{(h)}_{ij} \log \alpha^{(h)}_{ij}$$

Where $\alpha^{(h)}_{ij}$ is head $h$'s attention weight from query position $i$ to key position $j$.
High entropy = diffuse attention = likely prunable. Low entropy = focused attention = likely important.
3. Head Confidence
Measure how "confident" a head is in its attention pattern:

$$C_h = \frac{1}{T}\sum_{i=1}^{T} \max_j\, \alpha^{(h)}_{ij}$$
Heads that consistently have low max attention may be contributing less.
4. Learnable Importance (Soft Masking)
Train a scalar importance weight per head during fine-tuning:

$$\text{MHA}(x) = \sum_{h} z_h \cdot \text{Att}_h(x)\, W_O^{(h)}$$
Where $z_h$ is a learnable importance score. Add L1 regularization to push unimportant heads to zero.
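A minimal sketch of the gating idea (the gate placement inside the attention forward is shown as comments; the names here are illustrative, not a library API):
import torch
import torch.nn as nn

num_layers, num_heads = 12, 12
head_gates = nn.Parameter(torch.ones(num_layers, num_heads))  # z_h, one per head

def gate_penalty(lam=1e-3):
    # L1 pushes unimportant gates toward zero during fine-tuning
    return lam * head_gates.abs().sum()

# Inside layer l's attention forward, before the output projection:
#   context = context.view(B, T, num_heads, head_dim)
#   context = context * head_gates[l].view(1, 1, num_heads, 1)
# Training loss: task_loss + gate_penalty()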
Implementation: Complete Head Pruning Pipeline
import torch
import torch.nn as nn
import copy
from transformers import BertModel, BertTokenizer
from typing import Dict, List, Tuple
import numpy as np
class HeadImportanceAnalyzer:
"""
Analyze and compute importance scores for attention heads.
"""
def __init__(self, model: nn.Module):
self.model = model
self.num_layers = model.config.num_hidden_layers
self.num_heads = model.config.num_attention_heads
self.head_dim = model.config.hidden_size // self.num_heads
# Storage for importance scores
self.taylor_scores = None
self.entropy_scores = None
def compute_taylor_importance(
self,
dataloader,
num_batches: int = 50
) -> torch.Tensor:
"""
Compute Taylor (gradient-based) importance for each head.
Shape: [num_layers, num_heads]
"""
self.model.eval()
device = next(self.model.parameters()).device
# Initialize importance scores
importance = torch.zeros(self.num_layers, self.num_heads, device=device)
        # Register hooks to capture attention outputs; gradients are
        # obtained later by calling retain_grad() on the saved tensors
        attention_outputs = {}

        def save_output_hook(name):
            def hook(module, input, output):
                # output[0] is attention output: [batch, seq, hidden]
                attention_outputs[name] = output[0]
            return hook
# Register hooks on attention layers
handles = []
for layer_idx in range(self.num_layers):
attn = self.model.encoder.layer[layer_idx].attention.self
name = f"layer_{layer_idx}"
handle = attn.register_forward_hook(save_output_hook(name))
handles.append(handle)
# Process batches
for batch_idx, batch in enumerate(dataloader):
if batch_idx >= num_batches:
break
# Move to device
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
labels = batch.get('labels', input_ids).to(device)
# Forward pass
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True
)
# Compute loss (for gradient)
# Using a simple proxy: variance of outputs
loss = outputs.last_hidden_state.var()
            # Retain gradients on the saved attention outputs
for name, output in attention_outputs.items():
output.retain_grad()
# Backward pass
loss.backward()
# Compute Taylor importance per head
for layer_idx in range(self.num_layers):
name = f"layer_{layer_idx}"
output = attention_outputs[name] # [batch, seq, hidden]
grad = output.grad
if grad is None:
continue
# Reshape to [batch, seq, num_heads, head_dim]
batch_size, seq_len, _ = output.shape
output_reshaped = output.view(batch_size, seq_len, self.num_heads, self.head_dim)
grad_reshaped = grad.view(batch_size, seq_len, self.num_heads, self.head_dim)
# Taylor importance: |output * grad|
head_importance = (output_reshaped * grad_reshaped).abs().sum(dim=[0, 1, 3])
importance[layer_idx] += head_importance
# Clear gradients
self.model.zero_grad()
# Remove hooks
for handle in handles:
handle.remove()
# Normalize
importance = importance / num_batches
self.taylor_scores = importance
return importance
def compute_entropy_importance(
self,
dataloader,
num_batches: int = 50
) -> torch.Tensor:
"""
Compute attention entropy for each head.
Lower entropy = more focused = likely more important.
Returns negative entropy so higher = more important.
"""
self.model.eval()
device = next(self.model.parameters()).device
entropy = torch.zeros(self.num_layers, self.num_heads, device=device)
with torch.no_grad():
for batch_idx, batch in enumerate(dataloader):
if batch_idx >= num_batches:
break
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=True
)
# attentions: tuple of [batch, num_heads, seq, seq]
for layer_idx, attn in enumerate(outputs.attentions):
# Compute entropy per head
# attn: [batch, num_heads, seq, seq]
# Add small epsilon for numerical stability
attn_clamped = attn.clamp(min=1e-8)
head_entropy = -(attn_clamped * attn_clamped.log()).sum(dim=-1)
# Average over batch and sequence
head_entropy = head_entropy.mean(dim=[0, 2])
entropy[layer_idx] += head_entropy
entropy = entropy / num_batches
# Return negative entropy (higher = more important)
self.entropy_scores = -entropy
return -entropy
def get_heads_to_prune(
self,
importance: torch.Tensor,
prune_ratio: float = 0.5,
global_pruning: bool = True,
min_heads_per_layer: int = 1
) -> Dict[int, List[int]]:
"""
Determine which heads to prune based on importance scores.
Returns:
Dict mapping layer_idx -> list of head indices to prune
"""
if global_pruning:
# Global pruning: rank all heads together
flat_importance = importance.flatten()
num_to_prune = int(flat_importance.numel() * prune_ratio)
# Find threshold
threshold = torch.kthvalue(flat_importance, num_to_prune).values.item()
heads_to_prune = {}
for layer_idx in range(self.num_layers):
layer_importance = importance[layer_idx]
prune_mask = layer_importance < threshold
# Ensure minimum heads kept
num_to_keep = self.num_heads - prune_mask.sum().item()
if num_to_keep < min_heads_per_layer:
# Keep the top min_heads_per_layer
_, top_indices = layer_importance.topk(min_heads_per_layer)
prune_mask = torch.ones(self.num_heads, dtype=torch.bool)
prune_mask[top_indices] = False
pruned_heads = torch.where(prune_mask)[0].tolist()
if pruned_heads:
heads_to_prune[layer_idx] = pruned_heads
else:
# Per-layer pruning
heads_to_prune = {}
num_to_prune_per_layer = max(1, int(self.num_heads * prune_ratio))
for layer_idx in range(self.num_layers):
layer_importance = importance[layer_idx]
_, indices = layer_importance.topk(num_to_prune_per_layer, largest=False)
heads_to_prune[layer_idx] = indices.tolist()
return heads_to_prune
def visualize_importance(self, importance: torch.Tensor):
"""Print a text visualization of head importance."""
print("\nHead Importance Heatmap:")
print("(Higher = More Important)")
print("-" * (10 + self.num_heads * 8))
# Normalize for visualization
importance_norm = (importance - importance.min()) / (importance.max() - importance.min() + 1e-8)
header = "Layer " + " ".join([f"H{i:02d}" for i in range(self.num_heads)])
print(header)
for layer_idx in range(self.num_layers):
row = f"L{layer_idx:02d} "
for head_idx in range(self.num_heads):
val = importance_norm[layer_idx, head_idx].item()
if val > 0.8:
row += "████ "
elif val > 0.6:
row += "███░ "
elif val > 0.4:
row += "██░░ "
elif val > 0.2:
row += "█░░░ "
else:
row += "░░░░ "
print(row)
class HeadPruner:
"""
Apply structured head pruning to transformer models.
Creates actual smaller attention layers (not masked).
"""
def __init__(self, model: nn.Module):
self.model = copy.deepcopy(model)
def prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> nn.Module:
"""
Prune specified heads from the model.
Args:
heads_to_prune: Dict mapping layer_idx -> list of head indices to remove
Returns:
New model with pruned heads
"""
# Use HuggingFace's built-in head pruning for transformers
for layer_idx, head_indices in heads_to_prune.items():
self.model.encoder.layer[layer_idx].attention.prune_heads(head_indices)
# Update config
# Note: After pruning, different layers may have different head counts
return self.model
@staticmethod
def count_remaining_heads(model) -> Dict[int, int]:
"""Count remaining heads per layer after pruning."""
head_counts = {}
for layer_idx, layer in enumerate(model.encoder.layer):
# After pruning, attention.self has reduced dimensions
attn = layer.attention.self
num_heads = attn.num_attention_heads
head_counts[layer_idx] = num_heads
return head_counts
# Example usage
def prune_bert_heads_example():
"""Complete example of pruning BERT heads."""
# Load model
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
print("Original model:")
print(f" Layers: {model.config.num_hidden_layers}")
print(f" Heads per layer: {model.config.num_attention_heads}")
print(f" Total heads: {model.config.num_hidden_layers * model.config.num_attention_heads}")
# Create sample data
texts = [
"The quick brown fox jumps over the lazy dog.",
"Machine learning models can be pruned efficiently.",
"Attention heads have different specializations.",
] * 20
encodings = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
    dataset = torch.utils.data.TensorDataset(
        encodings['input_ids'],
        encodings['attention_mask']
    )
    # Collate into dict-style batches, which the analyzer expects
    def dict_collate(batch):
        return {
            'input_ids': torch.stack([item[0] for item in batch]),
            'attention_mask': torch.stack([item[1] for item in batch])
        }
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, collate_fn=dict_collate)
# Analyze importance
analyzer = HeadImportanceAnalyzer(model)
print("\nComputing head importance (this may take a moment)...")
importance = analyzer.compute_entropy_importance(dataloader, num_batches=10)
# Visualize
analyzer.visualize_importance(importance)
# Determine heads to prune
heads_to_prune = analyzer.get_heads_to_prune(
importance,
prune_ratio=0.4,
global_pruning=True,
min_heads_per_layer=4
)
print(f"\nHeads to prune (40% target):")
for layer, heads in heads_to_prune.items():
print(f" Layer {layer}: heads {heads}")
total_pruned = sum(len(h) for h in heads_to_prune.values())
print(f" Total pruned: {total_pruned} / {12*12} = {total_pruned/(12*12)*100:.1f}%")
# Apply pruning
pruner = HeadPruner(model)
pruned_model = pruner.prune_heads(heads_to_prune)
# Count parameters
orig_params = sum(p.numel() for p in model.parameters())
pruned_params = sum(p.numel() for p in pruned_model.parameters())
print(f"\nParameter reduction:")
print(f" Original: {orig_params:,}")
print(f" Pruned: {pruned_params:,}")
print(f" Reduction: {(1 - pruned_params/orig_params)*100:.1f}%")
return pruned_model
if __name__ == "__main__":
prune_bert_heads_example()
Hardware Impact of Head Pruning
Head pruning provides real speedup because it reduces matrix dimensions while maintaining dense operations.
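A minimal benchmark sketch of that effect (assumes a CUDA device and PyTorch 2.x for scaled_dot_product_attention; shapes are BERT-base-like and illustrative):
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, head_dim, batch, seq = 768, 64, 8, 512
x = torch.randn(batch, seq, hidden, device='cuda', dtype=torch.float16)

def attn_ms(num_heads):
    # Head pruning shrinks the Q/K/V and output projections from
    # [hidden, 12*64] down to [hidden, num_heads*64]
    qkv = nn.Linear(hidden, 3 * num_heads * head_dim).half().cuda()
    proj = nn.Linear(num_heads * head_dim, hidden).half().cuda()
    def forward():
        q, k, v = qkv(x).chunk(3, dim=-1)
        q, k, v = (t.view(batch, seq, num_heads, head_dim).transpose(1, 2)
                   for t in (q, k, v))
        a = F.scaled_dot_product_attention(q, k, v)
        return proj(a.transpose(1, 2).reshape(batch, seq, num_heads * head_dim))
    with torch.no_grad():
        for _ in range(10):  # warmup
            forward()
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(100):
            forward()
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / 100 * 1000

t12, t8 = attn_ms(12), attn_ms(8)
print(f"12 heads: {t12:.3f} ms  8 heads: {t8:.3f} ms  speedup: {t12/t8:.2f}x")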
FFN Width Pruning
The FFN sublayer is actually the largest part of a transformer block (~66% of parameters). Pruning FFN intermediate dimensions can yield significant savings.
import torch
import torch.nn as nn
def prune_ffn_neurons(
model,
prune_ratio: float = 0.3,
criterion: str = 'l1'
):
"""
Prune intermediate neurons in FFN layers.
FFN structure: input → intermediate (4x) → output
We prune neurons in the intermediate layer.
"""
for layer_idx, layer in enumerate(model.encoder.layer):
ffn = layer.intermediate # First linear: [hidden, intermediate]
ffn_out = layer.output # Second linear: [intermediate, hidden]
# Compute neuron importance
if criterion == 'l1':
# Sum of input and output weights for each intermediate neuron
importance = ffn.dense.weight.data.abs().sum(dim=1) + \
ffn_out.dense.weight.data.abs().sum(dim=0)
else:
# L2 norm
importance = ffn.dense.weight.data.pow(2).sum(dim=1).sqrt() + \
ffn_out.dense.weight.data.pow(2).sum(dim=0).sqrt()
# Determine neurons to keep
num_neurons = importance.numel()
num_to_keep = int(num_neurons * (1 - prune_ratio))
_, keep_indices = importance.topk(num_to_keep)
keep_indices = keep_indices.sort().values
# Create new smaller FFN
new_intermediate = nn.Linear(
ffn.dense.in_features,
num_to_keep,
bias=ffn.dense.bias is not None
)
new_output = nn.Linear(
num_to_keep,
ffn_out.dense.out_features,
bias=ffn_out.dense.bias is not None
)
# Copy weights
new_intermediate.weight.data = ffn.dense.weight.data[keep_indices]
if ffn.dense.bias is not None:
new_intermediate.bias.data = ffn.dense.bias.data[keep_indices]
new_output.weight.data = ffn_out.dense.weight.data[:, keep_indices]
if ffn_out.dense.bias is not None:
new_output.bias.data = ffn_out.dense.bias.data.clone()
# Replace in model
ffn.dense = new_intermediate
ffn_out.dense = new_output
print(f"Layer {layer_idx}: FFN {num_neurons} → {num_to_keep} neurons")
return model
# Example: Prune BERT FFN by 30%
from transformers import BertModel
model = BertModel.from_pretrained('bert-base-uncased')
print(f"Original intermediate size: {model.config.intermediate_size}")
pruned_model = prune_ffn_neurons(model, prune_ratio=0.3)
# New intermediate size: 3072 → 2150
# Parameter savings: ~17M of BERT-base's ~110M parameters (~15%)
Layer Pruning: Maximum Impact
The most aggressive form of structured pruning: remove entire transformer layers. This provides the largest speedup but requires careful selection.
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig
import copy
def prune_layers(
model: nn.Module,
layers_to_remove: list
) -> nn.Module:
"""
Remove entire transformer layers from the model.
Args:
model: BERT-style model
layers_to_remove: List of layer indices to remove (0-indexed)
Returns:
Model with specified layers removed
"""
model = copy.deepcopy(model)
# Get all layers
all_layers = list(model.encoder.layer)
num_layers = len(all_layers)
# Filter out layers to remove
layers_to_keep = [
i for i in range(num_layers)
if i not in layers_to_remove
]
# Create new layer list
new_layers = nn.ModuleList([all_layers[i] for i in layers_to_keep])
model.encoder.layer = new_layers
# Update config
model.config.num_hidden_layers = len(new_layers)
print(f"Removed layers: {layers_to_remove}")
print(f"Remaining layers: {len(new_layers)}")
return model
def compute_layer_importance(
model,
dataloader,
num_batches: int = 50
) -> torch.Tensor:
"""
Compute importance of each layer using hidden state similarity.
Layers where input ≈ output are less important.
"""
model.eval()
device = next(model.parameters()).device
num_layers = model.config.num_hidden_layers
importance = torch.zeros(num_layers, device=device)
def hook_fn(layer_idx):
def hook(module, input, output):
# Compute how much the layer changes the representation
input_hidden = input[0] # [batch, seq, hidden]
output_hidden = output[0] # [batch, seq, hidden]
# L2 distance between input and output (normalized)
diff = (output_hidden - input_hidden).pow(2).mean()
importance[layer_idx] += diff.item()
return hook
# Register hooks
handles = []
for i, layer in enumerate(model.encoder.layer):
handle = layer.register_forward_hook(hook_fn(i))
handles.append(handle)
# Process batches
with torch.no_grad():
for batch_idx, batch in enumerate(dataloader):
if batch_idx >= num_batches:
break
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)
_ = model(input_ids=input_ids, attention_mask=attention_mask)
# Remove hooks
for handle in handles:
handle.remove()
# Normalize
importance = importance / num_batches
return importance
# Example: Remove middle layers from BERT
model = BertModel.from_pretrained('bert-base-uncased')
# Common strategy: remove the middle layers 4-7 (0-indexed),
# keeping layers 0-3 and 8-11. This creates an 8-layer BERT.
layers_to_remove = [4, 5, 6, 7]
pruned_model = prune_layers(model, layers_to_remove)
# Count parameters
orig_params = sum(p.numel() for p in model.parameters())
pruned_params = sum(p.numel() for p in pruned_model.parameters())
print(f"\nParameters: {orig_params:,} → {pruned_params:,}")
print(f"Reduction: {(1 - pruned_params/orig_params)*100:.1f}%")
print(f"Speedup: ~{12/8:.2f}x (proportional to layer count)")
Real-World Results: Transformer Pruning
Summary: Structured Pruning for Transformers
| Pruning Type | Target | Typical Reduction | Accuracy Impact | Speedup |
|---|---|---|---|---|
| Head Pruning | Attention heads | 30-50% heads | <1% loss | 1.3-1.5x |
| FFN Pruning | Intermediate neurons | 20-40% neurons | <0.5% loss | 1.2-1.4x |
| Layer Pruning | Entire layers | 2-4 layers (16-33%) | 1-2% loss | 1.3-1.5x |
| Combined | All above | 50%+ reduction | 2-3% loss | 2x+ |
For BERT-style models:
- Start with head pruning (40% heads) for quick wins with minimal accuracy loss
- Add FFN neuron pruning (30%) for additional savings
- Use layer pruning (remove middle layers) only if you need >1.5x speedup
- Always fine-tune after pruning on your target task (5-10 epochs)
- Use Taylor importance for best accuracy, entropy for quick analysis
For LLMs (GPT-style): Head and FFN pruning work, but layer pruning is more sensitive. Consider N:M sparsity (Section 5) or quantization (Section 6+) instead.
5. N:M Sparsity: Hardware-Native Patterns
N:M sparsity is a pruning technique designed specifically for modern hardware accelerators. Unlike unstructured or structured pruning, N:M enforces a fixed pattern: in every group of M weights, exactly N are nonzero. This enables real speedup on GPUs and TPUs.
What is N:M Sparsity?
For a given layer, weights are divided into blocks of size M (e.g., 4, 8, 16). In each block, only N weights are kept, the rest are set to zero. Example: 2:4 sparsity means every group of 4 weights has exactly 2 nonzero.
Why N:M Enables Real Hardware Speedup
- Fixed pattern allows hardware to skip zeros efficiently
- Tensor cores (NVIDIA A100, H100) natively support 2:4 sparsity
- Dense kernels are replaced by sparse kernels with guaranteed structure
- Memory savings (50% for 2:4) and compute savings (up to 2x)
Implementation: N:M Pruning Pipeline
import torch
import torch.nn as nn
import numpy as np
def nm_prune(
    weight: torch.Tensor,
    N: int = 2,
    M: int = 4
) -> torch.Tensor:
    """
    Prune tensor to N:M sparsity (N nonzero kept per group of M).
    Assumes weight.numel() is divisible by M. For a 2D [out, in] weight
    with in divisible by M, row-major flattening groups along the input
    dimension, matching the layout hardware 2:4 kernels expect.
    """
    w = weight.clone()
    blocks = w.view(-1, M)                            # [num_blocks, M], shares storage with w
    # Indices of the M-N smallest-magnitude weights in each block
    _, drop_idx = blocks.abs().topk(M - N, dim=1, largest=False)
    blocks.scatter_(1, drop_idx, 0.0)                 # zero them in place
    return w
# Example: Prune a linear layer to 2:4 sparsity
layer = nn.Linear(1024, 1024)
layer.weight.data = nm_prune(layer.weight.data, N=2, M=4)
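To actually get the Tensor Core speedup in PyTorch, recent versions (2.1+, where semi-structured sparse support is still beta) can convert a 2:4-patterned FP16 weight into a semi-structured sparse tensor. A hedged sketch, assuming an Ampere-or-newer GPU and using the nm_prune function above:
import torch
from torch.sparse import to_sparse_semi_structured

# The weight must already satisfy the 2:4 pattern, use a supported dtype
# (e.g. FP16), and live on a CUDA device
W = nm_prune(torch.randn(4096, 4096), N=2, M=4).half().cuda()
W_24 = to_sparse_semi_structured(W)

x = torch.randn(4096, 4096, device='cuda', dtype=torch.float16)
y = torch.mm(W_24, x)   # dispatches to 2:4 sparse Tensor Core kernels where supported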
Real-World Results: N:M Sparsity
Summary: N:M Sparsity
| Pattern | Speedup | Accuracy Drop | Hardware Support |
|---|---|---|---|
| 2:4 (50%) | 1.5-1.7x | <1% | A100, H100, TPU |
| 1:4 (25%) | 1.2x | <2% | Experimental |
| 4:8 (50%) | 1.5x | <1% | TPU |
Practical guidance for N:M sparsity:
- Use 2:4 sparsity for maximum speedup on NVIDIA A100/H100
- Apply N:M pruning after structured pruning for best results
- Always fine-tune after pruning (5-10 epochs)
- Check hardware support before deploying
- For LLMs, combine N:M with quantization for best efficiency
6. Quantization Foundations
Content coming in Batch 6...
7. Post-Training Quantization (PTQ)
Content coming in Batch 7...
8. Quantization-Aware Training (QAT)
Content coming in Batch 8...
9. LLM Quantization Techniques
Content coming in Batch 9...
10. Knowledge Distillation
Content coming in Batch 10...
11. Efficient Attention & KV Cache
Content coming in Batch 11...
12. Hardware-Specific Deployment
Content coming in Batch 12...
13. Case Studies & Best Practices
Content coming in Batch 13...