# Optax: Optimizers You Can Compose Like LEGO

In our [last article](https://kambale.dev/training-loop-in-jax), published on this blog in March, we built a complete training loop. We used `optax.adamw()` to create an optimizer, wrapped it in `nnx.Optimizer`, and watched our model learn.

But we barely scratched the surface.

Optax isn't just a collection of optimizers. It's a **compositional system** for building gradient transformations. Instead of monolithic optimizer classes with dozens of parameters, Optax gives you small, focused building blocks that you chain together.

Want Adam with gradient clipping? Chain them. Want SGD with momentum, weight decay, and a cosine learning rate schedule? Chain them. Want something completely custom that doesn't exist in any other framework? Build it from primitives.

By the end of this article, you'll understand:

1.  How Optax's compositional design works
    
2.  How to prevent exploding gradients with clipping
    
3.  How to implement learning rate schedules without a separate scheduler object
    
4.  How to apply different optimization strategies to different parameters
    
5.  How to build production-grade optimizer stacks
    

Let's take apart the LEGO box.

## The Optax Mental Model

In PyTorch, an optimizer is an object that holds references to your parameters:

```python
# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizer.step()  # Updates parameters in place
```

In Optax, an optimizer is a **pair of functions**:

1.  `init(params)` → Creates the optimizer state (momentum buffers, etc.)
    
2.  `update(grads, state, params)` → Returns `(updates, new_state)`
    

You then apply the updates to your parameters separately:

```python
# Pure Optax (without NNX)
import optax

optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# In training loop:
updates, opt_state = optimizer.update(grads, opt_state, params)
new_params = optax.apply_updates(params, updates)
```

This separation of concerns—computing updates vs. applying them—is what makes Optax so flexible. Each transformation only needs to know about gradients; it doesn't care where they came from or where they're going.

When we use `nnx.Optimizer`, it handles the state management for us:

```python
# With NNX
optimizer = nnx.Optimizer(model, optax.adam(0.001), wrt=nnx.Param)
optimizer.update(model, grads)  # Handles everything internally
```

But understanding the underlying pattern helps when you need to do something custom.

## Composing Optimizers with optax.chain

The core of Optax's power is `optax.chain()`. It takes multiple gradient transformations and applies them in sequence:

```python
import optax

# Each transformation modifies the gradients, then passes them to the next
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),      # First: clip gradients
    optax.scale_by_adam(),                # Second: compute Adam scaling
    optax.add_decayed_weights(0.01),      # Third: add weight decay
    optax.scale(-0.001)                   # Fourth: scale by -learning_rate
)
```

The gradients flow through each transformation like water through pipes. Each pipe modifies the flow before passing it on.

### What optax.adam() Actually Is

Here's something that surprises most people: `optax.adam()` isn't a primitive. It's a convenience wrapper around a chain:

```python
# optax.adam(learning_rate) is roughly equivalent to:
optax.chain(
    optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
    optax.scale(-learning_rate)
)
```

And `optax.adamw()` adds weight decay:

```python
# optax.adamw(learning_rate, weight_decay) is roughly:
optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay),
    optax.scale(-learning_rate)
)
```

Understanding this lets you build exactly what you need.

## Gradient Clipping: Preventing Explosions

Deep networks, especially RNNs and Transformers, are prone to **exploding gradients**. One bad batch can produce gradients so large that a single update destroys your model.

Gradient clipping caps the magnitude of gradients before they're applied.

### Global Norm Clipping

The most common approach clips by the global norm—the total magnitude across all parameters:

```python
optimizer = optax.chain(
    optax.clip_by_global_norm(max_norm=1.0),
    optax.adam(learning_rate=0.001)
)
```

If the global norm of your gradients exceeds 1.0, they're scaled down proportionally. If it's below 1.0, nothing happens.

This is the right default for most training runs. It preserves the relative magnitudes between different parameters while preventing any single update from being catastrophically large.

### Value Clipping

Sometimes you want to clip each gradient element independently:

```python
optimizer = optax.chain(
    optax.clip(max_delta=1.0),  # Clip each element to [-1, 1]
    optax.adam(learning_rate=0.001)
)
```

This is more aggressive and can distort the gradient direction, so use it carefully.

### Monitoring Gradient Norms

Before blindly adding clipping, it's useful to know what your gradient norms actually look like:

```python
import jax
import jax.numpy as jnp

def compute_grad_norm(grads):
    """Compute the global L2 norm of gradients."""
    leaves = jax.tree_util.tree_leaves(grads)
    return jnp.sqrt(sum(jnp.sum(g ** 2) for g in leaves))

# In training:
loss, grads = nnx.value_and_grad(loss_fn)(model)
grad_norm = compute_grad_norm(grads)
print(f"Gradient norm: {grad_norm:.4f}")
```

If you regularly see norms of 100 or more, clipping is likely to help. If they stay consistently below your clipping threshold (1.0 in the examples above), clipping won't change anything.

## Learning Rate Schedules

In PyTorch, learning rate scheduling requires a separate `lr_scheduler` object that you call after each step or epoch:

```python
# PyTorch
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)

for batch in data:
    optimizer.step()
    scheduler.step()  # Don't forget this!
```

In Optax, the schedule is baked into the optimizer. No separate object, no extra `step()` call:

```python
# Optax
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=10000
)

optimizer = optax.adam(learning_rate=schedule)
```

The schedule is a function that takes a step count and returns a learning rate. Optax calls it automatically.

### Common Schedules

**Constant (no schedule):**

```python
optimizer = optax.adam(learning_rate=0.001)
```

**Linear warmup, then constant:**

```python
schedule = optax.linear_schedule(
    init_value=0.0,
    end_value=0.001,
    transition_steps=1000
)
```

**Warmup, then cosine decay:**

```python
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,           # Start at 0
    peak_value=0.001,         # Warm up to this
    warmup_steps=1000,        # Over this many steps
    decay_steps=10000,        # Total schedule length (1000 warmup + 9000 decay steps)
    end_value=0.0001          # Down to this final value
)
```

**Exponential decay:**

```python
schedule = optax.exponential_decay(
    init_value=0.001,
    transition_steps=1000,
    decay_rate=0.96
)
```

**Piecewise constant (step decay):**

```python
schedule = optax.piecewise_constant_schedule(
    init_value=0.001,
    boundaries_and_scales={
        5000: 0.1,   # At step 5000, multiply LR by 0.1
        8000: 0.1,   # At step 8000, multiply by 0.1 again
    }
)
```

### Visualizing Schedules

It's helpful to plot your schedule before training:

```python
import matplotlib.pyplot as plt
import jax.numpy as jnp
import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,
    warmup_steps=1000,
    decay_steps=10000  # Total length, including the 1000 warmup steps
)

steps = jnp.arange(10000)
lrs = schedule(steps)  # The schedule accepts an array, so one call covers every step

plt.figure(figsize=(10, 4))
plt.plot(steps, lrs)
plt.xlabel('Step')
plt.ylabel('Learning Rate')
plt.title('Warmup Cosine Decay Schedule')
plt.grid(True)
plt.show()
```

## Weight Decay: Regularization Done Right

Weight decay prevents overfitting by penalizing large weights. But there's a subtle difference between L2 regularization and true weight decay.

**L2 regularization** adds a term to the loss:

```plaintext
loss = original_loss + λ * ||weights||²
```

**Weight decay** subtracts directly from the weights:

```plaintext
weights = weights - lr * (grads + λ * weights)
```

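For plain SGD, these are mathematically equivalent (up to a rescaling of λ). For adaptive optimizers like Adam, they're not: an L2 term in the loss gets rescaled by Adam's per-parameter statistics along with the rest of the gradient, while decoupled weight decay is applied outside that scaling. The paper "Decoupled Weight Decay Regularization" showed that decoupled weight decay works better with Adam, hence AdamW.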

In Optax:

```python
# AdamW: Adam with decoupled weight decay
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.01)

# Or build it manually:
optimizer = optax.chain(
    optax.scale_by_adam(),
    optax.add_decayed_weights(weight_decay=0.01),
    optax.scale(-0.001)
)
```

### Excluding Bias Terms from Weight Decay

A common practice is to apply weight decay only to weight matrices, not to biases or layer norm parameters. Optax supports this with masks:

```python
import jax
from flax import nnx
import optax

# Create a mask that's True for kernel/weight parameters, False for biases
def create_weight_decay_mask(params):
    def should_decay(path, _):
        # path is a tuple of JAX key objects, not plain strings, so turn it
        # into a string like "['linear1']['kernel']" before matching on names
        path_str = jax.tree_util.keystr(path)
        return 'kernel' in path_str or 'weight' in path_str

    return jax.tree_util.tree_map_with_path(should_decay, params)

# Pass the mask function to adamw so weight decay is applied selectively
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=0.001, weight_decay=0.01, mask=create_weight_decay_mask)
)
```

## Per-Parameter Optimization

Sometimes you want different learning rates for different parts of your model. This is common when fine-tuning: you might want a small learning rate for pretrained layers and a larger one for the new classification head.

Optax handles this with `optax.multi_transform`:

```python
import optax
from flax import nnx

# Define different optimizers for different parameter groups
optimizer = optax.multi_transform(
    transforms={
        'backbone': optax.adam(learning_rate=1e-5),   # Small LR for pretrained
        'head': optax.adam(learning_rate=1e-3),       # Larger LR for new layers
    },
    param_labels=param_labels  # A pytree matching params, with string labels
)
```

The tricky part is creating `param_labels`—a pytree with the same structure as your parameters, where each leaf is a string label.

Here's a practical example:

```python
from flax import nnx
import jax

class FineTuneModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        # Pretrained backbone (we want small LR)
        self.backbone = nnx.Linear(784, 256, rngs=rngs)
        
        # New classification head (we want larger LR)
        self.head = nnx.Linear(256, 10, rngs=rngs)
    
    def __call__(self, x):
        x = nnx.relu(self.backbone(x))
        return self.head(x)

model = FineTuneModel(rngs=nnx.Rngs(0))

# Create parameter labels
def label_fn(path, _):
    """Assign labels based on parameter path."""
    path_str = '/'.join(str(p) for p in path)
    if 'backbone' in path_str:
        return 'backbone'
    else:
        return 'head'

# Get params and create labels
params = nnx.state(model, nnx.Param)
param_labels = jax.tree_util.tree_map_with_path(label_fn, params)

# Create the multi-transform optimizer
tx = optax.multi_transform(
    transforms={
        'backbone': optax.adam(learning_rate=1e-5),
        'head': optax.adam(learning_rate=1e-3),
    },
    param_labels=param_labels
)

optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
```

## Building a Production Optimizer Stack

Let's put everything together into a production-grade optimizer configuration:

```python
import optax

def create_optimizer(
    peak_learning_rate: float = 1e-3,
    warmup_steps: int = 1000,
    total_steps: int = 100000,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    end_learning_rate: float = 1e-5,
):
    """Create a production-ready optimizer with all the bells and whistles."""
    
    # Learning rate schedule: warmup then cosine decay
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=peak_learning_rate,
        warmup_steps=warmup_steps,
        decay_steps=total_steps,  # Optax counts warmup as part of decay_steps
        end_value=end_learning_rate,
    )
    
    # Build the optimizer chain
    optimizer = optax.chain(
        # 1. Gradient clipping (prevent explosions)
        optax.clip_by_global_norm(max_grad_norm),
        
        # 2. Adam scaling
        optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8),
        
        # 3. Weight decay (applied after Adam scaling)
        optax.add_decayed_weights(weight_decay),
        
        # 4. Scale by negative learning rate (makes it gradient descent)
        optax.scale_by_schedule(lambda step: -schedule(step)),
    )
    
    return optimizer

# Usage
tx = create_optimizer(
    peak_learning_rate=1e-3,
    warmup_steps=1000,
    total_steps=50000,
    weight_decay=0.01,
    max_grad_norm=1.0,
)

optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
```

This gives you:

*   Gradient clipping to prevent explosions
    
*   Adam's adaptive learning rates
    
*   Decoupled weight decay for regularization
    
*   Warmup to stabilize early training
    
*   Cosine decay to fine-tune convergence
    

This closely mirrors the recipe commonly used to train modern Transformers.

## Inspecting Optimizer State

Sometimes you need to peek inside the optimizer. With NNX:

```python
# Get the full optimizer state
opt_state = nnx.state(optimizer)

# The state is a nested structure containing:
# - Step count
# - First-moment estimates (Adam's `mu`, a running mean of the gradients)
# - Second-moment estimates (Adam's `nu`, a running mean of the squared gradients)
# - Any other state from transformations in the chain

print(jax.tree_util.tree_map(lambda x: x.shape, opt_state))
```

This is useful for:

*   Debugging optimization issues
    
*   Implementing custom logging
    
*   Understanding memory usage (Adam stores two extra values per parameter, so its state alone is about 2x the model size!)
    

## Common Optimizer Recipes

Here are ready-to-use configurations for common scenarios:

### Standard Training (Most Cases)

```python
tx = optax.adamw(learning_rate=1e-3, weight_decay=0.01)
```

### Transformer Training

```python
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=1e-4, warmup_steps=2000,
    decay_steps=100000  # Total length: 2000 warmup + 98000 decay steps
)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule, weight_decay=0.1)
)
```

### Fine-Tuning (Small LR, No Warmup)

```python
tx = optax.adamw(learning_rate=2e-5, weight_decay=0.01)
```

### When Training is Unstable

```python
tx = optax.chain(
    optax.clip_by_global_norm(0.5),  # Aggressive clipping
    optax.sgd(learning_rate=0.01, momentum=0.9)  # SGD is more stable
)
```

## Exercises

1.  **Visualize your schedule**: Create a warmup + cosine decay schedule and plot it. Experiment with different warmup lengths.
    
2.  **Compare clipping strategies**: Train the same model with `clip_by_global_norm(1.0)`, `clip_by_global_norm(0.1)`, and no clipping. Monitor the gradient norms and final accuracy.
    
3.  **Implement a custom schedule**: Create a schedule that uses a high learning rate for the first 1000 steps, then drops to 1/10th, then decays linearly to zero.
    
4.  **Fine-tuning setup**: Take a model with two parts (backbone + head) and set up a `multi_transform` optimizer whose learning rates differ by 10x.
    

## Quick Reference

```python
import optax

# Basic Optimizers
optax.sgd(learning_rate, momentum=0.9)
optax.adam(learning_rate)
optax.adamw(learning_rate, weight_decay)

# Gradient Clipping
optax.clip_by_global_norm(max_norm)
optax.clip(max_delta)

# Learning Rate Schedules
optax.warmup_cosine_decay_schedule(init, peak, warmup_steps, decay_steps)  # decay_steps = total length, incl. warmup
optax.exponential_decay(init, transition_steps, decay_rate)
optax.piecewise_constant_schedule(init, boundaries_and_scales)

# Composing
optax.chain(transform1, transform2, transform3)

# Building Blocks
optax.scale_by_adam()           # Adam's moment estimation
optax.add_decayed_weights(wd)   # Weight decay
optax.scale(-lr)                # Scale by learning rate
optax.scale_by_schedule(fn)     # Dynamic scaling

# With NNX
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
optimizer.update(model, grads)
```

## What's Next

We've built training loops and sophisticated optimizers. But what happens when things go wrong? What if your loss becomes NaN? What if shapes don't match?

Next week, we'll tackle **reliability engineering with Chex**. We'll learn to write assertions that catch bugs before they corrupt your model, validate tensor shapes at compile time, and hunt down the source of NaN values.

Because fast code that produces garbage is still garbage.

**Next week**: *Catching Bugs Before They Catch You: Reliability Engineering with Chex*
