
Building Neural Networks with Flax NNX

Flax NNX: Neural Networks That Feel Like PyTorch


I'm a Machine Learning Engineer passionate about building production-ready ML systems for the African market. With experience in TensorFlow, Keras, and Python-based workflows, I help teams bridge the gap between machine learning research and real-world deployment, especially on resource-constrained devices. I'm also a Google Developer Expert in AI. I regularly speak at tech conferences including PyCon Africa, DevFest Kampala, and DevFest Nairobi, and I write technical articles on AI/ML.

Over the past two weeks, we've learned that JAX is fast (jit), that it eliminates loops (vmap), and that it computes gradients automatically (grad). These are powerful primitives.

But if you've been following along, you might have noticed something uncomfortable: we've been passing arrays around manually, tracking parameters in tuples, and writing functions that return updated state. It works, but it doesn't feel like building neural networks. It feels like accounting.

If you've used PyTorch, you know how natural it is to define a model as a class, call model(x), and let the framework handle the rest. That ergonomic experience is what made PyTorch the dominant framework for research.

Today, we get that experience in JAX.

Flax NNX is a neural network library that gives you PyTorch-style classes and methods while compiling down to JAX's XLA backend. You write object-oriented code. JAX runs functional code. NNX bridges the gap.

By the end of this article, we'll have:

  1. Understood why NNX exists and what problem it solves

  2. Built our first neural network using nnx.Module

  3. Learned how NNX handles the critical issue of random number generation

  4. Created a CNN for image classification—the same architecture used in real production systems

Let's build something real.

The State Problem

To understand why Flax NNX is necessary, we need to understand the fundamental tension between PyTorch and JAX.

How PyTorch Handles State

In PyTorch, a model is an object that contains its own parameters:

import torch
import torch.nn as nn

class PyTorchMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.linear1(x))
        return self.linear2(x)

model = PyTorchMLP()
output = model(input_data)  # Parameters are hidden inside

This is intuitive. The nn.Linear layers create their own weight matrices internally. You don't see them, you don't manage them, they just exist.

But this "hidden state" creates problems for JAX. When you call jax.jit on a function, JAX traces through it and compiles a computational graph. If that function secretly reads or writes to hidden variables, JAX can't see those operations, and it can't optimize them.

How Pure JAX Handles State

JAX demands pure functions: functions where the output depends only on the inputs, with no side effects. If you want to update parameters, you must pass them in and return them out:

import jax
import jax.numpy as jnp

def forward(params, x):
    x = jax.nn.relu(x @ params['w1'] + params['b1'])
    return x @ params['w2'] + params['b2']

def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)  # mean squared error, for example

def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # e.g. plain SGD: subtract a scaled gradient from every parameter
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)
    return new_params, loss  # Must return the new state

# You're always passing params around
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)

This is explicit and JIT-compatible, but it's tedious. For a model with dozens of layers, managing nested dictionaries of parameters becomes error-prone.
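To make that tedium concrete, here's a minimal sketch (the helper name and initialization scale are made up for illustration) of hand-rolling the parameter dictionary for a small MLP in pure JAX:

```python
import jax
import jax.numpy as jnp

def init_params(key, sizes):
    # Manually split one key per layer and build a flat dict of
    # weights and biases — exactly the bookkeeping NNX automates.
    params = {}
    for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:]), start=1):
        key, wkey = jax.random.split(key)
        params[f'w{i}'] = jax.random.normal(wkey, (n_in, n_out)) * 0.01
        params[f'b{i}'] = jnp.zeros(n_out)
    return params

params = init_params(jax.random.PRNGKey(0), [784, 128, 10])
# params now holds w1, b1, w2, b2 — entries you must thread
# through every function by hand.
```

Even for two layers this is several lines of key-splitting and dict bookkeeping; for fifty layers, it's a real source of bugs.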

How Flax NNX Solves This

NNX gives you the PyTorch interface while secretly doing the JAX bookkeeping. You write classes with attributes. NNX intercepts those attributes, extracts the parameters when needed for JIT compilation, and updates them in place when you're done.

from flax import nnx

class NNX_MLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(784, 128, rngs=rngs)
        self.linear2 = nnx.Linear(128, 10, rngs=rngs)
    
    def __call__(self, x):
        x = nnx.relu(self.linear1(x))
        return self.linear2(x)

model = NNX_MLP(rngs=nnx.Rngs(0))
output = model(input_data)  # Looks just like PyTorch

The syntax is nearly identical to PyTorch. But under the hood, NNX can extract the parameters, pass them through JAX transformations, and put them back—all without you writing a single line of state management code.

Building Your First NNX Model

Let's build a simple multi-layer perceptron step by step.

The Basic Structure

Every NNX model inherits from nnx.Module:

from flax import nnx
import jax.numpy as jnp

class SimpleMLP(nnx.Module):
    def __init__(self, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        # Define layers as attributes
        self.linear1 = nnx.Linear(784, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, output_dim, rngs=rngs)
    
    def __call__(self, x):
        # Define forward pass
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x

Key differences from PyTorch:

  1. No super().__init__(): NNX uses Python metaclasses, so you don't need to call the parent constructor.

  2. __call__ instead of forward: In Python, __call__ makes an object callable. NNX uses this standard convention rather than PyTorch's custom forward method.

  3. The rngs parameter: This is required. We'll explain why in the next section.

Instantiating the Model

# Create a random number generator
rngs = nnx.Rngs(42)  # Seed for reproducibility

# Create the model
model = SimpleMLP(hidden_dim=128, output_dim=10, rngs=rngs)

# Test with dummy data
x = jnp.ones((32, 784))  # Batch of 32, 784 features each
output = model(x)

print(f"Input shape:  {x.shape}")       # (32, 784)
print(f"Output shape: {output.shape}")  # (32, 10)

That's it. The model is ready to use.

The Randomness Requirement

You might be wondering: why do we need to pass rngs everywhere?

The Problem with Hidden Randomness

In NumPy or PyTorch, random number generation uses a global state:

import numpy as np
np.random.seed(42)
print(np.random.randn())  # Always the same value
print(np.random.randn())  # Different value (global state advanced)

This global state is invisible. It changes every time you call a random function. And as we established in Week 1, JAX cannot work with hidden, changing state.
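To see why this matters for compilation, here's a small demonstration (a contrived sketch, not something you'd write deliberately): a jitted function that calls NumPy's global RNG captures the "random" value once, at trace time, and freezes it into the compiled graph.

```python
import jax
import numpy as np

@jax.jit
def noisy_add(x):
    # np.random runs only once, while JAX traces this function;
    # its output is baked into the compiled graph as a constant.
    return x + np.random.randn()

a = noisy_add(1.0)
b = noisy_add(1.0)  # reuses the compiled function — same "noise"
print(a == b)  # True: the randomness was frozen at compile time
```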

JAX's Explicit Randomness

In JAX, randomness is deterministic. You create a "key," and that key always produces the same random numbers:

from jax import random

key = random.PRNGKey(42)
print(random.normal(key, shape=(3,)))  # Always the same
print(random.normal(key, shape=(3,)))  # Still the same!

To get different random numbers, you must split the key:

key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # New values

How nnx.Rngs Helps

When you create a neural network, every layer needs random numbers to initialize its weights. If you have 50 layers, that's 50 key splits you'd need to manage manually.

nnx.Rngs automates this. It's a key dispenser that splits and distributes keys automatically:

rngs = nnx.Rngs(42)

# When you create a layer, it asks rngs for keys internally
linear1 = nnx.Linear(10, 20, rngs=rngs)  # Gets its own keys
linear2 = nnx.Linear(20, 30, rngs=rngs)  # Gets different keys

# Both layers have different, reproducible initializations

The critical benefit: reproducibility. If you and I both run nnx.Rngs(42), we get identical models. This matters for debugging, for scientific reproducibility, and for distributed training where multiple machines must initialize the same model.

Inspecting Your Model

PyTorch lets you print(model) to see the architecture. NNX has something better: nnx.display().

nnx.display(model)

This produces a rich, hierarchical view showing:

  • Every layer and sublayer

  • Parameter shapes and dtypes

  • The total parameter count

  • The structure of the computational graph

In Jupyter notebooks, this renders as an interactive tree you can expand and collapse. It's invaluable for debugging shape mismatches and verifying your architecture.

Building a CNN for Image Classification

Let's build something more substantial: a convolutional neural network for classifying images. This is the same architecture pattern used in production image classifiers.

The Architecture

We'll build a classic CNN with:

  • Two convolutional blocks (Conv → ReLU → Pool)

  • A flatten operation

  • Two dense layers for classification

from flax import nnx
import jax.numpy as jnp
from functools import partial

class CNN(nnx.Module):
    """A simple CNN for image classification."""
    
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # Convolutional layers
        self.conv1 = nnx.Conv(
            in_features=1,      # Input channels (1 for grayscale)
            out_features=32,    # Output channels
            kernel_size=(3, 3), # 3x3 filters
            rngs=rngs
        )
        self.conv2 = nnx.Conv(
            in_features=32,
            out_features=64,
            kernel_size=(3, 3),
            rngs=rngs
        )
        
        # Dense layers
        # After two 2x2 pooling operations on 28x28 input: 28 → 14 → 7
        # So we have 64 channels × 7 × 7 = 3136 features
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, num_classes, rngs=rngs)
        
        # Pooling as a reusable operation
        self.pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    
    def __call__(self, x):
        # Block 1: Conv → ReLU → Pool
        x = self.conv1(x)
        x = nnx.relu(x)
        x = self.pool(x)
        
        # Block 2: Conv → ReLU → Pool
        x = self.conv2(x)
        x = nnx.relu(x)
        x = self.pool(x)
        
        # Flatten: (batch, height, width, channels) → (batch, features)
        x = x.reshape(x.shape[0], -1)
        
        # Classification head
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        
        return x

Testing the Model

# Initialize
model = CNN(num_classes=10, rngs=nnx.Rngs(0))

# Create dummy input: batch of 4 grayscale 28×28 images
# Shape: (batch, height, width, channels)
dummy_input = jnp.ones((4, 28, 28, 1))

# Forward pass
output = model(dummy_input)

print(f"Input shape:  {dummy_input.shape}")  # (4, 28, 28, 1)
print(f"Output shape: {output.shape}")       # (4, 10)

# Inspect the model structure
nnx.display(model)

The output shape is (4, 10)—four images, ten class scores each. This is exactly what we'd feed into a softmax for classification.
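As a quick illustration of that last step, here's how logits of that shape turn into class probabilities (the logit values below are made up):

```python
import jax
import jax.numpy as jnp

# Hypothetical class scores for a batch of 4 images, 10 classes each —
# the same shape our CNN produces.
logits = jnp.array([[2.0, 1.0] + [0.0] * 8] * 4)

probs = jax.nn.softmax(logits, axis=-1)   # probabilities per class
preds = probs.argmax(axis=-1)             # predicted class per image

print(probs.shape)         # (4, 10)
print(probs.sum(axis=-1))  # each row sums to 1.0
```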

How NNX Compiles with JAX

Here's the magic: even though we're writing object-oriented code, we can still use JAX's transformations.

The @nnx.jit Decorator

NNX provides its own versions of JAX transforms that understand NNX objects:

@nnx.jit
def forward(model, x):
    return model(x)

# This is JIT-compiled, just like @jax.jit
output = forward(model, dummy_input)

When you call this function, NNX:

  1. Extracts the model's parameters into a pure JAX pytree

  2. Traces the computation with those parameters

  3. Compiles the trace with XLA

  4. Updates the model object with any changed state

You write familiar OOP code. JAX gets the pure functions it needs. Everyone wins.

Preview: The Split/Merge Pattern

Under the hood, NNX uses two key operations:

# Split: separate structure from state
graphdef, state = nnx.split(model)

# graphdef: the "blueprint" of the model (static)
# state: the actual parameter values (dynamic, a pytree)

# Merge: reconstruct the model from structure and state
reconstructed_model = nnx.merge(graphdef, state)

You rarely need to call these directly—@nnx.jit handles it automatically. But understanding this pattern helps when you need to do advanced things like:

  • Saving and loading checkpoints (Week 9)

  • Distributing models across devices (Week 10)

  • Custom training loops with fine-grained control

Common Layers Reference

Here are the NNX equivalents of layers you know from PyTorch:

PyTorch                  Flax NNX                 Notes
nn.Linear                nnx.Linear               Same signature
nn.Conv2d                nnx.Conv                 Uses in_features/out_features
nn.BatchNorm2d           nnx.BatchNorm            Tracks running stats automatically
nn.LayerNorm             nnx.LayerNorm            Same behavior
nn.Dropout               nnx.Dropout              Requires rngs, respects deterministic flag
nn.Embedding             nnx.Embed                For token embeddings
nn.MultiheadAttention    nnx.MultiHeadAttention   Transformer attention

Activation functions aren't layers in NNX, they're just functions:

x = nnx.relu(x)
x = nnx.gelu(x)
x = nnx.softmax(x)
x = nnx.sigmoid(x)

Working with Parameters

Sometimes you need direct access to the parameters—for logging, for custom initialization, or for freezing layers.

Accessing Parameters

# Access a specific layer's weights
print(model.linear1.kernel.value.shape)  # (3136, 256)
print(model.linear1.bias.value.shape)    # (256,)

# Parameters are nnx.Param objects; .value gets the JAX array

Extracting All Parameters

# Get all parameters as a state object
state = nnx.state(model)

# Or specifically just the trainable parameters
params = nnx.state(model, nnx.Param)

Updating Parameters

# Update the model with new state
nnx.update(model, new_state)

This becomes important next week when we build training loops.

Exercises

Before moving on, try these:

  1. Add Dropout: Modify the CNN to include nnx.Dropout(rate=0.5, rngs=rngs) between the dense layers. Note that dropout needs its own RNG stream for the random mask.

  2. Build a deeper network: Create a 4-layer MLP with hidden dimensions [512, 256, 128, 64]. Use a loop in __init__ to avoid repetition.

  3. Parameter counting: Write a function that takes an NNX model and returns the total number of trainable parameters. Hint: use nnx.state(model, nnx.Param) and jax.tree_util.tree_map.

Quick Reference

from flax import nnx
import jax.numpy as jnp

# Define a Model
class MyModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 5, rngs=rngs)
    
    def __call__(self, x):
        return nnx.relu(self.linear(x))

# Instantiate
model = MyModel(rngs=nnx.Rngs(0))

# Forward Pass
output = model(jnp.ones((32, 10)))

# Inspect
nnx.display(model)

# Access Parameters
weights = model.linear.kernel.value
state = nnx.state(model, nnx.Param)

# JIT Compile
@nnx.jit
def forward(model, x):
    return model(x)

# Common Layers
nnx.Linear(in_features, out_features, rngs=rngs)
nnx.Conv(in_features, out_features, kernel_size, rngs=rngs)
nnx.BatchNorm(num_features, rngs=rngs)
nnx.Dropout(rate, rngs=rngs)
nnx.Embed(num_embeddings, features, rngs=rngs)

What's Next

We have a model. But a model that can't learn is just a random number generator with extra steps.

Next week, we build the training loop. We'll use:

  • nnx.value_and_grad to compute loss and gradients

  • nnx.Optimizer to manage parameter updates

  • optax to define the optimization algorithm

  • Metrics to track accuracy and loss

We'll train our CNN on real data and watch the loss curve drop. That's when this all becomes real.