Optax: Optimizers You Can Compose Like LEGO


I'm a Machine Learning Engineer passionate about building production-ready ML systems for the African market. With experience in TensorFlow, Keras, and Python-based workflows, I help teams bridge the gap between machine learning research and real-world deployment, especially on resource-constrained devices. I'm also a Google Developer Expert in AI. I regularly speak at tech conferences, including PyCon Africa, DevFest Kampala, and DevFest Nairobi, and I write technical articles on AI/ML here.

In our last article on this blog in March, we built a complete training loop. We used optax.adamw() to create an optimizer, wrapped it in nnx.Optimizer, and watched our model learn.

But we barely scratched the surface.

Optax isn't just a collection of optimizers. It's a compositional system for building gradient transformations. Instead of monolithic optimizer classes with dozens of parameters, Optax gives you small, focused building blocks that you chain together.

Want Adam with gradient clipping? Chain them. Want SGD with momentum, weight decay, and a cosine learning rate schedule? Chain them. Want something completely custom that doesn't exist in any other framework? Build it from primitives.

By the end of this article, you'll understand:

  1. How Optax's compositional design works

  2. How to prevent exploding gradients with clipping

  3. How to implement learning rate schedules without a separate scheduler object

  4. How to apply different optimization strategies to different parameters

  5. How to build production-grade optimizer stacks

Let's take apart the LEGO box.

The Optax Mental Model

In PyTorch, an optimizer is an object that holds references to your parameters:

# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer.step()  # Updates parameters in place

In Optax, an optimizer is a pair of functions:

  1. init(params) → Creates the optimizer state (momentum buffers, etc.)

  2. update(grads, state, params) → Returns (updates, new_state)

You then apply the updates to your parameters separately:

# Pure Optax (without NNX)
import optax

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# In training loop:
updates, opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)

This separation of concerns—computing updates vs. applying them—is what makes Optax so flexible. Each transformation only needs to know about gradients; it doesn't care where they came from or where they're going.

When we use nnx.Optimizer, it handles the state management for us:

# With NNX
optimizer = nnx.Optimizer(model, optax.adam(0.001), wrt=nnx.Param)
optimizer.update(model, grads)  # Handles everything internally

But understanding the underlying pattern helps when you need to do something custom.

Composing Optimizers with optax.chain

The core of Optax's power is optax.chain(). It takes multiple gradient transformations and applies them in sequence:

import optax

# Each transformation modifies the gradients, then passes them to the next
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),      # First: clip gradients
    optax.scale_by_adam(),                # Second: compute Adam scaling
    optax.add_decayed_weights(0.01),      # Third: add weight decay
    optax.scale(-0.001)                   # Fourth: scale by -learning_rate
)

The gradients flow through each transformation like water through pipes. Each pipe modifies the flow before passing it on.

What optax.adam() Actually Is

Here's something that surprises most people: optax.adam() isn't a primitive. It's a convenience wrapper around a chain:

# optax.adam(learning_rate) is roughly equivalent to:
optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale(-learning_rate)
)

And optax.adamw() adds weight decay:

# optax.adamw(learning_rate, weight_decay) is roughly:
optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    optax.scale(-learning_rate)
)

Understanding this lets you build exactly what you need.

Gradient Clipping: Preventing Explosions

Deep networks, especially RNNs and Transformers, are prone to exploding gradients. One bad batch can produce gradients so large that a single update destroys your model.

Gradient clipping caps the magnitude of gradients before they're applied.

Global Norm Clipping

The most common approach clips by the global norm—the total magnitude across all parameters:

optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.adam(learning_rate=0.001)
)

If the global norm of your gradients exceeds 1.0, they're scaled down proportionally. If it's below 1.0, nothing happens.

This is the right default in most cases. It preserves the relative magnitudes between different parameters while preventing any single update from being catastrophically large.

Value Clipping

Sometimes you want to clip each gradient element independently:

optimizer = optax.chain(
    optax.clip(max_delta=1.0),  # Clip each element to [-1, 1]
    optax.adam(learning_rate=0.001)
)

This is more aggressive and can distort the gradient direction, so use it carefully.

Monitoring Gradient Norms

Before blindly adding clipping, it's useful to know what your gradient norms actually look like:

import jax
import jax.numpy as jnp

def compute_grad_norm(grads):
    """Compute the global L2 norm of gradients."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))

# In training:
loss, grads = nnx.value_and_grad(loss_fn)(model)
grad_norm = compute_grad_norm(grads)
print(f"Gradient norm: {grad_norm:.4f}")

If you regularly see norms of 100 or more, clipping will likely help. If they stay consistently below your clipping threshold (1.0 in the examples above), clipping never triggers and changes nothing.

Learning Rate Schedules

In PyTorch, learning rate scheduling requires a separate lr_scheduler object that you call after each step or epoch:

# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

for batch in data:
    optimizer.step()
    scheduler.step()  # Don't forget this!

In Optax, the schedule is baked into the optimizer. No separate object, no extra step() call:

# Optax
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=10000
)

optimizer = optax.adam(learning_rate=schedule)

The schedule is a function that takes a step count and returns a learning rate. Optax calls it automatically.

Common Schedules

Constant (no schedule):

optimizer = optax.adam(learning_rate=0.001)

Linear warmup, then constant:

schedule = optax.linear_schedule(
    init_value=0.0,
    end_value=0.001,
    transition_steps=1000
)

Warmup, then cosine decay:

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,           # Start at 0
    peak_value=0.001,         # Warm up to this
    warmup_steps=1000,        # Over this many steps
    decay_steps=10000,        # Total schedule length, including the 1000 warmup steps
    end_value=0.0001          # Down to this final value
)

Exponential decay:

schedule = optax.exponential_decay(
    init_value=0.001,
    transition_steps=1000,
    decay_rate=0.96
)

Piecewise constant (step decay):

schedule = optax.piecewise_constant_schedule(
    init_value=0.001,
    boundaries_and_scales={
        5000: 0.1,   # At step 5000, multiply LR by 0.1
        8000: 0.1,   # At step 8000, multiply by 0.1 again
    }
)

Visualizing Schedules

It's helpful to plot your schedule before training:

import matplotlib.pyplot as plt
import jax.numpy as jnp
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=10000  # total schedule length, including the warmup steps
)

steps = jnp.arange(10000)
lrs = schedule(steps)  # schedules accept arrays, so no Python loop is needed

plt.figure(figsize=(10, 4))
plt.plot(steps, lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('Warmup Cosine Decay Schedule')
plt.grid(True)
plt.show()

Weight Decay: Regularization Done Right

Weight decay prevents overfitting by penalizing large weights. But there's a subtle difference between L2 regularization and true weight decay.

L2 regularization adds a term to the loss:

loss = original_loss + λ * ||weights||²

Weight decay subtracts directly from the weights:

weights = weights - lr * (grads + λ * weights)

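For plain SGD, these are mathematically equivalent (up to a rescaling of λ). For adaptive optimizers like Adam, they're not: an L2 term in the loss is divided by Adam's per-parameter scaling along with the rest of the gradient, while true weight decay is applied after that scaling. The paper "Decoupled Weight Decay Regularization" (Loshchilov and Hutter) showed that true weight decay works better with Adam, which is where AdamW comes from.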

In Optax:

# AdamW: Adam with decoupled weight decay
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.01)

# Or build it manually:
optimizer = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay=0.01),
    optax.scale(-0.001)
)

Excluding Bias Terms from Weight Decay

A common practice is to apply weight decay only to weight matrices, not to biases or layer norm parameters. Optax supports this with masks:

import jax
import optax

# Create a mask that's True for kernel/weight parameters, False for biases
def create_weight_decay_mask(params):
    def should_decay(path, _):
        # path is a tuple of JAX key objects (e.g. DictKey('linear1'), DictKey('kernel')),
        # not plain strings, so convert it to a string before matching on names
        path_str = jax.tree_util.keystr(path)
        return 'kernel' in path_str or 'weight' in path_str

    return jax.tree_util.tree_map_with_path(should_decay, params)

# Pass the mask function to adamw so weight decay is applied selectively
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=0.001, weight_decay=0.01, mask=create_weight_decay_mask)
)

Per-Parameter Optimization

Sometimes you want different learning rates for different parts of your model. This is common when fine-tuning: you might want a small learning rate for pretrained layers and a larger one for the new classification head.

Optax handles this with optax.multi_transform:

import optax
from flax import nnx

# Define different optimizers for different parameter groups
optimizer = optax.multi_transform(
    transforms={
        'backbone': optax.adam(learning_rate=1e-5),   # Small LR for pretrained
        'head': optax.adam(learning_rate=1e-3),       # Larger LR for new layers
    },
    param_labels=param_labels  # A pytree matching params, with string labels
)

The tricky part is creating param_labels—a pytree with the same structure as your parameters, where each leaf is a string label.

Here's a practical example:

from flax import nnx
import jax

class FineTuneModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        # Pretrained backbone (we want small LR)
        self.backbone = nnx.Linear(784, 256, rngs=rngs)
        
        # New classification head (we want larger LR)
        self.head = nnx.Linear(256, 10, rngs=rngs)
    
    def __call__(self, x):
        x = nnx.relu(self.backbone(x))
        return self.head(x)

model = FineTuneModel(rngs=nnx.Rngs(0))

# Create parameter labels
def label_fn(path, _):
    """Assign labels based on parameter path."""
    path_str = '/'.join(str(p) for p in path)
    if 'backbone' in path_str:
        return 'backbone'
    else:
        return 'head'

# Get params and create labels
params = nnx.state(model, nnx.Param)
param_labels = jax.tree_util.tree_map_with_path(label_fn, params)

# Create the multi-transform optimizer
tx = optax.multi_transform(
    transforms={
        'backbone': optax.adam(learning_rate=1e-5),
        'head': optax.adam(learning_rate=1e-3),
    },
    param_labels=param_labels
)

optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

Building a Production Optimizer Stack

Let's put everything together into a production-grade optimizer configuration:

import optax

def create_optimizer(
    peak_learning_rate: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: int = 100000,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    end_learning_rate: float = 1e-5,
):
    """Create a production-ready optimizer with all the bells and whistles."""
    
    # Learning rate schedule: warmup then cosine decay
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_learning_rate,
        warmup_steps=warmup_steps,
        # decay_steps is the total schedule length and already includes warmup
        decay_steps=total_steps,
        end_value=end_learning_rate,
    )
    
    # Build the optimizer chain
    optimizer = optax.chain(
        # 1. Gradient clipping (prevent explosions)
        optax.clip_by_global_norm(max_grad_norm),
        
        # 2. Adam scaling
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        
        # 3. Weight decay (applied after Adam scaling)
        optax.add_decayed_weights(weight_decay),
        
        # 4. Scale by negative learning rate (makes it gradient descent)
        optax.scale_by_schedule(lambda step: -schedule(step)),
    )
    
    return optimizer

# Usage
tx = create_optimizer(
    peak_learning_rate=1e-3,
    warmup_steps=1000,
    total_steps=50000,
    weight_decay=0.01,
    max_grad_norm=1.0,
)

optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

This gives you:

  • Gradient clipping to prevent explosions

  • Adam's adaptive learning rates

  • Decoupled weight decay for regularization

  • Warmup to stabilize early training

  • Cosine decay to fine-tune convergence

This combination (Adam scaling, decoupled weight decay, global-norm clipping, warmup, and cosine decay) is a common recipe for training Transformer models.

Inspecting Optimizer State

Sometimes you need to peek inside the optimizer. With NNX:

# Get the full optimizer state
opt_state = nnx.state(optimizer)

# The state is a nested structure containing:
# - Step count
# - Momentum buffers (for Adam's first moment)
# - Velocity buffers (for Adam's second moment)
# - Any other state from transformations in the chain

print(jax.tree_util.tree_map(lambda x: x.shape, opt_state))

This is useful for:

  • Debugging optimization issues

  • Implementing custom logging

  • Understanding memory usage (Adam keeps two buffers per parameter, so its state is roughly 2x the model size, on top of the parameters themselves)

Common Optimizer Recipes

Here are ready-to-use configurations for common scenarios:

Standard Training (Most Cases)

tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01)

Transformer Training

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=2000,
    decay_steps=100000  # total schedule length, including the 2000 warmup steps
)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.1)
)

Fine-Tuning (Small LR, No Warmup)

tx = optax.adamw(learning_rate=2e-5, weight_decay=0.01)

When Training is Unstable

tx = optax.chain(
    optax.clip_by_global_norm(0.5),  # Aggressive clipping
    optax.sgd(learning_rate=0.01, momentum=0.9)  # No adaptive scaling; often easier to stabilize
)

Exercises

  1. Visualize your schedule: Create a warmup + cosine decay schedule and plot it. Experiment with different warmup lengths.

  2. Compare clipping strategies: Train the same model with clip_by_global_norm(1.0), clip_by_global_norm(0.1), and no clipping. Monitor the gradient norms and final accuracy.

  3. Implement a custom schedule: Create a schedule that uses a high learning rate for the first 1000 steps, then drops to 1/10th, then decays linearly to zero.

  4. Fine-tuning setup: Take a model with two parts (backbone + head) and set up a multi_transform optimizer with 10x different learning rates.

Quick Reference

import optax

# Basic Optimizers
optax.sgd(learning_rate, momentum=0.9)
optax.adam(learning_rate)
optax.adamw(learning_rate, weight_decay)

# Gradient Clipping
optax.clip_by_global_norm(max_norm)
optax.clip(max_delta)

# Learning Rate Schedules
optax.warmup_cosine_decay_schedule(init, peak, warmup_steps, decay_steps)
optax.exponential_decay(init, transition_steps, decay_rate)
optax.piecewise_constant_schedule(init, boundaries_and_scales)

# Composing
optax.chain(transform1, transform2, transform3)

# Building Blocks
optax.scale_by_adam()           # Adam's moment estimation
optax.add_decayed_weights(wd)   # Weight decay
optax.scale(-lr)                # Scale by learning rate
optax.scale_by_schedule(fn)     # Dynamic scaling

# With NNX
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
optimizer.update(model, grads)

What's Next

We've built training loops and sophisticated optimizers. But what happens when things go wrong? What if your loss becomes NaN? What if shapes don't match?

Next week, we'll tackle reliability engineering with Chex. We'll learn to write assertions that catch bugs before they corrupt your model, validate tensor shapes at compile time, and hunt down the source of NaN values.

Because fast code that produces garbage is still garbage.

Next week: Catching Bugs Before They Catch You; Reliability Engineering with Chex
