1. Why Efficient Inference Matters
The gap between training and inference is often overlooked. During training, we optimize for:
- Accuracy improvement
- Convergence speed
- Memory during backpropagation
But in production inference, we optimize for:
- Throughput: How many predictions per second?
- Latency: How fast is each prediction?
- Memory: Can it fit on edge devices or with many concurrent requests?
- Energy: Power consumption directly impacts cost and sustainability
A typical state-of-the-art LLM inference scenario:
- A single inference request: ~100-500ms latency without optimization
- Bandwidth bottleneck: A 7B parameter model needs 14GB to load (FP16)
- Energy cost: A GPU like H100 draws 700W under full load
This is why model compression (pruning + quantization) is critical: 2-10x speedups with minimal accuracy loss are routinely achievable, and combined with other serving optimizations the total gains can reach an order of magnitude or more.
2. Pruning Fundamentals: Removing Unnecessary Parameters
Pruning is the process of setting certain weights to zero, removing parameters that contribute little to the model's output. The key insight: not all parameters are equally important.
Why Pruning Works
- Overparameterization: Neural networks are trained with far more parameters than strictly necessary
- Lottery Ticket Hypothesis: Dense networks contain sparse subnetworks ("winning tickets") that, trained in isolation, can match the accuracy of the full network
- Redundancy: Many weights with small magnitudes contribute minimally to predictions
Unstructured Pruning: Individual Weight Removal
Unstructured pruning removes individual weights from a matrix based on their magnitude. The sparsity pattern is irregular—any weight can be removed.
How It Works
- Magnitude Ranking: Sort all weights by absolute value
- Threshold Setting: Choose a percentile (e.g., remove bottom 90% of weights)
- Masking: Multiply weight matrix element-wise with a binary mask
- Fine-tuning: Retrain the remaining weights to recover accuracy
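The four steps above (minus fine-tuning) can be sketched in a few lines of PyTorch; the threshold is found with `kthvalue`, and the function names are illustrative:

```python
import torch

def magnitude_prune(weight: torch.Tensor, sparsity: float) -> torch.Tensor:
    """Return a binary mask that zeroes the smallest-magnitude weights."""
    k = int(sparsity * weight.numel())          # number of weights to remove
    if k == 0:
        return torch.ones_like(weight)
    # Threshold = k-th smallest absolute value (step 1 + 2: rank, then cut)
    threshold = weight.abs().flatten().kthvalue(k).values
    return (weight.abs() > threshold).float()   # 1 = keep, 0 = prune

w = torch.randn(256, 256)
mask = magnitude_prune(w, sparsity=0.9)
w_pruned = w * mask                             # step 3: apply binary mask
print(f"Actual sparsity: {(w_pruned == 0).float().mean():.2%}")
```

Step 4 (fine-tuning) then trains `w_pruned` with the mask re-applied after each optimizer step so pruned weights stay zero.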
Pros & Cons
- ✓ Maximum flexibility: Can achieve very high sparsity (80-95%)
- ✓ Simple to implement: Just apply a binary mask
- ✗ Poor hardware utilization: Irregular memory access kills GPU bandwidth
- ✗ Limited speedup: Modern GPUs don't natively support unstructured sparsity
Structured Pruning: Channel and Filter Removal
Structured pruning removes entire channels, filters, or blocks at once. The sparsity pattern is regular and GPU-friendly.
Key Structures
- Channel pruning: remove entire output channels of a convolution (and the matching input channels of the next layer)
- Filter pruning: remove whole convolutional filters, ranked by a norm of their weights
- Block pruning: zero out fixed-size blocks of a weight matrix
- Head pruning (Transformers): remove entire attention heads
Pros & Cons
- ✓ Hardware efficient: Dense matrix multiplications on GPU
- ✓ Straightforward implementation: Just skip computation
- ✓ Real speedups: 2-10x on modern GPUs
- ✗ Lower max sparsity: Usually 30-50% vs 80-95% for unstructured
- ✗ Accuracy-sparsity trade-off: Lose more accuracy for same sparsity level
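Structured pruning can be realized as literal tensor slicing. Here is a sketch (hypothetical layer shapes) that drops the 16 lowest-L2-norm filters of a conv layer, producing a genuinely smaller dense layer rather than a masked one:

```python
import torch

# Rank filters of a 32-filter conv by L2 norm, keep the top 16
conv = torch.nn.Conv2d(16, 32, kernel_size=3)
norms = conv.weight.flatten(1).norm(dim=1)         # one L2 norm per filter
keep = norms.topk(16).indices.sort().values        # indices of kept filters

# Build a smaller layer and copy over the surviving filters
small = torch.nn.Conv2d(16, 16, kernel_size=3)
small.weight.data = conv.weight.data[keep]
small.bias.data = conv.bias.data[keep]

x = torch.randn(1, 16, 8, 8)
print(small(x).shape)   # torch.Size([1, 16, 6, 6]) -- dense, smaller compute
```

Because the result is an ordinary dense layer, it runs at full speed on any hardware; no sparse kernels are needed.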
N:M Sparsity: Hardware-Native Sparsity Patterns
N:M sparsity (also called fine-grained structured sparsity) is a compromise between unstructured and structured pruning. In each group of M consecutive elements, you keep exactly N non-zero values. The most common pattern is 2:4 sparsity (keep 2 out of every 4 weights).
2:4 Sparsity Example
Consider a weight matrix where we apply 2:4 sparsity:
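Since the pattern is easiest to see in code, here is a minimal sketch (plain PyTorch, illustration only; production deployments use hardware-aware libraries) that keeps the 2 largest-magnitude weights in each group of 4:

```python
import torch

def apply_2_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Keep the 2 largest-magnitude values in every group of 4 elements."""
    w = weight.reshape(-1, 4)                       # consecutive groups of 4
    # Indices of the 2 smallest |w| in each group -> zero them out
    _, drop = w.abs().topk(2, dim=1, largest=False)
    return w.scatter(1, drop, 0.0).reshape(weight.shape)

w = torch.tensor([[0.9, -0.1, 0.05, -1.2],
                  [0.3,  0.7, -0.6,  0.02]])
print(apply_2_4_sparsity(w))
# Each group of 4 keeps exactly its 2 largest-magnitude entries:
# [[0.9, 0.0, 0.0, -1.2],
#  [0.0, 0.7, -0.6, 0.0]]
```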
Why 2:4?
- 50% compression: 2 out of 4 weights → 50% sparsity
- GPU hardware native: NVIDIA Tensor Cores can execute directly
- Accuracy-efficient: Usually maintains 95-98% accuracy
- Trainable: the 2:4 mask can be chosen or refined during training (sparsity-aware training, analogous to QAT)
N:M Sparsity on NVIDIA GPUs
NVIDIA first introduced native N:M sparsity support in the Ampere architecture (A100, RTX A6000) and carried it forward in Hopper (H100, H200) and Ada Lovelace (L40S, RTX 6000 Ada).
How It Works on GPU
The GPU stores only the two non-zero values of each group of four, plus compact metadata indices recording their positions. Sparse Tensor Cores use that metadata to skip the zeros, giving up to 2x math throughput and roughly 50% weight-memory savings over the dense equivalent.
Supported NVIDIA GPUs (2:4 Sparsity)
- Ampere: A100, RTX A6000, RTX A5000, RTX 3090
- Ada Lovelace: L40S, RTX 6000 Ada, RTX 5880 Ada
- Hopper: H100, H200 (up to 2x math throughput w/ 2:4 sparsity)
- Not supported: V100, T4, older architectures
2:4 sparsity acceleration requires:
- cuSPARSELt library or TensorRT with sparsity support
- Weights to be structured in 2:4 pattern
- Not all frameworks support it natively (TensorRT does, PyTorch needs plugins)
Pruning CNNs vs Transformers
While pruning principles are universal, CNNs and Transformers respond differently due to their architectural differences: CNNs are typically pruned at the channel/filter level, whereas Transformers are usually pruned at the attention-head or FFN-neuron level, and their large projection matrices often tolerate higher unstructured sparsity.
3. Software Stack & Tools for Sparsity
To apply sparsity in practice, you need libraries and frameworks that understand your target hardware:
Framework-Level Support
- PyTorch:
- Native pruning (structured and unstructured masks) via `torch.nn.utils.prune`
- Sparse tensor support via `torch.sparse` (kernel coverage still limited)
- Quantization via `torch.ao.quantization`, which can be combined with pruning
- Limited 2:4 sparsity acceleration (pair with TensorRT)
- TensorFlow:
- `tensorflow_model_optimization` for magnitude pruning, structured and unstructured (e.g., `tfmot.sparsity.keras.prune_low_magnitude`)
- Quantization + pruning can be combined within the same toolkit
Inference Engines
- TensorRT (NVIDIA): Best for 2:4 sparsity on GPU
- Native 2:4 sparse kernels
- INT8 + 2:4 sparsity combined
- Up to 8x speedup reported
- DeepSparse (Neural Magic): CPU-optimized sparsity
- Unstructured sparsity on CPU
- Vectorized execution
- No GPU needed
- Triton (OpenAI): Custom sparse kernels
- Write custom sparse ops in Python
- Auto-compile to GPU
- TVM (Apache): Universal sparse compiler
- Cross-platform sparsity support
- Supports both structured and unstructured
import torch
import torch.nn.utils.prune as prune

# Load your model
model = YourModel()

# Apply structured magnitude pruning to all Conv2d layers
for module in model.modules():
    if isinstance(module, torch.nn.Conv2d):
        # Prune 30% of output channels, ranked by L2 norm
        prune.ln_structured(
            module,
            name="weight",
            amount=0.3,
            n=2,
            dim=0,  # prune output channels
        )
        # Make pruning permanent (folds the mask into the weight)
        prune.remove(module, 'weight')

# Fine-tune the pruned model to recover accuracy
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    for images, labels in train_loader:
        output = model(images)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Don't prune uniformly! Layers differ in sensitivity: in many networks the large later layers tolerate aggressive pruning (70-80%), while the small early layers near the input are more sensitive and need more of their weights intact. Tuning a per-layer sparsity schedule is called layer-adaptive pruning.
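A sketch of layer-adaptive pruning with `torch.nn.utils.prune`; the per-layer schedule here is hypothetical, and in practice you tune each layer's amount against a validation set:

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.ReLU(),
    nn.Conv2d(16, 32, 3), nn.ReLU(),
    nn.Conv2d(32, 64, 3),
)

# Hypothetical per-layer sparsity schedule -- tune per layer in practice
schedule = [0.2, 0.5, 0.7]

convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
for conv, amount in zip(convs, schedule):
    prune.l1_unstructured(conv, name="weight", amount=amount)

for conv in convs:
    sparsity = (conv.weight == 0).float().mean().item()
    print(f"{tuple(conv.weight.shape)}: {sparsity:.0%} zeros")
```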
4. Quantization Fundamentals: Reducing Precision
Quantization reduces the numerical precision of weights and activations from floating-point to fixed-point or integer representations. Unlike pruning, which removes weights, quantization represents the same weights using fewer bits.
Number Formats: From FP32 to INT4
Neural networks are typically trained in FP32 (32-bit floating-point). For inference, we can use lower-precision formats:
Key Formats (For Inference)
- FP32: Baseline, 32-bit. Full precision, slow on modern hardware
- FP16: 16-bit float. Fast on GPU, slight precision loss (~99% accuracy)
- BF16: Brain Float 16; same 8-bit exponent range as FP32 (wider than FP16) but fewer mantissa bits. Better for very large dynamic ranges (LLMs)
- INT8: 8-bit signed integer. 4x compression, requires calibration
- INT4: 4-bit integer, 8x compression. Mostly used for LLM weights; still a maturing technology
Quantization Theory: Scale and Zero-Point
Integer quantization maps floating-point values to integers using a linear (affine) scheme:

q = clamp(round(x / scale) + zero_point, q_min, q_max)
x ≈ scale · (q − zero_point), where scale = (x_max − x_min) / (q_max − q_min)
Symmetric vs Asymmetric
- Symmetric: Zero-point at center. Simpler math, good for weights
- Asymmetric: Zero-point offset. Better for activations (usually not symmetric around 0)
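The asymmetric (affine) scheme above can be sketched directly; the function names are illustrative:

```python
import torch

def quantize_asymmetric(x: torch.Tensor, n_bits: int = 8):
    """Affine quantization: q = clamp(round(x / scale) + zero_point, 0, 2^n - 1)."""
    qmin, qmax = 0, 2 ** n_bits - 1
    scale = (x.max() - x.min()) / (qmax - qmin)
    zero_point = qmin - torch.round(x.min() / scale)
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    # Recover an approximation of the original float values
    return scale * (q - zero_point)

x = torch.randn(1000)
q, s, zp = quantize_asymmetric(x)
err = (x - dequantize(q, s, zp)).abs().max()
print(f"max reconstruction error: {err:.4f} (scale = {s:.4f})")
```

The symmetric variant simply fixes `zero_point = 0` and uses `scale = max(|x|) / 127` for INT8, which is why it is preferred for roughly zero-centered weights.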
Post-Training Quantization (PTQ)
PTQ is the simplest approach: train the model normally, then quantize weights/activations without retraining.
PTQ Workflow
- Train the model in FP32 as usual
- Run a few hundred representative batches through the model to record activation ranges (calibration)
- Compute scale/zero-point per tensor (or per channel) from those ranges
- Replace FP32 ops with quantized ops and validate accuracy
Pros & Cons
- ✓ No retraining: Fast, one-shot quantization
- ✓ Simple to implement: Use libraries like TensorRT or ONNX quantizer
- ✗ Accuracy loss: typically a fraction of a percent to a few percent, depending on the model and calibration quality
- ✗ Limited flexibility: Can't adapt during training
Quantization-Aware Training (QAT)
QAT simulates quantization during training, allowing the model to adapt to the reduced precision. The model learns to work with quantization instead of fighting it.
QAT Process
- Insert "fake quantization" ops that round weights/activations in the forward pass while keeping FP32 underneath
- Backpropagate through the rounding with the straight-through estimator (gradients pass as if no rounding happened)
- Fine-tune for a few epochs, then convert to a true integer model
Pros & Cons
- ✓ Better accuracy: Model learns to work with quantization (~99%+ of baseline)
- ✓ Faster inference: Same speedup as PTQ but with minimal accuracy loss
- ✗ Requires retraining: Slower to deploy, needs labeled data
- ✗ Complex tuning: Need to tune quantization schedules, scales per layer
import torch
from torch.quantization import prepare_qat
from torchvision.models.quantization import resnet50

# Load a quantization-ready pre-trained model
model = resnet50(weights="IMAGENET1K_V1", quantize=False)
model.fuse_model()  # fuse Conv+BN+ReLU so they quantize as one op
model.train()       # prepare_qat expects training mode

# Enable QAT (insert fake quant ops)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model = prepare_qat(model)

# Fine-tune with QAT for 5 epochs
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
for epoch in range(5):
    for images, labels in train_loader:
        output = model(images)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Convert to a true INT8 model
model.eval()
quantized_model = torch.quantization.convert(model)

# Save quantized model for inference
torch.jit.save(torch.jit.trace(quantized_model, images), "model_int8.pt")
INT8 Tensor Cores & Hardware Acceleration
Modern GPUs (NVIDIA A100, H100) have dedicated INT8 Tensor Cores that can execute quantized operations in parallel.
Modern LLM Quantization Techniques
For large language models, quantization is more challenging because:
- Token-to-token generation is memory-bandwidth bound (not compute bound)
- Activations have dramatically different scales per layer
- Some weights/activations are extremely sensitive to precision
Advanced Techniques
- GPTQ: Post-training 4-bit quantization using approximate second-order (Hessian) information. Reduces a 7B model's weights to roughly 3.5-4 GB; quantizing takes on the order of GPU-hours, far cheaper than retraining.
- AWQ: Activation-aware Weight Quantization. Uses activation statistics to protect the small fraction of weights that matter most per layer. Typically faster to apply than GPTQ.
- SmoothQuant: Migrates quantization difficulty from activations to weights via a per-channel smoothing transform, making INT8 activations viable while maintaining accuracy.
- Per-channel quantization: a different scale per output channel (per row of the weight matrix, e.g., per neuron in an FFN) instead of one scale per tensor
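The per-channel idea in the last bullet fits in a few lines; this is a sketch (symmetric INT8, one scale per output row, names illustrative):

```python
import torch

def quantize_per_channel_symmetric(w: torch.Tensor, n_bits: int = 8):
    """Symmetric INT8 quantization with one scale per output channel (row)."""
    qmax = 2 ** (n_bits - 1) - 1                       # 127 for INT8
    scales = w.abs().amax(dim=1, keepdim=True) / qmax  # one scale per row
    q = torch.clamp(torch.round(w / scales), -qmax - 1, qmax).to(torch.int8)
    return q, scales

# Rows with wildly different magnitudes -- common in LLM weight matrices
w = torch.randn(4, 64) * torch.tensor([[0.1], [1.0], [5.0], [0.01]])
q, scales = quantize_per_channel_symmetric(w)
recon = q.float() * scales

# A single per-tensor scale would be dominated by the largest row and
# crush the small-magnitude rows toward zero; per-row scales adapt.
print((w - recon).abs().max() / w.abs().max())
```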
Combining Pruning + Quantization: Extreme Compression
The real magic happens when you combine pruning and quantization:
A 7B parameter LLM can be compressed to roughly 1.5-2 GB with aggressive pruning (~75%+ sparsity) plus INT8 quantization while retaining most of its accuracy. With 4-bit quantization plus moderate pruning, models fit on mobile devices!
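The back-of-envelope arithmetic behind these sizes (ignoring sparse-index and metadata overhead):

```python
def model_size_gb(n_params: float, bits_per_weight: float, density: float = 1.0) -> float:
    """Approximate weight-storage size in GB; density = fraction of weights kept."""
    return n_params * density * bits_per_weight / 8 / 1e9

n = 7e9  # 7B parameters
print(f"FP16 dense:          {model_size_gb(n, 16):.2f} GB")        # 14.00 GB
print(f"INT8 dense:          {model_size_gb(n, 8):.2f} GB")         # 7.00 GB
print(f"INT8 + 75% pruning:  {model_size_gb(n, 8, 0.25):.2f} GB")   # 1.75 GB
print(f"INT4 + 50% pruning:  {model_size_gb(n, 4, 0.5):.2f} GB")    # 1.75 GB
```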
Summary: Choosing Your Compression Strategy
| Scenario | Recommended Approach | Expected Speedup |
|---|---|---|
| Fast PTQ for inference | INT8 PTQ (calibrate on val data) | ~4x (structured pruning adds up to ~2x more) |
| Maximum accuracy | INT8 QAT (5-20 epochs) | 4x (+ 99%+ accuracy) |
| CNN inference (mobile) | Structured pruning 50% + INT8 PTQ | 8-16x |
| LLM on edge (phone/RPi) | 4-bit (GPTQ/AWQ) + pruning | 8-10x smaller model |
| GPU cluster training | 2:4 sparsity + mixed-precision FP16 | 2-3x throughput |
Start with INT8 PTQ (takes minutes, simple). If accuracy drops >1%, move to QAT. Only use 4-bit for LLMs where calibration is hard.