
Custom training loop from scratch in JAX

Writing a training loop that you can actually control

I'm a Machine Learning Engineer passionate about building production-ready ML systems for the African market. With experience in TensorFlow, Keras, and Python-based workflows, I help teams bridge the gap between machine learning research and real-world deployment, especially on resource-constrained devices. I'm also a Google Developer Expert in AI. I regularly speak at tech conferences including PyCon Africa, DevFest Kampala, and DevFest Nairobi, and I write technical articles on AI/ML here.

For the past three weeks, we've been building up to this moment.

Week 1 taught us that JAX is fast. Week 2 showed us how to eliminate loops with vmap and compute gradients with grad. Week 3 gave us Flax NNX—a way to define neural networks that feels like PyTorch but runs like JAX.

But we haven't actually trained anything yet.

Today, that changes. We're going to write a complete training loop from scratch: forward pass, loss calculation, gradient computation, parameter updates, and evaluation. No model.fit(). No magic. Just explicit, controllable, JIT-compiled code that you own completely.

By the end of this article, you'll have trained a CNN on MNIST and understood every single line of code that made it happen.

Why Write Your Own Training Loop?

If you've used Keras, you know how easy training can be:

model.fit(x_train, y_train, epochs=10)

One line. Done. So why would anyone write hundreds of lines to do the same thing?

Because model.fit() is a black box. It works until it doesn't. And when it doesn't (your loss explodes, you need gradient accumulation, you're debugging a custom layer, you need to log something specific every 37 steps), you're stuck.

The training loop is where machine learning actually happens. If you don't control it, you don't control your model.

In production systems, you almost always need:

  • Custom metrics that aren't built into the framework

  • Gradient clipping to prevent exploding gradients

  • Learning rate schedules that change based on validation loss

  • Mixed precision training for speed

  • Gradient accumulation when batches don't fit in memory

  • Early stopping based on custom criteria

Writing your own loop means you can implement any of these with a few lines of code.

The Four Beats of Training

Every training loop, regardless of framework, follows the same rhythm:

  1. Forward Pass: Run input through the model to get predictions

  2. Loss Calculation: Compare predictions to ground truth

  3. Backward Pass: Compute gradients of the loss with respect to parameters

  4. Parameter Update: Adjust parameters using the gradients

In PyTorch, this looks like:

optimizer.zero_grad()        # Clear old gradients
predictions = model(x)       # Forward pass
loss = loss_fn(predictions, y)  # Loss calculation
loss.backward()              # Backward pass
optimizer.step()             # Parameter update

PyTorch hides a lot here. loss.backward() traverses a hidden computational graph and secretly updates .grad attributes on every parameter.

JAX makes everything explicit.
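To see what "explicit" means in practice, here is a minimal sketch of the four beats in raw JAX, using a toy linear model (the model, data, and learning rate here are stand-ins for illustration, not the CNN we train below):

```python
import jax
import jax.numpy as jnp

# Toy linear model: every beat is a visible function call, nothing hidden.
params = {'w': jnp.zeros((3,)), 'b': jnp.zeros(())}

def loss_fn(params, x, y):
    predictions = x @ params['w'] + params['b']   # 1. forward pass
    return jnp.mean((predictions - y) ** 2)       # 2. loss calculation

x, y = jnp.ones((8, 3)), jnp.ones((8,))

grads = jax.grad(loss_fn)(params, x, y)           # 3. backward pass
params = jax.tree_util.tree_map(                  # 4. parameter update
    lambda p, g: p - 0.1 * g, params, grads
)
```

Gradients come back as a new pytree rather than being written onto the parameters, which is why the update step is an explicit tree_map instead of a hidden side effect.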

The Optax Library

Optax is JAX's optimization library. It handles gradient transformations and parameter updates.

The key insight about Optax is that optimizers are composable. Instead of monolithic optimizer classes, Optax gives you building blocks that you chain together:

import optax

# A simple Adam optimizer
optimizer = optax.adam(learning_rate=0.001)

# Adam with gradient clipping
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # First, clip gradients
    optax.adam(learning_rate=0.001)   # Then, apply Adam
)

# Adam with weight decay (AdamW)
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.01)

This composability is powerful. Need to add gradient clipping to your existing optimizer? Add one line. Need a custom transformation? Write a function that takes gradients and returns modified gradients.

Common Optax Optimizers

Optimizer                       Usage
optax.sgd(lr)                   Basic stochastic gradient descent
optax.adam(lr)                  Adam (the most common default)
optax.adamw(lr, weight_decay)   Adam with decoupled weight decay
optax.sgd(lr, momentum=0.9)     SGD with momentum
optax.rmsprop(lr)               RMSprop

Common Gradient Transformations

Transformation                          Purpose
optax.clip_by_global_norm(max_norm)     Prevent exploding gradients
optax.add_decayed_weights(decay)        Weight decay (L2-style regularization)
optax.scale_by_schedule(schedule_fn)    Learning rate scheduling

The nnx.Optimizer Wrapper

Optax is a functional library: it works with raw JAX arrays and pytrees. But our models are NNX objects with structured parameters. We need a bridge.

nnx.Optimizer is that bridge:

from flax import nnx
import optax

# Create the Optax optimizer
tx = optax.adam(learning_rate=0.001)

# Wrap it with nnx.Optimizer
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

The wrt parameter is critical. It stands for "with respect to" and tells the optimizer which variables to update. nnx.Param means "only update trainable parameters", not batch normalization statistics or other state.

Once wrapped, updating parameters is simple:

optimizer.update(model, grads)

This applies the gradients to the model's parameters in place. No need to manually extract parameters, apply updates, and put them back.

Computing Gradients with nnx.value_and_grad

In PyTorch, you call loss.backward() to compute gradients. In JAX, we use value_and_grad.

Here's the pattern:

def loss_fn(model, x, y):
    logits = model(x)
    loss = compute_loss(logits, y)
    return loss

# Get both the loss value and the gradients
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)

Critical rule: nnx.value_and_grad computes gradients with respect to the first argument of the function. That's why model must be the first parameter.

The grads object has the same structure as the model's parameters—it's a pytree where each leaf is the gradient for the corresponding parameter.

The Complete Training Step

Let's put it all together. Here's a complete, JIT-compiled training step:

@nnx.jit
def train_step(model, optimizer, x, y):
    """Execute one training step."""
    
    def loss_fn(model):
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, 
            labels=y
        ).mean()
        return loss, logits
    
    # Compute loss and gradients
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    
    # Update parameters
    optimizer.update(model, grads)
    
    # Compute accuracy for logging
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == y)
    
    return loss, accuracy

A few things to note:

  1. @nnx.jit: The entire function is JIT-compiled. The first call traces and compiles; subsequent calls execute the compiled code.

  2. has_aux=True: Our loss function returns (loss, logits). The has_aux=True flag tells value_and_grad that the function returns auxiliary data alongside the loss. Gradients are only computed for the first returned value.

  3. In-place updates: optimizer.update(model, grads) modifies the model in place. NNX handles the functional-to-stateful conversion internally.

Training a CNN on MNIST

Let's train a real model on real data. We'll use the CNN from Week 3 and the MNIST dataset.

Setup and Data Loading

import jax
import jax.numpy as jnp
from flax import nnx
import optax
import tensorflow_datasets as tfds
import tensorflow as tf

# Disable TensorFlow GPU to avoid conflicts with JAX
tf.config.set_visible_devices([], 'GPU')

# Load MNIST
def load_mnist(batch_size=32):
    """Load and preprocess MNIST dataset."""
    
    def preprocess(sample):
        image = tf.cast(sample['image'], tf.float32) / 255.0
        label = sample['label']
        return {'image': image, 'label': label}
    
    train_ds = tfds.load('mnist', split='train')
    test_ds = tfds.load('mnist', split='test')
    
    train_ds = (train_ds
                .map(preprocess)
                .shuffle(1024)
                .batch(batch_size, drop_remainder=True)
                .prefetch(1))
    
    test_ds = (test_ds
               .map(preprocess)
               .batch(batch_size, drop_remainder=True)
               .prefetch(1))
    
    return train_ds, test_ds

train_ds, test_ds = load_mnist(batch_size=32)

Define the Model

We'll use a slightly simplified version of the CNN from Week 3:

from functools import partial

class CNN(nnx.Module):
    """Convolutional Neural Network for MNIST."""
    
    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)
        self.pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    
    def __call__(self, x):
        x = nnx.relu(self.conv1(x))
        x = self.pool(x)
        x = nnx.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)  # Flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x

Initialize Model and Optimizer

# Initialize model
model = CNN(rngs=nnx.Rngs(0))

# Test forward pass
dummy_input = jnp.ones((1, 28, 28, 1))
dummy_output = model(dummy_input)
print(f"Model output shape: {dummy_output.shape}")  # (1, 10)

# Initialize optimizer
learning_rate = 0.005
momentum = 0.9
# adamw's second positional argument is b1 (Adam's first-moment decay),
# which plays the role of momentum, so pass it by name to be explicit
tx = optax.adamw(learning_rate, b1=momentum)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Display model structure
nnx.display(model)

Define Training and Evaluation Steps

@nnx.jit
def train_step(model, optimizer, batch):
    """Single training step."""
    
    def loss_fn(model):
        logits = model(batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits,
            labels=batch['label']
        ).mean()
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
    return loss, accuracy

@nnx.jit
def eval_step(model, batch):
    """Single evaluation step."""
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits,
        labels=batch['label']
    ).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
    return loss, accuracy

The Training Loop

# Training configuration
num_epochs = 5
train_steps_per_epoch = 1000
eval_steps = 200

print("Starting training...")
print("=" * 60)

for epoch in range(num_epochs):
    # Training
    model.train()  # Set model to training mode
    train_loss, train_acc = 0.0, 0.0
    
    for step, batch in enumerate(train_ds.as_numpy_iterator()):
        if step >= train_steps_per_epoch:
            break
        
        # Convert to JAX arrays
        batch = {k: jnp.array(v) for k, v in batch.items()}
        
        loss, acc = train_step(model, optimizer, batch)
        train_loss += loss
        train_acc += acc
    
    train_loss /= train_steps_per_epoch
    train_acc /= train_steps_per_epoch
    
    # Evaluation
    model.eval()  # Set model to evaluation mode
    eval_loss, eval_acc = 0.0, 0.0
    eval_batches = 0
    
    for step, batch in enumerate(test_ds.as_numpy_iterator()):
        if step >= eval_steps:
            break
        
        batch = {k: jnp.array(v) for k, v in batch.items()}
        
        loss, acc = eval_step(model, batch)
        eval_loss += loss
        eval_acc += acc
        eval_batches += 1
    
    eval_loss /= eval_batches
    eval_acc /= eval_batches
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Eval Loss:  {eval_loss:.4f} | Eval Acc:  {eval_acc:.4f}")

print("Training complete!")

Using nnx.MultiMetric for Cleaner Tracking

NNX provides a MultiMetric helper for accumulating metrics across batches:

# Define metrics
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

@nnx.jit
def train_step_with_metrics(model, optimizer, metrics, batch):
    def loss_fn(model):
        logits = model(batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    
    # Update metrics in place
    metrics.update(loss=loss, logits=logits, labels=batch['label'])

# In the training loop:
for batch in train_ds.as_numpy_iterator():
    batch = {k: jnp.array(v) for k, v in batch.items()}
    train_step_with_metrics(model, optimizer, metrics, batch)

# Get aggregated metrics
results = metrics.compute()
print(f"Loss: {results['loss']:.4f}, Accuracy: {results['accuracy']:.4f}")

# Reset for next epoch
metrics.reset()

This is cleaner than manual accumulation and handles edge cases (like different batch sizes) correctly.

Inference: Using the Trained Model

After training, making predictions is straightforward:

model.eval()  # Ensure evaluation mode (deterministic)

@nnx.jit
def predict(model, images):
    logits = model(images)
    return jnp.argmax(logits, axis=-1)

# Get predictions
test_batch = next(test_ds.as_numpy_iterator())
predictions = predict(model, jnp.array(test_batch['image']))

print(f"Predictions: {predictions[:10]}")
print(f"Actual: {test_batch['label'][:10]}")

Exercises

  1. Add gradient clipping: Modify the optimizer to include optax.clip_by_global_norm(1.0). Train the model and compare the loss curves.

  2. Learning rate schedule: Implement a learning rate that decays over time using optax.warmup_cosine_decay_schedule. Check the Optax documentation for the parameters.

  3. Track more metrics: Add precision and recall tracking alongside accuracy. You'll need to compute true positives, false positives, and false negatives.

  4. Early stopping: Modify the training loop to stop if validation accuracy doesn't improve for 3 consecutive epochs.

Quick Reference

import optax
from flax import nnx

# Optimizer Setup
tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Training Step Pattern
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch['x'])
        loss = compute_loss(logits, batch['y'])
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss

# Evaluation Step Pattern
@nnx.jit  
def eval_step(model, batch):
    logits = model(batch['x'])
    loss = compute_loss(logits, batch['y'])
    return loss

# Training/Eval Mode
model.train()  # Enable dropout, stochastic behavior
model.eval()   # Deterministic inference

What's Next

We've trained a model. But what happens when training goes wrong? What if loss becomes NaN? What if shapes don't match? What if the model silently produces garbage?

Next week, we dive into Optax in more depth: learning rate schedules, gradient clipping, weight decay, and building custom optimizer stacks. We'll make our training more robust and production-ready.
