Optax: Optimizers You Can Compose Like LEGO

I'm a Machine Learning Engineer passionate about building production-ready ML systems for the African market. With experience in TensorFlow, Keras, and Python-based workflows, I help teams bridge the gap between machine learning research and real-world deployment, especially on resource-constrained devices. I'm also a Google Developer Expert in AI. I regularly speak at tech conferences, including PyCon Africa, DevFest Kampala, and DevFest Nairobi, and I write technical articles on AI/ML here.
In our last article on this blog in March, we built a complete training loop. We used optax.adamw() to create an optimizer, wrapped it in nnx.Optimizer, and watched our model learn.
But we barely scratched the surface.
Optax isn't just a collection of optimizers. It's a compositional system for building gradient transformations. Instead of monolithic optimizer classes with dozens of parameters, Optax gives you small, focused building blocks that you chain together.
Want Adam with gradient clipping? Chain them. Want SGD with momentum, weight decay, and a cosine learning rate schedule? Chain them. Want something completely custom that doesn't exist in any other framework? Build it from primitives.
By the end of this article, you'll understand:
How Optax's compositional design works
How to prevent exploding gradients with clipping
How to implement learning rate schedules without a separate scheduler object
How to apply different optimization strategies to different parameters
How to build production-grade optimizer stacks
Let's take apart the LEGO box.
The Optax Mental Model
In PyTorch, an optimizer is an object that holds references to your parameters:
# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer.step() # Updates parameters in place
In Optax, an optimizer is a pair of functions:
init(params) → Creates the optimizer state (momentum buffers, etc.)
update(grads, state, params) → Returns (updates, new_state)
You then apply the updates to your parameters separately:
# Pure Optax (without NNX)
import optax
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)
# In training loop:
updates, opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
This separation of concerns—computing updates vs. applying them—is what makes Optax so flexible. Each transformation only needs to know about gradients; it doesn't care where they came from or where they're going.
When we use nnx.Optimizer, it handles the state management for us:
# With NNX
optimizer = nnx.Optimizer(model, optax.adam(0.001), wrt=nnx.Param)
optimizer.update(model, grads) # Handles everything internally
But understanding the underlying pattern helps when you need to do something custom.
Composing Optimizers with optax.chain
The core of Optax's power is optax.chain(). It takes multiple gradient transformations and applies them in sequence:
import optax
# Each transformation modifies the gradients, then passes them to the next
optimizer = optax.chain(
optax.clip_by_global_norm(1.0), # First: clip gradients
optax.scale_by_adam(), # Second: compute Adam scaling
optax.add_decayed_weights(0.01), # Third: add weight decay
optax.scale(-0.001) # Fourth: scale by -learning_rate
)
The gradients flow through each transformation like water through pipes. Each pipe modifies the flow before passing it on.
What optax.adam() Actually Is
Here's something that surprises most people: optax.adam() isn't a primitive. It's a convenience wrapper around a chain:
# optax.adam(learning_rate) is roughly equivalent to:
optax.chain(
optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
optax.scale(-learning_rate)
)
And optax.adamw() adds weight decay:
# optax.adamw(learning_rate, weight_decay) is roughly:
optax.chain(
optax.scale_by_adam(),
optax.add_decayed_weights(weight_decay),
optax.scale(-learning_rate)
)
Understanding this lets you build exactly what you need.
Gradient Clipping: Preventing Explosions
Deep networks, especially RNNs and Transformers, are prone to exploding gradients. One bad batch can produce gradients so large that a single update destroys your model.
Gradient clipping caps the magnitude of gradients before they're applied.
Global Norm Clipping
The most common approach clips by the global norm—the total magnitude across all parameters:
optimizer = optax.chain(
optax.clip_by_global_norm(max_norm=1.0),
optax.adam(learning_rate=0.001)
)
If the global norm of your gradients exceeds 1.0, they're scaled down proportionally. If it's below 1.0, nothing happens.
This is the right default for most training runs. It preserves the relative magnitudes between different parameters while preventing any single update from being catastrophically large.
Value Clipping
Sometimes you want to clip each gradient element independently:
optimizer = optax.chain(
optax.clip(max_delta=1.0), # Clip each element to [-1, 1]
optax.adam(learning_rate=0.001)
)
This is more aggressive and can distort the gradient direction, so use it carefully.
Monitoring Gradient Norms
Before blindly adding clipping, it's useful to know what your gradient norms actually look like:
import jax
import jax.numpy as jnp
def compute_grad_norm(grads):
    """Compute the global L2 norm of gradients."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))
# In training:
loss, grads = nnx.value_and_grad(loss_fn)(model)
grad_norm = compute_grad_norm(grads)
print(f"Gradient norm: {grad_norm:.4f}")
If you regularly see norms of 100 or more, clipping will likely help. If they're consistently below your threshold (for example, below 1 with max_norm=1.0), clipping won't change anything.
Learning Rate Schedules
In PyTorch, learning rate scheduling requires a separate lr_scheduler object that you call after each step or epoch:
# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
for batch in data:
    optimizer.step()
    scheduler.step()  # Don't forget this!
In Optax, the schedule is baked into the optimizer. No separate object, no extra step() call:
# Optax
schedule = optax.warmup_cosine_decay_schedule(
init_value=0.0,
peak_value=0.001,
warmup_steps=1000,
decay_steps=10000
)
optimizer = optax.adam(learning_rate=schedule)
The schedule is a function that takes a step count and returns a learning rate. Optax calls it automatically.
Common Schedules
Constant (no schedule):
optimizer = optax.adam(learning_rate=0.001)
Linear warmup, then constant:
schedule = optax.linear_schedule(
init_value=0.0,
end_value=0.001,
transition_steps=1000
)
Warmup, then cosine decay:
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,      # Start at 0
    peak_value=0.001,    # Warm up to this
    warmup_steps=1000,   # Over this many steps
    decay_steps=10000,   # Total schedule length, including the warmup steps
    end_value=0.0001     # Final value once the decay finishes
)
Exponential decay:
schedule = optax.exponential_decay(
init_value=0.001,
transition_steps=1000,
decay_rate=0.96
)
Piecewise constant (step decay):
schedule = optax.piecewise_constant_schedule(
init_value=0.001,
boundaries_and_scales={
5000: 0.1, # At step 5000, multiply LR by 0.1
8000: 0.1, # At step 8000, multiply by 0.1 again
}
)
Visualizing Schedules
It's helpful to plot your schedule before training:
import matplotlib.pyplot as plt
import jax.numpy as jnp
import optax
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=10000  # Total schedule length, including the warmup steps
)
steps = jnp.arange(10000)
lrs = schedule(steps)  # Schedules accept an array of step counts
plt.figure(figsize=(10, 4))
plt.plot(steps, lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('Warmup Cosine Decay Schedule')
plt.grid(True)
plt.show()
Weight Decay: Regularization Done Right
Weight decay prevents overfitting by penalizing large weights. But there's a subtle difference between L2 regularization and true weight decay.
L2 regularization adds a term to the loss:
loss = original_loss + λ * ||weights||²
Weight decay subtracts directly from the weights:
weights = weights - lr * (grads + λ * weights)
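For plain SGD, the two are mathematically equivalent (up to a rescaling of λ). For adaptive optimizers like Adam, they are not: the L2 term in the gradient gets rescaled by Adam's per-parameter statistics, while decoupled weight decay does not. The paper "Decoupled Weight Decay Regularization" showed that true weight decay works better with Adam, hence AdamW.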
In Optax:
# AdamW: Adam with decoupled weight decay
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.01)
# Or build it manually:
optimizer = optax.chain(
optax.scale_by_adam(),
optax.add_decayed_weights(weight_decay=0.01),
optax.scale(-0.001)
)
Excluding Bias Terms from Weight Decay
A common practice is to apply weight decay only to weight matrices, not to biases or layer norm parameters. Optax supports this with masks:
from flax import nnx
import jax
import optax
# Create a mask that's True for kernel/weight parameters, False for biases
def create_weight_decay_mask(params):
    def should_decay(path, _):
        # path is a tuple of JAX key objects, so convert it to a string first
        path_str = '/'.join(str(k) for k in path)
        return 'kernel' in path_str or 'weight' in path_str
    return jax.tree_util.tree_map_with_path(should_decay, params)
# Pass the mask function to adamw so weight decay is applied selectively
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=0.001, weight_decay=0.01, mask=create_weight_decay_mask)
)
Per-Parameter Optimization
Sometimes you want different learning rates for different parts of your model. This is common when fine-tuning: you might want a small learning rate for pretrained layers and a larger one for the new classification head.
Optax handles this with optax.multi_transform:
import optax
from flax import nnx
# Define different optimizers for different parameter groups
optimizer = optax.multi_transform(
transforms={
'backbone': optax.adam(learning_rate=1e-5), # Small LR for pretrained
'head': optax.adam(learning_rate=1e-3), # Larger LR for new layers
},
param_labels=param_labels # A pytree matching params, with string labels
)
The tricky part is creating param_labels—a pytree with the same structure as your parameters, where each leaf is a string label.
Here's a practical example:
from flax import nnx
import jax
class FineTuneModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        # Pretrained backbone (we want small LR)
        self.backbone = nnx.Linear(784, 256, rngs=rngs)
        # New classification head (we want larger LR)
        self.head = nnx.Linear(256, 10, rngs=rngs)
    def __call__(self, x):
        x = nnx.relu(self.backbone(x))
        return self.head(x)
model = FineTuneModel(rngs=nnx.Rngs(0))
# Create parameter labels
def label_fn(path, _):
    """Assign labels based on parameter path."""
    path_str = '/'.join(str(p) for p in path)
    if 'backbone' in path_str:
        return 'backbone'
    else:
        return 'head'
# Get params and create labels
params = nnx.state(model, nnx.Param)
param_labels = jax.tree_util.tree_map_with_path(label_fn, params)
# Create the multi-transform optimizer
tx = optax.multi_transform(
transforms={
'backbone': optax.adam(learning_rate=1e-5),
'head': optax.adam(learning_rate=1e-3),
},
param_labels=param_labels
)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
Building a Production Optimizer Stack
Let's put everything together into a production-grade optimizer configuration:
import optax
def create_optimizer(
    peak_learning_rate: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: int = 100000,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    end_learning_rate: float = 1e-5,
):
    """Create a production-ready optimizer with all the bells and whistles."""
    # Learning rate schedule: warmup then cosine decay.
    # decay_steps is the total schedule length and already includes the warmup.
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,
        end_value=end_learning_rate,
    )
    # Build the optimizer chain
    optimizer = optax.chain(
        # 1. Gradient clipping (prevent explosions)
        optax.clip_by_global_norm(max_grad_norm),
        # 2. Adam scaling
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        # 3. Weight decay (applied after Adam scaling)
        optax.add_decayed_weights(weight_decay),
        # 4. Scale by negative learning rate (makes it gradient descent)
        optax.scale_by_schedule(lambda step: -schedule(step)),
    )
    return optimizer
# Usage
tx = create_optimizer(
peak_learning_rate=1e-3,
warmup_steps=1000,
total_steps=50000,
weight_decay=0.01,
max_grad_norm=1.0,
)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
This gives you:
Gradient clipping to prevent explosions
Adam's adaptive learning rates
Decoupled weight decay for regularization
Warmup to stabilize early training
Cosine decay to fine-tune convergence
This combination (gradient clipping, Adam with decoupled weight decay, warmup, and cosine decay) is a common recipe for training modern Transformers.
Inspecting Optimizer State
Sometimes you need to peek inside the optimizer. With NNX:
# Get the full optimizer state
opt_state = nnx.state(optimizer)
# The state is a nested structure containing:
# - Step count
# - Momentum buffers (for Adam's first moment)
# - Velocity buffers (for Adam's second moment)
# - Any other state from transformations in the chain
print(jax.tree_util.tree_map(lambda x: x.shape, opt_state))
This is useful for:
Debugging optimization issues
Implementing custom logging
Understanding memory usage (Adam keeps two buffers per parameter, so its state is roughly 2x the size of the model parameters!)
Common Optimizer Recipes
Here are ready-to-use configurations for common scenarios:
Standard Training (Most Cases)
tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01)
Transformer Training
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=2000,
    decay_steps=100000  # Total schedule length, including the 2000 warmup steps
)
tx = optax.chain(
optax.clip_by_global_norm(1.0),
optax.adamw(learning_rate=schedule, weight_decay=0.1)
)
Fine-Tuning (Small LR, No Warmup)
tx = optax.adamw(learning_rate=2e-5, weight_decay=0.01)
When Training is Unstable
tx = optax.chain(
optax.clip_by_global_norm(0.5), # Aggressive clipping
optax.sgd(learning_rate=0.01, momentum=0.9) # SGD is more stable
)
Exercises
Visualize your schedule: Create a warmup + cosine decay schedule and plot it. Experiment with different warmup lengths.
Compare clipping strategies: Train the same model with clip_by_global_norm(1.0), clip_by_global_norm(0.1), and no clipping. Monitor the gradient norms and final accuracy.
Implement a custom schedule: Create a schedule that uses a high learning rate for the first 1000 steps, then drops to 1/10th, then decays linearly to zero.
Fine-tuning setup: Take a model with two parts (backbone + head) and set up a multi_transform optimizer with 10x different learning rates.
Quick Reference
import optax
# Basic Optimizers
optax.sgd(learning_rate, momentum=0.9)
optax.adam(learning_rate)
optax.adamw(learning_rate, weight_decay)
# Gradient Clipping
optax.clip_by_global_norm(max_norm)
optax.clip(max_delta)
# Learning Rate Schedules
optax.warmup_cosine_decay_schedule(init, peak, warmup_steps, decay_steps)
optax.exponential_decay(init, transition_steps, decay_rate)
optax.piecewise_constant_schedule(init, boundaries_and_scales)
# Composing
optax.chain(transform1, transform2, transform3)
# Building Blocks
optax.scale_by_adam() # Adam's moment estimation
optax.add_decayed_weights(wd) # Weight decay
optax.scale(-lr) # Scale by learning rate
optax.scale_by_schedule(fn) # Dynamic scaling
# With NNX
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
optimizer.update(model, grads)
What's Next
We've built training loops and sophisticated optimizers. But what happens when things go wrong? What if your loss becomes NaN? What if shapes don't match?
Now that we're back, next week we'll tackle reliability engineering with Chex. We'll learn to write assertions that catch bugs before they corrupt your model, validate tensor shapes at compile time, and hunt down the source of NaN values.
Because fast code that produces garbage is still garbage.
Next week: Catching Bugs Before They Catch You; Reliability Engineering with Chex



