1. Introduction: The Scaling Challenge

Training large models on a single GPU is increasingly impossible. Whether you're training a 7B-parameter LLM or a large vision transformer, modern deep learning demands distributed training. PyTorch offers two main strategies: Distributed Data Parallel (DDP) and Fully Sharded Data Parallel (FSDP).

Both scale training across multiple GPUs, but they solve fundamentally different problems.

Choosing the right strategy can mean the difference between training successfully and running out of memory, or between efficient training and wasted compute. Let's dive into how each works.

Prerequisites

This guide assumes familiarity with PyTorch basics and some understanding of GPU memory. We'll explain the distributed concepts from the ground up.

2. DDP Explained: Data Parallelism

Distributed Data Parallel (DDP) is PyTorch's workhorse for multi-GPU training. The concept is simple: replicate the entire model on each GPU, split the data across GPUs, and synchronize gradients after each backward pass.
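The gradient synchronization at the heart of DDP can be sketched without a cluster. This toy simulation (plain Python, not PyTorch's actual NCCL collectives) shows what an averaging AllReduce does: each simulated rank computes a different local gradient for the same parameter, and the sync makes all replicas identical again.

```python
def allreduce_mean(per_rank_grads):
    """Average one parameter's gradient elementwise across ranks,
    which is what DDP's AllReduce does after every backward pass."""
    n = len(per_rank_grads)
    return [sum(vals) / n for vals in zip(*per_rank_grads)]

# Two simulated ranks saw different data shards, so their local
# gradients for the same parameter differ.
grads = [[1.0, 3.0],   # rank 0
         [3.0, 5.0]]   # rank 1

print(allreduce_mean(grads))  # [2.0, 4.0]
```

After this step every replica applies the same averaged gradient, so the model copies stay bitwise-identical across GPUs without ever exchanging parameters.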

DDP: How It Works

[Diagram: each of N GPUs holds the full model, its own data batch, and a full set of gradients; an AllReduce synchronizes gradients across GPUs. Memory per GPU ≈ model size × ~4 (with Adam).]

Figure 1: DDP replicates the full model on each GPU. Gradients are synchronized via AllReduce after each backward pass.

Key Characteristics of DDP

DDP Advantages

Simplicity: Easy to implement and debug
Performance: Highly optimized, minimal overhead
Stability: Battle-tested across millions of training runs
Flexibility: Works with any model architecture

3. FSDP Explained: Sharding Everything

Fully Sharded Data Parallel (FSDP) takes a radically different approach. Instead of replicating the model, FSDP shards parameters, gradients, and optimizer states across all GPUs. Each GPU only holds a fraction of the model at any time.

FSDP is inspired by Microsoft's ZeRO (Zero Redundancy Optimizer) and provides similar memory savings. The key insight: at any point during training, a GPU only needs the parameters it's currently computing withβ€”not the entire model.
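That insight can be illustrated with a toy sketch (plain Python lists, not FSDP's real flat-parameter machinery): each rank permanently stores only its shard, and a full parameter set exists only transiently, reconstructed by an all-gather just before it is needed.

```python
def shard(params, world_size):
    """Split a flat parameter list into world_size near-equal shards,
    roughly as FSDP does when it wraps a module (padding omitted)."""
    base, rem = divmod(len(params), world_size)
    shards, start = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < rem else 0)
        shards.append(params[start:start + size])
        start += size
    return shards

def all_gather(shards):
    """Reconstruct the full parameter list from every rank's shard,
    which FSDP does on demand right before a layer's forward/backward."""
    return [p for s in shards for p in s]

params = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
shards = shard(params, 3)            # each rank stores 1/3 of the layer
assert all_gather(shards) == params  # gather -> compute -> free

print(shards)  # [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
```

The key point: the gathered copy is freed immediately after the layer's computation, so peak memory is one layer's full parameters plus 1/N of everything else, not the whole model.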

FSDP: How It Works

[Diagram: each of N GPUs holds 1/N of the parameters, optimizer states, and gradients, plus its own data batch; memory per GPU ≈ model size / N. AllGather reconstructs full parameters on demand during forward/backward.]

Figure 2: FSDP shards everything across GPUs. Parameters are gathered only when needed and immediately discarded.

FSDP's Sharding Strategies

FSDP offers several sharding levels (the ShardingStrategy enum) that trade memory savings against communication overhead:

FULL_SHARD: shard parameters, gradients, and optimizer states (maximum savings, most communication; analogous to ZeRO-3)
SHARD_GRAD_OP: shard only gradients and optimizer states; parameters stay replicated (analogous to ZeRO-2)
NO_SHARD: replicate everything, behaving like DDP
HYBRID_SHARD: fully shard within a node, replicate across nodes

FSDP Advantages

Memory efficiency: Train models that don't fit on a single GPU
Scalability: Linear memory reduction with more GPUs
Mixed precision: Native support for FP16/BF16 training
Activation checkpointing: Built-in integration

4. Head-to-Head Comparison

Let's compare DDP and FSDP across the key dimensions that matter for real-world training:

| Aspect | DDP | FSDP |
| --- | --- | --- |
| Memory per GPU | Full model + optimizer + gradients | 1/N of each (sharded) |
| Communication | AllReduce gradients only | AllGather params + ReduceScatter grads |
| Max model size | Limited by single-GPU memory | Limited by total GPU memory |
| Training speed | Faster (less communication) | Slightly slower (more communication) |
| Setup complexity | Simple (1-2 lines) | More complex (wrapping policies) |
| Debugging | Easy (full model visible) | Harder (sharded state) |
| Checkpointing | Standard PyTorch methods | Requires special handling |
| Best for | Models that fit in GPU memory | Large models, memory-constrained setups |

Communication Overhead

FSDP's communication pattern is more complex than DDP's. While DDP only syncs gradients, FSDP must gather parameters before each forward/backward and scatter gradients after. This overhead is usually acceptable for large models but can hurt throughput for smaller ones.
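These per-step volumes can be roughly estimated with the standard ring-collective approximation: an AllReduce moves about 2× the payload, while an AllGather or ReduceScatter moves about 1×. The sketch below applies that rule (dropping the exact (N-1)/N factor and ignoring compute/communication overlap, so treat the numbers as order-of-magnitude only):

```python
def comm_volume_gb(num_params, bytes_per_elem=4):
    """Rough per-step communication per GPU under the ring-collective
    approximation (the (N-1)/N factor is dropped).
    DDP:  one gradient AllReduce          ~ 2 x P
    FSDP: AllGather params in forward     ~ 1 x P
          AllGather params in backward    ~ 1 x P
          ReduceScatter gradients         ~ 1 x P
    """
    p_bytes = num_params * bytes_per_elem
    gib = 1024 ** 3
    return {"ddp": 2 * p_bytes / gib, "fsdp": 3 * p_bytes / gib}

# 7B parameters communicated in bf16 (2 bytes/element)
print(comm_volume_gb(7_000_000_000, bytes_per_elem=2))
```

Under this model FULL_SHARD moves about 1.5× as many bytes as DDP per step, which is why FSDP's overhead matters most when the model is small relative to the interconnect bandwidth.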

5. Code Examples

DDP: Simple and Effective

Setting up DDP requires minimal code changes:

Python train_ddp.py
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup(rank, world_size):
    """Initialize distributed process group."""
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size
    )
    torch.cuda.set_device(rank)

def train_ddp(rank, world_size):
    setup(rank, world_size)
    
    # Create model and move to GPU
    model = MyModel().to(rank)
    
    # Wrap with DDP - that's it!
    model = DDP(model, device_ids=[rank])
    
    # Standard training loop
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    for epoch in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            loss = model(batch)
            loss.backward()  # Gradients auto-synced!
            optimizer.step()
    
    dist.destroy_process_group()

# Launch with: torchrun --nproc_per_node=4 train_ddp.py

FSDP: Memory-Efficient Training

FSDP requires more configuration but unlocks training of much larger models:

Python train_fsdp.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    CPUOffload,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)
import functools

def train_fsdp(rank, world_size):
    setup(rank, world_size)
    
    # Define mixed precision policy
    mixed_precision_policy = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )
    
    # Define auto-wrap policy (wrap transformer layers)
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},  # Your layer class
    )
    
    # Create model
    model = MyLargeModel()
    
    # Wrap with FSDP
    model = FSDP(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision_policy,
        auto_wrap_policy=wrap_policy,
        device_id=torch.cuda.current_device(),
        # Optional: offload params to CPU
        # cpu_offload=CPUOffload(offload_params=True),
    )
    
    # Create the optimizer after wrapping, so it sees FSDP's sharded params
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    
    for epoch in range(num_epochs):
        for batch in dataloader:
            optimizer.zero_grad()
            
            # Forward pass - params gathered automatically
            loss = model(batch)
            
            # Backward pass - gradients sharded automatically
            loss.backward()
            optimizer.step()
    
    dist.destroy_process_group()

FSDP Checkpointing

Saving and loading FSDP checkpoints requires special handling:

Python fsdp_checkpoint.py
import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

def save_fsdp_checkpoint(model, optimizer, path, rank):
    """Save full state dict (gathered to rank 0)."""
    
    # Configure to gather full state dict
    save_policy = FullStateDictConfig(
        offload_to_cpu=True,
        rank0_only=True,
    )
    
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        state_dict = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
        
        if rank == 0:
            torch.save({
                "model": state_dict,
                "optimizer": optim_state,
            }, path)
    
    dist.barrier()  # Wait for rank 0 to finish saving

def load_fsdp_checkpoint(model, optimizer, path):
    """Load checkpoint into FSDP model."""
    
    checkpoint = torch.load(path, map_location="cpu")
    
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
        model.load_state_dict(checkpoint["model"])
        
    optim_state = FSDP.optim_state_dict_to_load(
        checkpoint["optimizer"], model, optimizer
    )
    optimizer.load_state_dict(optim_state)

Pro Tip: Sharded Checkpoints

For very large models, use StateDictType.SHARDED_STATE_DICT to save sharded checkpoints directly. This avoids gathering the entire model to rank 0 and is much faster for models with billions of parameters.

6. Decision Guide: When to Use Which

Here's a practical decision framework for choosing between DDP and FSDP:

🎯 Quick Decision Guide

✓ Use DDP (Recommended)

Model fits comfortably in GPU memory (with optimizer & gradients).
You want maximum training throughput.
You need simple debugging and checkpointing.

✓ Use FSDP (When Memory-Constrained)

Model doesn't fit on a single GPU.
You're training LLMs or very large vision models.
You want to maximize batch size for better convergence.

⚡ Performance Priority → DDP

You have fast interconnect (NVLink/InfiniBand).
Model is medium-sized (fits with overhead).
Throughput is more important than batch size.

🧠 Large Model Training → FSDP

Training 7B+ parameter models.
Using 80GB A100s but still memory-limited.
Need mixed precision + activation checkpointing.

Model Size Guidelines

As a rough guide, here's when to switch from DDP to FSDP based on model size and GPU memory:

| Model Size | GPU Memory | Recommendation |
| --- | --- | --- |
| < 1B params | 40GB+ (A100) | DDP |
| 1-3B params | 40GB (A100) | DDP with gradient checkpointing |
| 3-7B params | 40GB (A100) | FSDP (SHARD_GRAD_OP or FULL_SHARD) |
| 7-13B params | 80GB (A100) | FSDP FULL_SHARD |
| 13B+ params | Any | FSDP + CPU offload or tensor parallelism |
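The table above follows from the "model size × ~4" rule in Figure 1: with fp32 parameters and Adam, training needs roughly 16 bytes per parameter (4 params + 4 grads + 8 optimizer moments), and FULL_SHARD divides all three by the number of GPUs. A back-of-the-envelope sketch (activations and framework overhead deliberately ignored, so real usage is higher):

```python
def train_mem_gb(num_params, world_size=1, sharded=False,
                 bytes_per_param=16):
    """Rule-of-thumb training memory per GPU, ignoring activations.
    16 bytes/param = fp32 params (4) + grads (4) + Adam moments (8).
    FSDP FULL_SHARD divides the whole footprint by world_size."""
    total_gb = num_params * bytes_per_param / 1024 ** 3
    return total_gb / world_size if sharded else total_gb

for n in (1_000_000_000, 7_000_000_000, 13_000_000_000):
    ddp = train_mem_gb(n)
    fsdp = train_mem_gb(n, world_size=8, sharded=True)
    print(f"{n // 10**9}B params: DDP {ddp:.0f} GB/GPU, "
          f"FSDP(8 GPUs) {fsdp:.1f} GB/GPU")
```

Even this optimistic estimate shows why a 7B model is out of reach for DDP on a single 80GB A100 but comfortable under FSDP across 8 GPUs.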

7. Migrating from DDP to FSDP

If you've outgrown DDP and need FSDP, here's what changes:

Python migration_diff.py
# Before: DDP
# ─────────────────────────────────────
from torch.nn.parallel import DistributedDataParallel as DDP

model = MyModel().to(rank)
model = DDP(model, device_ids=[rank])

# After: FSDP
# ─────────────────────────────────────
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = MyModel()  # Don't move to GPU first!
model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
)

Key Migration Considerations

Common Migration Pitfalls

1. Moving model to GPU before FSDP wrapping (causes OOM)
2. Not updating checkpoint code (corrupted saves)
3. Missing synchronization barriers (race conditions)
4. Wrong wrap policy (poor performance or errors)

8. Summary & Best Practices

Let's wrap up with actionable best practices for both approaches:

DDP Best Practices

FSDP Best Practices

Final Recommendations

Start with DDP: it's simpler and faster for most workloads.
Switch to FSDP when you hit memory limits or need larger batch sizes.
Profile both if you're unsure; actual performance depends on your specific setup.