1. Introduction: The Scaling Challenge
Training large models on a single GPU is becoming increasingly impractical. Whether you're training a 7B-parameter LLM or a large vision transformer, modern deep learning demands distributed training. PyTorch offers two main strategies: Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP).
Both are designed to scale training across multiple GPUs, but they solve fundamentally different problems:
- DDP replicates the entire model on each GPU and synchronizes gradients
- FSDP shards model parameters, gradients, and optimizer states across GPUs
Choosing the right strategy can mean the difference between training successfully and running out of memory, or between efficient training and wasting compute. Let's dive deep into how each works.
This guide assumes familiarity with PyTorch basics and some understanding of GPU memory. We'll explain the distributed concepts from the ground up.
2. DDP Explained: Data Parallelism
Distributed Data Parallel (DDP) is PyTorch's workhorse for multi-GPU training. The concept is simple: replicate the entire model on each GPU, split the data across GPUs, and synchronize gradients after each backward pass.
Figure 1: DDP replicates the full model on each GPU. Gradients are synchronized via AllReduce after each backward pass.
Key Characteristics of DDP
- Full replication: Each GPU has a complete copy of the model
- Gradient synchronization: Uses efficient AllReduce to average gradients
- Overlapped communication: Gradient sync overlaps with backward pass
- Simple to use: Just wrap your model with DistributedDataParallel
DDP's key strengths:
- Simplicity: Easy to implement and debug
- Performance: Highly optimized, minimal overhead
- Stability: Battle-tested across millions of training runs
- Flexibility: Works with any model architecture
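Conceptually, the gradient synchronization that AllReduce performs is just an element-wise average across workers. Here's a framework-free sketch of that math (the worker gradients are toy values; real DDP does this with fused NCCL collectives overlapped with the backward pass):

```python
# Simulate the averaging that DDP's AllReduce performs after backward().
# Each "worker" holds gradients computed from its local data shard.
def allreduce_mean(worker_grads):
    """Average per-parameter gradients across workers (what AllReduce yields)."""
    n_workers = len(worker_grads)
    n_params = len(worker_grads[0])
    return [
        sum(g[i] for g in worker_grads) / n_workers
        for i in range(n_params)
    ]

# Two workers, three parameters each (toy numbers).
grads_gpu0 = [0.5, -0.25, 1.0]
grads_gpu1 = [1.5, 0.25, -1.0]

avg = allreduce_mean([grads_gpu0, grads_gpu1])
print(avg)  # [1.0, 0.0, 0.0] -- every GPU ends up with identical averaged gradients
```

Because every rank receives the same averaged gradients, the replicated model copies stay bit-identical after each optimizer step.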
3. FSDP Explained: Sharding Everything
Fully Sharded Data Parallel (FSDP) takes a radically different approach. Instead of replicating the model, FSDP shards parameters, gradients, and optimizer states across all GPUs. Each GPU only holds a fraction of the model at any time.
FSDP is inspired by Microsoft's ZeRO (Zero Redundancy Optimizer) and provides similar memory savings. The key insight: at any point during training, a GPU only needs the parameters it's currently computing with, not the entire model.
Figure 2: FSDP shards everything across GPUs. Parameters are gathered only when needed and immediately discarded.
FSDP's Sharding Strategies
FSDP offers different sharding levels to balance memory savings with communication overhead:
- FULL_SHARD: Shard parameters, gradients, and optimizer states (maximum memory savings)
- SHARD_GRAD_OP: Shard gradients and optimizer states only (less communication)
- NO_SHARD: No sharding, similar to DDP (useful for debugging)
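The memory impact of each strategy can be estimated with simple arithmetic. Below is a rough model assuming fp32 parameters and gradients plus Adam's two fp32 moment buffers (16 bytes per parameter total), with activations and buffers excluded; the numbers are back-of-the-envelope, not profiler output:

```python
def per_gpu_gib(n_params, n_gpus, strategy):
    """Rough per-GPU memory (GiB) for weights + grads + Adam states.

    fp32 everywhere: 4 bytes/param for weights, 4 for grads,
    8 for Adam's two moment buffers -> 16 bytes per parameter.
    """
    param_b, grad_b, optim_b = 4, 4, 8
    if strategy == "NO_SHARD":         # like DDP: everything replicated
        total = param_b + grad_b + optim_b
    elif strategy == "SHARD_GRAD_OP":  # grads + optimizer states sharded
        total = param_b + (grad_b + optim_b) / n_gpus
    elif strategy == "FULL_SHARD":     # params, grads, optimizer states sharded
        total = (param_b + grad_b + optim_b) / n_gpus
    else:
        raise ValueError(strategy)
    return n_params * total / 2**30

# A 7B-parameter model on 8 GPUs:
for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    print(s, round(per_gpu_gib(7e9, 8, s), 1))
# NO_SHARD ~104.3 GiB, SHARD_GRAD_OP ~35.9 GiB, FULL_SHARD ~13.0 GiB
```

This is why a 7B model that OOMs under DDP on 40GB or 80GB cards can train comfortably under FULL_SHARD.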
FSDP's key strengths:
- Memory efficiency: Train models that don't fit on a single GPU
- Scalability: Near-linear per-GPU memory reduction as you add GPUs
- Mixed precision: Native support for FP16/BF16 training
- Activation checkpointing: Built-in integration
4. Head-to-Head Comparison
Let's compare DDP and FSDP across the key dimensions that matter for real-world training:
| Aspect | DDP | FSDP |
|---|---|---|
| Memory per GPU | Full model + optimizer + gradients | 1/N of each (sharded) |
| Communication | AllReduce gradients only | AllGather params + ReduceScatter grads |
| Max Model Size | Limited by single GPU memory | Limited by total GPU memory |
| Training Speed | Faster (less communication) | Slightly slower (more communication) |
| Setup Complexity | Simple (1-2 lines) | More complex (wrapping policies) |
| Debugging | Easy (full model visible) | Harder (sharded state) |
| Checkpointing | Standard PyTorch methods | Requires special handling |
| Best For | Models that fit in GPU memory | Large models, memory-constrained |
FSDP's communication pattern is more complex than DDP's. While DDP only syncs gradients, FSDP must gather parameters before each forward/backward and scatter gradients after. This overhead is usually acceptable for large models but can hurt throughput for smaller ones.
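That overhead can be quantified with a back-of-the-envelope model. Under standard ring implementations, AllReduce moves roughly 2 bytes per parameter byte per GPU, while FULL_SHARD pays one parameter AllGather in the forward, another in the backward, and a gradient ReduceScatter, roughly 3x, i.e. about 1.5x DDP's volume. A sketch (the exact (N-1)/N ring factor is dropped for simplicity; these are order-of-magnitude estimates, not benchmarks):

```python
def comm_bytes_per_step(n_params, bytes_per_elem, strategy):
    """Approximate per-GPU communication volume per training step.

    Ring collectives: AllReduce moves ~2x the data; AllGather and
    ReduceScatter move ~1x each (ignoring the (N-1)/N factor).
    """
    data = n_params * bytes_per_elem
    if strategy == "ddp":
        return 2 * data  # one AllReduce over gradients
    if strategy == "fsdp_full_shard":
        # AllGather params (forward) + AllGather params (backward)
        # + ReduceScatter gradients
        return 3 * data
    raise ValueError(strategy)

ddp = comm_bytes_per_step(7e9, 2, "ddp")               # 7B params in bf16
fsdp = comm_bytes_per_step(7e9, 2, "fsdp_full_shard")
print(f"FSDP moves {fsdp / ddp:.1f}x the bytes of DDP per step")  # 1.5x
```

For large models the extra 50% is amortized by compute; for small models it dominates, which is exactly the throughput trade-off in the table above.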
5. Code Examples
DDP: Simple and Effective
Setting up DDP requires minimal code changes:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup(rank, world_size):
"""Initialize distributed process group."""
dist.init_process_group(
backend="nccl",
rank=rank,
world_size=world_size
)
torch.cuda.set_device(rank)
def train_ddp(rank, world_size):
setup(rank, world_size)
# Create model and move to GPU
model = MyModel().to(rank)
# Wrap with DDP - that's it!
model = DDP(model, device_ids=[rank])
# Standard training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
for batch in dataloader:
optimizer.zero_grad()
loss = model(batch)
loss.backward() # Gradients auto-synced!
optimizer.step()
dist.destroy_process_group()
# Launch with: torchrun --nproc_per_node=4 train_ddp.py
FSDP: Memory-Efficient Training
FSDP requires more configuration but unlocks training of much larger models:
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
MixedPrecision,
CPUOffload,
)
from torch.distributed.fsdp.wrap import (
transformer_auto_wrap_policy,
)
import functools
def train_fsdp(rank, world_size):
setup(rank, world_size)
# Define mixed precision policy
mixed_precision_policy = MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.bfloat16,
buffer_dtype=torch.bfloat16,
)
# Define auto-wrap policy (wrap transformer layers)
wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={TransformerBlock}, # Your layer class
)
# Create model
model = MyLargeModel()
# Wrap with FSDP
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
mixed_precision=mixed_precision_policy,
auto_wrap_policy=wrap_policy,
device_id=torch.cuda.current_device(),
# Optional: offload params to CPU
# cpu_offload=CPUOffload(offload_params=True),
)
    # Create the optimizer after FSDP wrapping so it references the sharded parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for epoch in range(num_epochs):
for batch in dataloader:
optimizer.zero_grad()
# Forward pass - params gathered automatically
loss = model(batch)
# Backward pass - gradients sharded automatically
loss.backward()
optimizer.step()
dist.destroy_process_group()
FSDP Checkpointing
Saving and loading FSDP checkpoints requires special handling:
from torch.distributed.fsdp import (
FullStateDictConfig,
StateDictType,
)
def save_fsdp_checkpoint(model, optimizer, path, rank):
"""Save full state dict (gathered to rank 0)."""
# Configure to gather full state dict
save_policy = FullStateDictConfig(
offload_to_cpu=True,
rank0_only=True,
)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
state_dict = model.state_dict()
optim_state = FSDP.optim_state_dict(model, optimizer)
if rank == 0:
torch.save({
"model": state_dict,
"optimizer": optim_state,
}, path)
dist.barrier() # Wait for rank 0 to finish saving
def load_fsdp_checkpoint(model, optimizer, path):
"""Load checkpoint into FSDP model."""
checkpoint = torch.load(path, map_location="cpu")
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
model.load_state_dict(checkpoint["model"])
optim_state = FSDP.optim_state_dict_to_load(
checkpoint["optimizer"], model, optimizer
)
optimizer.load_state_dict(optim_state)
For very large models, use StateDictType.SHARDED_STATE_DICT to save sharded
checkpoints directly. This avoids gathering the entire model to rank 0 and is much faster
for models with billions of parameters.
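The arithmetic behind that advice: a full state dict materializes every parameter on rank 0. A quick estimate (fp32 weights only; optimizer states would roughly triple this):

```python
def full_state_dict_gib(n_params, bytes_per_param=4):
    """CPU memory rank 0 must hold to gather a full fp32 state dict."""
    return n_params * bytes_per_param / 2**30

# A 13B-parameter model gathered to rank 0:
print(round(full_state_dict_gib(13e9)))  # ~48 GiB on a single host, before optimizer states
```

With SHARDED_STATE_DICT, each rank writes only its own 1/N shard, so no single host ever holds the full model.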
6. Decision Guide: When to Use Which
Here's a practical decision framework for choosing between DDP and FSDP:
Quick Decision Guide

Use DDP (recommended) when:
- The model fits comfortably in GPU memory (with optimizer states and gradients).
- You want maximum training throughput.
- You need simple debugging and checkpointing.

Use FSDP when memory-constrained:
- The model doesn't fit on a single GPU.
- You're training LLMs or very large vision models.
- You want to maximize batch size for better convergence.

Performance priority → DDP:
- You have a fast interconnect (NVLink/InfiniBand).
- The model is medium-sized (fits with overhead to spare).
- Throughput is more important than batch size.

Large model training → FSDP:
- You're training 7B+ parameter models.
- You're using 80GB A100s but are still memory-limited.
- You need mixed precision + activation checkpointing.
Model Size Guidelines
As a rough guide, here's when to switch from DDP to FSDP based on model size and GPU memory:
| Model Size | GPU Memory | Recommendation |
|---|---|---|
| < 1B params | 40GB+ (A100) | DDP |
| 1-3B params | 40GB (A100) | DDP with gradient checkpointing |
| 3-7B params | 40GB (A100) | FSDP (SHARD_GRAD_OP or FULL_SHARD) |
| 7-13B params | 80GB (A100) | FSDP FULL_SHARD |
| 13B+ params | Any | FSDP + CPU offload or tensor parallel |
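The table above can be codified as a small helper. This is only a rule of thumb with the table's thresholds, not hard limits, and where the table is silent (e.g. 7B+ on 40GB cards) the sketch falls back to the heavier option:

```python
def recommend_strategy(params_b, gpu_mem_gb=40):
    """Rough DDP-vs-FSDP recommendation by model size (billions of params),
    mirroring the guideline table. Thresholds are heuristics, not hard limits."""
    if params_b >= 13:
        return "FSDP + CPU offload or tensor parallel"
    if params_b >= 7:
        if gpu_mem_gb >= 80:
            return "FSDP FULL_SHARD"
        return "FSDP + CPU offload or tensor parallel"  # 7B+ on smaller cards
    if params_b >= 3:
        return "FSDP (SHARD_GRAD_OP or FULL_SHARD)"
    if params_b >= 1:
        return "DDP with gradient checkpointing"
    return "DDP"

print(recommend_strategy(0.5))    # DDP
print(recommend_strategy(7, 80))  # FSDP FULL_SHARD
```

Treat the output as a starting point; actual limits depend on sequence length, activation memory, and batch size.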
7. Migrating from DDP to FSDP
If you've outgrown DDP and need FSDP, here's what changes:
# Before: DDP
# -------------------------------------
from torch.nn.parallel import DistributedDataParallel as DDP
model = MyModel().to(rank)
model = DDP(model, device_ids=[rank])
# After: FSDP
# -------------------------------------
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
model = MyModel() # Don't move to GPU first!
model = FSDP(
model,
sharding_strategy=ShardingStrategy.FULL_SHARD,
device_id=torch.cuda.current_device(),
)
Key Migration Considerations
- Model initialization: Don't call .to(device) before FSDP wrapping
- Checkpointing: Update save/load code to handle sharded state
- Learning rate: May need adjustment due to different gradient handling
- Batch size: You can likely increase it (more memory available)
- Wrap policy: Define which layers to wrap for best performance
Common pitfalls to avoid:
1. Moving the model to GPU before FSDP wrapping (causes OOM)
2. Not updating checkpoint code (corrupted saves)
3. Missing synchronization barriers (race conditions)
4. Wrong wrap policy (poor performance or errors)
8. Summary & Best Practices
Let's wrap up with actionable best practices for both approaches:
DDP Best Practices
- Use gradient accumulation to simulate larger batch sizes
- Keep find_unused_parameters=False (the default) unless your model truly has unused parameters; it avoids an extra traversal of the autograd graph each iteration
- Use DistributedSampler for proper data distribution
- Sync batch norm with torch.nn.SyncBatchNorm.convert_sync_batchnorm
- Pin memory in the DataLoader for faster host-to-GPU transfers
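The gradient-accumulation tip above amounts to summing scaled micro-batch gradients and stepping the optimizer every k micro-batches. A framework-free sketch of the bookkeeping (toy scalar gradients stand in for what loss.backward() would accumulate into .grad):

```python
def train_with_accumulation(micro_grads, accum_steps):
    """Simulate gradient accumulation: sum scaled micro-batch gradients,
    'step the optimizer' every accum_steps micro-batches."""
    applied = []       # gradients actually applied at each optimizer step
    grad_buffer = 0.0  # stands in for param.grad
    for i, g in enumerate(micro_grads):
        grad_buffer += g / accum_steps   # loss = loss / accum_steps; loss.backward()
        if (i + 1) % accum_steps == 0:
            applied.append(grad_buffer)  # optimizer.step()
            grad_buffer = 0.0            # optimizer.zero_grad()
    return applied

# 8 micro-batches, accumulating every 4 -> 2 optimizer steps,
# each seeing the average gradient of 4 micro-batches.
print(train_with_accumulation([1.0, 3.0, 5.0, 7.0, 2.0, 2.0, 2.0, 2.0], 4))  # [4.0, 2.0]
```

One DDP-specific note: wrap the non-stepping micro-batches in model.no_sync() so gradients aren't all-reduced until the accumulation boundary.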
FSDP Best Practices
- Use transformer_auto_wrap_policy for transformer models
- Enable mixed precision (BF16 preferred on Ampere+)
- Combine with activation checkpointing for maximum memory savings
- Use SHARD_GRAD_OP if communication is your bottleneck
- Profile with torch.profiler to identify bottlenecks
- Use sharded checkpoints for models > 10B parameters
Start with DDP: it's simpler and faster for most workloads.
Switch to FSDP when you hit memory limits or need larger batch sizes.
Profile both if you're unsure: actual performance depends on your specific setup.