Building Neural Networks with Flax NNX
Flax NNX: Neural Networks That Feel Like PyTorch

I'm a Machine Learning Engineer passionate about building production-ready ML systems for the African market. With experience in TensorFlow, Keras, and Python-based workflows, I help teams bridge the gap between machine learning research and real-world deployment—especially on resource-constrained devices. I'm also a Google Developer Expert in AI. I regularly speak at tech conferences including PyCon Africa, DevFest Kampala, and DevFest Nairobi, and I write technical articles on AI/ML.
Over the past two weeks, we've learned that JAX is fast (jit), that it eliminates loops (vmap), and that it computes gradients automatically (grad). These are powerful primitives.
But if you've been following along, you might have noticed something uncomfortable: we've been passing arrays around manually, tracking parameters in tuples, and writing functions that return updated state. It works, but it doesn't feel like building neural networks. It feels like accounting.
If you've used PyTorch, you know how natural it is to define a model as a class, call model(x), and let the framework handle the rest. That ergonomic experience is what made PyTorch the dominant framework for research.
Today, we get that experience in JAX.
Flax NNX is a neural network library that gives you PyTorch-style classes and methods while compiling down to JAX's XLA backend. You write object-oriented code. JAX runs functional code. NNX bridges the gap.
By the end of this article, we'll have:
Understood why NNX exists and what problem it solves
Built our first neural network using `nnx.Module`
Learned how NNX handles the critical issue of random number generation
Created a CNN for image classification—the same architecture used in real production systems
Let's build something real.
The State Problem
To understand why Flax NNX is necessary, we need to understand the fundamental tension between PyTorch and JAX.
How PyTorch Handles State
In PyTorch, a model is an object that contains its own parameters:
```python
import torch
import torch.nn as nn

class PyTorchMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.linear1(x))
        return self.linear2(x)

model = PyTorchMLP()
output = model(input_data)  # Parameters are hidden inside
```
This is intuitive. The nn.Linear layers create their own weight matrices internally. You don't see them, you don't manage them, they just exist.
But this "hidden state" creates problems for JAX. When you call jax.jit on a function, JAX traces through it and compiles a computational graph. If that function secretly reads or writes to hidden variables, JAX can't see those operations, and it can't optimize them.
How Pure JAX Handles State
JAX demands pure functions: functions where the output depends only on the inputs, with no side effects. If you want to update parameters, you must pass them in and return them out:
```python
import jax

def forward(params, x):
    x = jax.nn.relu(x @ params['w1'] + params['b1'])
    return x @ params['w2'] + params['b2']

def train_step(params, x, y):
    # loss_fn and update_params are defined elsewhere
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = update_params(params, grads)
    return new_params, loss  # Must return the new state

# You're always passing params around
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)
```
This is explicit and JIT-compatible, but it's tedious. For a model with dozens of layers, managing nested dictionaries of parameters becomes error-prone.
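To make the tedium concrete, here's a complete, runnable version of that manual pattern. The mean-squared-error `loss_fn` and plain-SGD update are my own choices to fill the gaps, not fixed by the text:

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    x = jax.nn.relu(x @ params['w1'] + params['b1'])
    return x @ params['w2'] + params['b2']

def loss_fn(params, x, y):
    preds = forward(params, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(params, x, y, lr=0.01):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # SGD update: walk the params pytree and step against the gradient
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = {
    'w1': jax.random.normal(k1, (4, 8)) * 0.1, 'b1': jnp.zeros(8),
    'w2': jax.random.normal(k2, (8, 2)) * 0.1, 'b2': jnp.zeros(2),
}
x, y = jnp.ones((16, 4)), jnp.zeros((16, 2))

params, loss1 = train_step(params, x, y)
params, loss2 = train_step(params, x, y)  # loss goes down across steps
```

Even for this toy two-layer model, every helper must thread `params` in and out by hand — multiply that by dozens of layers and the bookkeeping dominates the code.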
How Flax NNX Solves This
NNX gives you the PyTorch interface while secretly doing the JAX bookkeeping. You write classes with attributes. NNX intercepts those attributes, extracts the parameters when needed for JIT compilation, and updates them in place when you're done.
```python
from flax import nnx

class NNX_MLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(784, 128, rngs=rngs)
        self.linear2 = nnx.Linear(128, 10, rngs=rngs)

    def __call__(self, x):
        x = nnx.relu(self.linear1(x))
        return self.linear2(x)

model = NNX_MLP(rngs=nnx.Rngs(0))
output = model(input_data)  # Looks just like PyTorch
```
The syntax is nearly identical to PyTorch. But under the hood, NNX can extract the parameters, pass them through JAX transformations, and put them back—all without you writing a single line of state management code.
Building Your First NNX Model
Let's build a simple multi-layer perceptron step by step.
The Basic Structure
Every NNX model inherits from nnx.Module:
```python
from flax import nnx
import jax.numpy as jnp

class SimpleMLP(nnx.Module):
    def __init__(self, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        # Define layers as attributes
        self.linear1 = nnx.Linear(784, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

    def __call__(self, x):
        # Define forward pass
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x
```
Key differences from PyTorch:
No `super().__init__()`: NNX uses Python metaclasses, so you don't need to call the parent constructor.
`__call__` instead of `forward`: In Python, `__call__` makes an object callable. NNX uses this standard convention rather than PyTorch's custom `forward` method.
The `rngs` parameter: This is required. We'll explain why in the next section.
Instantiating the Model
```python
# Create a random number generator
rngs = nnx.Rngs(42)  # Seed for reproducibility

# Create the model
model = SimpleMLP(hidden_dim=128, output_dim=10, rngs=rngs)

# Test with dummy data
x = jnp.ones((32, 784))  # Batch of 32, 784 features each
output = model(x)

print(f"Input shape: {x.shape}")        # (32, 784)
print(f"Output shape: {output.shape}")  # (32, 10)
```
That's it. The model is ready to use.
The Randomness Requirement
You might be wondering: why do we need to pass rngs everywhere?
The Problem with Hidden Randomness
In NumPy or PyTorch, random number generation uses a global state:
```python
import numpy as np

np.random.seed(42)
print(np.random.randn())  # Always the same first value after seeding
print(np.random.randn())  # Different value (global state advanced)
```
This global state is invisible. It changes every time you call a random function. And as we established in Week 1, JAX cannot work with hidden, changing state.
JAX's Explicit Randomness
In JAX, randomness is deterministic. You create a "key," and that key always produces the same random numbers:
```python
from jax import random

key = random.PRNGKey(42)
print(random.normal(key, shape=(3,)))  # Always the same
print(random.normal(key, shape=(3,)))  # Still the same!
```
To get different random numbers, you must split the key:
```python
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # New values
```
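When you need many independent streams at once, `random.split` also accepts a count — this is exactly the chore that `nnx.Rngs` automates in the next section:

```python
from jax import random

key = random.PRNGKey(0)
# one call produces a carry key plus three fresh subkeys
key, *subkeys = random.split(key, num=4)

# each subkey yields an independent, reproducible stream
samples = [random.normal(k, shape=(2,)) for k in subkeys]
```

Rerunning this from the same seed reproduces every value; forgetting a split (reusing a key) silently reuses the same "random" numbers.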
How nnx.Rngs Helps
When you create a neural network, every layer needs random numbers to initialize its weights. If you have 50 layers, that's 50 key splits you'd need to manage manually.
nnx.Rngs automates this. It's a key dispenser that splits and distributes keys automatically:
```python
rngs = nnx.Rngs(42)

# When you create a layer, it asks rngs for keys internally
linear1 = nnx.Linear(10, 20, rngs=rngs)  # Gets its own keys
linear2 = nnx.Linear(20, 30, rngs=rngs)  # Gets different keys

# Both layers have different, reproducible initializations
```
The critical benefit: reproducibility. If you and I both run nnx.Rngs(42), we get identical models. This matters for debugging, for scientific reproducibility, and for distributed training where multiple machines must initialize the same model.
Inspecting Your Model
PyTorch lets you print(model) to see the architecture. NNX has something better: nnx.display().
```python
nnx.display(model)
```
This produces a rich, hierarchical view showing:
Every layer and sublayer
Parameter shapes and dtypes
The total parameter count
The structure of the computational graph
In Jupyter notebooks, this renders as an interactive tree you can expand and collapse. It's invaluable for debugging shape mismatches and verifying your architecture.
Building a CNN for Image Classification
Let's build something more substantial: a convolutional neural network for classifying images. This is the same architecture pattern used in production image classifiers.
The Architecture
We'll build a classic CNN with:
Two convolutional blocks (Conv → ReLU → Pool)
A flatten operation
Two dense layers for classification
```python
from flax import nnx
import jax.numpy as jnp
from functools import partial

class CNN(nnx.Module):
    """A simple CNN for image classification."""

    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # Convolutional layers
        self.conv1 = nnx.Conv(
            in_features=1,       # Input channels (1 for grayscale)
            out_features=32,     # Output channels
            kernel_size=(3, 3),  # 3x3 filters
            rngs=rngs,
        )
        self.conv2 = nnx.Conv(
            in_features=32,
            out_features=64,
            kernel_size=(3, 3),
            rngs=rngs,
        )
        # Dense layers
        # After two 2x2 pooling operations on 28x28 input: 28 → 14 → 7
        # So we have 64 channels × 7 × 7 = 3136 features
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, num_classes, rngs=rngs)
        # Pooling as a reusable operation
        self.pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))

    def __call__(self, x):
        # Block 1: Conv → ReLU → Pool
        x = self.conv1(x)
        x = nnx.relu(x)
        x = self.pool(x)
        # Block 2: Conv → ReLU → Pool
        x = self.conv2(x)
        x = nnx.relu(x)
        x = self.pool(x)
        # Flatten: (batch, height, width, channels) → (batch, features)
        x = x.reshape(x.shape[0], -1)
        # Classification head
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x
```
Testing the Model
```python
# Initialize
model = CNN(num_classes=10, rngs=nnx.Rngs(0))

# Create dummy input: batch of 4 grayscale 28×28 images
# Shape: (batch, height, width, channels)
dummy_input = jnp.ones((4, 28, 28, 1))

# Forward pass
output = model(dummy_input)

print(f"Input shape: {dummy_input.shape}")  # (4, 28, 28, 1)
print(f"Output shape: {output.shape}")      # (4, 10)

# Inspect the model structure
nnx.display(model)
```
The output shape is (4, 10)—four images, ten class scores each. This is exactly what we'd feed into a softmax for classification.
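That last step — turning scores into probabilities — looks like this (the `logits` values are made up for illustration):

```python
import jax
import jax.numpy as jnp

# hypothetical logits for a batch of 2 images over 3 classes
logits = jnp.array([[2.0, 1.0, 0.1],
                    [0.0, 0.0, 5.0]])

probs = jax.nn.softmax(logits, axis=-1)
# every row now sums to 1, and the largest logit gets the largest probability
predictions = jnp.argmax(probs, axis=-1)  # class indices: [0, 2]
```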
How NNX Compiles with JAX
Here's the magic: even though we're writing object-oriented code, we can still use JAX's transformations.
The @nnx.jit Decorator
NNX provides its own versions of JAX transforms that understand NNX objects:
```python
@nnx.jit
def forward(model, x):
    return model(x)

# This is JIT-compiled, just like @jax.jit
output = forward(model, dummy_input)
```
When you call this function, NNX:
Extracts the model's parameters into a pure JAX pytree
Traces the computation with those parameters
Compiles the trace with XLA
Updates the model object with any changed state
You write familiar OOP code. JAX gets the pure functions it needs. Everyone wins.
Preview: The Split/Merge Pattern
Under the hood, NNX uses two key operations:
```python
# Split: separate structure from state
graphdef, state = nnx.split(model)
# graphdef: the "blueprint" of the model (static)
# state: the actual parameter values (dynamic, a pytree)

# Merge: reconstruct the model from structure and state
reconstructed_model = nnx.merge(graphdef, state)
```
You rarely need to call these directly—@nnx.jit handles it automatically. But understanding this pattern helps when you need to do advanced things like:
Saving and loading checkpoints (Week 9)
Distributing models across devices (Week 10)
Custom training loops with fine-grained control
Common Layers Reference
Here are the NNX equivalents of layers you know from PyTorch:
| PyTorch | Flax NNX | Notes |
|---|---|---|
| `nn.Linear` | `nnx.Linear` | Same signature |
| `nn.Conv2d` | `nnx.Conv` | Uses `in_features`/`out_features`; expects channels-last (NHWC) inputs |
| `nn.BatchNorm2d` | `nnx.BatchNorm` | Tracks running stats automatically |
| `nn.LayerNorm` | `nnx.LayerNorm` | Same behavior |
| `nn.Dropout` | `nnx.Dropout` | Requires `rngs` for the random mask |
| `nn.Embedding` | `nnx.Embed` | For token embeddings |
| `nn.MultiheadAttention` | `nnx.MultiHeadAttention` | Transformer attention |
Activation functions aren't layers in NNX, they're just functions:
```python
x = nnx.relu(x)
x = nnx.gelu(x)
x = nnx.softmax(x)
x = nnx.sigmoid(x)
```
Working with Parameters
Sometimes you need direct access to the parameters—for logging, for custom initialization, or for freezing layers.
Accessing Parameters
```python
# Access a specific layer's weights
print(model.linear1.kernel.value.shape)  # (3136, 256)
print(model.linear1.bias.value.shape)    # (256,)

# Parameters are nnx.Param objects; .value gets the JAX array
```
Extracting All Parameters
```python
# Get all parameters as a state object
state = nnx.state(model)

# Or specifically just the trainable parameters
params = nnx.state(model, nnx.Param)
```
Updating Parameters
```python
# Update the model with new state
nnx.update(model, new_state)
```
This becomes important next week when we build training loops.
Exercises
Before moving on, try these:
Add Dropout: Modify the CNN to include `nnx.Dropout(rate=0.5, rngs=rngs)` between the dense layers. Note that dropout needs its own RNG stream for the random mask.
Build a deeper network: Create a 4-layer MLP with hidden dimensions [512, 256, 128, 64]. Use a loop in `__init__` to avoid repetition.
Parameter counting: Write a function that takes an NNX model and returns the total number of trainable parameters. Hint: use `nnx.state(model, nnx.Param)` and `jax.tree_util.tree_map`.
Quick Reference
```python
from flax import nnx
import jax.numpy as jnp

# Define a Model
class MyModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 5, rngs=rngs)

    def __call__(self, x):
        return nnx.relu(self.linear(x))

# Instantiate
model = MyModel(rngs=nnx.Rngs(0))

# Forward Pass
output = model(jnp.ones((32, 10)))

# Inspect
nnx.display(model)

# Access Parameters
weights = model.linear.kernel.value
state = nnx.state(model, nnx.Param)

# JIT Compile
@nnx.jit
def forward(model, x):
    return model(x)

# Common Layers
nnx.Linear(in_features, out_features, rngs=rngs)
nnx.Conv(in_features, out_features, kernel_size, rngs=rngs)
nnx.BatchNorm(num_features, rngs=rngs)
nnx.Dropout(rate, rngs=rngs)
nnx.Embed(num_embeddings, features, rngs=rngs)
```
What's Next
We have a model. But a model that can't learn is just a random number generator with extra steps.
Next week, we build the training loop. We'll use:
`nnx.value_and_grad` to compute loss and gradients
`nnx.Optimizer` to manage parameter updates
`optax` to define the optimization algorithm
Metrics to track accuracy and loss
We'll train our CNN on real data and watch the loss curve drop. That's when this all becomes real.



