1. Introduction: Why ADMM?
The Alternating Direction Method of Multipliers (ADMM) is one of the most versatile algorithms in optimization. Originally developed in the 1970s, it has experienced a remarkable renaissance in the era of big data and distributed computing.
ADMM shines in scenarios where you need to:
- Distribute computation across multiple machines or workers
- Handle constraints elegantly through variable splitting
- Combine different loss functions that are individually easy to optimize
- Scale to massive datasets that don't fit on a single machine
Think of ADMM as a "divide and conquer" strategy for optimization. It breaks hard problems into simpler subproblems that can be solved in parallel, then coordinates the solutions through a consensus mechanism.
ADMM was independently discovered by Glowinski & Marroco (1975) and Gabay & Mercier (1976). The seminal review by Boyd et al. (2011) sparked its modern popularity, showing its power for large-scale distributed optimization.
2. Mathematical Background
Before diving into ADMM, let's establish the mathematical foundation. ADMM is designed to solve problems of the following form:
$$\min_{x,\, z} \; f(x) + g(z) \quad \text{subject to} \quad Ax + Bz = c$$

The standard ADMM problem formulation.
Where:
- $f(x)$ and $g(z)$ are convex functions (the objectives we want to minimize)
- $A \in \mathbb{R}^{p \times n}$ and $B \in \mathbb{R}^{p \times m}$ are matrices
- $x \in \mathbb{R}^n$ and $z \in \mathbb{R}^m$ are the optimization variables
- $c \in \mathbb{R}^p$ is a constant vector
The Augmented Lagrangian
ADMM works with the augmented Lagrangian, which adds a quadratic penalty term to the standard Lagrangian:
$$\mathcal{L}_\rho(x, z, y) = f(x) + g(z) + y^T(Ax + Bz - c) + \frac{\rho}{2}\|Ax + Bz - c\|_2^2$$

The augmented Lagrangian with penalty parameter $\rho$.
The key insight is that:
- $y$ is the dual variable (Lagrange multiplier) enforcing the constraint
- $\rho > 0$ is the penalty parameter that controls constraint violation
- The quadratic term improves convergence without changing the optimal solution
The standard Lagrangian only has the linear term $y^T(Ax + Bz - c)$. Adding the quadratic penalty $\frac{\rho}{2}\|Ax + Bz - c\|_2^2$ makes the subproblems better conditioned, improving numerical stability and convergence, while the penalty vanishes at any feasible point, so the optimal solution is unchanged.
3. The ADMM Algorithm
ADMM minimizes the augmented Lagrangian by alternating between updating $x$, $z$, and $y$. The "alternating direction" name comes from optimizing one variable at a time while keeping others fixed.
Algorithm: ADMM
- Initialize: Choose $z^0$, $y^0$, and penalty $\rho > 0$
Repeat until convergence:
- x-update: $x^{k+1} = \arg\min_x \mathcal{L}_\rho(x, z^k, y^k)$
- z-update: $z^{k+1} = \arg\min_z \mathcal{L}_\rho(x^{k+1}, z, y^k)$
- y-update: $y^{k+1} = y^k + \rho(Ax^{k+1} + Bz^{k+1} - c)$
Let's understand each step:
The x-update
Minimize with respect to $x$, treating $z$ and $y$ as constants:

$$x^{k+1} = \arg\min_x \; f(x) + \frac{\rho}{2}\|Ax + Bz^k - c + u^k\|_2^2$$

where $u = y/\rho$ is the scaled dual variable.
The z-update
Minimize with respect to $z$, using the fresh $x^{k+1}$:

$$z^{k+1} = \arg\min_z \; g(z) + \frac{\rho}{2}\|Ax^{k+1} + Bz - c + u^k\|_2^2$$
The y-update (Dual Update)
This is simply a gradient ascent step on the dual problem:

$$y^{k+1} = y^k + \rho(Ax^{k+1} + Bz^{k+1} - c)$$

Or equivalently, in scaled form: $u^{k+1} = u^k + Ax^{k+1} + Bz^{k+1} - c$
Figure 1: ADMM alternates between three updates, each using the most recent values from previous steps.
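To make the three updates concrete, here is a toy one-dimensional instance (our own illustrative example, not from the references above): minimize $\frac{1}{2}(x-3)^2 + |z|$ subject to $x = z$, in scaled form. Both subproblems have closed-form solutions, and the z-update is soft thresholding.

```python
import numpy as np

def soft_threshold(v, t):
    """Soft thresholding: the proximal operator of t*|.|."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

# Toy problem: minimize (1/2)(x - 3)^2 + lam*|z|  subject to  x - z = 0
lam, rho = 1.0, 1.0
x = z = u = 0.0
for _ in range(100):
    x = (3.0 + rho * (z - u)) / (1.0 + rho)  # x-update: quadratic, closed form
    z = soft_threshold(x + u, lam / rho)     # z-update: soft thresholding
    u = u + x - z                            # scaled dual update
# Converges to the soft-thresholded solution x = z = 2.0
```

The fixed point matches the known solution $\operatorname{sign}(3)\max(3-\lambda, 0) = 2$ of the one-dimensional lasso.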
4. ADMM for Distributed ML
The real power of ADMM emerges in distributed settings. Consider training a machine learning model across $N$ workers, each with local data $\{(x_i, y_i)\}$:
$$\min_\theta \; \sum_{i=1}^{N} f_i(\theta) + r(\theta)$$

Global objective: sum of local losses $f_i$ plus a regularizer $r$.
The challenge: data is distributed, but we want a single global model. ADMM solves this through a consensus formulation.
Consensus ADMM
The key idea: give each worker its own local copy $\theta_i$, then enforce they all agree on a global consensus variable $z$:
$$\min_{\theta_1, \dots, \theta_N,\, z} \; \sum_{i=1}^{N} f_i(\theta_i) + r(z) \quad \text{subject to} \quad \theta_i = z, \; i = 1, \dots, N$$

Consensus formulation: local variables must equal the global consensus variable.
Figure 2: Consensus ADMM architecture. Workers optimize locally with their data, server aggregates to enforce consensus.
The Distributed Algorithm
Algorithm: Consensus ADMM for Distributed ML
- Initialize: $z^0$, $u_i^0 = 0$ for all workers, choose $\rho > 0$
Repeat until convergence:
- Local update (parallel): Each worker $i$ solves: $$\theta_i^{k+1} = \arg\min_{\theta_i} f_i(\theta_i) + \frac{\rho}{2}\|\theta_i - z^k + u_i^k\|_2^2$$
- Global update (server): Compute consensus: $$z^{k+1} = \text{prox}_{r/(\rho N)}\left(\frac{1}{N}\sum_{i=1}^{N}(\theta_i^{k+1} + u_i^k)\right)$$
- Dual update (parallel): Each worker updates: $$u_i^{k+1} = u_i^k + \theta_i^{k+1} - z^{k+1}$$
ADMM requires only one communication round per outer iteration: each worker sends its local $\theta_i$ (plus $u_i$) and receives the updated $z$. Because each local update can take many optimization steps on local data, far fewer communication rounds are typically needed than with methods that exchange gradients at every step.
5. Convergence Guarantees
One of ADMM's strengths is its solid theoretical foundation. Under mild conditions, ADMM provides convergence guarantees that practitioners can rely on.
Convergence Theorem
Assume $f$ and $g$ are closed, proper, convex functions, and the unaugmented Lagrangian $\mathcal{L}_0$ has a saddle point. Then ADMM satisfies:
1. Residual convergence: $Ax^k + Bz^k - c \to 0$ as $k \to \infty$
2. Objective convergence: $f(x^k) + g(z^k) \to p^*$ (optimal value)
3. Dual convergence: $y^k \to y^*$ (optimal dual variable)
Convergence Rate
For general convex problems, ADMM achieves an $O(1/k)$ convergence rate, meaning the residuals decrease as $1/k$ after $k$ iterations. This can be improved under stronger assumptions:
| Problem Class | Convergence Rate | Conditions |
|---|---|---|
| General convex | $O(1/k)$ | f, g convex |
| Strongly convex | $O(\gamma^k)$ (linear) | f or g strongly convex |
| Smooth + strongly convex | $O(\gamma^k)$ (linear) | Accelerated variants |
Stopping Criteria
In practice, we monitor two residuals to decide when to stop: the primal residual $r^k = Ax^k + Bz^k - c$ (the constraint violation) and the dual residual $s^k = \rho A^T B(z^k - z^{k-1})$ (a measure of dual feasibility).
Stop when both $\|r^k\|_2 \leq \epsilon^{\text{pri}}$ and $\|s^k\|_2 \leq \epsilon^{\text{dual}}$.
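In practice the tolerances themselves are usually set in combined absolute-plus-relative form, following Boyd et al. (2011, §3.3). A sketch specialized to the consensus constraint $\theta_i = z$ (function name and defaults are our own) might look like:

```python
import numpy as np

def consensus_tolerances(theta_list, z, u_list, rho,
                         eps_abs=1e-4, eps_rel=1e-3):
    """Combined absolute/relative tolerances (Boyd et al., 2011, Sec. 3.3),
    specialized to the consensus constraint theta_i = z for all i."""
    N, n = len(theta_list), z.size
    theta_stack = np.concatenate(theta_list)  # stacked local variables
    z_stack = np.tile(z, N)                   # z repeated once per worker
    y_stack = rho * np.concatenate(u_list)    # unscaled duals y_i = rho * u_i
    eps_pri = np.sqrt(N * n) * eps_abs + eps_rel * max(
        np.linalg.norm(theta_stack), np.linalg.norm(z_stack))
    eps_dual = np.sqrt(N * n) * eps_abs + eps_rel * np.linalg.norm(y_stack)
    return eps_pri, eps_dual
```

The relative terms let the stopping test scale with the magnitude of the iterates, so the same tolerances work across differently scaled problems.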
6. Python Implementation
Let's implement consensus ADMM for distributed logistic regression. This is a common use case in federated learning and privacy-preserving ML.
Basic ADMM Solver
```python
import numpy as np
from typing import Callable, Optional, Tuple


class ADMMSolver:
    """Generic ADMM solver for consensus optimization."""

    def __init__(
        self,
        n_features: int,
        n_workers: int,
        rho: float = 1.0,
        max_iter: int = 100,
        tol: float = 1e-4,
    ):
        self.n_features = n_features
        self.n_workers = n_workers
        self.rho = rho
        self.max_iter = max_iter
        self.tol = tol
        # Initialize variables
        self.z = np.zeros(n_features)  # Global consensus
        self.theta = [np.zeros(n_features) for _ in range(n_workers)]
        self.u = [np.zeros(n_features) for _ in range(n_workers)]
        self.history = {'primal_residual': [], 'dual_residual': []}

    def local_update(self, worker_id: int, local_solver: Callable) -> np.ndarray:
        """
        Perform the local θ-update for a single worker.

        local_solver should minimize: f_i(θ) + (ρ/2)||θ - z + u||²
        """
        target = self.z - self.u[worker_id]
        self.theta[worker_id] = local_solver(target, self.rho)
        return self.theta[worker_id]

    def global_update(self, regularizer: Optional[Callable] = None) -> np.ndarray:
        """
        Perform the global z-update (consensus averaging).

        z = (1/N) Σ(θᵢ + uᵢ)  [+ proximal step if regularizer is given]
        """
        z_old = self.z.copy()
        # Average of (θ + u) across workers
        self.z = np.mean([
            self.theta[i] + self.u[i]
            for i in range(self.n_workers)
        ], axis=0)
        # Apply proximal operator for regularization
        if regularizer is not None:
            self.z = regularizer(self.z, self.rho * self.n_workers)
        return z_old

    def dual_update(self) -> None:
        """Update the scaled dual variables for all workers."""
        for i in range(self.n_workers):
            self.u[i] = self.u[i] + self.theta[i] - self.z

    def compute_residuals(self, z_old: np.ndarray) -> Tuple[float, float]:
        """Compute primal and dual residual norms."""
        # Primal residual: norm of the stacked (θᵢ - z) over all workers
        primal = np.sqrt(sum(
            np.sum((self.theta[i] - self.z) ** 2)
            for i in range(self.n_workers)
        ))
        # Dual residual: ρ√N ||z - z_old||
        dual = self.rho * np.sqrt(self.n_workers) * np.linalg.norm(self.z - z_old)
        return primal, dual

    def converged(self, primal: float, dual: float) -> bool:
        """Check convergence criteria."""
        return primal < self.tol and dual < self.tol
```
Distributed Logistic Regression
```python
import numpy as np
from scipy.optimize import minimize
from scipy.special import expit


def logistic_loss(theta, X, y, target, rho):
    """
    Mean logistic loss + ADMM proximal term.

    L(θ) = (1/n) Σ log(1 + exp(-yᵢ·xᵢᵀθ)) + (ρ/2)||θ - target||²
    """
    logits = y * (X @ theta)
    # Numerically stable log(1 + exp(-logits))
    log_loss = np.mean(np.logaddexp(0, -logits))
    # ADMM proximal term
    proximal = (rho / 2) * np.sum((theta - target) ** 2)
    return log_loss + proximal


def logistic_gradient(theta, X, y, target, rho):
    """Gradient of the mean logistic loss + proximal term."""
    n = X.shape[0]
    logits = y * (X @ theta)
    probs = expit(-logits)  # σ(-y·xᵀθ), computed without overflow
    grad_loss = -X.T @ (y * probs) / n
    grad_proximal = rho * (theta - target)
    return grad_loss + grad_proximal


def local_logistic_solver(X, y):
    """Return a solver function for one worker's local logistic subproblem."""
    def solver(target, rho):
        result = minimize(
            logistic_loss,
            x0=target,  # Warm start at the prox target z - u
            args=(X, y, target, rho),
            jac=logistic_gradient,
            method='L-BFGS-B',
            options={'maxiter': 50},
        )
        return result.x
    return solver


def l1_proximal(z, lambd):
    """Proximal operator of lambd·||·||₁ (soft thresholding)."""
    return np.sign(z) * np.maximum(np.abs(z) - lambd, 0)


def l2_proximal(z, lambd):
    """Proximal operator of (lambd/2)·||·||₂² (shrinkage)."""
    return z / (1 + lambd)
```
Full Training Loop
```python
import numpy as np
from typing import List

from admm_core import ADMMSolver
from admm_logistic import local_logistic_solver, l1_proximal


def train_distributed_logistic(
    X_splits: List[np.ndarray],
    y_splits: List[np.ndarray],
    lambda_l1: float = 0.01,
    rho: float = 1.0,
    max_iter: int = 100,
    verbose: bool = True,
):
    """
    Train L1-regularized logistic regression using consensus ADMM.

    Args:
        X_splits: List of feature matrices, one per worker
        y_splits: List of label vectors (±1), one per worker
        lambda_l1: L1 regularization strength
        rho: ADMM penalty parameter
    """
    n_workers = len(X_splits)
    n_features = X_splits[0].shape[1]

    # Initialize solver
    solver = ADMMSolver(
        n_features=n_features,
        n_workers=n_workers,
        rho=rho,
        max_iter=max_iter,
    )

    # Create local solvers for each worker
    local_solvers = [
        local_logistic_solver(X_splits[i], y_splits[i])
        for i in range(n_workers)
    ]

    # L1 regularization proximal operator (threshold λ/(ρN))
    def regularizer(z, scale):
        return l1_proximal(z, lambda_l1 / scale)

    # ADMM iterations
    for k in range(max_iter):
        # 1. Local updates (can be parallelized!)
        for i in range(n_workers):
            solver.local_update(i, local_solvers[i])

        # 2. Global update with regularization
        z_old = solver.global_update(regularizer)

        # 3. Dual update
        solver.dual_update()

        # 4. Check convergence
        primal, dual = solver.compute_residuals(z_old)
        solver.history['primal_residual'].append(primal)
        solver.history['dual_residual'].append(dual)

        if verbose and k % 10 == 0:
            print(f"Iter {k:3d}: primal={primal:.6f}, dual={dual:.6f}")

        if solver.converged(primal, dual):
            print(f"Converged at iteration {k}")
            break

    return solver.z, solver.history


# Example usage
if __name__ == "__main__":
    # Generate synthetic data
    np.random.seed(42)
    n_samples, n_features = 10000, 50
    n_workers = 4

    X = np.random.randn(n_samples, n_features)
    true_theta = np.random.randn(n_features)
    y = np.sign(X @ true_theta + 0.1 * np.random.randn(n_samples))

    # Split data across workers
    split_size = n_samples // n_workers
    X_splits = [X[i * split_size:(i + 1) * split_size] for i in range(n_workers)]
    y_splits = [y[i * split_size:(i + 1) * split_size] for i in range(n_workers)]

    # Train!
    theta, history = train_distributed_logistic(
        X_splits, y_splits,
        lambda_l1=0.01,
        rho=1.0,
    )
```
The penalty parameter $\rho$ significantly affects convergence. Too small → slow convergence; too large → numerical instability. A common heuristic: start with $\rho = 1$ and use adaptive ρ adjustment based on residual ratios.
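One common adaptive scheme is residual balancing, described in Boyd et al. (2011, §3.4.1): increase ρ when the primal residual dominates, decrease it when the dual residual dominates. A minimal sketch (function name and constants are illustrative); note that when ρ changes, the scaled duals $u_i = y_i/\rho$ must be rescaled:

```python
import numpy as np

def update_rho(rho, primal_res, dual_res, u_list, mu=10.0, tau=2.0):
    """Residual-balancing rho update (Boyd et al., 2011, Sec. 3.4.1).
    Returns the new rho and the rescaled scaled-dual variables."""
    if primal_res > mu * dual_res:
        rho_new = tau * rho   # primal residual dominates: penalize harder
    elif dual_res > mu * primal_res:
        rho_new = rho / tau   # dual residual dominates: relax penalty
    else:
        return rho, u_list    # residuals balanced: leave rho alone
    # u = y / rho, so changing rho requires rescaling u to keep y fixed
    scale = rho / rho_new
    return rho_new, [u * scale for u in u_list]

# Example: primal residual dominates, so rho doubles and each u halves
rho_new, u_new = update_rho(1.0, primal_res=50.0, dual_res=1.0,
                            u_list=[np.ones(3)])
```

Keeping the two residuals within an order of magnitude of each other tends to make convergence much less sensitive to the initial choice of ρ.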
7. Applications & Examples
ADMM's flexibility makes it applicable to many machine learning problems. Here are some common applications:
1. LASSO Regression
The classic sparse regression problem with L1 penalty:

$$\min_\theta \; \frac{1}{2}\|X\theta - y\|_2^2 + \lambda\|\theta\|_1$$

ADMM formulation: split into $f(\theta) = \frac{1}{2}\|X\theta - y\|_2^2$ and $g(z) = \lambda\|z\|_1$, with constraint $\theta = z$. The z-update becomes soft thresholding!
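A compact sketch of this LASSO ADMM loop (function name and defaults are illustrative), with a one-time Cholesky factorization so the θ-update is a pair of triangular solves:

```python
import numpy as np

def lasso_admm(X, y, lam, rho=1.0, n_iter=200):
    """LASSO via ADMM: f(theta) = 0.5||X theta - y||^2, g(z) = lam||z||_1,
    with constraint theta = z; the z-update is soft thresholding."""
    n = X.shape[1]
    # Factor once: each theta-update solves (X^T X + rho I) theta = X^T y + rho (z - u)
    L = np.linalg.cholesky(X.T @ X + rho * np.eye(n))
    Xty = X.T @ y
    z = np.zeros(n)
    u = np.zeros(n)
    for _ in range(n_iter):
        rhs = Xty + rho * (z - u)
        theta = np.linalg.solve(L.T, np.linalg.solve(L, rhs))   # theta-update
        z = np.sign(theta + u) * np.maximum(np.abs(theta + u) - lam / rho, 0.0)
        u = u + theta - z                                        # scaled dual update
    return z
```

With $X = I$, the LASSO solution reduces to soft thresholding of $y$ by $\lambda$, which gives a quick sanity check of the implementation.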
2. Federated Learning
Train a global model across devices without centralizing data:
- Each device keeps its data locally (privacy)
- Devices send only model updates to server
- ADMM ensures consensus on the global model
3. Constrained Deep Learning
Enforce constraints on neural network weights:
- Weight bounds: $\|W\|_\infty \leq c$ (robustness)
- Sparsity: Prune weights during training
- Low-rank: Compress layers via nuclear norm
Comparison with Other Methods
| Method | Communication | Convergence | Best For |
|---|---|---|---|
| ADMM | O(d) per iter | O(1/k) | Constrained, sparse, distributed |
| SGD | O(d) per iter | O(1/√k) | Large-scale deep learning |
| FedAvg | O(d) per round | Varies | Federated, non-iid data |
| CoCoA | O(d) per iter | O(1/k) | Dual methods, kernels |
8. Summary & Further Reading
ADMM is a powerful tool for distributed optimization that combines the benefits of decomposition methods with the convergence guarantees of augmented Lagrangian approaches.
Key Takeaways
- Flexibility: Handle constraints, regularizers, and distributed data naturally
- Convergence: Guaranteed O(1/k) rate for convex problems, often faster in practice
- Parallelization: Local updates are embarrassingly parallel
- Communication: Only requires exchanging model parameters, not gradients
When to Use ADMM
- Problems with natural decomposition structure
- Constraints that are easy to handle separately
- Distributed settings with moderate communication costs
- When convergence guarantees matter
Further Reading
- Boyd et al. (2011) - "Distributed Optimization and Statistical Learning via ADMM" - The definitive survey
- Parikh & Boyd (2014) - "Proximal Algorithms" - Deep dive into proximal operators
- Zhang & Kwok (2014) - "Asynchronous Distributed ADMM" - Handling stragglers
- Hong et al. (2016) - "Convergence Analysis of ADMM" - Theoretical foundations
Try implementing ADMM for your own distributed ML problem! Start with the consensus formulation, and experiment with different ρ values. For production systems, consider asynchronous variants that handle heterogeneous workers and network delays.