<?xml version="1.0" encoding="UTF-8"?><rss xmlns:dc="http://purl.org/dc/elements/1.1/" xmlns:content="http://purl.org/rss/1.0/modules/content/" xmlns:atom="http://www.w3.org/2005/Atom" version="2.0"><channel><title><![CDATA[Wesley Kambale]]></title><description><![CDATA[I'm a machine learning engineer and Google Developer Expert in AI, adept at crafting production-ready ML systems that provide impactful solutions in the African]]></description><link>https://kambale.dev</link><generator>RSS for Node</generator><lastBuildDate>Sat, 18 Apr 2026 00:16:35 GMT</lastBuildDate><atom:link href="https://kambale.dev/rss.xml" rel="self" type="application/rss+xml"/><language><![CDATA[en]]></language><ttl>60</ttl><item><title><![CDATA[Custom training loop from scratch in JAX]]></title><description><![CDATA[For the past three weeks, we've been building up to this moment.
Week 1 taught us that JAX is fast. Week 2 showed us how to eliminate loops with vmap and compute gradients with grad. Week 3 gave us Fl]]></description><link>https://kambale.dev/training-loop-in-jax</link><guid isPermaLink="true">https://kambale.dev/training-loop-in-jax</guid><category><![CDATA[jax]]></category><category><![CDATA[training-loop]]></category><category><![CDATA[Machine Learning]]></category><category><![CDATA[optax]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Wed, 04 Mar 2026 22:25:37 GMT</pubDate><enclosure url="https://cdn.hashnode.com/uploads/covers/61143e31119030192497a888/46e5dfb7-7634-4ea1-9b04-a5b4909af30b.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>For the past three weeks, we've been building up to this moment.</p>
<p>Week 1 taught us that JAX is fast. Week 2 showed us how to eliminate loops with <code>vmap</code> and compute gradients with <code>grad</code>. Week 3 gave us Flax NNX—a way to define neural networks that feels like PyTorch but runs like JAX.</p>
<p>But we haven't actually <em>trained</em> anything yet.</p>
<p>Today, that changes. We're going to write a complete training loop from scratch: forward pass, loss calculation, gradient computation, parameter updates, and evaluation. No <code>model.fit()</code>. No magic. Just explicit, controllable, JIT-compiled code that you own completely.</p>
<p>By the end of this article, you'll have trained a CNN on MNIST and understood every single line of code that made it happen.</p>
<h2>Why Write Your Own Training Loop?</h2>
<p>If you've used Keras, you know how easy training can be:</p>
<pre><code class="language-python">model.fit(x_train, y_train, epochs=10)
</code></pre>
<p>One line. Done. So why would anyone write hundreds of lines to do the same thing?</p>
<p>Because <code>model.fit()</code> is a black box. It works until it doesn't. And when it doesn't, when your loss explodes, when you need gradient accumulation, when you're debugging a custom layer, when you need to log something specific every 37 steps, you're stuck.</p>
<p>The training loop is where machine learning actually happens. If you don't control it, you don't control your model.</p>
<p>In production systems, you almost always need:</p>
<ul>
<li><p><strong>Custom metrics</strong> that aren't built into the framework</p>
</li>
<li><p><strong>Gradient clipping</strong> to prevent exploding gradients</p>
</li>
<li><p><strong>Learning rate schedules</strong> that change based on validation loss</p>
</li>
<li><p><strong>Mixed precision training</strong> for speed</p>
</li>
<li><p><strong>Gradient accumulation</strong> when batches don't fit in memory</p>
</li>
<li><p><strong>Early stopping</strong> based on custom criteria</p>
</li>
</ul>
<p>Writing your own loop means you can implement any of these with a few lines of code.</p>
<h2>The Four Beats of Training</h2>
<p>Every training loop, regardless of framework, follows the same rhythm:</p>
<ol>
<li><p><strong>Forward Pass</strong>: Run input through the model to get predictions</p>
</li>
<li><p><strong>Loss Calculation</strong>: Compare predictions to ground truth</p>
</li>
<li><p><strong>Backward Pass</strong>: Compute gradients of the loss with respect to parameters</p>
</li>
<li><p><strong>Parameter Update</strong>: Adjust parameters using the gradients</p>
</li>
</ol>
<p>In PyTorch, this looks like:</p>
<pre><code class="language-python">optimizer.zero_grad()        # Clear old gradients
predictions = model(x)       # Forward pass
loss = loss_fn(predictions, y)  # Loss calculation
loss.backward()              # Backward pass
optimizer.step()             # Parameter update
</code></pre>
<p>PyTorch hides a lot here. <code>loss.backward()</code> traverses a hidden computational graph and secretly updates <code>.grad</code> attributes on every parameter.</p>
<p>JAX makes everything explicit.</p>
<h2>The Optax Library</h2>
<p><strong>Optax</strong> is JAX's optimization library. It handles gradient transformations and parameter updates.</p>
<p>The key insight about Optax is that optimizers are <strong>composable</strong>. Instead of monolithic optimizer classes, Optax gives you building blocks that you chain together:</p>
<pre><code class="language-python">import optax

# A simple Adam optimizer
optimizer = optax.adam(learning_rate=0.001)

# Adam with gradient clipping
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # First, clip gradients
    optax.adam(learning_rate=0.001)   # Then, apply Adam
)

# Adam with weight decay (AdamW)
optimizer = optax.adamw(learning_rate=0.001, weight_decay=0.01)
</code></pre>
<p>This composability is powerful. Need to add gradient clipping to your existing optimizer? Add one line. Need a custom transformation? Write a function that takes gradients and returns modified gradients.</p>
<h3>Common Optax Optimizers</h3>
<table>
<thead>
<tr>
<th>Optimizer</th>
<th>Usage</th>
</tr>
</thead>
<tbody><tr>
<td><code>optax.sgd(lr)</code></td>
<td>Basic stochastic gradient descent</td>
</tr>
<tr>
<td><code>optax.adam(lr)</code></td>
<td>Adam (most common default)</td>
</tr>
<tr>
<td><code>optax.adamw(lr, weight_decay)</code></td>
<td>Adam with decoupled weight decay</td>
</tr>
<tr>
<td><code>optax.sgd(lr, momentum=0.9)</code></td>
<td>SGD with momentum</td>
</tr>
<tr>
<td><code>optax.rmsprop(lr)</code></td>
<td>RMSprop</td>
</tr>
</tbody></table>
<h3>Common Gradient Transformations</h3>
<table>
<thead>
<tr>
<th>Transformation</th>
<th>Purpose</th>
</tr>
</thead>
<tbody><tr>
<td><code>optax.clip_by_global_norm(max_norm)</code></td>
<td>Prevent exploding gradients</td>
</tr>
<tr>
<td><code>optax.add_decayed_weights(decay)</code></td>
<td>L2 regularization</td>
</tr>
<tr>
<td><code>optax.scale_by_schedule(schedule_fn)</code></td>
<td>Learning rate scheduling</td>
</tr>
</tbody></table>
<h2>The nnx.Optimizer Wrapper</h2>
<p>Optax is a functional library, it works with raw JAX arrays. But our models are NNX objects with structured parameters. We need a bridge.</p>
<p><code>nnx.Optimizer</code> is that bridge:</p>
<pre><code class="language-python">from flax import nnx
import optax

# Create the Optax optimizer
tx = optax.adam(learning_rate=0.001)

# Wrap it with nnx.Optimizer
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)
</code></pre>
<p>The <code>wrt</code> parameter is critical. It stands for "with respect to" and tells the optimizer which variables to update. <code>nnx.Param</code> means "only update trainable parameters", not batch normalization statistics or other state.</p>
<p>Once wrapped, updating parameters is simple:</p>
<pre><code class="language-python">optimizer.update(model, grads)
</code></pre>
<p>This applies the gradients to the model's parameters in place. No need to manually extract parameters, apply updates, and put them back.</p>
<h2>Computing Gradients with nnx.value_and_grad</h2>
<p>In PyTorch, you call <code>loss.backward()</code> to compute gradients. In JAX, we use <code>value_and_grad</code>.</p>
<p>Here's the pattern:</p>
<pre><code class="language-python">def loss_fn(model, x, y):
    logits = model(x)
    loss = compute_loss(logits, y)
    return loss

# Get both the loss value and the gradients
loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
</code></pre>
<p><strong>Critical rule</strong>: <code>nnx.value_and_grad</code> computes gradients with respect to the <strong>first argument</strong> of the function. That's why <code>model</code> must be the first parameter.</p>
<p>The <code>grads</code> object has the same structure as the model's parameters—it's a pytree where each leaf is the gradient for the corresponding parameter.</p>
<h2>The Complete Training Step</h2>
<p>Let's put it all together. Here's a complete, JIT-compiled training step:</p>
<pre><code class="language-python">@nnx.jit
def train_step(model, optimizer, x, y):
    """Execute one training step."""
    
    def loss_fn(model):
        logits = model(x)
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, 
            labels=y
        ).mean()
        return loss, logits
    
    # Compute loss and gradients
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    
    # Update parameters
    optimizer.update(model, grads)
    
    # Compute accuracy for logging
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == y)
    
    return loss, accuracy
</code></pre>
<p>A few things to note:</p>
<ol>
<li><p><code>@nnx.jit</code>: The entire function is JIT-compiled. The first call traces and compiles; subsequent calls execute the compiled code.</p>
</li>
<li><p><code>has_aux=True</code>: Our loss function returns <code>(loss, logits)</code>. The <code>has_aux=True</code> flag tells <code>value_and_grad</code> that the function returns auxiliary data alongside the loss. Gradients are only computed for the first returned value.</p>
</li>
<li><p><strong>In-place updates</strong>: <code>optimizer.update(model, grads)</code> modifies the model in place. NNX handles the functional-to-stateful conversion internally.</p>
</li>
</ol>
<h2>Training a CNN on MNIST</h2>
<p>Let's train a real model on real data. We'll use the CNN from Week 3 and the MNIST dataset.</p>
<h3>Setup and Data Loading</h3>
<pre><code class="language-python">import jax
import jax.numpy as jnp
from flax import nnx
import optax
import tensorflow_datasets as tfds
import tensorflow as tf

# Disable TensorFlow GPU to avoid conflicts with JAX
tf.config.set_visible_devices([], 'GPU')

# Load MNIST
def load_mnist(batch_size=32):
    """Load and preprocess MNIST dataset."""
    
    def preprocess(sample):
        image = tf.cast(sample['image'], tf.float32) / 255.0
        label = sample['label']
        return {'image': image, 'label': label}
    
    train_ds = tfds.load('mnist', split='train')
    test_ds = tfds.load('mnist', split='test')
    
    train_ds = (train_ds
                .map(preprocess)
                .shuffle(1024)
                .batch(batch_size, drop_remainder=True)
                .prefetch(1))
    
    test_ds = (test_ds
               .map(preprocess)
               .batch(batch_size, drop_remainder=True)
               .prefetch(1))
    
    return train_ds, test_ds

train_ds, test_ds = load_mnist(batch_size=32)
</code></pre>
<h3>Define the Model</h3>
<p>We'll use a slightly simplified version of the CNN from Week 3:</p>
<pre><code class="language-python">from functools import partial

class CNN(nnx.Module):
    """Convolutional Neural Network for MNIST."""
    
    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)
        self.pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    
    def __call__(self, x):
        x = nnx.relu(self.conv1(x))
        x = self.pool(x)
        x = nnx.relu(self.conv2(x))
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)  # Flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x
</code></pre>
<h3>Initialize Model and Optimizer</h3>
<pre><code class="language-python"># Initialize model
model = CNN(rngs=nnx.Rngs(0))

# Test forward pass
dummy_input = jnp.ones((1, 28, 28, 1))
dummy_output = model(dummy_input)
print(f"Model output shape: {dummy_output.shape}")  # (1, 10)

# Initialize optimizer
learning_rate = 0.005
momentum = 0.9
tx = optax.adamw(learning_rate, momentum)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Display model structure
nnx.display(model)
</code></pre>
<h3>Define Training and Evaluation Steps</h3>
<pre><code class="language-python">@nnx.jit
def train_step(model, optimizer, batch):
    """Single training step."""
    
    def loss_fn(model):
        logits = model(batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits,
            labels=batch['label']
        ).mean()
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
    return loss, accuracy

@nnx.jit
def eval_step(model, batch):
    """Single evaluation step."""
    logits = model(batch['image'])
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits,
        labels=batch['label']
    ).mean()
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == batch['label'])
    return loss, accuracy
</code></pre>
<h3>The Training Loop</h3>
<pre><code class="language-python"># Training configuration
num_epochs = 5
train_steps_per_epoch = 1000
eval_steps = 200

print("Starting training...")
print("=" * 60)

for epoch in range(num_epochs):
    # Training
    model.train()  # Set model to training mode
    train_loss, train_acc = 0.0, 0.0
    
    for step, batch in enumerate(train_ds.as_numpy_iterator()):
        if step &gt;= train_steps_per_epoch:
            break
        
        # Convert to JAX arrays
        batch = {k: jnp.array(v) for k, v in batch.items()}
        
        loss, acc = train_step(model, optimizer, batch)
        train_loss += loss
        train_acc += acc
    
    train_loss /= train_steps_per_epoch
    train_acc /= train_steps_per_epoch
    
    # Evaluation
    model.eval()  # Set model to evaluation mode
    eval_loss, eval_acc = 0.0, 0.0
    eval_batches = 0
    
    for step, batch in enumerate(test_ds.as_numpy_iterator()):
        if step &gt;= eval_steps:
            break
        
        batch = {k: jnp.array(v) for k, v in batch.items()}
        
        loss, acc = eval_step(model, batch)
        eval_loss += loss
        eval_acc += acc
        eval_batches += 1
    
    eval_loss /= eval_batches
    eval_acc /= eval_batches
    
    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Eval Loss:  {eval_loss:.4f} | Eval Acc:  {eval_acc:.4f}")

print("Training complete!")
</code></pre>
<h2>Using nnx.MultiMetric for Cleaner Tracking</h2>
<p>NNX provides a <code>MultiMetric</code> helper for accumulating metrics across batches:</p>
<pre><code class="language-python"># Define metrics
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

@nnx.jit
def train_step_with_metrics(model, optimizer, metrics, batch):
    def loss_fn(model):
        logits = model(batch['image'])
        loss = optax.softmax_cross_entropy_with_integer_labels(
            logits=logits, labels=batch['label']
        ).mean()
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    
    # Update metrics in place
    metrics.update(loss=loss, logits=logits, labels=batch['label'])

# In training loop:
for batch in train_ds:
    train_step_with_metrics(model, optimizer, metrics, batch)

# Get aggregated metrics
results = metrics.compute()
print(f"Loss: {results['loss']:.4f}, Accuracy: {results['accuracy']:.4f}")

# Reset for next epoch
metrics.reset()
</code></pre>
<p>This is cleaner than manual accumulation and handles edge cases (like different batch sizes) correctly.</p>
<h2>Inference: Using the Trained Model</h2>
<p>After training, making predictions is straightforward:</p>
<pre><code class="language-python">model.eval()  # Ensure evaluation mode (deterministic)

@nnx.jit
def predict(model, images):
    logits = model(images)
    return jnp.argmax(logits, axis=-1)

# Get predictions
test_batch = next(test_ds.as_numpy_iterator())
predictions = predict(model, jnp.array(test_batch['image']))

print(f"Predictions: {predictions[:10]}")
print(f"Actual: {test_batch['label'][:10]}")
</code></pre>
<h2>Exercises</h2>
<ol>
<li><p><strong>Add gradient clipping</strong>: Modify the optimizer to include <code>optax.clip_by_global_norm(1.0)</code>. Train the model and compare the loss curves.</p>
</li>
<li><p><strong>Learning rate schedule</strong>: Implement a learning rate that decays over time using <code>optax.warmup_cosine_decay_schedule</code>. Check the Optax documentation for the parameters.</p>
</li>
<li><p><strong>Track more metrics</strong>: Add precision and recall tracking alongside accuracy. You'll need to compute true positives, false positives, and false negatives.</p>
</li>
<li><p><strong>Early stopping</strong>: Modify the training loop to stop if validation accuracy doesn't improve for 3 consecutive epochs.</p>
</li>
</ol>
<h2>Quick Reference</h2>
<pre><code class="language-python">import optax
from flax import nnx

# Optimizer Setup
tx = optax.adam(learning_rate=0.001)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)

# Training Step Pattern
@nnx.jit
def train_step(model, optimizer, batch):
    def loss_fn(model):
        logits = model(batch['x'])
        loss = compute_loss(logits, batch['y'])
        return loss, logits
    
    (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)
    optimizer.update(model, grads)
    return loss

# Evaluation Step Pattern
@nnx.jit  
def eval_step(model, batch):
    logits = model(batch['x'])
    loss = compute_loss(logits, batch['y'])
    return loss

# Training/Eval Mode
model.train()  # Enable dropout, stochastic behavior
model.eval()   # Deterministic inference
</code></pre>
<h2>What's Next</h2>
<p>We've trained a model. But what happens when training goes wrong? What if loss becomes <code>NaN</code>? What if shapes don't match? What if the model silently produces garbage?</p>
<p>Next week, we dive into <strong>Optax</strong> in more depth: learning rate schedules, gradient clipping, weight decay, and building custom optimizer stacks. We'll make our training more robust and production-ready.</p>
]]></content:encoded></item><item><title><![CDATA[Building Neural Networks with Flax NNX]]></title><description><![CDATA[Over the past two weeks, we've learned that JAX is fast (jit), that it eliminates loops (vmap), and that it computes gradients automatically (grad). These are powerful primitives.
But if you've been f]]></description><link>https://kambale.dev/flax-nnx</link><guid isPermaLink="true">https://kambale.dev/flax-nnx</guid><category><![CDATA[jax]]></category><category><![CDATA[flax]]></category><category><![CDATA[neural networks]]></category><category><![CDATA[nnx]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Mon, 23 Feb 2026 19:46:48 GMT</pubDate><enclosure url="https://cloudmate-test.s3.us-east-1.amazonaws.com/uploads/covers/61143e31119030192497a888/476b79f9-b003-4338-86d7-9b270d60f550.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>Over the past two weeks, we've learned that JAX is fast (<code>jit</code>), that it eliminates loops (<code>vmap</code>), and that it computes gradients automatically (<code>grad</code>). These are powerful primitives.</p>
<p>But if you've been following along, you might have noticed something uncomfortable: we've been passing arrays around manually, tracking parameters in tuples, and writing functions that return updated state. It works, but it doesn't feel like building neural networks. It feels like accounting.</p>
<p>If you've used PyTorch, you know how natural it is to define a model as a class, call <code>model(x)</code>, and let the framework handle the rest. That ergonomic experience is what made PyTorch the dominant framework for research.</p>
<p>Today, we get that experience in JAX.</p>
<p><strong>Flax NNX</strong> is a neural network library that gives you PyTorch-style classes and methods while compiling down to JAX's XLA backend. You write object-oriented code. JAX runs functional code. NNX bridges the gap.</p>
<p>By the end of this article, we'll have:</p>
<ol>
<li><p>Understood why NNX exists and what problem it solves</p>
</li>
<li><p>Built our first neural network using <code>nnx.Module</code></p>
</li>
<li><p>Learned how NNX handles the critical issue of random number generation</p>
</li>
<li><p>Created a CNN for image classification—the same architecture used in real production systems</p>
</li>
</ol>
<p>Let's build something real.</p>
<h2>The State Problem</h2>
<p>To understand why Flax NNX is necessary, we need to understand the fundamental tension between PyTorch and JAX.</p>
<h3>How PyTorch Handles State</h3>
<p>In PyTorch, a model is an object that contains its own parameters:</p>
<pre><code class="language-python">import torch.nn as nn

class PyTorchMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = torch.relu(self.linear1(x))
        return self.linear2(x)

model = PyTorchMLP()
output = model(input_data)  # Parameters are hidden inside
</code></pre>
<p>This is intuitive. The <code>nn.Linear</code> layers create their own weight matrices internally. You don't see them, you don't manage them, they just exist.</p>
<p>But this "hidden state" creates problems for JAX. When you call <code>jax.jit</code> on a function, JAX traces through it and compiles a computational graph. If that function secretly reads or writes to hidden variables, JAX can't see those operations, and it can't optimize them.</p>
<h3>How Pure JAX Handles State</h3>
<p>JAX demands <strong>pure functions</strong>: functions where the output depends only on the inputs, with no side effects. If you want to update parameters, you must pass them in and return them out:</p>
<pre><code class="language-python">def forward(params, x):
    x = jax.nn.relu(x @ params['w1'] + params['b1'])
    return x @ params['w2'] + params['b2']

def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = update_params(params, grads)
    return new_params, loss  # Must return the new state

# You're always passing params around
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)
params, loss = train_step(params, x_batch, y_batch)
</code></pre>
<p>This is explicit and JIT-compatible, but it's tedious. For a model with dozens of layers, managing nested dictionaries of parameters becomes error-prone.</p>
<h3>How Flax NNX Solves This</h3>
<p>NNX gives you the PyTorch interface while secretly doing the JAX bookkeeping. You write classes with attributes. NNX intercepts those attributes, extracts the parameters when needed for JIT compilation, and updates them in place when you're done.</p>
<pre><code class="language-python">from flax import nnx

class NNX_MLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(784, 128, rngs=rngs)
        self.linear2 = nnx.Linear(128, 10, rngs=rngs)
    
    def __call__(self, x):
        x = nnx.relu(self.linear1(x))
        return self.linear2(x)

model = NNX_MLP(rngs=nnx.Rngs(0))
output = model(input_data)  # Looks just like PyTorch
</code></pre>
<p>The syntax is nearly identical to PyTorch. But under the hood, NNX can extract the parameters, pass them through JAX transformations, and put them back—all without you writing a single line of state management code.</p>
<h2>Building Your First NNX Model</h2>
<p>Let's build a simple multi-layer perceptron step by step.</p>
<h3>The Basic Structure</h3>
<p>Every NNX model inherits from <code>nnx.Module</code>:</p>
<pre><code class="language-python">from flax import nnx
import jax.numpy as jnp

class SimpleMLP(nnx.Module):
    def __init__(self, hidden_dim: int, output_dim: int, *, rngs: nnx.Rngs):
        # Define layers as attributes
        self.linear1 = nnx.Linear(784, hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(hidden_dim, output_dim, rngs=rngs)
    
    def __call__(self, x):
        # Define forward pass
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        return x
</code></pre>
<p>Key differences from PyTorch:</p>
<ol>
<li><p><strong>No</strong> <code>super().__init__()</code>: NNX uses Python metaclasses, so you don't need to call the parent constructor.</p>
</li>
<li><p><code>__call__</code> <strong>instead of</strong> <code>forward</code>: In Python, <code>__call__</code> makes an object callable. NNX uses this standard convention rather than PyTorch's custom <code>forward</code> method.</p>
</li>
<li><p><strong>The</strong> <code>rngs</code> <strong>parameter</strong>: This is required. We'll explain why in the next section.</p>
</li>
</ol>
<h3>Instantiating the Model</h3>
<pre><code class="language-python"># Create a random number generator
rngs = nnx.Rngs(42)  # Seed for reproducibility

# Create the model
model = SimpleMLP(hidden_dim=128, output_dim=10, rngs=rngs)

# Test with dummy data
x = jnp.ones((32, 784))  # Batch of 32, 784 features each
output = model(x)

print(f"Input shape:  {x.shape}")       # (32, 784)
print(f"Output shape: {output.shape}")  # (32, 10)
</code></pre>
<p>That's it. The model is ready to use.</p>
<h2>The Randomness Requirement</h2>
<p>You might be wondering: why do we need to pass <code>rngs</code> everywhere?</p>
<h3>The Problem with Hidden Randomness</h3>
<p>In NumPy or PyTorch, random number generation uses a global state:</p>
<pre><code class="language-python">import numpy as np
np.random.seed(42)
print(np.random.randn())  # Always the same value
print(np.random.randn())  # Different value (global state advanced)
</code></pre>
<p>This global state is invisible. It changes every time you call a random function. And as we established in Week 1, JAX cannot work with hidden, changing state.</p>
<h3>JAX's Explicit Randomness</h3>
<p>In JAX, randomness is deterministic. You create a "key," and that key always produces the same random numbers:</p>
<pre><code class="language-python">from jax import random

key = random.PRNGKey(42)
print(random.normal(key, shape=(3,)))  # Always the same
print(random.normal(key, shape=(3,)))  # Still the same!
</code></pre>
<p>To get different random numbers, you must <strong>split</strong> the key:</p>
<pre><code class="language-python">key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,)))  # New values
</code></pre>
<h3>How nnx.Rngs Helps</h3>
<p>When you create a neural network, every layer needs random numbers to initialize its weights. If you have 50 layers, that's 50 key splits you'd need to manage manually.</p>
<p><code>nnx.Rngs</code> automates this. It's a key dispenser that splits and distributes keys automatically:</p>
<pre><code class="language-python">rngs = nnx.Rngs(42)

# When you create a layer, it asks rngs for keys internally
linear1 = nnx.Linear(10, 20, rngs=rngs)  # Gets its own keys
linear2 = nnx.Linear(20, 30, rngs=rngs)  # Gets different keys

# Both layers have different, reproducible initializations
</code></pre>
<p>The critical benefit: <strong>reproducibility</strong>. If you and I both run <code>nnx.Rngs(42)</code>, we get identical models. This matters for debugging, for scientific reproducibility, and for distributed training where multiple machines must initialize the same model.</p>
<h2>Inspecting Your Model</h2>
<p>PyTorch lets you <code>print(model)</code> to see the architecture. NNX has something better: <code>nnx.display()</code>.</p>
<pre><code class="language-python">nnx.display(model)
</code></pre>
<p>This produces a rich, hierarchical view showing:</p>
<ul>
<li><p>Every layer and sublayer</p>
</li>
<li><p>Parameter shapes and dtypes</p>
</li>
<li><p>The total parameter count</p>
</li>
<li><p>The structure of the computational graph</p>
</li>
</ul>
<p>In Jupyter notebooks, this renders as an interactive tree you can expand and collapse. It's invaluable for debugging shape mismatches and verifying your architecture.</p>
<h2>Building a CNN for Image Classification</h2>
<p>Let's build something more substantial: a convolutional neural network for classifying images. This is the same architecture pattern used in production image classifiers.</p>
<h3>The Architecture</h3>
<p>We'll build a classic CNN with:</p>
<ul>
<li><p>Two convolutional blocks (Conv → ReLU → Pool)</p>
</li>
<li><p>A flatten operation</p>
</li>
<li><p>Two dense layers for classification</p>
</li>
</ul>
<pre><code class="language-python">from flax import nnx
import jax.numpy as jnp
from functools import partial

class CNN(nnx.Module):
    """A simple CNN for image classification."""
    
    def __init__(self, num_classes: int, *, rngs: nnx.Rngs):
        # Convolutional layers
        self.conv1 = nnx.Conv(
            in_features=1,      # Input channels (1 for grayscale)
            out_features=32,    # Output channels
            kernel_size=(3, 3), # 3x3 filters
            rngs=rngs
        )
        self.conv2 = nnx.Conv(
            in_features=32,
            out_features=64,
            kernel_size=(3, 3),
            rngs=rngs
        )
        
        # Dense layers
        # After two 2x2 pooling operations on 28x28 input: 28 → 14 → 7
        # So we have 64 channels × 7 × 7 = 3136 features
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, num_classes, rngs=rngs)
        
        # Pooling as a reusable operation
        self.pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
    
    def __call__(self, x):
        # Block 1: Conv → ReLU → Pool
        x = self.conv1(x)
        x = nnx.relu(x)
        x = self.pool(x)
        
        # Block 2: Conv → ReLU → Pool
        x = self.conv2(x)
        x = nnx.relu(x)
        x = self.pool(x)
        
        # Flatten: (batch, height, width, channels) → (batch, features)
        x = x.reshape(x.shape[0], -1)
        
        # Classification head
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        
        return x
</code></pre>
<h3>Testing the Model</h3>
<pre><code class="language-python"># Initialize
model = CNN(num_classes=10, rngs=nnx.Rngs(0))

# Create dummy input: batch of 4 grayscale 28×28 images
# Shape: (batch, height, width, channels)
dummy_input = jnp.ones((4, 28, 28, 1))

# Forward pass
output = model(dummy_input)

print(f"Input shape:  {dummy_input.shape}")  # (4, 28, 28, 1)
print(f"Output shape: {output.shape}")       # (4, 10)

# Inspect the model structure
nnx.display(model)
</code></pre>
<p>The output shape is <code>(4, 10)</code>—four images, ten class scores each. This is exactly what we'd feed into a softmax for classification.</p>
<h2>How NNX Compiles with JAX</h2>
<p>Here's the magic: even though we're writing object-oriented code, we can still use JAX's transformations.</p>
<h3>The @nnx.jit Decorator</h3>
<p>NNX provides its own versions of JAX transforms that understand NNX objects:</p>
<pre><code class="language-python">@nnx.jit
def forward(model, x):
    return model(x)

# This is JIT-compiled, just like @jax.jit
output = forward(model, dummy_input)
</code></pre>
<p>When you call this function, NNX:</p>
<ol>
<li><p><strong>Extracts</strong> the model's parameters into a pure JAX pytree</p>
</li>
<li><p><strong>Traces</strong> the computation with those parameters</p>
</li>
<li><p><strong>Compiles</strong> the trace with XLA</p>
</li>
<li><p><strong>Updates</strong> the model object with any changed state</p>
</li>
</ol>
<p>You write familiar OOP code. JAX gets the pure functions it needs. Everyone wins.</p>
<h3>Preview: The Split/Merge Pattern</h3>
<p>Under the hood, NNX uses two key operations:</p>
<pre><code class="language-python"># Split: separate structure from state
graphdef, state = nnx.split(model)

# graphdef: the "blueprint" of the model (static)
# state: the actual parameter values (dynamic, a pytree)

# Merge: reconstruct the model from structure and state
reconstructed_model = nnx.merge(graphdef, state)
</code></pre>
<p>You rarely need to call these directly—<code>@nnx.jit</code> handles it automatically. But understanding this pattern helps when you need to do advanced things like:</p>
<ul>
<li><p>Saving and loading checkpoints (Week 9)</p>
</li>
<li><p>Distributing models across devices (Week 10)</p>
</li>
<li><p>Custom training loops with fine-grained control</p>
</li>
</ul>
<h2>Common Layers Reference</h2>
<p>Here are the NNX equivalents of layers you know from PyTorch:</p>
<table style="min-width:75px"><colgroup><col style="min-width:25px"></col><col style="min-width:25px"></col><col style="min-width:25px"></col></colgroup><tbody><tr><th><p>PyTorch</p></th><th><p>Flax NNX</p></th><th><p>Notes</p></th></tr><tr><td><p><code>nn.Linear</code></p></td><td><p><code>nnx.Linear</code></p></td><td><p>Same signature</p></td></tr><tr><td><p><code>nn.Conv2d</code></p></td><td><p><code>nnx.Conv</code></p></td><td><p>Uses <code>in_features</code>/<code>out_features</code></p></td></tr><tr><td><p><code>nn.BatchNorm2d</code></p></td><td><p><code>nnx.BatchNorm</code></p></td><td><p>Tracks running stats automatically</p></td></tr><tr><td><p><code>nn.LayerNorm</code></p></td><td><p><code>nnx.LayerNorm</code></p></td><td><p>Same behavior</p></td></tr><tr><td><p><code>nn.Dropout</code></p></td><td><p><code>nnx.Dropout</code></p></td><td><p>Requires <code>rngs</code>, respects <code>deterministic</code> flag</p></td></tr><tr><td><p><code>nn.Embedding</code></p></td><td><p><code>nnx.Embed</code></p></td><td><p>For token embeddings</p></td></tr><tr><td><p><code>nn.MultiheadAttention</code></p></td><td><p><code>nnx.MultiHeadAttention</code></p></td><td><p>Transformer attention</p></td></tr></tbody></table>

<p>Activation functions aren't layers in NNX, they're just functions:</p>
<pre><code class="language-python">x = nnx.relu(x)
x = nnx.gelu(x)
x = nnx.softmax(x)
x = nnx.sigmoid(x)
</code></pre>
<h2>Working with Parameters</h2>
<p>Sometimes you need direct access to the parameters—for logging, for custom initialization, or for freezing layers.</p>
<h3>Accessing Parameters</h3>
<pre><code class="language-python"># Access a specific layer's weights
print(model.linear1.kernel.value.shape)  # (3136, 256)
print(model.linear1.bias.value.shape)    # (256,)

# Parameters are nnx.Param objects; .value gets the JAX array
</code></pre>
<h3>Extracting All Parameters</h3>
<pre><code class="language-python"># Get all parameters as a state object
state = nnx.state(model)

# Or specifically just the trainable parameters
params = nnx.state(model, nnx.Param)
</code></pre>
<h3>Updating Parameters</h3>
<pre><code class="language-python"># Update the model with new state
nnx.update(model, new_state)
</code></pre>
<p>This becomes important next week when we build training loops.</p>
<h2>Exercises</h2>
<p>Before moving on, try these:</p>
<ol>
<li><p><strong>Add Dropout</strong>: Modify the CNN to include <code>nnx.Dropout(rate=0.5, rngs=rngs)</code> between the dense layers. Note that dropout needs its own RNG stream for the random mask.</p>
</li>
<li><p><strong>Build a deeper network</strong>: Create a 4-layer MLP with hidden dimensions [512, 256, 128, 64]. Use a loop in <code>__init__</code> to avoid repetition.</p>
</li>
<li><p><strong>Parameter counting</strong>: Write a function that takes an NNX model and returns the total number of trainable parameters. Hint: use <code>nnx.state(model, nnx.Param)</code> and <code>jax.tree_util.tree_map</code>.</p>
</li>
</ol>
<h2>Quick Reference</h2>
<pre><code class="language-python">from flax import nnx
import jax.numpy as jnp

# Define a Model
class MyModel(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.linear = nnx.Linear(10, 5, rngs=rngs)
    
    def __call__(self, x):
        return nnx.relu(self.linear(x))

# Instantiate
model = MyModel(rngs=nnx.Rngs(0))

# Forward Pass
output = model(jnp.ones((32, 10)))

# Inspect
nnx.display(model)

# Access Parameters
weights = model.linear.kernel.value
state = nnx.state(model, nnx.Param)

# JIT Compile
@nnx.jit
def forward(model, x):
    return model(x)

# Common Layers
nnx.Linear(in_features, out_features, rngs=rngs)
nnx.Conv(in_features, out_features, kernel_size, rngs=rngs)
nnx.BatchNorm(num_features, rngs=rngs)
nnx.Dropout(rate, rngs=rngs)
nnx.Embed(num_embeddings, features, rngs=rngs)
</code></pre>
<h2>What's Next</h2>
<p>We have a model. But a model that can't learn is just a random number generator with extra steps.</p>
<p>Next week, we build the <strong>training loop</strong>. We'll use:</p>
<ul>
<li><p><code>nnx.value_and_grad</code> to compute loss and gradients</p>
</li>
<li><p><code>nnx.Optimizer</code> to manage parameter updates</p>
</li>
<li><p><code>optax</code> to define the optimization algorithm</p>
</li>
<li><p>Metrics to track accuracy and loss</p>
</li>
</ul>
<p>We'll train our CNN on real data and watch the loss curve drop. That's when this all becomes real.</p>
]]></content:encoded></item><item><title><![CDATA[Transformations That Change Everything]]></title><description><![CDATA[Last week, we learned that JAX makes code fast through JIT compilation. We took a matrix multiplication from 2 seconds to 0.001 seconds with a single decorator.
But speed isn't JAX's only trick. The real power of JAX lies in its transformations; func...]]></description><link>https://kambale.dev/transformations-that-change-everything</link><guid isPermaLink="true">https://kambale.dev/transformations-that-change-everything</guid><category><![CDATA[jax]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Mon, 09 Feb 2026 17:57:28 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1770659671216/1cee86fb-2dc9-42ed-a66c-5294c8c11bc1.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>Last week, we learned that JAX makes code fast through JIT compilation. We took a matrix multiplication from 2 seconds to 0.001 seconds with a single decorator.</p>
<p>But speed isn't JAX's only trick. The real power of JAX lies in its <strong>transformations</strong>; functions that take functions and return new functions with different behavior.</p>
<p>Today we're covering the two transformations that make JAX indispensable for machine learning:</p>
<ol>
<li><p><code>jax.vmap</code>: Automatic vectorization. Write code for one example, run it on a million.</p>
</li>
<li><p><code>jax.grad</code>: Automatic differentiation. Get gradients of any function, for free.</p>
</li>
</ol>
<p>By the end of this article, you'll understand why JAX developers almost never write for-loops, and you'll have trained multiple machine learning models in parallel without writing any loop at all.</p>
<h2 id="heading-the-problem-with-python-loops">The Problem with Python Loops</h2>
<p>Let's start with why loops are the enemy.</p>
<p>When you write a Python for-loop, the interpreter does a surprising amount of work for each iteration:</p>
<pre><code class="lang-python">results = []
<span class="hljs-keyword">for</span> x <span class="hljs-keyword">in</span> data:
    <span class="hljs-comment"># For EACH iteration, Python must:</span>
    <span class="hljs-comment"># 1. Fetch x from memory</span>
    <span class="hljs-comment"># 2. Check the type of x</span>
    <span class="hljs-comment"># 3. Look up what "+" means for that type</span>
    <span class="hljs-comment"># 4. Execute the addition</span>
    <span class="hljs-comment"># 5. Append to the list (which may require memory reallocation)</span>
    results.append(x + <span class="hljs-number">1</span>)
</code></pre>
<p>If <code>data</code> has a million elements, Python performs those administrative steps a million times. The actual math (<code>x + 1</code>) takes nanoseconds. The overhead takes microseconds. You're spending 99% of your time on bookkeeping.</p>
<p>In deep learning, we process batches: 64 images, 128 sentences, 256 audio clips at once. If you loop through them in Python, your GPU sits idle while Python shuffles paperwork.</p>
<p>NumPy helps by pushing operations into C:</p>
<pre><code class="lang-python">results = data + <span class="hljs-number">1</span>  <span class="hljs-comment"># Vectorized, fast</span>
</code></pre>
<p>But what if your function is more complex than addition? What if it involves multiple steps, conditionals, or nested operations? You'd have to manually rewrite everything to handle batches, adding batch dimensions everywhere and keeping track of which axis is which.</p>
<p>This is where <code>jax.vmap</code> changes the game.</p>
<h2 id="heading-jaxvmap-automatic-vectorization">jax.vmap: Automatic Vectorization</h2>
<p><code>vmap</code> stands for "vectorizing map." It takes a function written for a single example and transforms it into a function that operates on batches.</p>
<h3 id="heading-the-basic-pattern">The Basic Pattern</h3>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp

<span class="hljs-comment"># A function that works on ONE number</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">square</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-keyword">return</span> x ** <span class="hljs-number">2</span>

<span class="hljs-comment"># Transform it to work on MANY numbers</span>
batched_square = jax.vmap(square)

<span class="hljs-comment"># Now use it</span>
numbers = jnp.array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">4</span>, <span class="hljs-number">5</span>])
result = batched_square(numbers)
print(result)  <span class="hljs-comment"># [1, 4, 9, 16, 25]</span>
</code></pre>
<p>"But wait," you might say, "I could just write <code>numbers ** 2</code> directly."</p>
<p>True. The power of <code>vmap</code> shows up with complex functions:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">complex_operation</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-string">"""A function with multiple steps."""</span>
    a = jnp.sin(x)
    b = jnp.exp(-x ** <span class="hljs-number">2</span>)
    c = a * b + jnp.log(<span class="hljs-number">1</span> + jnp.abs(x))
    <span class="hljs-keyword">return</span> c

<span class="hljs-comment"># Without vmap, you'd need to think about broadcasting at each step</span>
<span class="hljs-comment"># With vmap, you just wrap it</span>
batched_complex = jax.vmap(complex_operation)

x_batch = jnp.linspace(<span class="hljs-number">-3</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1000</span>)
results = batched_complex(x_batch)
</code></pre>
<p>You write the function thinking about one input. <code>vmap</code> handles the batch dimension for you.</p>
<h3 id="heading-multiple-arguments-the-inaxes-parameter">Multiple Arguments: The <code>in_axes</code> Parameter</h3>
<p>Real functions have multiple arguments. <code>in_axes</code> tells <code>vmap</code> which arguments to map over and which to broadcast.</p>
<p>Consider a dot product:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">dot_product</span>(<span class="hljs-params">weights, features</span>):</span>
    <span class="hljs-keyword">return</span> jnp.dot(weights, features)
</code></pre>
<p>Different scenarios require different <code>in_axes</code>:</p>
<pre><code class="lang-python"><span class="hljs-comment"># Scenario 1: One set of weights, many feature vectors</span>
<span class="hljs-comment"># weights: don't map (None), features: map over axis 0</span>
batch_predict = jax.vmap(dot_product, in_axes=(<span class="hljs-literal">None</span>, <span class="hljs-number">0</span>))

weights = jnp.array([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>])
features = jnp.array([
    [<span class="hljs-number">1.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">0.0</span>],
    [<span class="hljs-number">0.0</span>, <span class="hljs-number">1.0</span>, <span class="hljs-number">0.0</span>],
    [<span class="hljs-number">0.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">1.0</span>],
])

results = batch_predict(weights, features)
print(results)  <span class="hljs-comment"># [1., 2., 3.]</span>
</code></pre>
<pre><code class="lang-python"><span class="hljs-comment"># Scenario 2: Many weight sets, one feature vector (ensemble of models)</span>
<span class="hljs-comment"># weights: map over axis 0, features: don't map (None)</span>
ensemble_predict = jax.vmap(dot_product, in_axes=(<span class="hljs-number">0</span>, <span class="hljs-literal">None</span>))

many_weights = jnp.array([
    [<span class="hljs-number">1.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">0.0</span>],
    [<span class="hljs-number">0.0</span>, <span class="hljs-number">1.0</span>, <span class="hljs-number">0.0</span>],
    [<span class="hljs-number">0.0</span>, <span class="hljs-number">0.0</span>, <span class="hljs-number">1.0</span>],
])
single_features = jnp.array([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>])

results = ensemble_predict(many_weights, single_features)
print(results)  <span class="hljs-comment"># [1., 2., 3.]</span>
</code></pre>
<pre><code class="lang-python"><span class="hljs-comment"># Scenario 3: Many weights, many features (parallel evaluation)</span>
<span class="hljs-comment"># Both: map over axis 0</span>
parallel_predict = jax.vmap(dot_product, in_axes=(<span class="hljs-number">0</span>, <span class="hljs-number">0</span>))

results = parallel_predict(many_weights, features)
print(results)  <span class="hljs-comment"># [1., 2., 3.]</span>
</code></pre>
<p>The rule is simple:</p>
<ul>
<li><p><code>0</code> means "iterate over the first axis of this argument"</p>
</li>
<li><p><code>None</code> means "broadcast this argument to all iterations"</p>
</li>
<li><p>You can use other integers for different axes</p>
</li>
</ul>
<h3 id="heading-nested-vmap">Nested vmap</h3>
<p>You can stack <code>vmap</code> calls for multi-dimensional batching:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">single_multiply</span>(<span class="hljs-params">a, b</span>):</span>
    <span class="hljs-keyword">return</span> a * b

<span class="hljs-comment"># Map over rows, then over columns</span>
double_batched = jax.vmap(jax.vmap(single_multiply))

matrix_a = jnp.array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]])
matrix_b = jnp.array([[<span class="hljs-number">10</span>, <span class="hljs-number">20</span>], [<span class="hljs-number">30</span>, <span class="hljs-number">40</span>]])

result = double_batched(matrix_a, matrix_b)
print(result)
<span class="hljs-comment"># [[10, 40],</span>
<span class="hljs-comment">#  [90, 160]]</span>
</code></pre>
<h2 id="heading-jaxgrad-automatic-differentiation">jax.grad: Automatic Differentiation</h2>
<p>The other transformation that makes JAX essential for ML is <code>jax.grad</code>. It computes gradients automatically.</p>
<h3 id="heading-the-basic-pattern-1">The Basic Pattern</h3>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">f</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-keyword">return</span> x ** <span class="hljs-number">2</span>

<span class="hljs-comment"># grad returns a NEW FUNCTION that computes the derivative</span>
df_dx = jax.grad(f)

print(f(<span class="hljs-number">3.0</span>))      <span class="hljs-comment"># 9.0</span>
print(df_dx(<span class="hljs-number">3.0</span>))  <span class="hljs-comment"># 6.0 (derivative of x² is 2x, and 2*3=6)</span>
</code></pre>
<p>This works for any function, no matter how complex:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">messy_function</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-keyword">return</span> jnp.sin(x) * jnp.exp(-x ** <span class="hljs-number">2</span>) + jnp.tanh(x)

gradient_fn = jax.grad(messy_function)

<span class="hljs-comment"># The gradient at x=1.0</span>
print(gradient_fn(<span class="hljs-number">1.0</span>))  <span class="hljs-comment"># -0.5047...</span>
</code></pre>
<p>You didn't write any derivative rules. JAX traced through your function and computed the gradient automatically using the chain rule.</p>
<h3 id="heading-jaxvalueandgrad-get-both-at-once">jax.value_and_grad: Get Both at Once</h3>
<p>In training loops, you need both the loss value (to log it) and the gradients (to update parameters). <code>jax.value_and_grad</code> gives you both in one pass:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">loss_fn</span>(<span class="hljs-params">params, x, y</span>):</span>
    prediction = params * x
    <span class="hljs-keyword">return</span> (prediction - y) ** <span class="hljs-number">2</span>

<span class="hljs-comment"># Returns (loss_value, gradients)</span>
loss_and_grad_fn = jax.value_and_grad(loss_fn)

params = <span class="hljs-number">1.0</span>
x, y = <span class="hljs-number">2.0</span>, <span class="hljs-number">6.0</span>  <span class="hljs-comment"># We want params=3.0 so that 3*2=6</span>

loss, grad = loss_and_grad_fn(params, x, y)
print(<span class="hljs-string">f"Loss: <span class="hljs-subst">{loss}</span>"</span>)  <span class="hljs-comment"># 16.0 (because (1*2 - 6)² = 16)</span>
print(<span class="hljs-string">f"Grad: <span class="hljs-subst">{grad}</span>"</span>)  <span class="hljs-comment"># -16.0</span>
</code></pre>
<h3 id="heading-gradients-with-respect-to-specific-arguments">Gradients with Respect to Specific Arguments</h3>
<p>By default, <code>grad</code> differentiates with respect to the first argument. Use <code>argnums</code> to change this:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">f</span>(<span class="hljs-params">x, y</span>):</span>
    <span class="hljs-keyword">return</span> x ** <span class="hljs-number">2</span> + x * y

<span class="hljs-comment"># Gradient with respect to x (first argument, default)</span>
df_dx = jax.grad(f, argnums=<span class="hljs-number">0</span>)

<span class="hljs-comment"># Gradient with respect to y (second argument)</span>
df_dy = jax.grad(f, argnums=<span class="hljs-number">1</span>)

<span class="hljs-comment"># Gradients with respect to both</span>
df_both = jax.grad(f, argnums=(<span class="hljs-number">0</span>, <span class="hljs-number">1</span>))

x, y = <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>
print(<span class="hljs-string">f"df/dx: <span class="hljs-subst">{df_dx(x, y)}</span>"</span>)  <span class="hljs-comment"># 2*2 + 3 = 7</span>
print(<span class="hljs-string">f"df/dy: <span class="hljs-subst">{df_dy(x, y)}</span>"</span>)  <span class="hljs-comment"># 2</span>
print(<span class="hljs-string">f"Both:  <span class="hljs-subst">{df_both(x, y)}</span>"</span>)  <span class="hljs-comment"># (7.0, 2.0)</span>
</code></pre>
<h2 id="heading-combining-transforms-the-real-power">Combining Transforms: The Real Power</h2>
<p>JAX transforms compose. You can combine <code>jit</code>, <code>vmap</code>, and <code>grad</code> freely:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">loss_single</span>(<span class="hljs-params">params, x, y</span>):</span>
    <span class="hljs-string">"""Loss for a single data point."""</span>
    pred = params[<span class="hljs-number">0</span>] * x + params[<span class="hljs-number">1</span>]  <span class="hljs-comment"># Linear: y = mx + b</span>
    <span class="hljs-keyword">return</span> (pred - y) ** <span class="hljs-number">2</span>

<span class="hljs-comment"># Stack the transforms:</span>
<span class="hljs-comment"># 1. grad: compute gradients with respect to params</span>
<span class="hljs-comment"># 2. vmap: do this for a batch of (x, y) pairs</span>
<span class="hljs-comment"># 3. jit: compile the whole thing</span>

batched_grad_fn = jax.jit(
    jax.vmap(
        jax.grad(loss_single),
        in_axes=(<span class="hljs-literal">None</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>)  <span class="hljs-comment"># Same params, batch of x, batch of y</span>
    )
)

params = jnp.array([<span class="hljs-number">1.0</span>, <span class="hljs-number">0.0</span>])  <span class="hljs-comment"># Initial guess: y = 1*x + 0</span>
x_batch = jnp.array([<span class="hljs-number">1.0</span>, <span class="hljs-number">2.0</span>, <span class="hljs-number">3.0</span>])
y_batch = jnp.array([<span class="hljs-number">2.0</span>, <span class="hljs-number">4.0</span>, <span class="hljs-number">6.0</span>])  <span class="hljs-comment"># True relationship: y = 2x</span>

<span class="hljs-comment"># Get gradients for each example in the batch</span>
grads_per_example = batched_grad_fn(params, x_batch, y_batch)
print(<span class="hljs-string">"Gradients per example:"</span>)
print(grads_per_example)

<span class="hljs-comment"># Average them for a batch gradient</span>
batch_grad = jnp.mean(grads_per_example, axis=<span class="hljs-number">0</span>)
print(<span class="hljs-string">f"Batch gradient: <span class="hljs-subst">{batch_grad}</span>"</span>)
</code></pre>
<p>This pattern, <code>jit(vmap(grad(...)))</code>, is the backbone of efficient training in JAX.</p>
<h2 id="heading-project-parallel-linear-regression">Project: Parallel Linear Regression</h2>
<p>Let's train multiple models simultaneously without a single Python loop during training.</p>
<p><strong>Scenario</strong>: We have housing price data from three different cities. Each city has different price dynamics, so we want to train a separate linear regression model for each.</p>
<h3 id="heading-step-1-generate-synthetic-data">Step 1: Generate Synthetic Data</h3>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">from</span> jax <span class="hljs-keyword">import</span> random

<span class="hljs-comment"># Seed for reproducibility</span>
key = random.PRNGKey(<span class="hljs-number">42</span>)

<span class="hljs-comment"># True parameters for 3 cities</span>
<span class="hljs-comment"># City 0: price = 50 * size + 100</span>
<span class="hljs-comment"># City 1: price = 30 * size + 200  </span>
<span class="hljs-comment"># City 2: price = 80 * size + 50</span>
true_slopes = jnp.array([<span class="hljs-number">50.0</span>, <span class="hljs-number">30.0</span>, <span class="hljs-number">80.0</span>])
true_intercepts = jnp.array([<span class="hljs-number">100.0</span>, <span class="hljs-number">200.0</span>, <span class="hljs-number">50.0</span>])

n_cities = <span class="hljs-number">3</span>
n_samples = <span class="hljs-number">100</span>

<span class="hljs-comment"># Generate features (house sizes) for each city</span>
key, subkey = random.split(key)
X = random.uniform(subkey, (n_cities, n_samples, <span class="hljs-number">1</span>), minval=<span class="hljs-number">10</span>, maxval=<span class="hljs-number">100</span>)

<span class="hljs-comment"># Generate targets (prices) with some noise</span>
key, subkey = random.split(key)
noise = random.normal(subkey, (n_cities, n_samples)) * <span class="hljs-number">50</span>

<span class="hljs-comment"># Y[i] = true_slopes[i] * X[i] + true_intercepts[i] + noise[i]</span>
Y = (X[:, :, <span class="hljs-number">0</span>] * true_slopes[:, <span class="hljs-literal">None</span>] + 
     true_intercepts[:, <span class="hljs-literal">None</span>] + 
     noise)

print(<span class="hljs-string">f"X shape: <span class="hljs-subst">{X.shape}</span>"</span>)  <span class="hljs-comment"># (3, 100, 1)</span>
print(<span class="hljs-string">f"Y shape: <span class="hljs-subst">{Y.shape}</span>"</span>)  <span class="hljs-comment"># (3, 100)</span>
</code></pre>
<h3 id="heading-step-2-define-the-model-for-one-city">Step 2: Define the Model for ONE City</h3>
<p>We write everything as if we only have one city:</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">predict</span>(<span class="hljs-params">params, x</span>):</span>
    <span class="hljs-string">"""Predict price for one city's houses."""</span>
    slope, intercept = params[<span class="hljs-number">0</span>], params[<span class="hljs-number">1</span>]
    <span class="hljs-keyword">return</span> x[:, <span class="hljs-number">0</span>] * slope + intercept

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">loss_fn</span>(<span class="hljs-params">params, x, y</span>):</span>
    <span class="hljs-string">"""MSE loss for one city."""</span>
    predictions = predict(params, x)
    <span class="hljs-keyword">return</span> jnp.mean((predictions - y) ** <span class="hljs-number">2</span>)

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_step</span>(<span class="hljs-params">params, x, y, learning_rate</span>):</span>
    <span class="hljs-string">"""One gradient descent step for one city."""</span>
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = params - learning_rate * grads
    <span class="hljs-keyword">return</span> new_params, loss
</code></pre>
<h3 id="heading-step-3-vectorize-across-cities">Step 3: Vectorize Across Cities</h3>
<p>Now we use <code>vmap</code> to run training for all three cities in parallel:</p>
<pre><code class="lang-python"><span class="hljs-comment"># Vectorize the training step</span>
<span class="hljs-comment"># params: axis 0 (each city has its own params)</span>
<span class="hljs-comment"># x: axis 0 (each city has its own data)</span>
<span class="hljs-comment"># y: axis 0 (each city has its own targets)</span>
<span class="hljs-comment"># learning_rate: None (same for all)</span>
parallel_train_step = jax.jit(
    jax.vmap(train_step, in_axes=(<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-literal">None</span>))
)

<span class="hljs-comment"># Initialize parameters for all 3 cities</span>
<span class="hljs-comment"># Shape: (3, 2) - 3 cities, 2 params each (slope, intercept)</span>
key, subkey = random.split(key)
params = random.normal(subkey, (n_cities, <span class="hljs-number">2</span>))

print(<span class="hljs-string">"Initial parameters:"</span>)
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(n_cities):
    print(<span class="hljs-string">f"  City <span class="hljs-subst">{i}</span>: slope=<span class="hljs-subst">{params[i, <span class="hljs-number">0</span>]:<span class="hljs-number">.2</span>f}</span>, intercept=<span class="hljs-subst">{params[i, <span class="hljs-number">1</span>]:<span class="hljs-number">.2</span>f}</span>"</span>)
</code></pre>
<h3 id="heading-step-4-train-all-models-in-parallel">Step 4: Train All Models in Parallel</h3>
<pre><code class="lang-python">learning_rate = <span class="hljs-number">0.0001</span>
n_epochs = <span class="hljs-number">2000</span>

print(<span class="hljs-string">"\nTraining..."</span>)
<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(n_epochs):
    <span class="hljs-comment"># This single line trains ALL THREE models simultaneously</span>
    params, losses = parallel_train_step(params, X, Y, learning_rate)

    <span class="hljs-keyword">if</span> epoch % <span class="hljs-number">500</span> == <span class="hljs-number">0</span>:
        print(<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch:<span class="hljs-number">4</span>d}</span> | Losses: <span class="hljs-subst">{losses}</span>"</span>)

print(<span class="hljs-string">"\nFinal learned parameters:"</span>)
print(<span class="hljs-string">f"<span class="hljs-subst">{<span class="hljs-string">'City'</span>:&lt;<span class="hljs-number">6</span>}</span> <span class="hljs-subst">{<span class="hljs-string">'True Slope'</span>:&lt;<span class="hljs-number">12</span>}</span> <span class="hljs-subst">{<span class="hljs-string">'Learned Slope'</span>:&lt;<span class="hljs-number">15</span>}</span> <span class="hljs-subst">{<span class="hljs-string">'True Intercept'</span>:&lt;<span class="hljs-number">15</span>}</span> <span class="hljs-subst">{<span class="hljs-string">'Learned Intercept'</span>:&lt;<span class="hljs-number">15</span>}</span>"</span>)
print(<span class="hljs-string">"-"</span> * <span class="hljs-number">70</span>)
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(n_cities):
    print(<span class="hljs-string">f"<span class="hljs-subst">{i:&lt;<span class="hljs-number">6</span>}</span> <span class="hljs-subst">{true_slopes[i]:&lt;<span class="hljs-number">12.1</span>f}</span> <span class="hljs-subst">{params[i, <span class="hljs-number">0</span>]:&lt;<span class="hljs-number">15.2</span>f}</span> <span class="hljs-subst">{true_intercepts[i]:&lt;<span class="hljs-number">15.1</span>f}</span> <span class="hljs-subst">{params[i, <span class="hljs-number">1</span>]:&lt;<span class="hljs-number">15.2</span>f}</span>"</span>)
</code></pre>
<p>Expected output:</p>
<pre><code class="lang-plaintext">Final learned parameters:
City   True Slope   Learned Slope   True Intercept   Learned Intercept  
----------------------------------------------------------------------
0      50.0         49.87           100.0            102.34         
1      30.0         29.92           200.0            201.15         
2      80.0         79.78           50.0             52.89
</code></pre>
<p>We trained three separate models, and there's not a single Python for-loop in the training logic. The <code>parallel_train_step</code> function processes all cities in one fused GPU kernel.</p>
<h2 id="heading-why-this-matters">Why This Matters</h2>
<p>The pattern we just used scales to serious applications:</p>
<p><strong>Hyperparameter search</strong>: Train 100 models with different learning rates simultaneously. Pick the best one.</p>
<p><strong>Ensemble methods</strong>: Train 10 models with different random seeds. Average their predictions for more robust results.</p>
<p><strong>Per-user personalization</strong>: Train a tiny model for each of your 10,000 users. <code>vmap</code> handles the parallelization.</p>
<p><strong>Bayesian methods</strong>: Sample 1000 parameter configurations from a posterior distribution and evaluate all of them at once.</p>
<p>The key insight is that <code>vmap</code> doesn't just save you from writing loops; it enables computations that would be impractical with sequential processing.</p>
<h2 id="heading-exercises">Exercises</h2>
<ol>
<li><p><strong>Gradient verification</strong>: Use <code>jax.grad</code> to compute the derivative of <code>f(x) = sin(x)</code>. Plot it alongside <code>cos(x)</code> to verify they match.</p>
</li>
<li><p><strong>The ensemble challenge</strong>: Modify the project to train 10 models on the <em>same</em> city but with different random initializations. Use <code>vmap</code> over the params axis only (<code>in_axes=(0, None, None, None)</code>). Check if they all converge to similar values.</p>
</li>
<li><p><strong>Second derivatives</strong>: <code>jax.grad</code> returns a function; and you can take its gradient too. Compute the second derivative of <code>f(x) = x³</code> and verify it equals <code>6x</code>.</p>
</li>
</ol>
<h2 id="heading-quick-reference">Quick Reference</h2>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp

<span class="hljs-comment"># vmap: Automatic Vectorization</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">single_fn</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-keyword">return</span> x ** <span class="hljs-number">2</span>

batched_fn = jax.vmap(single_fn)
results = batched_fn(jnp.array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>]))

<span class="hljs-comment"># With multiple arguments</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">dot</span>(<span class="hljs-params">w, x</span>):</span>
    <span class="hljs-keyword">return</span> jnp.dot(w, x)

<span class="hljs-comment"># Shared weights, batched inputs</span>
batch_dot = jax.vmap(dot, in_axes=(<span class="hljs-literal">None</span>, <span class="hljs-number">0</span>))

<span class="hljs-comment"># grad: Automatic Differentiation</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">loss</span>(<span class="hljs-params">params</span>):</span>
    <span class="hljs-keyword">return</span> params ** <span class="hljs-number">2</span>

grad_fn = jax.grad(loss)
gradient = grad_fn(<span class="hljs-number">3.0</span>)  <span class="hljs-comment"># 6.0</span>

<span class="hljs-comment"># Get both value and gradient</span>
loss_val, grad_val = jax.value_and_grad(loss)(<span class="hljs-number">3.0</span>)

<span class="hljs-comment"># Gradient with respect to specific argument</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">f</span>(<span class="hljs-params">x, y</span>):</span>
    <span class="hljs-keyword">return</span> x * y

df_dy = jax.grad(f, argnums=<span class="hljs-number">1</span>)

<span class="hljs-comment"># Combining Transforms</span>
fast_batched_grad = jax.jit(jax.vmap(jax.grad(loss_fn), in_axes=(<span class="hljs-literal">None</span>, <span class="hljs-number">0</span>)))
</code></pre>
<h2 id="heading-whats-next">What's Next</h2>
<p>We've now covered the core JAX transforms: <code>jit</code> for speed, <code>vmap</code> for batching, and <code>grad</code> for gradients. These three tools are enough to train neural networks from scratch.</p>
<p>But writing raw JAX for complex models gets tedious. Next week, we'll introduce <strong>Flax NNX';</strong> a neural network library that gives you PyTorch-style ergonomics while keeping all the power of JAX transformations.</p>
<p>We'll build our first real neural network: a CNN for image classification.</p>
<h2 id="heading-resources">Resources</h2>
<p>Automatic Vectorization</p>
<p><a target="_blank" href="https://jaxstack.ai/">JAX AI Stack Guide</a></p>
]]></content:encoded></item><item><title><![CDATA[Why JAX? The NumPy You Know, But Faster]]></title><description><![CDATA[If you've been doing machine learning in Python for any length of time, you've written code like this:
import numpy as np

x = np.random.randn(1000, 1000)
y = np.random.randn(1000, 1000)
result = np.dot(x, y)

NumPy is comfortable. It's the first thi...]]></description><link>https://kambale.dev/why-jax-the-numpy-you-know-but-faster</link><guid isPermaLink="true">https://kambale.dev/why-jax-the-numpy-you-know-but-faster</guid><category><![CDATA[jax]]></category><category><![CDATA[xla]]></category><category><![CDATA[numpy]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Mon, 02 Feb 2026 14:15:09 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1770040892040/a8adacb0-31ca-4dd9-bc27-d29584712241.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>If you've been doing machine learning in Python for any length of time, you've written code like this:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np

x = np.random.randn(<span class="hljs-number">1000</span>, <span class="hljs-number">1000</span>)
y = np.random.randn(<span class="hljs-number">1000</span>, <span class="hljs-number">1000</span>)
result = np.dot(x, y)
</code></pre>
<p>NumPy is comfortable. It's the first thing we reach for when we need to do math. But here's the uncomfortable truth: for the kind of work we're about to do: training neural networks, processing millions of samples, running on GPUs and TPUs, NumPy is holding us back.</p>
<p>Not because NumPy is bad. It's genuinely excellent at what it was designed for. The problem is that NumPy was designed in 2005, before GPUs became the workhorses of machine learning, before TPUs existed, and before we needed to compute gradients of functions with millions of parameters.</p>
<p>JAX is what NumPy would look like if we designed it today, knowing what we know now.</p>
<p>By the end of this article, we'll have:</p>
<ol>
<li><p>Understood <em>why</em> JAX is faster (not just that it is)</p>
</li>
<li><p>Written our first JAX code and seen the speedup ourselves</p>
</li>
<li><p>Learned the one mental shift that trips up everyone coming from NumPy or PyTorch</p>
</li>
<li><p>Built a working benchmark that proves the difference</p>
</li>
</ol>
<p>Let's get into it.</p>
<h2 id="heading-the-problem-with-normal-python">The Problem with "Normal" Python</h2>
<p>When you run Python code, the interpreter reads your instructions one line at a time, translates each one to machine code, executes it, then moves to the next line. This is called <strong>interpretation</strong>.</p>
<pre><code class="lang-plaintext">Line 1 → Translate → Execute
Line 2 → Translate → Execute
Line 3 → Translate → Execute
...
</code></pre>
<p>For a script that processes a CSV file or serves a web page, this is fine. The overhead of interpretation is negligible compared to the actual work being done.</p>
<p>But matrix multiplication? That's different. When we multiply two 5000×5000 matrices, we're doing 125 billion floating-point operations. The "translate → execute" overhead for each operation adds up fast.</p>
<p>NumPy helps by pushing the heavy lifting into compiled C code. When you call <a target="_blank" href="http://np.dot"><code>np.dot</code></a><code>()</code>, Python hands off the work to optimized BLAS libraries that run at near-hardware speed. That's why NumPy is fast<em>er</em> than pure Python.</p>
<p>But there's still a problem: <strong>Python is still orchestrating the operations</strong>. Every time you chain NumPy calls together: <a target="_blank" href="http://np.dot"><code>np.dot</code></a><code>()</code>, then <code>np.sum()</code>, then <code>np.exp()</code>, Python has to:</p>
<ol>
<li><p>Call into C</p>
</li>
<li><p>Wait for the result</p>
</li>
<li><p>Copy the result back to Python</p>
</li>
<li><p>Call into C again for the next operation</p>
</li>
</ol>
<p>Each of those handoffs has overhead. And when you're doing this millions of times in a training loop, it adds up.</p>
<h2 id="heading-how-jax-fixes-this-xla-compilation">How JAX Fixes This: XLA Compilation</h2>
<p>JAX takes a different approach. Instead of executing operations one at a time, JAX can <strong>compile your entire function</strong> into a single optimized program using XLA (Accelerated Linear Algebra).</p>
<p>Here's what that means in practice:</p>
<pre><code class="lang-plaintext">Read entire function → Analyze → Optimize → Fuse operations → Execute once
</code></pre>
<p>When JAX compiles a function, it:</p>
<ul>
<li><p><strong>Fuses operations</strong>: Instead of computing <code>a + b</code>, storing the result, then computing <code>result * c</code>, XLA fuses these into a single kernel that does <code>(a + b) * c</code> without intermediate storage.</p>
</li>
<li><p><strong>Eliminates dead code</strong>: If you compute something but never use it, XLA removes it entirely.</p>
</li>
<li><p><strong>Optimizes memory access</strong>: XLA reorders operations to minimize cache misses and memory transfers.</p>
</li>
<li><p><strong>Targets your hardware</strong>: The same JAX code compiles to optimized instructions for CPU, GPU, or TPU without you changing anything.</p>
</li>
</ul>
<p>The result? Code that runs 10x, 100x, sometimes 1000x faster than the NumPy equivalent.</p>
<h2 id="heading-setting-up">Setting Up</h2>
<p>Let's stop talking and start coding. We'll use Google Colab for this, it's free and gives us access to GPUs.</p>
<pre><code class="lang-bash"><span class="hljs-comment"># Install the JAX AI stack (includes JAX, Flax, Optax, and friends)</span>
!pip install -q jax-ai-stack
</code></pre>
<p>Now let's verify our setup:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> time

print(<span class="hljs-string">f"JAX version: <span class="hljs-subst">{jax.__version__}</span>"</span>)
print(<span class="hljs-string">f"Available devices: <span class="hljs-subst">{jax.devices()}</span>"</span>)
</code></pre>
<p>If you're on a GPU runtime, you should see something like:</p>
<pre><code class="lang-plaintext">JAX version: 0.8.0
Available devices: [CudaDevice(id=0)]
</code></pre>
<p>If you see <code>CpuDevice</code>, that's fine too; JAX still provides speedups on CPU through XLA compilation.</p>
<h2 id="heading-the-first-mental-shift-immutability">The First Mental Shift: Immutability</h2>
<p>Before we benchmark anything, we need to talk about the one thing that trips up <em>everyone</em> coming from NumPy or PyTorch.</p>
<p><strong>In JAX, arrays are immutable. You cannot modify them in place.</strong></p>
<p>In NumPy, this is perfectly normal:</p>
<pre><code class="lang-python"><span class="hljs-comment"># NumPy: Mutable arrays</span>
arr = np.zeros(<span class="hljs-number">5</span>)
arr[<span class="hljs-number">0</span>] = <span class="hljs-number">42</span>  <span class="hljs-comment"># Modify in place</span>
print(arr)   <span class="hljs-comment"># [42. 0. 0. 0. 0.]</span>
</code></pre>
<p>In JAX, this will raise an error:</p>
<pre><code class="lang-python"><span class="hljs-comment"># JAX: This will FAIL</span>
arr = jnp.zeros(<span class="hljs-number">5</span>)
arr[<span class="hljs-number">0</span>] = <span class="hljs-number">42</span>  <span class="hljs-comment">#TypeError: JAX arrays are immutable</span>
</code></pre>
<p>Why? Because immutability is what makes JAX's optimizations possible. If arrays can be modified from anywhere in your code, the compiler can't safely reorder operations or run them in parallel. By guaranteeing that arrays never change, JAX can aggressively optimize your code.</p>
<p>So how do we update arrays? We use the <code>.at[].set()</code> syntax, which <strong>returns a new array</strong> with the modification:</p>
<pre><code class="lang-python"><span class="hljs-comment"># JAX: The correct way</span>
arr = jnp.zeros(<span class="hljs-number">5</span>)
new_arr = arr.at[<span class="hljs-number">0</span>].set(<span class="hljs-number">42</span>)

print(arr)      <span class="hljs-comment"># [0. 0. 0. 0. 0.] — Original unchanged</span>
print(new_arr)  <span class="hljs-comment"># [42. 0. 0. 0. 0.] — New array with the update</span>
</code></pre>
<p>This feels wasteful at first, are we really copying the entire array just to change one element? In practice, no. JAX and XLA are smart enough to optimize this. But conceptually, you should think of it as creating a new array.</p>
<p>Here's the full set of <code>.at[]</code> operations:</p>
<pre><code class="lang-python">x = jnp.array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">4</span>, <span class="hljs-number">5</span>])

<span class="hljs-comment"># Set a value</span>
x.at[<span class="hljs-number">0</span>].set(<span class="hljs-number">10</span>)         <span class="hljs-comment"># [10, 2, 3, 4, 5]</span>

<span class="hljs-comment"># Add to a value</span>
x.at[<span class="hljs-number">0</span>].add(<span class="hljs-number">10</span>)         <span class="hljs-comment"># [11, 2, 3, 4, 5]</span>

<span class="hljs-comment"># Multiply a value</span>
x.at[<span class="hljs-number">0</span>].multiply(<span class="hljs-number">10</span>)    <span class="hljs-comment"># [10, 2, 3, 4, 5]</span>

<span class="hljs-comment"># Works with slices too</span>
x.at[<span class="hljs-number">1</span>:<span class="hljs-number">3</span>].set(<span class="hljs-number">99</span>)       <span class="hljs-comment"># [1, 99, 99, 4, 5]</span>
</code></pre>
<p>Commit this to memory. You'll use it constantly.</p>
<h2 id="heading-the-benchmark-numpy-vs-jax-vs-jaxjit">The Benchmark: NumPy vs JAX vs JAX+JIT</h2>
<p>Now let's prove that JAX is actually faster. We'll multiply two large matrices; this is the core operation in neural networks (every linear layer is a matrix multiplication).</p>
<h3 id="heading-step-1-create-the-data">Step 1: Create the Data</h3>
<pre><code class="lang-python"><span class="hljs-comment"># Matrix size</span>
size = <span class="hljs-number">3000</span>

<span class="hljs-comment"># Create random matrices with NumPy</span>
x_np = np.random.normal(size=(size, size)).astype(np.float32)
y_np = np.random.normal(size=(size, size)).astype(np.float32)

<span class="hljs-comment"># Convert to JAX arrays</span>
x_jax = jnp.array(x_np)
y_jax = jnp.array(y_np)

print(<span class="hljs-string">f"Matrix shape: <span class="hljs-subst">{x_np.shape}</span>"</span>)
print(<span class="hljs-string">f"Total elements per matrix: <span class="hljs-subst">{size * size:,}</span>"</span>)
print(<span class="hljs-string">f"Operations for multiplication: <span class="hljs-subst">{size ** <span class="hljs-number">3</span>:,}</span>"</span>)
</code></pre>
<p>Output:</p>
<pre><code class="lang-plaintext">Matrix shape: (3000, 3000)
Total elements per matrix: 9,000,000
Operations for multiplication: 27,000,000,000
</code></pre>
<p>That's 27 billion operations. Let's see who can do it fastest.</p>
<h3 id="heading-step-2-benchmark-numpy">Step 2: Benchmark NumPy</h3>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">matmul_numpy</span>(<span class="hljs-params">x, y</span>):</span>
    <span class="hljs-keyword">return</span> np.dot(x, y)

<span class="hljs-comment"># Warmup</span>
_ = matmul_numpy(x_np, y_np)

<span class="hljs-comment"># Timed run</span>
start = time.perf_counter()
result_np = matmul_numpy(x_np, y_np)
numpy_time = time.perf_counter() - start

print(<span class="hljs-string">f"NumPy time: <span class="hljs-subst">{numpy_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
</code></pre>
<p>Result:</p>
<pre><code class="lang-plaintext">NumPy time: 0.6294 seconds
</code></pre>
<h3 id="heading-step-3-benchmark-jax-without-jit">Step 3: Benchmark JAX (without JIT)</h3>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">matmul_jax</span>(<span class="hljs-params">x, y</span>):</span>
    <span class="hljs-keyword">return</span> jnp.dot(x, y)

<span class="hljs-comment"># Warmup</span>
_ = matmul_jax(x_jax, y_jax).block_until_ready()

<span class="hljs-comment"># Timed run</span>
start = time.perf_counter()
result_jax = matmul_jax(x_jax, y_jax).block_until_ready()
jax_time = time.perf_counter() - start

print(<span class="hljs-string">f"JAX time (no JIT): <span class="hljs-subst">{jax_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
</code></pre>
<p><strong>Important</strong>: We call <code>.block_until_ready()</code> because JAX operations are <strong>asynchronous</strong>. When you call <a target="_blank" href="http://jnp.dot"><code>jnp.dot</code></a><code>()</code>, JAX immediately returns a "future" and continues executing Python code while the GPU works in the background. Without <code>block_until_ready()</code>, we'd be timing how fast JAX can <em>dispatch</em> the operation, not how fast it actually <em>runs</em>.</p>
<p>Result:</p>
<pre><code class="lang-plaintext">JAX time (no JIT): 0.0094 seconds
</code></pre>
<h3 id="heading-step-4-benchmark-jax-with-jit-compilation">Step 4: Benchmark JAX with JIT Compilation</h3>
<p>Here's where JAX shows its true power. We add a single decorator:</p>
<pre><code class="lang-python"><span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">matmul_jax_jit</span>(<span class="hljs-params">x, y</span>):</span>
    <span class="hljs-keyword">return</span> jnp.dot(x, y)

<span class="hljs-comment"># First call: JAX traces and compiles the function</span>
<span class="hljs-comment"># This takes a moment, so we don't include it in the benchmark</span>
print(<span class="hljs-string">"Compiling..."</span>)
_ = matmul_jax_jit(x_jax, y_jax).block_until_ready()
print(<span class="hljs-string">"Done."</span>)

<span class="hljs-comment"># Timed run (using the compiled version)</span>
start = time.perf_counter()
result_jit = matmul_jax_jit(x_jax, y_jax).block_until_ready()
jit_time = time.perf_counter() - start

print(<span class="hljs-string">f"JAX time (with JIT): <span class="hljs-subst">{jit_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
</code></pre>
<p>The first call to a <code>@jax.jit</code> function is slow because JAX is <strong>tracing</strong> your function, figuring out what operations it contains, and then <strong>compiling</strong> it with XLA. Subsequent calls use the compiled version and are extremely fast.</p>
<h3 id="heading-step-5-compare-results">Step 5: Compare Results</h3>
<pre><code class="lang-python">print(<span class="hljs-string">"RESULTS"</span>)
print(<span class="hljs-string">f"NumPy:          <span class="hljs-subst">{numpy_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
print(<span class="hljs-string">f"JAX (no JIT):   <span class="hljs-subst">{jax_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
print(<span class="hljs-string">f"JAX (with JIT): <span class="hljs-subst">{jit_time:<span class="hljs-number">.4</span>f}</span> seconds"</span>)
print(<span class="hljs-string">f"\nSpeedup (JIT vs NumPy): <span class="hljs-subst">{numpy_time / jit_time:<span class="hljs-number">.1</span>f}</span>x"</span>)
</code></pre>
<p>Typical results on a Colab GPU:</p>
<pre><code class="lang-plaintext">RESULTS
NumPy:          0.6294 seconds
JAX (no JIT):   0.0094 seconds
JAX (with JIT): 0.0198 seconds

Speedup (JAX vs NumPy):     67.2x
Speedup (JIT vs NumPy):     31.8x
Speedup (JIT vs JAX):       0.5x
</code></pre>
<h2 id="heading-what-just-happened">What Just Happened?</h2>
<p>Let's break down why the JIT version is so much faster:</p>
<ol>
<li><p><strong>Hardware acceleration</strong>: JAX moved the computation to the GPU, which has thousands of cores optimized for parallel math.</p>
</li>
<li><p><strong>XLA compilation</strong>: Even on GPU, the JIT version is faster than raw JAX because XLA fuses operations and optimizes memory access patterns.</p>
</li>
<li><p><strong>No Python overhead</strong>: Once compiled, the function runs entirely in native code. Python is only involved in dispatching the call.</p>
</li>
</ol>
<p>The key insight is that <code>@jax.jit</code> doesn't just run your code on a GPU; it fundamentally changes <em>how</em> your code runs.</p>
<h2 id="heading-the-randomness-trap-bonus-lesson">The Randomness Trap (Bonus Lesson)</h2>
<p>There's one more gotcha that catches everyone early on. Try this:</p>
<pre><code class="lang-python"><span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">broken_random</span>():</span>
    <span class="hljs-keyword">return</span> np.random.randn(<span class="hljs-number">5</span>)  <span class="hljs-comment"># Using NumPy's random</span>

result1 = broken_random()
result2 = broken_random()
print(<span class="hljs-string">f"First call:  <span class="hljs-subst">{result1}</span>"</span>)
print(<span class="hljs-string">f"Second call: <span class="hljs-subst">{result2}</span>"</span>)
</code></pre>
<p>You'll notice that <code>result1</code> and <code>result2</code> are <strong>identical</strong>. The random numbers got "baked in" during compilation.</p>
<p>JAX requires <strong>explicit random state</strong> management. Here's the correct way:</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> jax <span class="hljs-keyword">import</span> random

<span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">correct_random</span>(<span class="hljs-params">key</span>):</span>
    <span class="hljs-keyword">return</span> random.normal(key, shape=(<span class="hljs-number">5</span>,))

<span class="hljs-comment"># Create a PRNG key</span>
key = random.PRNGKey(<span class="hljs-number">42</span>)

<span class="hljs-comment"># Split the key for each use</span>
key, subkey1 = random.split(key)
result1 = correct_random(subkey1)

key, subkey2 = random.split(key)
result2 = correct_random(subkey2)

print(<span class="hljs-string">f"First call:  <span class="hljs-subst">{result1}</span>"</span>)
print(<span class="hljs-string">f"Second call: <span class="hljs-subst">{result2}</span>"</span>)
</code></pre>
<p>Now you get different random numbers each time. We'll cover this pattern in depth when we build neural networks, but for now, just remember: <strong>never use</strong> <code>np.random</code> inside JIT-compiled functions.</p>
<h2 id="heading-exercises">Exercises</h2>
<p>Before moving on, try these:</p>
<ol>
<li><p><strong>Break the rules</strong>: Try to modify a JAX array in place (<code>x[0] = 1</code>). Read the error message carefully: JAX errors are verbose but informative.</p>
</li>
<li><p><strong>Vary the size</strong>: Run the benchmark with different matrix sizes (1000, 2000, 5000). How does the speedup change?</p>
</li>
<li><p><strong>Chain operations</strong>: Write a function that does multiple operations (<a target="_blank" href="http://jnp.dot"><code>jnp.dot</code></a>, then <code>jnp.sum</code>, then <code>jnp.exp</code>). Compare JIT vs non-JIT. The speedup should be even larger because XLA fuses the operations.</p>
</li>
</ol>
<h2 id="heading-whats-next">What's Next</h2>
<p>We've established <em>why</em> JAX is fast and seen the proof. But speed is only half the story. Next week, we'll explore <strong>transformations</strong>; the features that make JAX genuinely different from NumPy, not just faster.</p>
<p>Specifically, we'll cover:</p>
<ul>
<li><p><code>jax.vmap</code>: Automatic vectorization that eliminates for-loops</p>
</li>
<li><p><code>jax.grad</code>: Automatic differentiation that makes backpropagation trivial</p>
</li>
</ul>
<p>These two functions are why JAX has become the framework of choice for machine learning research. Once you understand them, you'll never look at NumPy the same way again.</p>
<h2 id="heading-quick-reference">Quick Reference</h2>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">from</span> jax <span class="hljs-keyword">import</span> random

<span class="hljs-comment"># Basic array operations (same as NumPy)</span>
x = jnp.array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>])
y = jnp.zeros((<span class="hljs-number">3</span>, <span class="hljs-number">3</span>))
z = jnp.dot(a, b)

<span class="hljs-comment"># Updating arrays (immutable style)</span>
new_x = x.at[<span class="hljs-number">0</span>].set(<span class="hljs-number">99</span>)
new_x = x.at[<span class="hljs-number">1</span>:].add(<span class="hljs-number">10</span>)

<span class="hljs-comment"># JIT compilation</span>
<span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">fast_function</span>(<span class="hljs-params">x</span>):</span>
    <span class="hljs-keyword">return</span> jnp.dot(x, x.T)

<span class="hljs-comment"># Explicit randomness</span>
key = random.PRNGKey(<span class="hljs-number">0</span>)
key, subkey = random.split(key)
samples = random.normal(subkey, shape=(<span class="hljs-number">100</span>,))

<span class="hljs-comment"># Block for accurate timing</span>
result = fast_function(x).block_until_ready()
</code></pre>
<p><strong>Next week</strong>: <em>Transformations That Change Everything—Automatic Vectorization with vmap and Gradients with grad</em></p>
<h3 id="heading-resources"><strong>Resources:</strong></h3>
<p><a target="_blank" href="https://jax.readthedocs.io">JAX Documentation</a></p>
<p><a target="_blank" href="https://jaxstack.ai">JAX AI Stack</a></p>
<p><a target="_blank" href="https://colab.research.google.com/drive/1f-qg4vlfBHSQSdEdhfgDdRA4k1dPtAqW?usp=sharing">Notebook</a></p>
]]></content:encoded></item><item><title><![CDATA[The reverse turing test: We must now prove we are “dumb” to beat AI]]></title><description><![CDATA[In the dystopian logic of the digital age, a new anxiety has gripped the writing world. From university lecture halls in Makerere to newsrooms in Kampala, humans are facing a pressure that would have seemed laughable just two years ago: the pressure ...]]></description><link>https://kambale.dev/the-reverse-turing-test</link><guid isPermaLink="true">https://kambale.dev/the-reverse-turing-test</guid><category><![CDATA[AI Writer]]></category><category><![CDATA[Ai detector]]></category><category><![CDATA[writing]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Wed, 10 Dec 2025 13:11:19 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1765372066941/1a94be35-ebd3-4f18-9138-1e345f1182ed.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>In the dystopian logic of the digital age, a new anxiety has gripped the writing world. From university lecture halls in Makerere to newsrooms in Kampala, humans are facing a pressure that would have seemed laughable just two years ago: the pressure to appear less polished, less articulate, or even "dumber", simply to avoid being mistaken for a machine.</p>
<p>It sounds absurd, but the "Reverse Turing Test" is here. Writers are deliberately inserting typos, breaking grammatical rules, and shunning perfectly functional words like "delighted," "landscape," or "delve," replacing them with awkward alternatives. Why? Because the black-box algorithms of AI detection tools have decided that high-proficiency English is evidence of a robot.</p>
<p>We have reached a dangerous inflection point where linguistic competence is treated with suspicion, and human error is fetishized as the only remaining proof of authenticity. But should the degradation of language really be the price of being believed?</p>
<h3 id="heading-the-statistical-mirror-why-ai-sounds-like-us">The statistical mirror: Why AI sounds like us</h3>
<p>To understand this crisis, we must first demystify the adversary. Large Language Models (LLMs) like GPT-4 or Gemini are not sentient poets; they are probabilistic engines. They are trained on the internet’s vast corpus of text; trillions of words written by humans over centuries.</p>
<p>When an AI uses words like "pivotal," "crucial," or "emphasize," it is not because it has a personal preference for corporate speak. It is demonstrating <strong>Zipf’s Law</strong>. This linguistic principle states that in any natural language, a small number of words are used with disproportionately high frequency to maximize efficiency. Humans naturally gravitate toward words that reduce cognitive load while maintaining clarity. LLMs simply mirror this statistical reality.</p>
<p>Therefore, the list of "banned" AI-sounding words currently circulating on social media reads like the standard vocabulary of any Ugandan NGO report, government white paper, or academic thesis from the last twenty years. If AI sounds "academic," it is only because it was trained on our academia. To penalize a writer for using structure and precision is to penalize them for being well-read.</p>
<h3 id="heading-the-bias-against-excellence">The bias against excellence</h3>
<p>The reliance on AI detectors is not just unscientific; it is discriminatory. We are witnessing a collision with <strong>Goodhart’s Law</strong>: "When a measure becomes a target, it ceases to be a good measure." By using "perplexity" (a measure of randomness) to judge humanity, detectors punish clear, logical writing.</p>
<p>This has grave implications for Africa. A 2023 study by researchers at Stanford University revealed a stunning bias: AI detectors flagged over <strong>61% of essays written by non-native English speakers</strong> as AI-generated, compared to nearly zero for native US 8th graders.</p>
<p>For the African student who has spent years mastering the "Queen’s English"; learning the formal transitions and structured arguments prized by schools, this is a slap in the face. Writing with the clarity and structure taught in our schools now puts you at risk of being labeled a fraud. The message sent to our students and professionals is chilling: <em>Write badly, or be doubted.</em></p>
<h3 id="heading-the-soul-in-the-machine">The soul in the machine</h3>
<p>However, while we should not dumb down our syntax, we must accept that AI forces us to elevate our substance.</p>
<p>AI can mimic the <em>form</em> of human expression, the rhythm of a sonnet or the structure of a press release, but it lacks the <em>referent</em>. It has no connection to the physical world. It processes symbols, not reality.</p>
<p>This is where the true differentiation lies. An AI can generate a paragraph about the concept of "cultural heritage," but it does not know the specific, heavy silence that falls over a clan meeting when <em>obuntu bulamu</em> is violated. It can describe the ingredients of <em>luwombo</em>, but it cannot understand the politics of the banana plantation or why the preparation of food is a language of love in Buganda.</p>
<p>AI operates on prediction; humans operate on intention.</p>
<p><strong>Contextual wisdom:</strong> A model knows that traffic jams are bad. It does not know the specific, communal frustration of a Friday evening gridlock on Jinja road, nor the humor exchanged between strangers in a taxi.</p>
<p><strong>Emotional weight:</strong> AI can string together words about grief, but it cannot choose a metaphor that breaks the heart because it has never had a heart to break. It cannot anchor a sentence in the lived experience of paying school fees in January.</p>
<h3 id="heading-the-path-forward">The path forward</h3>
<p>We must reject the impulse to perform incompetence. Trying to "beat the detector" by inserting errors is a losing battle; the models will eventually learn those tricks too.</p>
<p>Instead, the rise of AI should force a renaissance of <strong>voice</strong>. The era of generic, filler-heavy writing is indeed over; not because it is "AI," but because AI can do it faster. The human writer must now bring something the machine cannot: original insight grounded in lived reality.</p>
<p>We must lean into our idiosyncrasies, our cultural nuances, our irony, and our specific Ugandan perspectives. We must tell stories that rely on the messy, unpredictable texture of real life.</p>
<p>The real test of humanity is not whether we can write "less like a robot." It is whether we can think, feel, and observe the world deeply enough to say something that no statistical model could ever predict. On that front, we still hold the advantage.</p>
]]></content:encoded></item><item><title><![CDATA[Google Colab in VS Code: A Deep Dive into the New Extension]]></title><description><![CDATA[For years, a subtle but significant divide has existed in the workflow of millions of developers, data scientists, and AI researchers. On one side stood Visual Studio Code, the fast, lightweight, and endlessly customizable code editor beloved by the ...]]></description><link>https://kambale.dev/google-colab-in-vs-code-a-deep-dive-into-the-new-extension</link><guid isPermaLink="true">https://kambale.dev/google-colab-in-vs-code-a-deep-dive-into-the-new-extension</guid><category><![CDATA[colab]]></category><category><![CDATA[vscode extensions]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sun, 16 Nov 2025 20:33:15 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1763323817437/4c5383bd-d4e7-4b33-aac4-cb1937f1d6d6.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>For years, a subtle but significant divide has existed in the workflow of millions of developers, data scientists, and AI researchers. On one side stood Visual Studio Code, the fast, lightweight, and endlessly customizable code editor beloved by the global developer community. On the other was Google Colab, the go-to platform for seamless access to powerful compute resources like GPUs and TPUs, simplifying the process of writing, executing, and collaborating on code. The workflow often involved a cumbersome dance between a customized local VS Code environment for project development and a separate, web-based Colab interface for training and inference.</p>
<p>Responding to years of passionate community requests manifested in blog posts, forum threads, and creative GitHub workarounds, Google has officially bridged this gap. Today, we are thrilled to explore the new <strong>Google Colab extension for Visual Studio Code</strong>, a tool that promises the best of both worlds. This article provides a detailed, technical tutorial on how to install, configure, and harness the power of this game-changing extension, transforming your local VS Code into a control room for Google’s heavy-lifting cloud infrastructure.</p>
<h4 id="heading-the-best-of-both-worlds-unifying-local-ide-and-cloud-compute">The Best of Both Worlds: Unifying Local IDE and Cloud Compute</h4>
<p>The core value of the Colab extension is its ability to meet developers where they are. It acknowledges that while Colab’s simplicity is a major strength, many users crave the advanced features of a full-fledged IDE for larger projects and complex workflows.</p>
<ul>
<li><p><strong>For VS Code Users:</strong> The primary advantage is the ability to connect local <code>.ipynb</code> notebooks to high-powered Colab runtimes. This means you can continue using your familiar, highly customized editor while seamlessly accessing premium GPUs and TPUs, including those available through Colab Pro subscriptions, without leaving your local environment.</p>
</li>
<li><p><strong>For Colab Users:</strong> This integration supports the common practice of working on notebooks that are part of a larger project or Git repository. It empowers users who need more robust IDE features—such as superior code completion, version control, and advanced debugging—by pairing the simplicity of Colab's provisioned runtimes with the prolific VS Code editor.</p>
</li>
</ul>
<p>Essentially, this move bridges the gap between code productivity and cloud compute scalability, eliminating the need to switch tabs, export notebooks, or manage credentials across different platforms.</p>
<h3 id="heading-getting-started-a-step-by-step-guide">Getting Started: A Step-by-Step Guide</h3>
<p>You can get up and running with the Colab extension in just a few clicks. The setup is designed to be intuitive and fast.</p>
<h4 id="heading-step-1-install-the-colab-extension">Step 1: Install the Colab Extension</h4>
<p>First, you need to add the extension to your VS Code installation.</p>
<ol>
<li><p>Open the <strong>Extensions</strong> view from the Activity Bar on the left side of your VS Code window (or press <code>Ctrl+Shift+X</code>).</p>
</li>
<li><p>In the marketplace search bar, type <code>Google Colab</code>.</p>
</li>
<li><p>Click <strong>Install</strong> on the official extension published by Google.</p>
</li>
<li><p>If you do not already have it, the installer will prompt you to install its required dependency, the official <strong>Jupyter</strong> extension.</p>
</li>
</ol>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763320984138/90506969-e1ac-44ee-bf66-ebdf803e45cc.png" alt class="image--center mx-auto" /></p>
<h4 id="heading-step-2-connect-to-a-colab-runtime">Step 2: Connect to a Colab Runtime</h4>
<p>Once installed, you can connect any local notebook to a Colab runtime.</p>
<ol>
<li><p>Create a new notebook (<code>.ipynb</code> file) or open an existing one in your local VS Code workspace.</p>
</li>
<li><p>To select the execution environment, you can either run a cell, which will prompt you to choose a kernel, or click the <strong>Select Kernel</strong> button in the top-right corner of the notebook interface.</p>
</li>
<li><p>From the dropdown menu, choose <strong>Select Another Kernel...</strong></p>
</li>
<li><p>Click on the <strong>Colab</strong> option. You will be prompted to sign in with your Google account.</p>
</li>
<li><p>After signing in, you can choose to create a <strong>New Colab Server</strong> or connect to an existing one you may have running. For your first time, you will create a new one.</p>
</li>
</ol>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763321266631/3267217f-9ab9-436d-b584-8639a329fb9a.png" alt class="image--center mx-auto" /></p>
<p>Your local notebook is now powered by a Google Colab runtime! You can give your Colab server a name so that it can be easily referenced and reused in the future.</p>
<h4 id="heading-step-3-select-your-compute-resources">Step 3: Select Your Compute Resources</h4>
<p>The true power of this extension lies in accessing specialized hardware. The available accelerator options and memory limits are determined by your Google Colab subscription plan.</p>
<ul>
<li><p><strong>Free Tier:</strong> Users have access to NVIDIA T4 GPUs and TPU v5e accelerators.</p>
</li>
<li><p><strong>Colab Pro Tier:</strong> Subscribers gain access to more powerful hardware, such as premium GPUs like the NVIDIA A100.</p>
</li>
</ul>
<p>After connecting to the Colab kernel, you can select your desired hardware accelerator for the session.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763321472629/c39d6bf8-ba6b-421a-8798-277b0555f5d9.png" alt class="image--center mx-auto" /></p>
<h3 id="heading-practical-examples-beyond-the-basics">Practical Examples — Beyond the Basics</h3>
<p>To truly appreciate this new workflow, let's move beyond simple demonstrations. Here are two original, real-world scenarios that are impractical on a standard local machine but become trivial with the Colab extension.</p>
<h4 id="heading-example-1-gpu-accelerated-big-data-analysis-with-rapids-cudf">Example 1: GPU-Accelerated Big Data Analysis with RAPIDS cuDF</h4>
<p><strong>The Challenge:</strong> You need to analyze a large CSV file (several gigabytes) containing millions of records. Using a standard library like Pandas on a CPU can be painfully slow, with simple grouping and aggregation operations taking minutes to complete.</p>
<p><strong>The Solution:</strong> We'll use <strong>RAPIDS cuDF</strong>, a GPU-accelerated DataFrame library with a Pandas-like API. By running this in VS Code connected to a Colab GPU, we can perform the analysis in seconds. RAPIDS cuDF is now pre-installed in Colab GPU runtimes, making this seamless.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763322318242/2e2070c3-2e92-4d82-ba9c-4971680afac7.png" alt class="image--center mx-auto" /></p>
<p><strong>The Result:</strong></p>
<p>The complex aggregation on <strong>3,066,766 rows</strong> of data completes in just <strong>0.3070 seconds</strong>. This incredible speed, demonstrated in the output below, transforms what would be a coffee-break task on a CPU into an interactive, real-time query. This showcases a real-world data engineering task made efficient and seamless, all within the comfort of VS Code.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763322464430/d77ec1f4-6d34-467d-ba8e-8645530cecb7.png" alt class="image--center mx-auto" /></p>
<p><strong>The Code:</strong></p>
<p><em>You can find a copy of the corresponding Colab workbook for this example</em> <a target="_blank" href="https://github.com/wkambale/vscode-colab-extension-tutorial/blob/main/01-gpu-data-analysis-rapids.ipynb"><em>here</em></a><em>.</em></p>
<h4 id="heading-example-2-creative-ai-generating-art-with-stable-diffusion">Example 2: Creative AI — Generating Art with Stable Diffusion</h4>
<p><strong>The Challenge:</strong> Text-to-image models like Stable Diffusion are computationally expensive and require significant GPU VRAM, making them inaccessible to users without high-end local hardware.</p>
<p><strong>The Solution:</strong> We'll use the Hugging Face <code>diffusers</code> library to run a Stable Diffusion pipeline on our Colab GPU kernel. This allows us to generate high-quality images from text prompts directly inside a VS Code notebook.</p>
<p><strong>Install required libraries</strong> We need <code>diffusers</code>, <code>transformers</code>, and <code>accelerate</code> for this task.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763322920056/5e9660b2-f0a5-4f25-9fee-882a81b051d4.png" alt class="image--center mx-auto" /></p>
<p><strong>Set up the Stable Diffusion Pipeline</strong> This code downloads the pre-trained model weights and prepares the pipeline for inference on the GPU.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763323239647/9949e611-24ff-44c3-b07d-9430edcee16d.png" alt class="image--center mx-auto" /></p>
<p><strong>Generate an Image</strong> Define your creative prompt and let the model generate an image.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763323355765/f57b7af2-f8f2-42b1-b201-283c1774ac36.png" alt class="image--center mx-auto" /></p>
<p><strong>The Result:</strong> Within a minute, a high-resolution, AI-generated image appears directly in your VS Code notebook output. This showcases how the extension democratizes access to powerful generative AI models, enabling creative experimentation without the need for a dedicated local GPU.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1763323425396/ca4ee33c-ef6f-445b-b091-ace783bd689f.png" alt class="image--center mx-auto" /></p>
<p><strong>The Code:</strong></p>
<p><em>You can find the corresponding Colab notebook for this example</em> <a target="_blank" href="https://github.com/wkambale/vscode-colab-extension-tutorial/blob/main/02-creative-ai-stable-diffusion.ipynb"><em>here</em></a><em>.</em></p>
<h3 id="heading-advanced-tips-and-current-limitations">Advanced Tips and Current Limitations</h3>
<p>To get the most out of the extension, keep these points in mind:</p>
<ul>
<li><p><strong>File Management:</strong> You are working on a remote Colab file system. Use commands like <code>!ls -l</code> or VS Code's built-in file explorer to see files generated in your runtime. To persist your work, consider mounting your Google Drive:</p>
<pre><code class="lang-python">  <span class="hljs-keyword">from</span> google.colab <span class="hljs-keyword">import</span> drive
  drive.mount(<span class="hljs-string">'/content/drive'</span>)
</code></pre>
</li>
<li><p><strong>Secrets Management:</strong> The web UI's native secrets manager is not yet available. For securely handling API keys, use the file-upload workaround to upload a <code>.env</code> file, as detailed in the source articles.</p>
</li>
<li><p><strong>Session Lifetime:</strong> Remember that Colab runtimes are ephemeral. They will disconnect after a period of inactivity (typically 90 minutes for free-tier users) or if you exceed the maximum session duration (12 hours). Save your work frequently.</p>
</li>
</ul>
<h3 id="heading-the-bigger-picture-whats-next">The Bigger Picture: What's Next?</h3>
<p>As a newly released tool, the Colab extension is still in its early stages, and some limitations exist. As noted, certain web-UI-specific functions like the secrets manager are not yet implemented. However, Google has positioned this release as a "launchpad," signaling a commitment to bringing even more of Colab's functionality to developers everywhere.</p>
<p>This launch also places Google in a fascinating strategic position, turning VS Code into a key battleground for AI developer mindshare. By bringing its powerful code <em>execution</em> engine into the same interface where tools like GitHub Copilot excel at code <em>generation</em>, Google is challenging the AI-assisted developer landscape.</p>
<p>For developers, this rising competition is a net positive. It promises a future where the lines between code generation and execution blur, and where powerful, integrated, and accessible AI tools become a fundamental component of the editor itself.</p>
<h4 id="heading-conclusion">Conclusion</h4>
<p>The new Google Colab extension for VS Code is more than just a convenience; it's a transformative tool that unifies the best of local development with the power of cloud computing. It empowers developers and ML engineers to harness free GPUs and TPUs directly within their preferred editor, streamlining workflows and accelerating innovation. While still in its early stages, the extension represents a significant step forward in making AI and machine learning development more accessible and productive. The future looks bright, and it runs on a seamless connection between your local machine and the cloud.</p>
]]></content:encoded></item><item><title><![CDATA[Build Your First AI Agent with Gemini and LlamaIndex]]></title><description><![CDATA[The world of LLMs is moving beyond simple chatbots. The new frontier is AI agents: systems that can reason, plan, and use external tools to accomplish complex tasks. In tourism, this means our assistant can automatically fetch up-to-date info on dest...]]></description><link>https://kambale.dev/build-your-first-ai-agent-with-gemini-and-llamaindex</link><guid isPermaLink="true">https://kambale.dev/build-your-first-ai-agent-with-gemini-and-llamaindex</guid><category><![CDATA[gemini]]></category><category><![CDATA[ai agents]]></category><category><![CDATA[tourism]]></category><category><![CDATA[LlamaIndex]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sun, 14 Sep 2025 22:27:20 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1757888693269/ee93cfd2-3554-4dbf-baf2-4d7f2867425f.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>The world of LLMs is moving beyond simple chatbots. The new frontier is <strong>AI agents</strong>: systems that can reason, plan, and use external tools to accomplish complex tasks. In tourism, this means our assistant can automatically fetch up-to-date info on destinations, guide users to attractions or restaurants, find relevant images, and even draft promotional posts. Under the hood, we use Google’s advanced Gemini model as the “brain” and LlamaIndex to connect Gemini to custom Python tools. The agent follows the <strong>ReAct (Reason+Act)</strong> framework: it thinks about the query, chooses tools, acts (calls them), then reasons over the results. The result is a flexible travel assistant that can handle multi-step tasks in a human-like way.</p>
<p>Our agent will demonstrate how to:</p>
<ul>
<li><p><strong>Combine LLMs and Tools:</strong> Overcome LLM limitations by letting them call Python functions (tools) for fresh data.</p>
</li>
<li><p><strong>Build Custom Tools:</strong> Write scrapers and generators (e.g. for attractions and images) that the agent can invoke.</p>
</li>
<li><p><strong>Use Gemini as the LLM:</strong> Leverage Google’s multimodal Gemini 2.5 Pro model (our most advanced AI model) for reasoning.</p>
</li>
<li><p><strong>Leverage LlamaIndex:</strong> Use LlamaIndex’s <code>ReActAgent</code> and <code>FunctionTool</code> wrappers to glue tools and the LLM together.</p>
</li>
<li><p><strong>See the ReAct Loop in Action:</strong> Peek at the agent’s chain of thought as it solves a query by sequentially “thinking” and “acting”.</p>
</li>
<li><p><strong>Advanced Prompting:</strong> Go beyond raw data by feeding tool outputs into a carefully crafted prompt for insightful recommendations.</p>
</li>
</ul>
<p>By the end, you’ll have a working Tourism AI Assistant that answers travel questions, scrapes live data, finds images, and even tweets about destinations—all powered by Gemini and LlamaIndex.</p>
<h2 id="heading-our-technology-stack">Our Technology Stack</h2>
<p>Before coding, let’s understand the pieces we’ll use:</p>
<ul>
<li><p><strong>Google Gemini 2.5 Pro:</strong> A state-of-the-art multimodal LLM (text+code+images+video) with strong reasoning capabilities. This is the “brain” of our agent. We use the <code>models/gemini-2.5-pro</code> endpoint via Google’s AI Generative SDK.</p>
</li>
<li><p><strong>LlamaIndex:</strong> An open-source framework for building LLM apps. It provides the <code>ReActAgent</code> class and <code>FunctionTool</code> wrapper to connect LLMs with Python tools. LlamaIndex lets us expose any function to the agent in a structured way.</p>
</li>
<li><p><strong>Web Scraping Libraries:</strong> We’ll use <code>requests</code> and <code>beautifulsoup4</code> to fetch live data from travel sites (for attractions, etc.). (In production, a tourism API is often preferred to avoid brittle scrapers, but for our tutorial we’ll show how to scrape responsibly.)</p>
</li>
<li><p><strong>Image Search:</strong> We’ll implement a simple DuckDuckGo image scraper to let the agent fetch pictures of destinations.</p>
</li>
<li><p><strong>Content Generator:</strong> A function to draft a promotional social-media post about a place or event, illustrating how the agent can take action, not just fetch data.</p>
</li>
</ul>
<h2 id="heading-installing-dependencies">Installing Dependencies</h2>
<p>First, install the required Python packages. In your notebook or script, run:</p>
<pre><code class="lang-bash">!pip install -q llama-index llama-index-llms-gemini google-generativeai python-dotenv beautifulsoup4 requests
</code></pre>
<ul>
<li><p><strong>llama-index-llms-gemini:</strong> Adds Gemini support to LlamaIndex.</p>
</li>
<li><p><strong>google-generativeai:</strong> Google’s SDK to call the Gemini API.</p>
</li>
<li><p><strong>python-dotenv:</strong> For loading API keys from a <code>.env</code> file (or Colab secrets).</p>
</li>
<li><p><strong>beautifulsoup4 &amp; requests:</strong> For our web-scraping tools.</p>
</li>
</ul>
<p>We’ll also enable detailed logging to trace the agent’s reasoning:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> os, logging
logging.basicConfig(level=logging.INFO, format=<span class="hljs-string">'%(asctime)s - %(levelname)s - %(message)s'</span>)
logger = logging.getLogger(__name__)
</code></pre>
<h2 id="heading-authentication">Authentication</h2>
<p>To use Gemini, you need a Google AI Studio API key:</p>
<ol>
<li><p>Create an API key in Google AI Studio.</p>
</li>
<li><p>Store it as <code>GOOGLE_API_KEY</code> in your environment (or Colab secrets).</p>
</li>
</ol>
<p>Then configure the Google Generative AI SDK:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> google.generativeai <span class="hljs-keyword">as</span> genai
<span class="hljs-keyword">from</span> dotenv <span class="hljs-keyword">import</span> load_dotenv

load_dotenv()  <span class="hljs-comment"># if using .env file</span>
GOOGLE_API_KEY = os.getenv(<span class="hljs-string">"GOOGLE_API_KEY"</span>)
<span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> GOOGLE_API_KEY:
    <span class="hljs-keyword">raise</span> ValueError(<span class="hljs-string">"Please set GOOGLE_API_KEY environment variable."</span>)
genai.configure(api_key=GOOGLE_API_KEY)
logger.info(<span class="hljs-string">"Configured Google API key."</span>)
</code></pre>
<h2 id="heading-configuring-the-llm-and-embedding-model">Configuring the LLM and Embedding Model</h2>
<p>We tell LlamaIndex which LLM (Gemini) and embedding model to use. The embedding model isn’t crucial here (we won’t do retrieval-heavy tasks), but LlamaIndex requires one. We’ll use a small HuggingFace model:</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> llama_index.llms.gemini <span class="hljs-keyword">import</span> Gemini
<span class="hljs-keyword">from</span> llama_index.core <span class="hljs-keyword">import</span> Settings
<span class="hljs-keyword">from</span> llama_index.embeddings.huggingface <span class="hljs-keyword">import</span> HuggingFaceEmbedding

MODEL_NAME = <span class="hljs-string">"models/gemini-2.5-pro"</span>  <span class="hljs-comment"># Google’s Gemini 2.5 Pro model</span>

logger.info(<span class="hljs-string">"Configuring LlamaIndex settings..."</span>)
Settings.embed_model = HuggingFaceEmbedding(model_name=<span class="hljs-string">"BAAI/bge-small-en-v1.5"</span>)
Settings.llm = Gemini(model=MODEL_NAME)  <span class="hljs-comment"># uses GOOGLE_API_KEY automatically</span>
logger.info(<span class="hljs-string">f"Using Gemini LLM: <span class="hljs-subst">{MODEL_NAME}</span>"</span>)
</code></pre>
<p>Note: Gemini 2.5 Pro is a powerful model that excels at complex tasks. It supports text, code, and images, and even has built-in “thinking” ability for chain-of-thought reasoning.</p>
<h2 id="heading-crafting-our-tools">Crafting Our Tools</h2>
<p>An agent is only as capable as its tools. A <strong>tool</strong> here is a Python function that performs some action—like fetching data or creating content—that the LLM alone cannot do (e.g. real-time web queries). We’ll create a small toolkit for tourism tasks:</p>
<ul>
<li><p><strong>Tool 1:</strong> <code>get_city_attractions(city)</code> – Scrape top attractions for a city. (As an example, we’ll scrape a known travel guide or Wikipedia for a few highlights.)</p>
</li>
<li><p><strong>Tool 2:</strong> <code>get_city_restaurants(city)</code> – (Optional) Scrape or list popular restaurants. For brevity, this could return a static list or use a simple scrape.</p>
</li>
<li><p><strong>Tool 3:</strong> <code>search_for_destination_images(place)</code> – Search DuckDuckGo for images of a place (like “Eiffel Tower Paris”). Returns a few image URLs.</p>
</li>
<li><p><strong>Tool 4:</strong> <code>generate_tourism_post(place, summary)</code> – Generate a friendly promotional social-media post (tweet) about an attraction or city, given its name and a short summary.</p>
</li>
</ul>
<p>Each tool will have a clear docstring explaining its purpose and inputs. The agent reads these docstrings to know when to use which tool. For example:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> requests
<span class="hljs-keyword">from</span> bs4 <span class="hljs-keyword">import</span> BeautifulSoup

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_city_attractions</span>(<span class="hljs-params">city: str</span>) -&gt; str:</span>
    <span class="hljs-string">"""
    Fetches the top tourist attractions for the given city.
    Returns a bullet list of attractions and brief info.
    """</span>
    <span class="hljs-keyword">try</span>:
        <span class="hljs-comment"># Example: scrape PlanetWare or Wikipedia page for top attractions</span>
        url = <span class="hljs-string">f"https://www.planetware.com/<span class="hljs-subst">{city.lower()}</span>/top-rated-tourist-attractions-in-<span class="hljs-subst">{city.lower()}</span>.htm"</span>
        resp = requests.get(url)
        soup = BeautifulSoup(resp.text, <span class="hljs-string">"html.parser"</span>)
        attractions = [h2.get_text(strip=<span class="hljs-literal">True</span>) <span class="hljs-keyword">for</span> h2 <span class="hljs-keyword">in</span> soup.find_all(<span class="hljs-string">'h2'</span>)[:<span class="hljs-number">5</span>]]
        <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> attractions:
            <span class="hljs-keyword">raise</span> Exception(<span class="hljs-string">"No attractions found"</span>)
        <span class="hljs-keyword">return</span> <span class="hljs-string">"\n"</span>.join(<span class="hljs-string">f"- <span class="hljs-subst">{a}</span>"</span> <span class="hljs-keyword">for</span> a <span class="hljs-keyword">in</span> attractions)
    <span class="hljs-keyword">except</span> Exception <span class="hljs-keyword">as</span> e:
        <span class="hljs-keyword">return</span> <span class="hljs-string">f"Could not retrieve attractions for <span class="hljs-subst">{city}</span> (error: <span class="hljs-subst">{e}</span>)"</span>
</code></pre>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_city_restaurants</span>(<span class="hljs-params">city: str</span>) -&gt; str:</span>
    <span class="hljs-string">"""
    Returns a short list of popular restaurants in the city.
    (For demo purposes, this may be a static list or scraped from a simple source.)
    """</span>
    <span class="hljs-comment"># Placeholder: in a real app, use an API or reliable source</span>
    dummy_data = {
        <span class="hljs-string">"Paris"</span>: [<span class="hljs-string">"Le Jules Verne (Eiffel Tower)"</span>, <span class="hljs-string">"L'Ambroisie"</span>, <span class="hljs-string">"Septime"</span>],
        <span class="hljs-string">"Rome"</span>: [<span class="hljs-string">"Da Enzo al 29"</span>, <span class="hljs-string">"Roscioli"</span>, <span class="hljs-string">"La Pergola"</span>]
    }
    <span class="hljs-keyword">return</span> <span class="hljs-string">"\n"</span>.join(<span class="hljs-string">f"- <span class="hljs-subst">{r}</span>"</span> <span class="hljs-keyword">for</span> r <span class="hljs-keyword">in</span> dummy_data.get(city, [<span class="hljs-string">"No data available"</span>]))
</code></pre>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">search_for_destination_images</span>(<span class="hljs-params">query: str</span>) -&gt; list:</span>
    <span class="hljs-string">"""
    Searches DuckDuckGo Images for the query and returns a list of image URLs.
    """</span>
    <span class="hljs-keyword">try</span>:
        res = requests.get(<span class="hljs-string">f"https://duckduckgo.com/?q=<span class="hljs-subst">{query.replace(<span class="hljs-string">' '</span>, <span class="hljs-string">'+'</span>)}</span>&amp;iar=images&amp;iax=images"</span>)
        soup = BeautifulSoup(res.text, <span class="hljs-string">"html.parser"</span>)
        imgs = soup.select(<span class="hljs-string">"img.tile--img__img"</span>)[:<span class="hljs-number">5</span>]
        <span class="hljs-keyword">return</span> [img.get(<span class="hljs-string">"src"</span>) <span class="hljs-keyword">for</span> img <span class="hljs-keyword">in</span> imgs <span class="hljs-keyword">if</span> img.get(<span class="hljs-string">"src"</span>)]
    <span class="hljs-keyword">except</span> Exception <span class="hljs-keyword">as</span> e:
        <span class="hljs-keyword">return</span> [<span class="hljs-string">f"Image search failed: <span class="hljs-subst">{e}</span>"</span>]
</code></pre>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">generate_tourism_post</span>(<span class="hljs-params">place: str, summary: str</span>) -&gt; str:</span>
    <span class="hljs-string">"""
    Generates a friendly social media post about a place.
    """</span>
    prompt = <span class="hljs-string">f"Write an enthusiastic tweet about visiting <span class="hljs-subst">{place}</span>. Summary: <span class="hljs-subst">{summary}</span>"</span>
    resp = genai.generate_message(model=MODEL_NAME, messages=[{<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: prompt}])
    <span class="hljs-keyword">return</span> resp.last
</code></pre>
<p><em>Note:</em> In production, scraping arbitrary sites (like PlanetWare) can be fragile. Sites may block scrapers, change layout, or have legal terms against scraping. For robust applications, a dedicated travel data API is preferable. Here we use simple scrapers for illustration.</p>
<h2 id="heading-wrapping-functions-for-llamaindex">Wrapping Functions for LlamaIndex</h2>
<p>LlamaIndex needs tools wrapped in <code>FunctionTool</code> objects. This exposes each function’s signature and docstring to the LLM. Docstrings become part of the agent’s understanding of when to use each tool. For example:</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> llama_index.core.tools <span class="hljs-keyword">import</span> FunctionTool

logger.info(<span class="hljs-string">"Wrapping functions into LlamaIndex tools..."</span>)
tools = [
    FunctionTool.from_defaults(fn=get_city_attractions,
                               name=<span class="hljs-string">"get_city_attractions"</span>,
                               description=<span class="hljs-string">"Get the top tourist attractions in a city."</span>),
    FunctionTool.from_defaults(fn=get_city_restaurants,
                               name=<span class="hljs-string">"get_city_restaurants"</span>,
                               description=<span class="hljs-string">"Get popular restaurants in a city."</span>),
    FunctionTool.from_defaults(fn=search_for_destination_images,
                               name=<span class="hljs-string">"search_for_destination_images"</span>,
                               description=<span class="hljs-string">"Search for images of a place."</span>),
    FunctionTool.from_defaults(fn=generate_tourism_post,
                               name=<span class="hljs-string">"generate_tourism_post"</span>,
                               description=<span class="hljs-string">"Create a social media post promoting an attraction or city."</span>)
]
logger.info(<span class="hljs-string">f"Created <span class="hljs-subst">{len(tools)}</span> tools: <span class="hljs-subst">{[t.name <span class="hljs-keyword">for</span> t <span class="hljs-keyword">in</span> tools]}</span>"</span>)
</code></pre>
<p>Each <code>FunctionTool</code> includes the function name, parameters, and a human-readable docstring. The agent will <strong>choose which tool to call</strong> by looking at your query and the tool descriptions, then pass the right arguments.</p>
<h2 id="heading-initializing-the-react-agent">Initializing the ReAct Agent</h2>
<p>With tools ready, we create the agent. We use LlamaIndex’s <code>ReActAgent</code>, which implements the Reason-and-Act loop. The agent will <strong>think</strong> about the user’s question, decide <strong>which tool</strong> to use and with what inputs, <strong>execute</strong> the tool, then observe the result and repeat as needed. Setting <code>verbose=True</code> lets us see the entire chain of thought:</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> llama_index.core.agent <span class="hljs-keyword">import</span> ReActAgent

logger.info(<span class="hljs-string">"Initializing the ReAct agent..."</span>)
agent = ReActAgent.from_tools(tools=tools, llm=Settings.llm, verbose=<span class="hljs-literal">True</span>)
logger.info(<span class="hljs-string">"Tourism Agent is ready to go!"</span>)
</code></pre>
<p>Under the hood, <code>ReActAgent</code> will prompt Gemini with a system message describing the tools and this reasoning framework. The LLM will follow a structure like:</p>
<pre><code class="lang-bash">Thought: I should use [tool] because ...
Action: [tool] with input {...}
Observation: (tool output)
Thought: ... <span class="hljs-built_in">continue</span> or final answer.
</code></pre>
<p>This explicit reasoning flow is characteristic of the ReAct approach.</p>
<h2 id="heading-trying-out-the-tourism-agent">Trying Out the Tourism Agent</h2>
<p>Our agent is now online! Let’s test it with a couple of example queries and inspect its reasoning (verbose output):</p>
<h3 id="heading-scenario-1-simple-question">Scenario 1: Simple Question</h3>
<p><strong>User:</strong> “What are the top attractions in Paris?”</p>
<pre><code class="lang-plaintext">response = agent.chat("What are the top tourist attractions in Paris?")
print(response)
</code></pre>
<p><em>Agent’s (verbose) Thought Process:</em></p>
<ul>
<li><p><strong>Thought:</strong> The user asks for top attractions in Paris. The <code>get_city_attractions</code> tool seems appropriate.</p>
</li>
<li><p><strong>Action:</strong> get_city_attractions</p>
</li>
<li><p><strong>Action Input:</strong> <code>{"city": "Paris"}</code></p>
</li>
<li><p><strong>Observation:</strong></p>
<pre><code class="lang-bash">  - Eiffel Tower
  - Louvre Museum
  - Notre-Dame Cathedral
  - Sacré-Cœur Basilica
  - Musée d<span class="hljs-string">'Orsay</span>
</code></pre>
</li>
<li><p><strong>Thought:</strong> I have the list of attractions. Now I answer the user using this info.</p>
</li>
<li><p><strong>Answer:</strong> <em>“Top attractions in Paris include the Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, Sacré-Cœur, and Musée d’Orsay.”</em></p>
</li>
</ul>
<p><strong>Agent Answer:</strong> <em>Top attractions in Paris include: Eiffel Tower, Louvre Museum, Notre-Dame Cathedral, Sacré-Cœur Basilica, and Musée d’Orsay.</em></p>
<p>This shows the ReAct loop: the agent identified the need for a tool, executed it, and synthesized the output.</p>
<h3 id="heading-scenario-2-multi-step-task">Scenario 2: Multi-Step Task</h3>
<p><strong>User:</strong> “Find art-related attractions in Paris and then write a tweet about them.”</p>
<p>This requires multiple steps: identify art attractions, then generate a tweet.</p>
<pre><code class="lang-python">response = agent.chat(<span class="hljs-string">"Find art-related attractions in Paris and draft a tweet about them."</span>)
print(response)
</code></pre>
<p><em>Agent’s Thought Process:</em></p>
<ul>
<li><p><strong>Thought:</strong> The user needs two things: 1) art-related attractions in Paris, 2) a tweet about them. First, get all attractions.</p>
</li>
<li><p><strong>Action:</strong> get_city_attractions</p>
</li>
<li><p><strong>Input:</strong> <code>{"city": "Paris"}</code></p>
</li>
<li><p><strong>Observation:</strong> (same list as before)</p>
</li>
<li><p><strong>Thought:</strong> The attractions list is obtained. Now pick those related to art (e.g. museums, galleries). These include Louvre Museum and Musée d'Orsay. I will create a summary and use the <code>generate_tourism_post</code> tool.</p>
</li>
<li><p><strong>Action:</strong> generate_tourism_post</p>
</li>
<li><p><strong>Input:</strong> <code>{"place": "Paris Museums", "summary": "Highlights: Louvre Museum and Musée d'Orsay"}</code></p>
</li>
<li><p><strong>Observation:</strong></p>
<pre><code class="lang-bash">  Explore Paris<span class="hljs-string">'s Artistic Treasures!
  Don’t miss the Louvre Museum’s timeless classics and the Musée d’Orsay’s Impressionist masterpieces. From ancient wonders to modern art, Paris has it all. Bon voyage and happy exploring!</span>
</code></pre>
</li>
<li><p><strong>Thought:</strong> I have the tweet text. Now I can share the final answer.</p>
</li>
<li><p><strong>Answer:</strong> <em>Here’s a tweet promoting Paris’s art attractions:… (tweet above)…</em></p>
</li>
</ul>
<p><strong>Agent Answer:</strong> <em>“Here’s a tweet about Paris’s art attractions: Explore Paris’s Artistic Treasures! Don’t miss the Louvre Museum’s timeless classics and the Musée d’Orsay’s Impressionist masterpieces. From ancient wonders to modern art, Paris has it all. Bon voyage and happy exploring!”</em></p>
<p>This demonstrates chaining: the agent used <code>get_city_attractions</code>, filtered the results, then used <code>generate_tourism_post</code> with the summary. The LLM reasoned at each step, as expected in a ReAct agent.</p>
<h2 id="heading-advanced-prompting-personalized-recommendations">Advanced Prompting: Personalized Recommendations</h2>
<p>So far we’ve fetched raw data and posted about it. But LLMs can also provide <strong>insights</strong> on top of data. For example, suppose a user asks:</p>
<p><em>“I’m a history buff visiting Rome. Which attractions should I see?”</em></p>
<p>Simply returning a list of <em>all</em> Rome attractions isn’t ideal. Instead, we want <strong>personalized recommendations</strong>. We can achieve this by combining our tools with a clever prompt:</p>
<ol>
<li><p><strong>Get raw data via tool:</strong> Call <code>get_city_attractions("Rome")</code>.</p>
</li>
<li><p><strong>Engineer a focused prompt:</strong> Tell Gemini to read the list and pick the best 3-5 for a history enthusiast, explaining why.</p>
</li>
<li><p><strong>Generate the answer with Gemini:</strong> The LLM acts as a reasoning layer on the data.</p>
</li>
</ol>
<p>For example:</p>
<pre><code class="lang-python"><span class="hljs-comment"># Step 1: fetch attractions</span>
attractions = get_city_attractions(<span class="hljs-string">"Rome"</span>)
<span class="hljs-comment"># (assume this returns a bullet list of sites like Colosseum, Vatican, Pantheon, etc.)</span>

<span class="hljs-comment"># Step 2: craft a targeted prompt</span>
prompt = <span class="hljs-string">f"""
You are an expert tour guide for Rome. A user is interested in historical sites. 
From the following list of Rome attractions, recommend 3-5 must-see locations for a history buff. Explain why each is a good choice.
Rome Attractions:
<span class="hljs-subst">{attractions}</span>

Focus on historical significance and include brief descriptions.
"""</span>
<span class="hljs-comment"># Step 3: call Gemini directly with this prompt</span>
historical_resp = genai.generate_message(model=MODEL_NAME, messages=[{<span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>, <span class="hljs-string">"content"</span>: prompt}])
print(historical_resp.last)
</code></pre>
<p><strong>Sample Output:</strong></p>
<blockquote>
<p>Excellent choices for a history enthusiast in Rome include:</p>
<ul>
<li><p><strong>Colosseum</strong> – The iconic ancient amphitheater where gladiators once fought. A symbol of Rome’s imperial past and architecture.</p>
</li>
<li><p><strong>Roman Forum</strong> – The heart of ancient Rome’s political and social life. You can walk among ruins of temples and government buildings.</p>
</li>
<li><p><strong>Pantheon</strong> – A 2,000-year-old temple-turned-church, showcasing Rome’s engineering and dedication to the gods. Its monumental dome is a marvel.</p>
</li>
<li><p><strong>Vatican Museums &amp; St. Peter’s Basilica</strong> – While a church, the Vatican holds vast historical and artistic treasures spanning millennia, including Raphael’s Rooms and Michelangelo’s Pietà.</p>
</li>
<li><p><strong>Catacombs of Callixtus</strong> – Underground burial chambers that reveal early Christian history and traditions.</p>
</li>
</ul>
<p>Each of these sites offers rich historical insights into Rome’s past civilizations and will captivate any history buff!</p>
</blockquote>
<p>This “reasoning layer” approach leverages Gemini’s understanding to <em>analyze and filter</em> the tool output. We transformed a raw list into a personalized recommendation list with explanation. One could even wrap this workflow into a new tool (e.g. <code>recommend_attractions_for_interest</code>) for the agent to use directly.</p>
<h2 id="heading-conclusion">Conclusion</h2>
<p>We’ve built a fully functional <strong>Tourism AI Assistant</strong> that can interpret user requests, choose the right tool, fetch live data, and present it in a helpful way. We saw:</p>
<ul>
<li><p><strong>Setting up the agent:</strong> Configuring Gemini and LlamaIndex.</p>
</li>
<li><p><strong>Writing tools:</strong> Python functions with clear docstrings, wrapped as <code>FunctionTool</code> for the agent.</p>
</li>
<li><p><strong>Using ReAct:</strong> The agent’s chain of thought (“Thought”, “Action”, “Observation”) shows how it plans and executes.</p>
</li>
<li><p><strong>Multimodal capability:</strong> Although we didn’t show it here, Gemini supports images and code, so you could extend this agent to analyze photos of landmarks or compute routes.</p>
</li>
<li><p><strong>Advanced prompting:</strong> We enhanced raw data by feeding it into Gemini with a crafted prompt, yielding richer, customized advice.</p>
</li>
</ul>
<p>This agentic architecture is highly extensible. Next steps might include integrating real travel APIs (for flights or hotels), adding a <strong>memory</strong> so the assistant recalls user preferences, or connecting to mapping services for directions. The key idea is that the LLM does the reasoning, while we supply specialized tools for any real-world data or action.</p>
<p><em>Happy travels and happy building!</em></p>
]]></content:encoded></item><item><title><![CDATA[Who Cares About the Mental Health of IT Professionals?]]></title><description><![CDATA[The flickering neon glow of my monitor is often the only constant at night, a stark contrast to the erratic power supply that often interrupts my work. Deadlines loom, datasets often stubbornly imperfect, reflecting biases I'm struggling to mitigate,...]]></description><link>https://kambale.dev/mental-health-of-it-professionals</link><guid isPermaLink="true">https://kambale.dev/mental-health-of-it-professionals</guid><category><![CDATA[mentalhealth]]></category><category><![CDATA[Mental Health]]></category><category><![CDATA[It Professionals ]]></category><category><![CDATA[software development]]></category><category><![CDATA[Software Engineering]]></category><category><![CDATA[HR]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Wed, 07 May 2025 14:51:00 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1746628092731/22d9e8d2-8a1b-409f-a939-af858384723a.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>The flickering neon glow of my monitor is often the only constant at night, a stark contrast to the erratic power supply that often interrupts my work. Deadlines loom, datasets often stubbornly imperfect, reflecting biases I'm struggling to mitigate, and a familiar knot of anxiety tightens in my chest. Is this algorithm truly going to help, or would its flaws inadvertently harm the very communities we aim to serve? This pressure, this ethical weight, combined with the relentless demand for innovation, is a silent burden carried by many of us in Africa's burgeoning tech scene.</p>
<p>So, who cares about the mental health of software engineers and IT professionals, particularly those powering Africa's digital transformation? The answer must be: we all should. Behind the celebratory headlines of Africa's tech unicorns, a human cost is mounting – the often-invisible erosion of mental well-being among the architects of this progress. This is not merely a collection of individual struggles; it is a systemic challenge that threatens the sustainability of innovation and the very fabric of our tech ecosystems.</p>
<p>As we observe Mental Health Awareness Month this May, it is time to move beyond acknowledging this crisis to implementing tangible changes. My experiences, and those of my colleagues across the continent, underscore the urgency of this call. The narrative of "Africa Rising" in technology is inspiring, but it must not overshadow the profound pressures faced by those on the ground. The expectation to contribute to this grand narrative, often with limited resources and against significant infrastructural odds, can itself become a significant stressor.  </p>
<p><strong>Beyond the Global North: The Unique Pressures on Africa's Tech Talent</strong></p>
<p>Globally, the mental toll on tech professionals is well-documented. Demanding work environments, tight deadlines, the constant need for upskilling, and the specter of burnout, anxiety, and imposter syndrome are common parlance in Silicon Valley and beyond. Studies show that over half of software engineers globally report burnout. However, in Africa, these universal strains are not just replicated; they are amplified and compounded by a unique set of local realities.  </p>
<p>In many Ugandan communities, and indeed across much of Africa, mental health remains a taboo subject. Seeking help for anxiety or depression can be misconstrued as personal weakness or even attributed to spiritual failings, rather than being recognized as a legitimate health concern. This cultural barrier is a formidable obstacle; if employees are afraid to speak out, companies may remain unaware of the depth of the problem or feel unmotivated to invest in solutions. This silence perpetuates a cycle where low demand for services (due to stigma) leads to low investment, reinforcing the notion that these are not critical workplace issues.  </p>
<p>Economic and infrastructural realities add further pressure. In a context where a tech job can be a life-changing opportunity for an individual and their family, the stakes for success are incredibly high. The daily grind can be exacerbated by challenges like unreliable power or expensive and erratic internet connectivity – issues that can turn a straightforward coding task into a marathon of frustration and directly impact productivity and stress levels.  </p>
<p>For those of us working in Machine Learning and AI in Africa, there are additional, nuanced burdens. The global discourse on "Ethical AI" and "AI for Good" often places a heavy responsibility on African engineers to solve complex societal problems, sometimes without adequate local data, resources, or established ethical frameworks. We grapple with building AI that is not only innovative but also culturally relevant and free from biases that could harm vulnerable populations, for example, in healthcare or financial services. The challenge of "decolonizing AI," ensuring our creations are truly beneficial and equitable for African contexts, adds a significant mental and ethical load. This isn't just about delivering a product; it's about navigating a complex ethical landscape with potentially far-reaching consequences, a unique form of pressure that can lead to "ethical burnout."  </p>
<p><strong>"System Error": The Current State of Mental Health Support in African Tech</strong></p>
<p>When we look at the mental health support typically available within African tech companies and startups, it often resembles a patchwork of well-intentioned but ultimately inadequate measures. Employee Assistance Programs (EAPs), if they exist, may be generic and not culturally attuned. The approach is frequently reactive, addressing crises rather than fostering a culture of proactive well-being. Many "wellness" initiatives might touch on stress management but fail to address the deep-seated systemic issues or the unique pressures of the tech environment.  </p>
<p>Consider a software developer in Kampala, let's call her Aisha. Aisha is brilliant and dedicated, working for a promising fintech startup. She's battling intense burnout from months of 70-hour weeks leading up to a critical product launch. When she finally musters the courage to hint at her struggles to her manager, she’s met with a well-meaning but ultimately unhelpful suggestion to "take a short break" and a reminder of how crucial her role is. There's no formal HR pathway for mental health support, no access to confidential counselling that understands her context. Her story, and others like it – the junior IT support staff in Nairobi feeling overwhelmed by user demands and job insecurity, or the data scientist in Lagos wrestling with the ethical implications of an algorithm with little institutional guidance – are common, though often whispered in hushed tones. These anonymized experiences highlight a gap: the tech sector, while creating innovative solutions for Africa, sometimes fails to apply that same innovative spirit to supporting its own people.  </p>
<p>There are glimmers of hope. Organizations like Mindverse Uganda are conducting workshops in companies, and innovative approaches like the Tele-Support Psychotherapy (TSP) model show the potential of culturally sensitive, tech-enabled solutions, even if primarily focused outside the corporate sphere for now. Some forward-thinking companies, like Andela or Interswitch, have begun to list mental health support among their employee perks. However, these are often exceptions. The "startup culture" globally tends to prioritize breakneck growth over sustained well-being, a trend likely amplified in Africa's resource-tighter, higher-pressure emerging ecosystems. This can lead to a systemic lack of support, masked by a few positive but isolated examples.  </p>
<p><strong>Acknowledging Realities, Not Excuses</strong></p>
<p>It is important to acknowledge the genuine constraints. Many African tech companies, particularly startups, are indeed navigating challenging economic landscapes where every shilling counts. Implementing comprehensive mental health programs can seem like a daunting expense when survival itself is a daily concern. The argument of "resource constraints," while valid, risks becoming an excuse if mental health is not viewed as a critical investment in human capital—an investment that yields tangible returns in productivity, innovation, and retention.  </p>
<p>Furthermore, mental well-being is a multifaceted societal challenge, influenced by factors far beyond the workplace, including the overall strength of national healthcare systems, which in many African countries are underfunded and under-resourced. The workplace cannot solve everything, but it has a profound impact and a significant role to play.  </p>
<p>Encouragingly, awareness is growing. The rise of African mental health tech startups is a testament to this, even as it underscores a service gap that traditional employers, including tech companies themselves, have yet to fully address. These platforms offer innovative ways to bypass stigma and access barriers, yet their reach can be limited by the same infrastructural challenges—like internet accessibility or data costs—that contribute to workplace stress.  </p>
<p><strong>Debugging Our Approach: A Call for Action in African Tech</strong></p>
<p>As we embrace the 2025 Mental Health Awareness Month theme, the well-being of Africa's tech talent must become a strategic priority. This is not merely an HR concern; it is fundamental to fostering sustainable innovation, driving economic growth, and ensuring our continent's technological advancement is both equitable and human-centered.  </p>
<p>For tech leaders and companies in Uganda and across Africa, the first step is to <strong>cultivate cultures of openness and psychological safety</strong>. This means leadership openly discussing mental health, destigmatizing help-seeking, and establishing safe, confidential channels for employees to voice concerns without fear of reprisal. For those of us in AI/ML, this safety must extend to discussing the ethical quandaries and potential societal impacts of our work without fear of being penalized for raising difficult questions.  </p>
<p>Secondly, <strong>invest in culturally-attuned, accessible mental health resources</strong>. This means moving beyond generic offerings to partner with local mental health professionals and organizations that understand the specific cultural contexts. Companies like Mindverse Uganda already offer workplace workshops. Services leveraging principles from successful local models like Uganda's TSP, which emphasizes cultural sensitivity, should be explored. Therapy benefits, flexible mental health leave policies, and confidential counselling should be the norm, not the exception.  </p>
<p>Thirdly, <strong>design for well-being within work environments and project workflows</strong>. This includes promoting manageable workloads, offering flexible work arrangements where practical, ensuring clear role expectations, and embedding ethical AI development frameworks that explicitly consider developer well-being, especially when projects involve sensitive data or carry high societal impact.  </p>
<p>Finally, <strong>empower managers and team leads with mental health literacy</strong>. Training them to recognize signs of distress, engage in supportive conversations, and guide team members to appropriate resources is crucial for early intervention.  </p>
<p>As tech professionals, we also have a role. We must <strong>foster peer support networks</strong> and advocate responsibly for better conditions within our workplaces. Sharing our experiences, even anonymously, can build collective awareness and drive change. Crucially, we must prioritize <strong>practicing self-care</strong> and utilizing available resources, however limited they may seem. As ML engineers, we are uniquely positioned to not only highlight these issues but also to contribute to designing tech-enabled, data-informed mental health solutions for our workplaces and the wider community.  </p>
<p>For policymakers and ecosystem enablers—governments, investors, and tech hubs—the call is to <strong>integrate mental health into tech development agendas</strong>. This means recognizing that a healthy workforce is a productive and innovative workforce. Investors can play a powerful role by <strong>incentivizing and supporting best practices</strong>, perhaps by including employee well-being metrics in their investment criteria or incubator programs.  </p>
<p>Who cares about the mental health of Africa's tech talent? The answer must be a resounding "We do." The vision is an African tech ecosystem that is not only a global hub of innovation but also a beacon of human-centered progress. An ecosystem where the brilliant minds building the future, like Aisha, feel supported, valued, and mentally sound. Achieving this requires a collective commitment—from the boardroom to the individual developer, from policymakers to investors. Only then can we ensure that Africa's tech revolution truly benefits all.</p>
]]></content:encoded></item><item><title><![CDATA[Luganda Inference on Gemma 3]]></title><description><![CDATA[Introduction
Google has unveiled Gemma 3, the latest iteration of its open AI models, featuring four versions: gemma-3-1b-it, gemma-3-4b-it, gemma-3-12b-it, and gemma-3-27b-it.
The gemma-3-1b-it model is limited to text-only input, supports English e...]]></description><link>https://kambale.dev/luganda-inference-on-gemma-3</link><guid isPermaLink="true">https://kambale.dev/luganda-inference-on-gemma-3</guid><category><![CDATA[luganda]]></category><category><![CDATA[gemma]]></category><category><![CDATA[genai]]></category><category><![CDATA[inference]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Mon, 17 Mar 2025 15:12:41 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1742224299893/52f52d5b-aeff-46f3-a84a-c63c21db3b3c.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h2 id="heading-introduction">Introduction</h2>
<p>Google has unveiled <strong>Gemma 3</strong>, the latest iteration of its open AI models, featuring four versions: <code>gemma-3-1b-it</code>, <code>gemma-3-4b-it</code>, <code>gemma-3-12b-it</code>, and <code>gemma-3-27b-it</code>.</p>
<p>The <code>gemma-3-1b-it</code> model is limited to <strong>text-only input</strong>, supports <strong>English exclusively</strong>, and comes with a <strong>32k context length</strong>. Due to its lack of multilingual capabilities, it is unsuitable for a <strong>Luganda inference</strong>.</p>
<p>In contrast, the <code>gemma-3-4b-it</code><strong>,</strong> <code>gemma-3-12b-it</code><strong>,</strong> and <code>gemma-3-27b-it</code> models support <strong>both text and image input</strong>, recognize <strong>140+ languages</strong>, and offer an extended <strong>128k context length</strong>, making them far better suited for multilingual tasks.</p>
<p>For this specific task, we are using <code>gemma-3-4b-it</code> due to its balance of <strong>performance and efficiency</strong>.</p>
<p><strong>Accessing Gemma 3 models</strong></p>
<p>Before using Gemma 3 for the first time, you must request access to the model through Hugging Face by completing the following steps:</p>
<ol>
<li><p>Log in to <a target="_blank" href="https://www.google.com/url?q=https%3A%2F%2Fhuggingface.co">Hugging Face</a>, or create a new Hugging Face account if you don't already have one.</p>
</li>
<li><p>Go to the <a target="_blank" href="https://www.google.com/url?q=https%3A%2F%2Fhuggingface.co%2Fgoogle%2Fgemma-3-4b-it">Gemma 3 model card</a> to get access to the model.</p>
</li>
<li><p>Complete the consent form and accept the terms and conditions.</p>
</li>
</ol>
<p>To generate a Hugging Face token, open your <a target="_blank" href="https://www.google.com/url?q=https%3A%2F%2Fhuggingface.co%2Fsettings"><strong>Settings</strong> page in Hugging Face</a>, choose <strong>Access Tokens</strong> option in the left pane and click <strong>New token</strong>. In the next window that appears, give a name to your token and choose the type as <strong>Write</strong> to get the write access.</p>
<p>Then, in Colab, select <strong>Secrets</strong> (🔑) in the left pane and add your Hugging Face token. Store your Hugging Face token under the name <code>HF_TOKEN</code>.</p>
<p><strong>Select the runtime</strong></p>
<p>To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to load the Gemma 3 model. In this case, a T4/L4 GPU would be needed to load the model weights.</p>
<ol>
<li><p>In the upper-right of the Colab window, click the dropdown menu.</p>
</li>
<li><p>Select <strong>Change runtime type</strong>.</p>
</li>
<li><p>Under <strong>Hardware accelerator</strong>, select <strong>T4 or L4</strong>.</p>
</li>
</ol>
<h3 id="heading-install-transformers">Install Transformers</h3>
<pre><code class="lang-bash">!pip install git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3
</code></pre>
<h3 id="heading-import-libraries-and-dependencies">Import libraries and dependencies</h3>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> transformers <span class="hljs-keyword">import</span> AutoProcessor, Gemma3ForConditionalGeneration
<span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-keyword">import</span> cv2
<span class="hljs-keyword">from</span> IPython.display <span class="hljs-keyword">import</span> Markdown, HTML
<span class="hljs-keyword">from</span> base64 <span class="hljs-keyword">import</span> b64encode
<span class="hljs-keyword">import</span> requests
<span class="hljs-keyword">import</span> torch
</code></pre>
<h3 id="heading-choose-the-gemma-3-model-variant-to-use">Choose the Gemma 3 model variant to use</h3>
<p>Gemma 3 is available in four sizes, each supporting different features:</p>
<ul>
<li><p><code>gemma-3-1b-it</code></p>
<ul>
<li><p>Supports only text input and English language</p>
</li>
<li><p>32k context length</p>
</li>
</ul>
</li>
<li><p><code>gemma-3-4b-it</code>, <code>gemma-3-12b-it</code>, <code>gemma-3-27b-it</code></p>
<ul>
<li><p>Supports both text and image input</p>
</li>
<li><p>Supports 140+ languages</p>
</li>
<li><p>128k context length</p>
</li>
</ul>
</li>
</ul>
<pre><code class="lang-python">model_name = <span class="hljs-string">'gemma-3-4b-it'</span> <span class="hljs-comment">#We are using 4b</span>
model_id = <span class="hljs-string">f"google/<span class="hljs-subst">{model_name}</span>"</span>

model = Gemma3ForConditionalGeneration.from_pretrained(
    model_id, device_map=<span class="hljs-string">"auto"</span>, torch_dtype=torch.bfloat16,
).eval()

processor = AutoProcessor.from_pretrained(model_id)
</code></pre>
<h3 id="heading-define-helper-functions">Define helper functions</h3>
<ul>
<li><p><code>resize_image</code>: Resizes the input images to <code>n x n</code> pixels, ensuring the aspect ratio is preserved.</p>
</li>
<li><p><code>get_model_response</code>: Send a text prompt and an image to the model, and retrieve the model's response.</p>
</li>
<li><p><code>extract_frames</code>: Extracts a specified number of evenly spaced frames from a video file along with their timestamps.</p>
</li>
<li><p><code>show_video</code>: Embeds and displays a video in an HTML5 player.</p>
</li>
</ul>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">resize_image</span>(<span class="hljs-params">image_path</span>):</span>
    img = Image.open(image_path)

    target_width, target_height = <span class="hljs-number">640</span>, <span class="hljs-number">640</span>
    <span class="hljs-comment"># Calculate the target size (maximum width and height).</span>
    <span class="hljs-keyword">if</span> target_width <span class="hljs-keyword">and</span> target_height:
        max_size = (target_width, target_height)
    <span class="hljs-keyword">elif</span> target_width:
        max_size = (target_width, img.height)
    <span class="hljs-keyword">elif</span> target_height:
        max_size = (img.width, target_height)

    img.thumbnail(max_size)

    <span class="hljs-keyword">return</span> img


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">get_model_response</span>(<span class="hljs-params">img: Image, prompt: str, model, processor</span>):</span>
    <span class="hljs-comment"># Prepare the messages for the model.</span>
    messages = [
        {
            <span class="hljs-string">"role"</span>: <span class="hljs-string">"system"</span>,
            <span class="hljs-string">"content"</span>: [{<span class="hljs-string">"type"</span>: <span class="hljs-string">"text"</span>, <span class="hljs-string">"text"</span>: <span class="hljs-string">"You are a helpful assistant. Reply only with the answer to the question asked in Luganda language only, and avoid using additional text in your response like 'here's the answer'."</span>}]
        },
        {
            <span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>,
            <span class="hljs-string">"content"</span>: [
                {<span class="hljs-string">"type"</span>: <span class="hljs-string">"image"</span>, <span class="hljs-string">"image"</span>: img},
                {<span class="hljs-string">"type"</span>: <span class="hljs-string">"text"</span>, <span class="hljs-string">"text"</span>: prompt}
            ]
        }
    ]

    <span class="hljs-comment"># Tokenize inputs and prepare for the model.</span>
    inputs = processor.apply_chat_template(
        messages, add_generation_prompt=<span class="hljs-literal">True</span>, tokenize=<span class="hljs-literal">True</span>,
        return_dict=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>
    ).to(model.device, dtype=torch.bfloat16)

    input_len = inputs[<span class="hljs-string">"input_ids"</span>].shape[<span class="hljs-number">-1</span>]

    <span class="hljs-comment"># Generate response from the model.</span>
    <span class="hljs-keyword">with</span> torch.inference_mode():
        generation = model.generate(**inputs, max_new_tokens=<span class="hljs-number">100</span>, do_sample=<span class="hljs-literal">False</span>)
        generation = generation[<span class="hljs-number">0</span>][input_len:]

    <span class="hljs-comment"># Decode the response.</span>
    response = processor.decode(generation, skip_special_tokens=<span class="hljs-literal">True</span>)
    <span class="hljs-keyword">return</span> response


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">extract_frames</span>(<span class="hljs-params">video_path, num_frames</span>):</span>
    <span class="hljs-string">"""
    The function is adapted from:
    https://github.com/merveenoyan/smol-vision/blob/main/Gemma_3_for_Video_Understanding.ipynb
    """</span>
    cap = cv2.VideoCapture(video_path)

    <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> cap.isOpened():
        print(<span class="hljs-string">"Error: Could not open video file."</span>)
        <span class="hljs-keyword">return</span> []

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    <span class="hljs-comment"># Calculate the step size to evenly distribute frames across the video.</span>
    step = total_frames // num_frames
    frames = []

    <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(num_frames):
        frame_idx = i * step
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        <span class="hljs-keyword">if</span> <span class="hljs-keyword">not</span> ret:
            <span class="hljs-keyword">break</span>
        img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        timestamp = round(frame_idx / fps, <span class="hljs-number">2</span>)
        frames.append((img, timestamp))

    cap.release()
    <span class="hljs-keyword">return</span> frames


<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">show_video</span>(<span class="hljs-params">video_path, video_width = <span class="hljs-number">600</span></span>):</span>
  video_file = open(video_path, <span class="hljs-string">"r+b"</span>).read()
  video_url = <span class="hljs-string">f"data:video/mp4;base64,<span class="hljs-subst">{b64encode(video_file).decode()}</span>"</span>
  video_html = <span class="hljs-string">f"""&lt;video width=<span class="hljs-subst">{video_width}</span> controls&gt;&lt;source src="<span class="hljs-subst">{video_url}</span>"&gt;&lt;/video&gt;"""</span>
  <span class="hljs-keyword">return</span> HTML(video_html)
</code></pre>
<h2 id="heading-run-an-inference-on-images">Run an inference on images</h2>
<p>Fetch some sample images for inferencing.</p>
<pre><code class="lang-bash">!wget https://raw.githubusercontent.com/wkambale/Luganda-Inference-on-Gemma-3/main/assets/image_1.jpg -O /content/image_1.jpg
!wget https://raw.githubusercontent.com/wkambale/Luganda-Inference-on-Gemma-3/main/assets/image_2.jpg -O /content/image_2.jpg
!wget https://raw.githubusercontent.com/wkambale/Luganda-Inference-on-Gemma-3/main/assets/image_3.jpg -O /content/image_3.jpg
!wget https://raw.githubusercontent.com/wkambale/Luganda-Inference-on-Gemma-3/main/assets/image_4.png -O /content/image_4.png
</code></pre>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1741973322171/639309f6-cd16-4ac9-8c38-6a71883cafe5.jpeg" alt class="image--center mx-auto" /></p>
<ul>
<li><em>Image 1 Credit: The Pearl</em></li>
</ul>
<p><strong>Task 1: Describe an image</strong></p>
<p>The prompt is in Luganda language which translates to: "Describe the image."</p>
<pre><code class="lang-python">image_file = <span class="hljs-string">'image_1.jpg'</span>
prompt = <span class="hljs-string">"Nnyonnyola emmeere eri mu kifaananyi."</span>


img = resize_image(image_file)
display(img)
response = get_model_response(img, prompt, model, processor)
display(Markdown(response))
</code></pre>
<p>Response:</p>
<pre><code class="lang-bash">Omukoyogo.
</code></pre>
<h4 id="heading-example-2-identify-a-landmark">Example 2: Identify a landmark</h4>
<p>The prompt is in Luganda language which translates to: "Identify the famous landmark and location"</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1742221026408/d08c920b-36d2-4ceb-bae4-cd315cffc8c8.jpeg" alt class="image--center mx-auto" /></p>
<ul>
<li><em>Image 2: Ali Zali</em></li>
</ul>
<pre><code class="lang-python">image_file = <span class="hljs-string">'image_2.jpg'</span>
prompt = <span class="hljs-string">"Londoola ekifo kino ekimanyiddwa ennyo nne w'ekisangibwa."</span>

img = resize_image(image_file)
display(img)
response = get_model_response(img, prompt, model, processor)
display(Markdown(response))
</code></pre>
<p>Response:</p>
<pre><code class="lang-bash">Ebibuga.
</code></pre>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1742221137566/1d334f7e-2236-4903-9d79-079a022af5b7.jpeg" alt class="image--center mx-auto" /></p>
<ul>
<li><em>Image 3: The Tower Post</em></li>
</ul>
<pre><code class="lang-python">image_file = <span class="hljs-string">'image_3.jpg'</span>
prompt = <span class="hljs-string">"Londoola ekifo kino ekimanyiddwa ennyo nne w'ekisangibwa."</span>

img = resize_image(image_file)
display(img)
response = get_model_response(img, prompt, model, processor)
display(Markdown(response))
</code></pre>
<p>Response:</p>
<pre><code class="lang-bash">Kampala Bbwalo.
</code></pre>
<h4 id="heading-task-3-mathematical-reasoningokulowooza-mu-kubala">Task 3: Mathematical Reasoning/Okulowooza mu Kubala</h4>
<p>The prompt is in Luganda language which translates to: "What is the value of x?"</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1742221248993/8ac94b19-069a-4ace-bc5a-338c2cb79125.png" alt class="image--center mx-auto" /></p>
<ul>
<li><em>Image: Nitin</em></li>
</ul>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-keyword">from</span> IPython.display <span class="hljs-keyword">import</span> Markdown

image_file = <span class="hljs-string">'image_4.png'</span>
prompt = <span class="hljs-string">"Omuwendo gwa x gwe guliwa?"</span>

img = resize_image(image_file)
display(img)
response = get_model_response(img, prompt, model, processor)
display(Markdown(response))
</code></pre>
<p>Response:</p>
<pre><code class="lang-bash">x = 3
</code></pre>
<h3 id="heading-inference-on-videos"><strong>Inference on videos</strong></h3>
<p>The video is a clip from "Why Uganda is the Pearl Of Africa!" shoot.</p>
<ul>
<li><p>Credits: Eunice Tess</p>
</li>
<li><p>Source: <a target="_blank" href="https://www.google.com/url?q=https%3A%2F%2Fyoutu.be%2Fu4D20WDrZyY%3Fsi%3DCU03hErvHfHCtxaX">YouTube</a></p>
</li>
</ul>
<pre><code class="lang-python"><span class="hljs-comment"># Video file.</span>
video_path = <span class="hljs-string">"video.mp4"</span>

<span class="hljs-comment"># No. of frames to be extracted from the video.</span>
num_frames = <span class="hljs-number">10</span>

video_output = show_video(video_path, video_width=<span class="hljs-number">800</span>)
display(video_output)
</code></pre>
<div class="embed-wrapper"><div class="embed-loading"><div class="loadingRow"></div><div class="loadingRow"></div></div><a class="embed-card" href="https://youtu.be/u4D20WDrZyY?si=uigb9f9hjCFjFxPV">https://youtu.be/u4D20WDrZyY?si=uigb9f9hjCFjFxPV</a></div>
<p> </p>
<p>The prompt is in Luganda language which translates to: "Please summarize what is happening in this video"</p>
<pre><code class="lang-python">video_frames = extract_frames(video_path, num_frames=num_frames)

messages = [
    {
        <span class="hljs-string">"role"</span>: <span class="hljs-string">"system"</span>,
        <span class="hljs-string">"content"</span>: [{<span class="hljs-string">"type"</span>: <span class="hljs-string">"text"</span>, <span class="hljs-string">"text"</span>: <span class="hljs-string">"You are a helpful assistant."</span>}]
    },
    {
        <span class="hljs-string">"role"</span>: <span class="hljs-string">"user"</span>,
        <span class="hljs-string">"content"</span>: [{<span class="hljs-string">"type"</span>: <span class="hljs-string">"text"</span>, <span class="hljs-string">"text"</span>: <span class="hljs-string">"Nsaba mufunze ebigenda mu maaso mu katambi kano"</span>}]
    }
]


<span class="hljs-comment"># Add frames to the messages structure.</span>
<span class="hljs-keyword">for</span> frame_data <span class="hljs-keyword">in</span> video_frames:
    img, timestamp = frame_data
    messages[<span class="hljs-number">1</span>][<span class="hljs-string">"content"</span>].append({<span class="hljs-string">"type"</span>: <span class="hljs-string">"text"</span>, <span class="hljs-string">"text"</span>: <span class="hljs-string">f"Frame at <span class="hljs-subst">{timestamp}</span> seconds:"</span>})
    img.save(<span class="hljs-string">f"/content/frames/frame_<span class="hljs-subst">{timestamp}</span>.png"</span>)
    messages[<span class="hljs-number">1</span>][<span class="hljs-string">"content"</span>].append({<span class="hljs-string">"type"</span>: <span class="hljs-string">"image"</span>, <span class="hljs-string">"url"</span>: <span class="hljs-string">f"/content/frames/frame_<span class="hljs-subst">{timestamp}</span>.png"</span>})


inputs = processor.apply_chat_template(
    messages, add_generation_prompt=<span class="hljs-literal">True</span>, tokenize=<span class="hljs-literal">True</span>,
    return_dict=<span class="hljs-literal">True</span>, return_tensors=<span class="hljs-string">"pt"</span>
).to(model.device)


input_length = inputs[<span class="hljs-string">"input_ids"</span>].shape[<span class="hljs-number">-1</span>]

<span class="hljs-comment"># Generate a response based on the inputs.</span>
output = model.generate(**inputs, max_new_tokens=<span class="hljs-number">500</span>, do_sample=<span class="hljs-literal">False</span>)
output = output[<span class="hljs-number">0</span>][input_length:]
response = processor.decode(output, skip_special_tokens=<span class="hljs-literal">True</span>)

display(Markdown(response))
</code></pre>
<p>Response:</p>
<pre><code class="lang-markdown">Okay, let's look at these images. Here's what I see in Luganda:

Frame 0.0: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 13.2: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 26.4: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 39.6: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 52.8: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 66.0: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso." (It's beautiful, very impressive. It's in a world of wonder, it's in the water.)
Frame 79.2: "Nyo mu maaso, ttiima nnyo. Twee nyo mu nsi y'omutwe, nti nyo mu maaso."
</code></pre>
<h3 id="heading-deductions"><strong>Deductions</strong></h3>
<p>The outputs above reveal a significant limitation: despite <strong>Gemma 3 models</strong> boasting multilingual capabilities across <strong>140+ languages</strong>, they still struggle to handle <strong>vision tasks (images and videos) in Luganda</strong> effectively.</p>
<p>This demonstration underscores the urgent need for:</p>
<ol>
<li><p><strong>More research</strong> into optimizing AI models for low-resource languages like Luganda.</p>
</li>
<li><p><strong>Expanding datasets</strong> with high-quality, Luganda-specific image and video annotations.</p>
</li>
<li><p><strong>Training foundational models</strong> that natively understand Luganda in multimodal contexts.</p>
</li>
</ol>
<p>Without these critical steps, AI-powered vision systems will continue to exclude Luganda and other underrepresented languages from advancements in <strong>multimodal AI</strong>.</p>
<p><strong>Resources</strong></p>
<p>Here is the <a target="_blank" href="https://colab.research.google.com/drive/1xkB4O3yjjDEnd4x2yZCy56MbmK_OuASn?usp=sharing">notebook</a>.</p>
]]></content:encoded></item><item><title><![CDATA[Building Convolutional Neural Networks in JAX]]></title><description><![CDATA[Introduction
Deep learning has revolutionized the field of artificial intelligence, and at the heart of this revolution are Convolutional Neural Networks (CNNs). CNNs have become the go-to architectures for tasks involving images, such as object dete...]]></description><link>https://kambale.dev/build-cnn-in-jax</link><guid isPermaLink="true">https://kambale.dev/build-cnn-in-jax</guid><category><![CDATA[CNN]]></category><category><![CDATA[jax]]></category><category><![CDATA[CNN for begginers]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Wed, 12 Mar 2025 15:05:24 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1741790696877/2599696c-54c6-4121-9022-c158009ae0bd.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h2 id="heading-introduction"><strong>Introduction</strong></h2>
<p>Deep learning has revolutionized the field of artificial intelligence, and at the heart of this revolution are <strong>Convolutional Neural Networks (CNNs)</strong>. CNNs have become the go-to architectures for tasks involving images, such as object detection, facial recognition, medical imaging, and self-driving cars.</p>
<p>Traditionally, frameworks like TensorFlow and PyTorch have dominated the deep learning landscape. However, <strong>JAX</strong> has emerged as a powerful alternative, especially for research and high-performance computing. Developed by Google, JAX provides <strong>automatic differentiation</strong> and <strong>Just-In-Time (JIT) compilation</strong>, making it highly efficient for numerical computing and deep learning applications.</p>
<p><strong>Why Use JAX for CNNs?</strong></p>
<p>JAX stands out because of its ability to seamlessly run code on <strong>CPUs, GPUs, and TPUs</strong> while maintaining a NumPy-like API. This means you can develop models using familiar syntax while benefiting from:</p>
<ol>
<li><p><strong>Automatic Vectorization</strong> – With functions like <code>vmap</code>, JAX makes it easy to apply operations over large batches of data without writing explicit loops.</p>
</li>
<li><p><strong>Efficient Autograd</strong> – JAX provides <strong>automatic differentiation</strong> using <code>grad</code>, which simplifies training deep learning models.</p>
</li>
<li><p><strong>XLA Compilation</strong> – Just-In-Time (JIT) compilation speeds up execution by compiling computation graphs for efficient hardware utilization.</p>
</li>
<li><p><strong>Functional Programming Paradigm</strong> – Unlike traditional deep learning frameworks, JAX encourages <strong>pure functions</strong>, which improves reproducibility and debugging.</p>
</li>
</ol>
<p><strong>Prerequisites</strong></p>
<p>Before proceeding, ensure you are familiar with:</p>
<ul>
<li><p>JAX fundamentals, you can check out the JAX documentation <a target="_blank" href="https://docs.jax.dev/en/latest/index.html">here</a>.</p>
</li>
<li><p>Building CNNs in TensorFlow or PyTorch</p>
</li>
<li><p>JAX optimizers and loss functions</p>
</li>
</ul>
<p><strong>Install Dependencies</strong></p>
<p>Install the required libraries:</p>
<pre><code class="lang-bash">!pip install jax jaxlib flax optax tensorflow tensorflow_datasets dm-pix tqdm matplotlib
</code></pre>
<ul>
<li><p><code>jax</code> and <code>jaxlib</code> – The core JAX library and its hardware acceleration backend.</p>
</li>
<li><p><code>flax</code> – A neural network library for JAX, similar to PyTorch’s <code>torch.nn</code>.</p>
</li>
<li><p><code>optax</code> – A library for optimization algorithms in JAX.</p>
</li>
<li><p><code>dm_pix</code> – A lightweight image processing library for JAX.</p>
</li>
<li><p><code>matplotlib</code> – For visualizing images.</p>
</li>
</ul>
<p><strong>Import Packages</strong></p>
<p>Load necessary libraries:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> os
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">import</span> tensorflow_datasets <span class="hljs-keyword">as</span> tfds
<span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
<span class="hljs-keyword">import</span> jax
<span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">import</span> optax
<span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-keyword">from</span> flax <span class="hljs-keyword">import</span> linen <span class="hljs-keyword">as</span> nn
<span class="hljs-keyword">from</span> flax.training <span class="hljs-keyword">import</span> train_state
<span class="hljs-keyword">import</span> dm_pix <span class="hljs-keyword">as</span> pix  <span class="hljs-comment"># Image processing in JAX</span>
</code></pre>
<p><strong>Verify GPU Access</strong></p>
<p>JAX runs computations on <strong>CPUs, GPUs, and TPUs</strong> seamlessly. To check if your machine has a GPU or TPU available:</p>
<pre><code class="lang-python"><span class="hljs-comment"># Get available devices</span>
print(<span class="hljs-string">"Available Devices:"</span>, jax.devices())

<span class="hljs-comment"># Check if GPU is available</span>
<span class="hljs-keyword">if</span> jax.default_backend() == <span class="hljs-string">"gpu"</span>:
    print(<span class="hljs-string">"Using GPU:"</span>, jax.devices(<span class="hljs-string">"gpu"</span>))
<span class="hljs-keyword">elif</span> jax.default_backend() == <span class="hljs-string">"tpu"</span>:
    print(<span class="hljs-string">"Using TPU:"</span>, jax.devices(<span class="hljs-string">"tpu"</span>))
<span class="hljs-keyword">else</span>:
    print(<span class="hljs-string">"Using CPU"</span>)
</code></pre>
<h2 id="heading-data-preprocessing"><strong>Data Preprocessing</strong></h2>
<p>Raw image data comes in various sizes, orientations, and quality levels. Preprocessing is crucial for:</p>
<ul>
<li><p>Ensuring uniform input dimensions.</p>
</li>
<li><p>Normalizing pixel values for stable training.</p>
</li>
<li><p>Augmenting data to improve model generalization.</p>
</li>
<li><p>Converting images into JAX-compatible tensors.</p>
</li>
</ul>
<p><strong>Loading the Dataset</strong></p>
<p>We will use the <strong>Cats vs. Dogs dataset</strong> available on Kaggle. Download and unzip the dataset using:</p>
<pre><code class="lang-bash">!kaggle datasets download -d chetankv/dogs-cats-images
!unzip dogs-cats-images.zip
</code></pre>
<p>This dataset gives us:</p>
<ul>
<li><p>A <strong>training dataset</strong> (<code>train_data</code>)</p>
</li>
<li><p>A <strong>test dataset</strong> (<code>test_data</code>)</p>
</li>
</ul>
<p>Define the path to the images and the batch size:</p>
<pre><code class="lang-python">base_dir = <span class="hljs-string">"/content/dog vs cat/dataset/training_set"</span>
batch_size = <span class="hljs-number">64</span>
</code></pre>
<p><strong>Resizing and Normalizing Images</strong></p>
<p>Since images in the dataset have varying sizes, we must <strong>resize</strong> them to a fixed size (e.g., <strong>128×128 pixels</strong>). Additionally, we normalize pixel values from <strong>[0, 255] → [0, 1]</strong> for stable training.</p>
<pre><code class="lang-python">IMG_SIZE = <span class="hljs-number">128</span>

resize_and_rescale = tf.keras.Sequential(
    [
        tf.keras.layers.Resizing(IMG_SIZE, IMG_SIZE),
        tf.keras.layers.Rescaling(<span class="hljs-number">1.0</span> / <span class="hljs-number">255</span>),
    ]
)
</code></pre>
<p><strong>Data Augmentation for Better Generalization</strong></p>
<p>Data augmentation helps improve model generalization by applying transformations like <strong>flipping, rotation, brightness adjustments, and cropping</strong>.</p>
<pre><code class="lang-python">rng = jax.random.PRNGKey(<span class="hljs-number">0</span>)
rng, inp_rng, init_rng = jax.random.split(rng, <span class="hljs-number">3</span>)

delta = <span class="hljs-number">0.42</span>
factor = <span class="hljs-number">0.42</span>

<span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">data_augmentation</span>(<span class="hljs-params">image</span>):</span>
    new_image = pix.adjust_brightness(image=image, delta=delta)
    new_image = pix.random_brightness(image=new_image, max_delta=delta, key=inp_rng)
    new_image = pix.flip_up_down(image=image)
    new_image = pix.flip_left_right(image=new_image)
    new_image = pix.rot90(k=<span class="hljs-number">1</span>, image=new_image) <span class="hljs-comment"># k = number of times the rotation is applied</span>

    <span class="hljs-keyword">return</span> new_image
</code></pre>
<p><strong>Converting Data to JAX-Compatible Tensors</strong></p>
<p>JAX primarily operates on NumPy-like arrays (<code>jnp.array</code>). TensorFlow uses <code>tf.Tensor</code>, so we must convert our dataset into a JAX-friendly format.</p>
<pre><code class="lang-python">AUTOTUNE = tf.data.AUTOTUNE

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">prepare</span>(<span class="hljs-params">ds, shuffle=False</span>):</span>
    <span class="hljs-comment"># Rescale and resize all datasets.</span>
    ds = ds.map(<span class="hljs-keyword">lambda</span> x, y: (resize_and_rescale(x), y), num_parallel_calls=AUTOTUNE)

    <span class="hljs-keyword">if</span> shuffle:
        ds = ds.shuffle(<span class="hljs-number">1000</span>)

    <span class="hljs-comment"># Use buffered prefetching on all datasets.</span>
    <span class="hljs-keyword">return</span> ds.prefetch(buffer_size=AUTOTUNE)

train_ds = prepare(training_set, shuffle=<span class="hljs-literal">True</span>)
val_ds = prepare(validation_set)
evaluation_set = prepare(eval_set)
</code></pre>
<p><strong>Visualizing Preprocessed Images</strong></p>
<p>Let’s check if preprocessing is working as expected.</p>
<pre><code class="lang-python">plt.figure(figsize=(<span class="hljs-number">10</span>, <span class="hljs-number">10</span>))

augmented_images = []

<span class="hljs-keyword">for</span> images, _ <span class="hljs-keyword">in</span> training_set.take(<span class="hljs-number">1</span>):
  <span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(<span class="hljs-number">9</span>):
    augmented_image = data_augmentation(np.array(images[i], dtype=jnp.float32))
    augmented_images.append(augmented_image)
    ax = plt.subplot(<span class="hljs-number">3</span>, <span class="hljs-number">3</span>, i + <span class="hljs-number">1</span>)
    plt.imshow(augmented_images[i].astype(<span class="hljs-string">"uint8"</span>))
    plt.axis(<span class="hljs-string">"off"</span>)
</code></pre>
<h2 id="heading-defining-a-cnn-in-jax"><strong>Defining a CNN in JAX</strong></h2>
<p>In JAX, neural networks are often implemented using <strong>Flax</strong>, a high-level neural network library designed to work seamlessly with JAX’s functional paradigm. Flax provides an intuitive way to define models using <strong>Module</strong> classes.</p>
<p>Below is a simple implementation of a Convolutional Neural Network (CNN) in JAX using Flax:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax.numpy <span class="hljs-keyword">as</span> jnp
<span class="hljs-keyword">import</span> flax.linen <span class="hljs-keyword">as</span> nn

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CNN</span>(<span class="hljs-params">nn.Module</span>):</span>
    num_classes: int  <span class="hljs-comment"># Number of output classes</span>

<span class="hljs-meta">    @nn.compact</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__call__</span>(<span class="hljs-params">self, x</span>):</span>
        x = nn.Conv(features=<span class="hljs-number">32</span>, kernel_size=(<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), strides=(<span class="hljs-number">1</span>, <span class="hljs-number">1</span>), padding=<span class="hljs-string">"SAME"</span>)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), strides=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>))

        x = nn.Conv(features=<span class="hljs-number">64</span>, kernel_size=(<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), strides=(<span class="hljs-number">1</span>, <span class="hljs-number">1</span>), padding=<span class="hljs-string">"SAME"</span>)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), strides=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>))

        x = nn.Conv(features=<span class="hljs-number">128</span>, kernel_size=(<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), strides=(<span class="hljs-number">1</span>, <span class="hljs-number">1</span>), padding=<span class="hljs-string">"SAME"</span>)(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), strides=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>))

        x = x.reshape((x.shape[<span class="hljs-number">0</span>], <span class="hljs-number">-1</span>))  <span class="hljs-comment"># Flatten feature maps</span>
        x = nn.Dense(features=<span class="hljs-number">128</span>)(x)
        x = nn.relu(x)

        x = nn.Dense(features=self.num_classes)(x)  <span class="hljs-comment"># Output layer</span>
        <span class="hljs-keyword">return</span> x
</code></pre>
<ol>
<li><p><strong>First Convolutional Layer</strong></p>
<ul>
<li><p>Applies a <strong>32-channel</strong> convolution with a <strong>3×3</strong> kernel.</p>
</li>
<li><p>Uses <strong>ReLU</strong> activation to introduce non-linearity.</p>
</li>
<li><p>Applies <strong>max pooling</strong> with a <strong>2×2</strong> window and a stride of <strong>2</strong>, reducing the spatial dimensions by half.</p>
</li>
</ul>
</li>
<li><p><strong>Second Convolutional Layer</strong></p>
<ul>
<li><p>Uses a <strong>64-channel</strong> convolution with a <strong>3×3</strong> kernel.</p>
</li>
<li><p>Again applies <strong>ReLU</strong> activation.</p>
</li>
<li><p>Another <strong>max pooling</strong> operation further reduces spatial dimensions.</p>
</li>
</ul>
</li>
<li><p><strong>Third Convolutional Layer</strong></p>
<ul>
<li><p>Increases the number of channels to <strong>128</strong> while keeping the <strong>3×3</strong> kernel size.</p>
</li>
<li><p>Applies <strong>ReLU</strong> activation and another <strong>max pooling</strong> step.</p>
</li>
</ul>
</li>
<li><p><strong>Flattening and Fully Connected Layers</strong></p>
<ul>
<li><p>The feature maps from the final convolutional layer are <strong>flattened</strong> into a 1D vector.</p>
</li>
<li><p>A <strong>dense layer with 128 neurons</strong> applies a ReLU activation.</p>
</li>
<li><p>The final <strong>output layer</strong> produces logits corresponding to the number of classes.</p>
</li>
</ul>
</li>
</ol>
<p><strong>Why Use</strong> <code>@nn.compact</code>?</p>
<p>Flax provides two ways to define models:</p>
<ul>
<li><p>Using <code>@nn.compact</code>, which allows direct instantiation of layers within the <code>__call__</code> method.</p>
</li>
<li><p>Using <code>setup()</code>, where layers are explicitly defined as attributes.</p>
</li>
</ul>
<p>The <strong>compact</strong> approach is cleaner and more intuitive for simple models, avoiding the need to define layer attributes separately.</p>
<h2 id="heading-initializing-the-model"><strong>Initializing the Model</strong></h2>
<p>JAX does not use an implicit state, so model parameters must be explicitly initialized. The <code>init</code> function from Flax helps generate the model’s parameters using a random key and an input shape.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> flax.core <span class="hljs-keyword">import</span> freeze, unfreeze

<span class="hljs-comment"># Set up PRNG key</span>
key = jax.random.PRNGKey(<span class="hljs-number">0</span>)

<span class="hljs-comment"># Define input shape (batch_size, height, width, channels)</span>
input_shape = (<span class="hljs-number">1</span>, <span class="hljs-number">32</span>, <span class="hljs-number">32</span>, <span class="hljs-number">3</span>)  <span class="hljs-comment"># Example for a 32x32 RGB image</span>

<span class="hljs-comment"># Initialize model</span>
model = CNN(num_classes=<span class="hljs-number">10</span>)  <span class="hljs-comment"># Assuming 10 output classes</span>
params = model.init(key, jnp.ones(input_shape))[<span class="hljs-string">"params"</span>]
</code></pre>
<ul>
<li><p>A <strong>random key</strong> is generated using <code>jax.random.PRNGKey(0)</code>. JAX requires explicit control over random number generation for reproducibility.</p>
</li>
<li><p>A <strong>dummy input tensor</strong> of shape <code>(1, 32, 32, 3)</code> is created to initialize the network.</p>
</li>
<li><p>The model is <strong>instantiated</strong> and the <code>init</code> function generates model parameters using the random key.</p>
</li>
<li><p>The <code>"params"</code> field is extracted from the initialization output, as Flax’s <code>init</code> method returns a dictionary containing additional information (e.g., batch statistics if using BatchNorm).</p>
</li>
</ul>
<p><strong>Defining the Training State</strong></p>
<p>Flax provides a <code>train_state</code> abstraction to manage model parameters, optimizer state, and other training-related information. The <code>optax</code> library is used for defining the optimizer.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> optax
<span class="hljs-keyword">from</span> flax.training <span class="hljs-keyword">import</span> train_state

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">TrainState</span>(<span class="hljs-params">train_state.TrainState</span>):</span>
    <span class="hljs-keyword">pass</span>  <span class="hljs-comment"># No additional attributes needed for now</span>

<span class="hljs-comment"># Define the optimizer</span>
learning_rate = <span class="hljs-number">0.001</span>
optimizer = optax.adam(learning_rate)

<span class="hljs-comment"># Initialize the training state</span>
state = TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
</code></pre>
<ul>
<li><p><code>TrainState</code> is a dataclass that stores the model's parameters, optimizer state, and <code>apply_fn</code> (the function used for forward passes).</p>
</li>
<li><p><strong>Optax's Adam optimizer</strong> is set up with a learning rate of <code>0.001</code>.</p>
</li>
<li><p>The <code>state.create()</code> method initializes the model’s training state with:</p>
<ul>
<li><p><code>apply_fn</code>: The forward pass function from the model.</p>
</li>
<li><p><code>params</code>: The initialized parameters from the previous step.</p>
</li>
<li><p><code>tx</code>: The optimizer (Adam in this case).</p>
</li>
</ul>
</li>
</ul>
<p><strong>Defining Loss and Accuracy Metrics</strong></p>
<p>A loss function is required to guide training, while an accuracy function evaluates model performance.</p>
<p><strong>Loss Function</strong></p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> jax.nn <span class="hljs-keyword">as</span> jnn

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">cross_entropy_loss</span>(<span class="hljs-params">params, state, batch</span>):</span>
    logits = state.apply_fn({<span class="hljs-string">'params'</span>: params}, batch[<span class="hljs-string">'images'</span>])
    labels = jnn.one_hot(batch[<span class="hljs-string">'labels'</span>], num_classes=<span class="hljs-number">10</span>)
    <span class="hljs-keyword">return</span> -jnp.sum(labels * jnn.log_softmax(logits)) / batch[<span class="hljs-string">'labels'</span>].shape[<span class="hljs-number">0</span>]
</code></pre>
<ul>
<li><p>The function takes <strong>model parameters</strong>, the <strong>current training state</strong>, and a <strong>batch of input data</strong>.</p>
</li>
<li><p>The <strong>forward pass</strong> is performed using <code>apply_fn</code>, producing logits (raw model predictions).</p>
</li>
<li><p>The <strong>labels are one-hot encoded</strong> to match the logits' shape.</p>
</li>
<li><p>The <strong>cross-entropy loss</strong> is computed using <code>log_softmax(logits)</code>, ensuring numerical stability.</p>
</li>
<li><p>The loss is <strong>averaged over the batch size</strong> for proper optimization.</p>
</li>
</ul>
<p><strong>Accuracy Function</strong></p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">compute_accuracy</span>(<span class="hljs-params">params, state, batch</span>):</span>
    logits = state.apply_fn({<span class="hljs-string">'params'</span>: params}, batch[<span class="hljs-string">'images'</span>])
    predicted_labels = jnp.argmax(logits, axis=<span class="hljs-number">1</span>)
    <span class="hljs-keyword">return</span> jnp.mean(predicted_labels == batch[<span class="hljs-string">'labels'</span>])
</code></pre>
<ul>
<li><p>The function takes <strong>model parameters</strong>, the <strong>training state</strong>, and a <strong>batch of data</strong>.</p>
</li>
<li><p>The <strong>forward pass</strong> is executed to obtain logits.</p>
</li>
<li><p>The <strong>highest-scoring class</strong> is selected using <code>argmax()</code>, determining the model’s predicted label.</p>
</li>
<li><p>Accuracy is computed by <strong>comparing predictions with actual labels</strong> and averaging the correct classifications.</p>
</li>
</ul>
<h2 id="heading-training-and-evaluating-a-cnn-in-jax"><strong>Training and Evaluating a CNN in JAX</strong></h2>
<p>Training in JAX is based on functional transformations, meaning explicit gradient computation and parameter updates are required. The <code>jax.grad</code> function is used to compute gradients efficiently.</p>
<p><strong>Training Step Function</strong></p>
<pre><code class="lang-python"><span class="hljs-meta">@jax.jit</span>
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_step</span>(<span class="hljs-params">state, batch</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">loss_fn</span>(<span class="hljs-params">params</span>):</span>
        logits = state.apply_fn({<span class="hljs-string">'params'</span>: params}, batch[<span class="hljs-string">'images'</span>])
        labels = jax.nn.one_hot(batch[<span class="hljs-string">'labels'</span>], num_classes=<span class="hljs-number">10</span>)
        loss = -jnp.sum(labels * jax.nn.log_softmax(logits)) / batch[<span class="hljs-string">'labels'</span>].shape[<span class="hljs-number">0</span>]
        <span class="hljs-keyword">return</span> loss

    <span class="hljs-comment"># Compute gradients</span>
    grads = jax.grad(loss_fn)(state.params)

    <span class="hljs-comment"># Update model state</span>
    state = state.apply_gradients(grads=grads)

    <span class="hljs-keyword">return</span> state
</code></pre>
<ul>
<li><p><strong>JIT Compilation (</strong><code>@jax.jit</code>): JAX’s just-in-time compilation speeds up training by optimizing computation.</p>
</li>
<li><p><code>loss_fn</code> Function: Defines the cross-entropy loss to be minimized.</p>
</li>
<li><p><code>jax.grad(loss_fn)</code>: Computes gradients with respect to model parameters.</p>
</li>
<li><p><code>state.apply_gradients(grads=grads)</code>: Updates the training state using computed gradients.</p>
</li>
</ul>
<p><strong>Training Loop</strong></p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">train_model</span>(<span class="hljs-params">state, train_loader, num_epochs=<span class="hljs-number">10</span></span>):</span>
    <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(num_epochs):
        <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> train_loader:
            state = train_step(state, batch)
        print(<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch + <span class="hljs-number">1</span>}</span> completed"</span>)
    <span class="hljs-keyword">return</span> state
</code></pre>
<ul>
<li><p><strong>Iterates through multiple epochs</strong>, training the model for <code>num_epochs</code>.</p>
</li>
<li><p><strong>Processes each batch</strong>, updating the model parameters.</p>
</li>
<li><p><strong>Logs progress</strong> at the end of each epoch.</p>
</li>
</ul>
<p><strong>Evaluation Step Function</strong></p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">evaluate_model</span>(<span class="hljs-params">state, test_loader</span>):</span>
    accuracies = []

    <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> test_loader:
        acc = compute_accuracy(state.params, state, batch)
        accuracies.append(acc)

    final_accuracy = jnp.mean(jnp.array(accuracies))
    print(<span class="hljs-string">f"Test Accuracy: <span class="hljs-subst">{final_accuracy * <span class="hljs-number">100</span>:<span class="hljs-number">.2</span>f}</span>%"</span>)
    <span class="hljs-keyword">return</span> final_accuracy
</code></pre>
<ul>
<li><p>Iterates through the <strong>test dataset</strong>, computing accuracy for each batch.</p>
</li>
<li><p><strong>Aggregates accuracy scores</strong> across batches to compute the final accuracy.</p>
</li>
<li><p><strong>Prints the test accuracy</strong>, indicating how well the model generalizes to unseen data.</p>
</li>
</ul>
<h2 id="heading-conclusion"><strong>Conclusion</strong></h2>
<p>JAX’s functional and hardware-accelerated approach allows for efficient model training, particularly on GPUs and TPUs. The explicit handling of gradients and optimizers ensures flexibility while maintaining high performance.</p>
<p>Future work could explore <strong>advanced techniques</strong> such as data augmentation, model regularization, and hyperparameter tuning to improve performance. Additionally, integrating JAX with frameworks like TensorFlow or PyTorch could provide hybrid workflows for deep learning research and production applications.</p>
]]></content:encoded></item><item><title><![CDATA[Non-technical summits and conferences not impactful in the AI race]]></title><description><![CDATA[For decades, developer communities across the globe have served as vital engines for fostering technological adoption and innovation. These communities are not merely gatherings for networking but essential conduits for pushing developers to embrace ...]]></description><link>https://kambale.dev/non-technical-summits-and-conferences-not-impactful-in-the-ai-race</link><guid isPermaLink="true">https://kambale.dev/non-technical-summits-and-conferences-not-impactful-in-the-ai-race</guid><category><![CDATA[summit]]></category><category><![CDATA[AI]]></category><category><![CDATA[technical]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Tue, 04 Mar 2025 21:00:00 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1741950125921/ba637f86-58b3-4046-8fc3-ab17215f1a53.jpeg" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>For decades, developer communities across the globe have served as vital engines for fostering technological adoption and innovation. These communities are not merely gatherings for networking but essential conduits for pushing developers to embrace emerging technologies, providing feedback loops that guide improvements, and ensuring the technology being built is practical and useful. Global tech giants like Google operate the Google Developer Groups (GDGs) worldwide, Facebook previously ran Facebook Developer Circles with a chapter in Kampala, and African companies like Africa’s Talking have also actively supported developer communities.</p>
<p>A key mandate of these communities is to organize regular meetups, workshops, summits, and conferences. These gatherings are typically hands-on, bringing together expert developers to share skills, knowledge, and technological advancements with attendees through interactive workshops, code labs, and hackathons. In Uganda, groups such as GDG Cloud Kampala and GDG Cloud Mbarara have consistently held technical events like DevFests, while the Python community has successfully organized PyCon Uganda, now heading into its third edition.</p>
<p>However, alongside the rise of these developer-led events, corporate entities have increasingly taken an interest in organizing technology-related summits and conferences. This trend, while seemingly positive, presents a growing concern: many of these corporate-backed events are significantly lacking in technical depth, especially in the crucial domain of artificial intelligence (AI).</p>
<p><strong>The Corporate Takeover and the Technical Deficit</strong></p>
<p>These corporate-led summits and conferences often secure substantial funding, sometimes from government entities and corporate sponsors eager to see their logos on billboards and banners. This kind of financial backing is something developer-led community events can only dream of. Having organized conferences like Google I/O Extended and DevFest Mbarara for three years—each attracting close to 200 students and young professionals—I have personally experienced the uphill battle of securing sponsorship. The struggle to find support is a familiar nightmare for many tech community organizers, who must often rely on volunteers and minimal budgets to pull off impactful events.</p>
<p>Yet, despite their financial strength, corporate-led events frequently fall short of delivering meaningful technical content. A crucial weakness is the lack of hands-on sessions where attendees can actively engage in building, deploying, and experimenting with AI models and other emerging technologies. In many cases, not a single AI model is built or deployed throughout these high-profile summits. The reason? The invited speakers are usually corporate executives, heads of IT departments, and business strategists who, while experienced in management, often lack deep familiarity with emerging AI technologies.</p>
<p><strong>The AI Knowledge Gap at Corporate Summits</strong></p>
<p>To be clear, this is not to undermine the experience and expertise of IT heads in various companies and entities. Many of them have led digital transformation projects, implemented enterprise systems, and managed IT infrastructure at scale. However, AI is a different ballgame. The ability to discuss generative AI tools like OpenAI’s ChatGPT, Google’s Gemini, xAI’s Grok, or DeepSeek’s R1 is one thing; understanding how to fine-tune or integrate large language models (LLMs) such as Gemma 2, LLaMA 3, or OpenAI’s GPT series to solve local problems another.</p>
<p>Few, if any, of these corporate speakers can claim hands-on experience in fine-tuning models, optimizing neural networks, or deploying AI systems into production. This expertise gap means that rather than meaningful, technical AI discussions, these summits often feature surface-level talks filled with buzzwords and high-level strategies that lack practical application. Meanwhile, developers seeking actual skills in AI model development and deployment leave these events empty-handed.</p>
<p>Contrast this with many corporate-led AI summits that focus solely on panel discussions and keynote addresses. While such formats are useful for high-level industry insights, they do little to equip attendees with the hands-on skills necessary for AI development. Without technical immersion, these summits risk becoming echo chambers where the same concepts are discussed repeatedly without any tangible output.</p>
<p><strong>The Danger of Mimicking Bureaucratic Inefficiency</strong></p>
<p>We frequently criticize governments for their endless boardroom meetings and benchmarking trips that yield little practical implementation. It would be a grave mistake to let this inefficiency seep into the developer community. If AI is the future, then our engagement with it must be hands-on. We cannot afford to merely discuss AI trends at conferences while leaving the actual coding and model development to others. We must actively build and deploy AI models, integrate them into real-world applications, and refine them through continuous experimentation.</p>
<p><strong>Shifting the Focus: How AI Events Can Be More Impactful</strong></p>
<p>To ensure that our AI summits and conferences are truly impactful, organizers should:</p>
<ul>
<li><p><strong>Prioritize Hands-On Workshops</strong> – Every AI conference should include code labs where participants can build and deploy AI models in real time. For example, workshops on fine-tuning LLMs, training custom vision models, or implementing AI in cloud environments should be staple sessions.</p>
</li>
<li><p><strong>Feature Technical Speakers</strong> – Rather than filling panels with corporate executives, AI summits should prioritize experts with hands-on experience in AI research and development. This means bringing in data scientists, machine learning engineers, and AI practitioners who can demonstrate real-world applications.</p>
</li>
<li><p><strong>Incorporate Hackathons and AI Challenges</strong> – Practical AI challenges, such as hackathons, Kaggle-style competitions, or live coding challenges, should be an integral part of any AI summit.</p>
</li>
<li><p><strong>Encourage Open-Source Contributions</strong> – Conferences should foster contributions to AI open-source projects, enabling attendees to leave with more than just theoretical knowledge.</p>
</li>
<li><p><strong>Partner with Developer Communities</strong> – Corporate entities should work with developer communities to deliver the hands-on approach to their summits. These communities have the target audience and the experience in organizing these events.</p>
</li>
<li><p><strong>Develop Local AI Talent</strong> – AI conferences should facilitate mentorship programs that connect experienced AI practitioners with budding developers to ensure continuous learning beyond the event itself.</p>
</li>
</ul>
<p><strong>The Call to Action: Time to Get Technical</strong></p>
<p>If we are serious about competing in the AI race, we must radically rethink the structure of our summits and conferences. The AI field evolves at a blistering pace, and merely talking about it will leave us perpetually playing catch-up. To secure our place in the global AI ecosystem, we must do the actual work: writing the code, training the models, deploying real-world AI solutions, and iterating on them.</p>
<p>Uganda’s developer community has shown immense potential, but this potential must be nurtured with the right kind of engagement. Technical summits, practical workshops, and real-world AI applications should be our focus. Anything less than this risks turning our AI discourse into an empty parade of slogans and missed opportunities.</p>
<p>We must choose: do we want to be passive spectators in the AI revolution, or do we want to be active builders shaping the future? The answer lies in how we structure our conferences today.</p>
<p><em>…</em></p>
<p><em>Originally published in</em> <strong><em>The Independent - Uganda</em></strong> <a target="_blank" href="https://www.independent.co.ug/non-technical-summits-and-conferences-not-impactful-in-the-ai-race/"><em>here</em></a><em>.</em></p>
]]></content:encoded></item><item><title><![CDATA[Scalable Model Serving with TensorFlow Serving]]></title><description><![CDATA[Introduction
In this article, we explore how to deploy machine learning models in a scalable and efficient manner using TensorFlow Serving. TensorFlow Serving is a flexible, high-performance serving system designed for production environments, enabli...]]></description><link>https://kambale.dev/scalable-model-serving-with-tensorflow-serving</link><guid isPermaLink="true">https://kambale.dev/scalable-model-serving-with-tensorflow-serving</guid><category><![CDATA[scalable models]]></category><category><![CDATA[TensorFlow]]></category><category><![CDATA[tensorflow-serving]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Mon, 10 Feb 2025 22:09:09 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1723164320089/ca0bea24-fd42-464d-9547-4cc9db9fedcd.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h1 id="heading-introduction">Introduction</h1>
<p>In this article, we explore how to deploy machine learning models in a scalable and efficient manner using TensorFlow Serving. TensorFlow Serving is a flexible, high-performance serving system designed for production environments, enabling you to serve your machine learning models to a large number of clients efficiently. We will cover the basics of TensorFlow Serving, how to set it up, how to serve models, and best practices for scaling your deployment.</p>
<p><strong>What is TensorFlow Serving?</strong></p>
<p>TensorFlow Serving is an open-source serving system specifically designed for deploying machine learning models in production environments. It allows you to serve multiple models or multiple versions of the same model simultaneously, and it can be easily integrated with TensorFlow models.</p>
<p><strong>Why Use TensorFlow Serving?</strong></p>
<p><strong>Scalability</strong>: TensorFlow Serving is designed to handle high-throughput predictions, making it suitable for large-scale deployments.</p>
<p><strong>Flexibility</strong>: It supports multiple models and versions, allowing for easy model management.</p>
<p><strong>Efficiency</strong>: TensorFlow Serving is optimized for performance, with low latency and high throughput.</p>
<h2 id="heading-setting-up-tensorflow-serving">Setting Up TensorFlow Serving</h2>
<p><strong>Installation</strong></p>
<p>TensorFlow Serving can be installed on various platforms, including Linux, macOS, and Windows. However, the most common way to get TensorFlow Serving up and running is through Docker, which simplifies the process and ensures a consistent environment.</p>
<p><strong>Installing TensorFlow Serving on Linux</strong></p>
<p>If you prefer to install TensorFlow Serving directly on your system, you can follow these steps:</p>
<pre><code class="lang-bash"><span class="hljs-built_in">echo</span> <span class="hljs-string">"deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal"</span> | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list &gt; /dev/null
curl -fsSL https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
sudo apt-get update &amp;&amp; sudo apt-get install tensorflow-model-server
</code></pre>
<p><strong>Verify Installation</strong></p>
<pre><code class="lang-bash">tensorflow_model_server --version
</code></pre>
<p><strong>Docker Setup</strong></p>
<p>Using Docker is the recommended way to set up TensorFlow Serving, as it provides a consistent environment across different platforms.</p>
<ol>
<li><p><strong>Pull the TensorFlow Serving Docker Image</strong></p>
<pre><code class="lang-bash"> docker pull tensorflow/serving
</code></pre>
</li>
<li><p><strong>Run the Docker Container</strong></p>
<pre><code class="lang-bash"> docker run -p 8501:8501 --name=tf_serving \
 --mount <span class="hljs-built_in">type</span>=<span class="hljs-built_in">bind</span>,<span class="hljs-built_in">source</span>=$(<span class="hljs-built_in">pwd</span>)/model,target=/models/model_name \
 -e MODEL_NAME=model_name -t tensorflow/serving
</code></pre>
</li>
</ol>
<p>This command starts a TensorFlow Serving container, serving the model located in the <code>model</code> directory.</p>
<h2 id="heading-serving-a-tensorflow-model">Serving a TensorFlow Model</h2>
<h3 id="heading-exporting-a-tensorflow-model">Exporting a TensorFlow Model</h3>
<p>Before serving a model, you need to export it in a format that TensorFlow Serving can understand. Typically, TensorFlow models are saved in the SavedModel format.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf

<span class="hljs-comment"># Define a simple model</span>
model = tf.keras.Sequential([
    tf.keras.layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'relu'</span>, input_shape=(<span class="hljs-number">32</span>,)),
    tf.keras.layers.Dense(<span class="hljs-number">1</span>)
])

<span class="hljs-comment"># Save the model</span>
model.save(<span class="hljs-string">'/path/to/exported_model'</span>)
</code></pre>
<h3 id="heading-loading-the-model-into-tensorflow-serving">Loading the Model into TensorFlow Serving</h3>
<p>With the model saved in the correct format, you can load it into TensorFlow Serving by pointing the server to the directory containing the exported model.</p>
<pre><code class="lang-bash">docker run -p 8501:8501 --name=tf_serving \
--mount <span class="hljs-built_in">type</span>=<span class="hljs-built_in">bind</span>,<span class="hljs-built_in">source</span>=/path/to/exported_model,target=/models/my_model \
-e MODEL_NAME=my_model -t tensorflow/serving
</code></pre>
<h3 id="heading-making-predictions-via-rest-api">Making Predictions via REST API</h3>
<p>TensorFlow Serving provides a REST API to interact with your models. You can make predictions by sending HTTP POST requests.</p>
<pre><code class="lang-bash">curl -d <span class="hljs-string">'{"instances": [[1.0, 2.0, 5.0, 1.0, 2.0, 5.0]]}'</span> \
  -X POST http://localhost:8501/v1/models/my_model:predict
</code></pre>
<p>This request sends a JSON payload containing the input data and receives the model's predictions as a response.</p>
<h2 id="heading-scaling-tensorflow-serving">Scaling TensorFlow Serving</h2>
<h3 id="heading-horizontal-and-vertical-scaling">Horizontal and Vertical Scaling</h3>
<ul>
<li><p><strong>Horizontal Scaling</strong>: Involves adding more instances of TensorFlow Serving, distributing the load across multiple servers. This can be achieved using container orchestration platforms like Kubernetes.</p>
</li>
<li><p><strong>Vertical Scaling</strong>: Involves increasing the resources (CPU, memory) of a single TensorFlow Serving instance. This can be done by allocating more resources to the Docker container.</p>
</li>
</ul>
<h3 id="heading-load-balancing">Load Balancing</h3>
<p>Load balancing is crucial for handling large volumes of requests efficiently. You can use a load balancer to distribute incoming requests across multiple TensorFlow Serving instances.</p>
<h3 id="heading-monitoring-and-logging">Monitoring and Logging</h3>
<p>Monitoring and logging are essential for understanding the performance of your TensorFlow Serving deployment. TensorFlow Serving integrates well with monitoring tools like Prometheus and Grafana.</p>
<p>Example Prometheus configuration:</p>
<pre><code class="lang-yaml"><span class="hljs-attr">global:</span>
  <span class="hljs-attr">scrape_interval:</span> <span class="hljs-string">15s</span>

<span class="hljs-attr">scrape_configs:</span>
  <span class="hljs-bullet">-</span> <span class="hljs-attr">job_name:</span> <span class="hljs-string">'tensorflow_serving'</span>
    <span class="hljs-attr">static_configs:</span>
      <span class="hljs-bullet">-</span> <span class="hljs-attr">targets:</span> [<span class="hljs-string">'localhost:8501'</span>]
</code></pre>
<p>This configuration will scrape metrics from your TensorFlow Serving instance every 15 seconds.</p>
<h2 id="heading-advanced-features-of-tensorflow-serving">Advanced Features of TensorFlow Serving</h2>
<h3 id="heading-model-versioning">Model Versioning</h3>
<p>TensorFlow Serving supports serving multiple versions of the same model. You can specify which version to serve using the <code>--model_version_policy</code> flag.</p>
<pre><code class="lang-bash">docker run -p 8501:8501 --name=tf_serving \
--mount <span class="hljs-built_in">type</span>=<span class="hljs-built_in">bind</span>,<span class="hljs-built_in">source</span>=/path/to/exported_model,target=/models/my_model \
-e MODEL_NAME=my_model -e MODEL_VERSION_POLICY=<span class="hljs-string">"latest"</span> -t tensorflow/serving
</code></pre>
<h3 id="heading-batch-prediction">Batch Prediction</h3>
<p>TensorFlow Serving can be configured to perform batch predictions, which can significantly improve performance for high-throughput scenarios.</p>
<pre><code class="lang-bash">docker run -p 8501:8501 --name=tf_serving \
--mount <span class="hljs-built_in">type</span>=<span class="hljs-built_in">bind</span>,<span class="hljs-built_in">source</span>=/path/to/exported_model,target=/models/my_model \
-e MODEL_NAME=my_model -e TF_SERVING_BATCHING_PARAMETERS_FILE=<span class="hljs-string">"/path/to/batching_parameters"</span> \
-t tensorflow/serving
</code></pre>
<h3 id="heading-customizing-tensorflow-serving">Customizing TensorFlow Serving</h3>
<p>You can customize TensorFlow Serving by adding custom code for preprocessing, postprocessing, or integrating with other systems.</p>
<p>Example of a custom model handler:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow_serving.apis <span class="hljs-keyword">import</span> predict_pb2, prediction_service_pb2_grpc

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CustomModelHandler</span>:</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, model_path</span>):</span>
        self.model = tf.keras.models.load_model(model_path)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">predict</span>(<span class="hljs-params">self, request: predict_pb2.PredictRequest</span>) -&gt; predict_pb2.PredictResponse:</span>
        <span class="hljs-comment"># Custom preprocessing</span>
        inputs = request.inputs[<span class="hljs-string">'input_tensor'</span>].numpy()

        <span class="hljs-comment"># Model prediction</span>
        predictions = self.model.predict(inputs)

        <span class="hljs-comment"># Custom postprocessing</span>
        response = predict_pb2.PredictResponse()
        response.outputs[<span class="hljs-string">'output_tensor'</span>].CopyFrom(tf.make_tensor_proto(predictions))
        <span class="hljs-keyword">return</span> response
</code></pre>
<h2 id="heading-best-practices">Best Practices</h2>
<h3 id="heading-security-considerations">Security Considerations</h3>
<ul>
<li><p><strong>Authentication</strong>: Implement authentication mechanisms to ensure that only authorized clients can access your models.</p>
</li>
<li><p><strong>Encryption</strong>: Use TLS to encrypt data in transit between clients and TensorFlow Serving.</p>
</li>
</ul>
<h3 id="heading-resource-management">Resource Management</h3>
<ul>
<li><p><strong>CPU and Memory Limits</strong>: Set appropriate limits on CPU and memory usage to prevent resource exhaustion.</p>
</li>
<li><p><strong>Autoscaling</strong>: Use autoscaling to dynamically adjust the number of TensorFlow Serving instances based on demand.</p>
</li>
</ul>
<h3 id="heading-optimizing-performance">Optimizing Performance</h3>
<ul>
<li><p><strong>Model Optimization</strong>: Optimize your model using techniques like quantization to reduce latency.</p>
</li>
<li><p><strong>Caching</strong>: Implement caching mechanisms to store frequently requested predictions, reducing the load on your TensorFlow Serving instance.</p>
</li>
</ul>
<h2 id="heading-conclusion">Conclusion</h2>
<p>TensorFlow Serving provides a powerful and flexible solution for serving machine learning models in production environments. By following the steps outlined in this tutorial, you can set up a scalable TensorFlow Serving deployment that is capable of handling large volumes of requests efficiently. With advanced features like model versioning, batch prediction, and custom handlers, TensorFlow Serving can be tailored to meet the specific needs of your application.</p>
]]></content:encoded></item><item><title><![CDATA[Introduction to TensorFlow Extended (TFX)]]></title><description><![CDATA[TensorFlow Extended (TFX) is an end-to-end platform for deploying production machine learning (ML) pipelines. TFX allows data scientists and ML engineers to build, evaluate, and deploy ML models in a scalable, reliable, and reproducible manner. This ...]]></description><link>https://kambale.dev/tensorflow-extended-tfx</link><guid isPermaLink="true">https://kambale.dev/tensorflow-extended-tfx</guid><category><![CDATA[TensorFlow]]></category><category><![CDATA[TFX]]></category><category><![CDATA[Pipeline]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Tue, 06 Aug 2024 00:11:38 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1722774005457/afdce8da-141e-4909-8eca-3ab92381f512.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>TensorFlow Extended (TFX) is an end-to-end platform for deploying production machine learning (ML) pipelines. TFX allows data scientists and ML engineers to build, evaluate, and deploy ML models in a scalable, reliable, and reproducible manner. This article will introduce you to the core components of TFX, provide practical examples using the Iris dataset, and guide you through building a simple TFX pipeline.</p>
<h2 id="heading-what-is-tfx">What is TFX?</h2>
<p>TFX is a production-ready ML platform designed to help you build, deploy, and manage ML models. It consists of a set of libraries and tools that help automate and manage the ML lifecycle. TFX pipelines are portable and can run on various platforms, including Apache Beam, Apache Airflow, and Kubeflow.</p>
<h3 id="heading-key-features-of-tfx">Key Features of TFX</h3>
<ul>
<li><p><strong>Scalability</strong>: TFX can handle large-scale data processing and training.</p>
</li>
<li><p><strong>Portability</strong>: Pipelines can run on different platforms and environments.</p>
</li>
<li><p><strong>Modularity</strong>: TFX components are designed to be modular, allowing you to customize and extend them as needed.</p>
</li>
<li><p><strong>Production-Ready</strong>: TFX is built with production deployment in mind, ensuring reliability and robustness.</p>
</li>
</ul>
<h2 id="heading-tfx-components">TFX Components</h2>
<p>TFX pipelines are composed of several components, each responsible for a specific part of the ML lifecycle. Here are the core components:</p>
<p><strong>ExampleGen</strong></p>
<p>ExampleGen is the first component in a TFX pipeline. It ingests and splits data into training and evaluation datasets. ExampleGen supports various data sources, such as CSV files, TFRecord files, and BigQuery.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> CsvExampleGen

example_gen = CsvExampleGen(input_base=<span class="hljs-string">'/content/iris.csv'</span>)
</code></pre>
<p><strong>StatisticsGen</strong></p>
<p>StatisticsGen computes statistics over the data for data visualization and validation. It generates statistics using TensorFlow Data Validation (TFDV).</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> StatisticsGen

statistics_gen = StatisticsGen(examples=example_gen.outputs[<span class="hljs-string">'examples'</span>])
</code></pre>
<p><strong>SchemaGen</strong></p>
<p>SchemaGen generates a schema for the data based on the statistics computed by StatisticsGen. The schema includes information about the data types, domains, and constraints.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> SchemaGen

schema_gen = SchemaGen(statistics=statistics_gen.outputs[<span class="hljs-string">'statistics'</span>])
</code></pre>
<p><strong>ExampleValidator</strong></p>
<p>ExampleValidator detects anomalies in the data by comparing the data against the schema generated by SchemaGen. It helps ensure data quality.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> ExampleValidator

example_validator = ExampleValidator(statistics=statistics_gen.outputs[<span class="hljs-string">'statistics'</span>], schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>])
</code></pre>
<p><strong>Transform</strong></p>
<p>Transform performs feature engineering and data transformation using TensorFlow Transform (TFT). It preprocesses the data for model training and serving.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Transform

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">preprocessing_fn</span>(<span class="hljs-params">inputs</span>):</span>
    outputs = {
        <span class="hljs-string">'sepal_length'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'sepal.length'</span>]),
        <span class="hljs-string">'sepal_width'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'sepal.width'</span>]),
        <span class="hljs-string">'petal_length'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'petal.ength'</span>]),
        <span class="hljs-string">'petal_width'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'petal.width'</span>]),
        <span class="hljs-string">'species'</span>: inputs[<span class="hljs-string">'variety'</span>]
    }
    <span class="hljs-keyword">return</span> outputs

transform = Transform(
    examples=example_gen.outputs[<span class="hljs-string">'examples'</span>],
    schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>],
    module_file=<span class="hljs-string">'/content/preprocessing.py'</span>
)
</code></pre>
<p><strong>Trainer</strong></p>
<p>Trainer trains an ML model using the preprocessed data. It supports various frameworks, including TensorFlow, Keras, and Estimator.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Trainer
<span class="hljs-keyword">from</span> tfx.proto <span class="hljs-keyword">import</span> trainer_pb2

trainer = Trainer(
    module_file=<span class="hljs-string">'/content/trainer_module.py'</span>,
    examples=transform.outputs[<span class="hljs-string">'transformed_examples'</span>],
    schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>],
    transform_graph=transform.outputs[<span class="hljs-string">'transform_graph'</span>],
    train_args=trainer_pb2.TrainArgs(num_steps=<span class="hljs-number">1000</span>),
    eval_args=trainer_pb2.EvalArgs(num_steps=<span class="hljs-number">100</span>)
)
</code></pre>
<p><strong>Evaluator</strong></p>
<p>Evaluator evaluates the trained model using TensorFlow Model Analysis (TFMA). It helps in validating and comparing different models.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow_model_analysis <span class="hljs-keyword">as</span> tfma

eval_config = tfma.EvalConfig(
    slicing_specs=[tfma.SlicingSpec()],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(class_name=<span class="hljs-string">'SparseCategoricalAccuracy'</span>)
            ]
        )
    ]
)

evaluator = Evaluator(
    examples=example_gen.outputs[<span class="hljs-string">'examples'</span>],
    model=trainer.outputs[<span class="hljs-string">'model'</span>],
    eval_config=eval_config
)
</code></pre>
<p><strong>Pusher</strong></p>
<p>Pusher deploys the trained model to a serving infrastructure. It ensures that the model meets certain criteria before pushing it to production.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Pusher
<span class="hljs-keyword">from</span> tfx.proto <span class="hljs-keyword">import</span> pusher_pb2

pusher = Pusher(
    model=trainer.outputs[<span class="hljs-string">'model'</span>],
    model_blessing=evaluator.outputs[<span class="hljs-string">'blessing'</span>],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=<span class="hljs-string">'/content/model'</span>
        )
    )
)
</code></pre>
<h2 id="heading-setting-up-tfx">Setting Up TFX</h2>
<p>Before building a TFX pipeline, it's essential to set up the environment. This involves installing the necessary packages and configuring the runtime environment.</p>
<p><strong>Installing TFX</strong>: TFX can be installed via pip. The installation includes all the required libraries and dependencies for running TFX components.</p>
<pre><code class="lang-bash">pip install tfx
</code></pre>
<p><strong>Configuring the Environment</strong>: Setting up the environment involves configuring paths for data, pipelines, and model artifacts. This ensures that all components can access the necessary resources and save outputs in the correct locations.</p>
<h2 id="heading-building-a-simple-tfx-pipeline">Building a Simple TFX Pipeline</h2>
<p>To illustrate the capabilities of TFX, we will build a pipeline using the Iris dataset, a well-known dataset for classification tasks. The Iris dataset contains 150 samples of iris flowers, each with four features (sepal length, sepal width, petal length, petal width) and a class label (species).</p>
<p><strong>Data Ingestion</strong></p>
<p>The first step in the TFX pipeline is to ingest the Iris dataset using <code>ExampleGen</code>. This component reads the dataset, splits it into training and evaluation sets, and converts it into the TFX internal format.</p>
<p><strong>CSV ExampleGen</strong>: For the Iris dataset, we use the <code>CsvExampleGen</code> component, which ingests data from CSV files. It automatically splits the data into training and evaluation sets based on a specified ratio.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> CsvExampleGen

example_gen = CsvExampleGen(input_base=<span class="hljs-string">'/content/iris.csv'</span>)
</code></pre>
<p><strong>Data Statistics</strong></p>
<p><code>StatisticsGen</code> computes descriptive statistics for the dataset, providing insights into data distributions and detecting anomalies. It uses TensorFlow Data Validation (TFDV) to generate statistics such as mean, median, and standard deviation for each feature.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> StatisticsGen

statistics_gen = StatisticsGen(examples=example_gen.outputs[<span class="hljs-string">'examples'</span>])
</code></pre>
<p><strong>Importance</strong>: Understanding the distribution of data is crucial for identifying potential issues and ensuring that the data is suitable for training. StatisticsGen helps detect anomalies such as outliers and missing values, which can affect model performance.</p>
<p><strong>Schema Generation</strong></p>
<p>Based on the statistics computed by StatisticsGen, SchemaGen generates a schema for the dataset. The schema includes information about feature types, value ranges, and presence constraints, serving as a blueprint for data validation and transformation.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> SchemaGen

schema_gen = SchemaGen(statistics=statistics_gen.outputs[<span class="hljs-string">'statistics'</span>])
</code></pre>
<p><strong>Definition</strong>: The schema defines the expected structure of the data, including feature types (numeric, categorical, etc.), value ranges, and constraints (e.g., required features). This information is critical for ensuring data consistency and preparing it for model training.</p>
<p><strong>Data Validation</strong></p>
<p><code>ExampleValidator</code> validates the dataset against the schema, identifying anomalies and missing values. It ensures that the data adheres to the expected format, which is essential for training reliable models.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> ExampleValidator

example_validator = ExampleValidator(statistics=statistics_gen.outputs[<span class="hljs-string">'statistics'</span>], schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>])
</code></pre>
<p><strong>Anomaly Detection</strong>: ExampleValidator detects anomalies such as outliers, missing values, and unexpected feature values. These issues can affect model performance and lead to unreliable predictions, making data validation a crucial step in the pipeline.</p>
<p><strong>Data Transformation</strong></p>
<p><code>Transform</code> performs feature engineering and data preprocessing using TensorFlow Transform (TFT). It applies transformations such as scaling, normalization, and encoding, preparing the data for model training.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Transform

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">preprocessing_fn</span>(<span class="hljs-params">inputs</span>):</span>
    <span class="hljs-comment"># Normalize the numeric features</span>
    outputs = {
        <span class="hljs-string">'sepal_length'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'sepal.length'</span>]),
        <span class="hljs-string">'sepal_width'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'sepal.width'</span>]),
        <span class="hljs-string">'petal_length'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'petal.ength'</span>]),
        <span class="hljs-string">'petal_width'</span>: Transform.scale_to_z_score(inputs[<span class="hljs-string">'petal.width'</span>]),
        <span class="hljs-string">'species'</span>: inputs[<span class="hljs-string">'variety'</span>]
    }
    <span class="hljs-keyword">return</span> outputs

transform = Transform(
    examples=example_gen.outputs[<span class="hljs-string">'examples'</span>],
    schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>],
    module_file=<span class="hljs-string">'/content/preprocessing.py'</span>
)
</code></pre>
<p><strong>Preprocessing Functions</strong>: Transform defines preprocessing functions that apply transformations to the raw data. These functions can include operations such as scaling numerical features, encoding categorical features, and generating new features based on existing ones.</p>
<p><strong>Model Training</strong></p>
<p>The <code>Trainer</code> component trains the ML model using the transformed data. It leverages TensorFlow's capabilities to define, train, and evaluate models.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tfx <span class="hljs-keyword">import</span> v1 <span class="hljs-keyword">as</span> tfxio

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">_input_fn</span>(<span class="hljs-params">file_pattern, data_accessor, schema, batch_size=<span class="hljs-number">200</span></span>):</span>
    raw_data = data_accessor.tf_dataset_factory(
        file_pattern, tfxio.TensorFlowDatasetOptions(batch_size=batch_size), schema)
    transformed_data = raw_data.map(_parse_fn)
    <span class="hljs-keyword">return</span> transformed_data

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">_build_keras_model</span>():</span>
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(<span class="hljs-number">4</span>,)),
        tf.keras.layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'relu'</span>),
        tf.keras.layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'relu'</span>),
        tf.keras.layers.Dense(<span class="hljs-number">3</span>, activation=<span class="hljs-string">'softmax'</span>)
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(lr=<span class="hljs-number">0.001</span>),
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])
    <span class="hljs-keyword">return</span> model

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">run_fn</span>(<span class="hljs-params">fn_args</span>):</span>
    schema = tfx_bsl.public.tfxio_utils.get_tfx_schema_from_tensorflow_metadata_schema(
        fn_args.schema)
    train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor, schema, batch_size=<span class="hljs-number">200</span>)
    eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor, schema, batch_size=<span class="hljs-number">200</span>)

    model = _build_keras_model()
    model.fit(train_dataset, steps_per_epoch=fn_args.train_steps,
              validation_data=eval_dataset, validation_steps=fn_args.eval_steps)
    model.save(fn_args.serving_model_dir, save_format=<span class="hljs-string">'tf'</span>)
</code></pre>
<p><strong>Model Definition</strong>: Trainer uses a module file containing the model definition and training logic. This file defines the architecture of the model, the loss function, and the optimization algorithm. It also includes the training and evaluation steps, specifying the number of epochs, batch size, and evaluation metrics.</p>
<p><strong>Trainer</strong></p>
<p>Trainer trains an ML model using the preprocessed data. It supports various frameworks, including TensorFlow, Keras, and Estimator</p>
<pre><code class="lang-python">
<span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Trainer
<span class="hljs-keyword">from</span> tfx.proto <span class="hljs-keyword">import</span> trainer_pb2

trainer = Trainer(
    module_file=<span class="hljs-string">'/content/trainer_module.py'</span>,
    examples=transform.outputs[<span class="hljs-string">'transformed_examples'</span>],
    schema=schema_gen.outputs[<span class="hljs-string">'schema'</span>],
    transform_graph=transform.outputs[<span class="hljs-string">'transform_graph'</span>],
    train_args=trainer_pb2.TrainArgs(num_steps=<span class="hljs-number">1000</span>),
    eval_args=trainer_pb2.EvalArgs(num_steps=<span class="hljs-number">100</span>)
)
</code></pre>
<p><strong>Model Evaluation</strong></p>
<p>Evaluator evaluates the trained model using TensorFlow Model Analysis (TFMA). It performs a detailed analysis of model performance, identifying potential issues and ensuring that the model meets the desired criteria before deployment.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow_model_analysis <span class="hljs-keyword">as</span> tfma

eval_config = tfma.EvalConfig(
    slicing_specs=[tfma.SlicingSpec()],
    metrics_specs=[
        tfma.MetricsSpec(
            metrics=[
                tfma.MetricConfig(class_name=<span class="hljs-string">'SparseCategoricalAccuracy'</span>)
            ]
        )
    ]
)

evaluator = Evaluator(
    examples=example_gen.outputs[<span class="hljs-string">'examples'</span>],
    model=trainer.outputs[<span class="hljs-string">'model'</span>],
    eval_config=eval_config
)
</code></pre>
<p><strong>Evaluation Configuration</strong>: Evaluator uses an evaluation configuration to specify the metrics and slicing specifications for model evaluation. Metrics such as accuracy, precision, and recall are used to assess model performance, while slicing specifications allow for analyzing performance across different subsets of the data.</p>
<p><strong>Model Deployment</strong></p>
<p>The final step in the TFX pipeline is to deploy the validated model using <code>Pusher</code>. This component ensures that only the best models are deployed to the serving infrastructure.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.components <span class="hljs-keyword">import</span> Pusher
<span class="hljs-keyword">from</span> tfx.proto <span class="hljs-keyword">import</span> pusher_pb2

pusher = Pusher(
    model=trainer.outputs[<span class="hljs-string">'model'</span>],
    model_blessing=evaluator.outputs[<span class="hljs-string">'blessing'</span>],
    push_destination=pusher_pb2.PushDestination(
        filesystem=pusher_pb2.PushDestination.Filesystem(
            base_directory=<span class="hljs-string">'/content/model'</span>
        )
    )
)
</code></pre>
<p><strong>Model Deployment</strong>: Pusher deploys the model to a specified serving infrastructure, such as TensorFlow Serving. It ensures that the model is production-ready and meets the desired performance criteria, facilitating continuous model improvement.</p>
<h2 id="heading-running-the-pipeline">Running the Pipeline</h2>
<p>The pipeline is executed using an orchestrator like the Local DAG Runner, which processes the data, trains the model, evaluates its performance, and deploys it if it meets the required criteria.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tfx.orchestration <span class="hljs-keyword">import</span> pipeline
<span class="hljs-keyword">from</span> tfx.orchestration.local.local_dag_runner <span class="hljs-keyword">import</span> LocalDagRunner

<span class="hljs-comment"># Define the pipeline</span>
pipeline = pipeline.Pipeline(
    pipeline_name=<span class="hljs-string">'iris_pipeline'</span>,
    pipeline_root=<span class="hljs-string">'/content/iris_pipeline'</span>,
    components=[example_gen, statistics_gen, schema_gen, example_validator, transform, trainer, evaluator, pusher],
    enable_cache=<span class="hljs-literal">True</span>,
    metadata_connection_config=<span class="hljs-literal">None</span>
)

<span class="hljs-comment"># Run the pipeline</span>
LocalDagRunner().run(pipeline)
</code></pre>
<p>After execution, the pipeline produces outputs such as transformed data, trained models, evaluation results, and metadata, all stored in a specified directory for further analysis and deployment.</p>
<p>Monitoring involves checking logs and reviewing evaluation metrics, while debugging includes inspecting artifacts and re-executing specific components to resolve any issues. This process ensures the pipeline runs smoothly, producing reliable and scalable machine learning models ready for deployment.</p>
<h2 id="heading-conclusion">Conclusion</h2>
<p>In this article, we introduced TensorFlow Extended (TFX) and its core components. We demonstrated how to set up a TFX environment, build a simple TFX pipeline using the Iris dataset, and run it. TFX provides a powerful and flexible framework for managing the end-to-end ML lifecycle, from data ingestion to model deployment. You should now have a solid foundation for building and deploying your own TFX pipelines.</p>
<p>TFX's modularity and scalability make it suitable for a wide range of ML applications, ensuring that you can build robust and production-ready ML systems. Happy experimenting with TFX!</p>
<h1 id="heading-resources">Resources</h1>
<p>GitHub Gist to the notebook</p>
<div class="gist-block embed-wrapper" data-gist-show-loading="false" data-id="ce26c4ab2a7f040372007d422ae0fbe2"><div class="embed-loading"><div class="loadingRow"></div><div class="loadingRow"></div></div><a href="https://gist.github.com/wkambale/ce26c4ab2a7f040372007d422ae0fbe2" class="embed-card">https://gist.github.com/wkambale/ce26c4ab2a7f040372007d422ae0fbe2</a></div>]]></content:encoded></item><item><title><![CDATA[Distributed Model Training with TensorFlow]]></title><description><![CDATA[Training machine learning models on large datasets can be time-consuming and computationally intensive. To address this, TensorFlow provides robust support for distributed training, allowing models to be trained across multiple devices and machines. ...]]></description><link>https://kambale.dev/distributed-model-training</link><guid isPermaLink="true">https://kambale.dev/distributed-model-training</guid><category><![CDATA[Model]]></category><category><![CDATA[TensorFlow]]></category><category><![CDATA[distributed training]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Tue, 30 Jul 2024 21:50:00 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1722371260527/f4429f7e-1563-4c4e-89d0-0c55ff8709ba.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p>Training machine learning models on large datasets can be time-consuming and computationally intensive. To address this, TensorFlow provides robust support for distributed training, allowing models to be trained across multiple devices and machines. This article will guide you through the process of setting up and running distributed model training with TensorFlow.</p>
<h2 id="heading-what-is-distributed-training">What is Distributed Training</h2>
<p>Distributed training allows you to leverage multiple GPUs, TPUs, or even multiple machines to accelerate the training process of your machine learning models. TensorFlow's distributed training capabilities are built around the concept of a "distribution strategy," which specifies how computation is distributed across devices.</p>
<h2 id="heading-types-of-distributed-strategies">Types of Distributed Strategies</h2>
<p>TensorFlow provides several strategies for distributed training, each suited for different scenarios and hardware configurations. Let's get into each strategy, including their use cases and advantages to help you get started.</p>
<p><strong>MirroredStrategy</strong></p>
<p><code>tf.distribute.MirroredStrategy</code> is designed for synchronous training on multiple GPUs on a single machine. It replicates all of the model variables across the GPUs and then performs a synchronous update to keep them in sync.</p>
<div class="hn-table">
<table>
<thead>
<tr>
<td><strong>Use Case</strong></td><td><strong>Advantages</strong></td></tr>
</thead>
<tbody>
<tr>
<td>Best suited for training on a single machine with multiple GPUs.</td><td>Easy to set up and use.</td></tr>
<tr>
<td>Ideal for high-performance workstations or cloud instances with multiple GPUs.</td><td>Provides synchronous training, which is generally easier to debug and produces consistent results.</td></tr>
</tbody>
</table>
</div><p><strong>MultiWorkerMirroredStrategy</strong></p>
<p><code>tf.distribute.MultiWorkerMirroredStrategy</code> extends <code>MirroredStrategy</code> to multiple machines. Each worker (machine) runs a replica of the model and synchronizes updates across all workers.</p>
<div class="hn-table">
<table>
<thead>
<tr>
<td><strong>Use Case</strong></td><td><strong>Advantages</strong></td></tr>
</thead>
<tbody>
<tr>
<td>Suitable for large-scale training on multiple machines.</td><td>Scales seamlessly from a few to many workers.</td></tr>
<tr>
<td>Ideal for scenarios where a single machine's resources are insufficient.</td><td>Utilizes the collective communication strategy to aggregate gradients and synchronize updates.</td></tr>
</tbody>
</table>
</div><p><strong>TPUStrategy</strong></p>
<p><code>tf.distribute.TPUStrategy</code> is used to train models on Google's TPUs. It is optimized for high-performance training and requires minimal code changes from GPU training.</p>
<div class="hn-table">
<table>
<thead>
<tr>
<td><strong>Use Case</strong></td><td><strong>Advantages</strong></td></tr>
</thead>
<tbody>
<tr>
<td>Best for large-scale models and datasets that require high computational power.</td><td>TPUs provide significant speedup compared to GPUs for specific workloads.</td></tr>
<tr>
<td>Ideal for cloud environments where TPU resources are available.</td><td>TensorFlow seamlessly integrates with TPUs, making it easier to switch from GPU to TPU.</td></tr>
</tbody>
</table>
</div><p><strong>ParameterServerStrategy</strong></p>
<p><code>tf.distribute.experimental.ParameterServerStrategy</code> is an asynchronous training strategy where the computation is divided between parameter servers and workers. Parameter servers store model parameters, and workers perform the computations.</p>
<div class="hn-table">
<table>
<thead>
<tr>
<td><strong>Use Case</strong></td><td><strong>Advantages</strong></td></tr>
</thead>
<tbody>
<tr>
<td>Suitable for large-scale distributed training where asynchronous updates are acceptable.</td><td>Allows for more flexible and scalable training.</td></tr>
<tr>
<td>Ideal for scenarios with large models and datasets where synchronous updates may cause bottlenecks.</td><td>Reduces synchronization overhead, potentially speeding up training.</td></tr>
</tbody>
</table>
</div><h2 id="heading-preparing-the-data">Preparing the Data</h2>
<p>Data preparation is a critical step in any machine learning workflow. For distributed training, the way you prepare and feed data to your model can significantly impact the training efficiency and performance. TensorFlow's <a target="_blank" href="http://tf.data"><code>tf.data</code></a> API is a powerful tool for building input pipelines that can be easily integrated with distributed training.</p>
<p><strong>Loading and Preprocessing Data</strong></p>
<p>We will use the MNIST dataset, a classic dataset of handwritten digits. The dataset is available directly through TensorFlow, which makes loading and preprocessing straightforward.</p>
<pre><code class="lang-python">(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

x_train = x_train / <span class="hljs-number">255.0</span>
x_test = x_test / <span class="hljs-number">255.0</span>

x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]
</code></pre>
<p><strong>Creating TensorFlow Datasets</strong></p>
<p>TensorFlow Datasets (<a target="_blank" href="http://tf.data"><code>tf.data.Dataset</code></a>) provides a high-level API for creating and manipulating data pipelines. Using this API, we can create efficient input pipelines that are capable of feeding data to the model in a scalable and efficient manner.</p>
<pre><code class="lang-python">train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

BATCH_SIZE = <span class="hljs-number">64</span>
train_dataset = train_dataset.shuffle(buffer_size=<span class="hljs-number">10000</span>).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)
</code></pre>
<p><strong>Optimizing Data Pipelines</strong></p>
<p>For distributed training, it’s important to ensure that the data pipeline does not become a bottleneck. TensorFlow provides several techniques to optimize data pipelines:</p>
<ul>
<li><p><strong>Prefetching</strong>: Overlap the preprocessing and model execution of data.</p>
</li>
<li><p><strong>Caching</strong>: Cache data in memory to avoid redundant computations.</p>
</li>
<li><p><strong>Parallel Interleave</strong>: Read data from multiple files in parallel.</p>
</li>
</ul>
<pre><code class="lang-python">AUTOTUNE = tf.data.experimental.AUTOTUNE

train_dataset = train_dataset.cache()
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)
</code></pre>
<h2 id="heading-defining-the-model">Defining the Model</h2>
<p>Defining a model in TensorFlow is typically done using the Keras API, which provides a simple and flexible way to build neural networks. Let's define a convolutional neural network (CNN) for the MNIST dataset.</p>
<p><strong>Creating the Model</strong></p>
<p>A CNN is well-suited for image classification tasks. Here, we'll create a simple CNN with two convolutional layers followed by pooling layers, a flattening layer, and two dense layers.</p>
<pre><code class="lang-python"><span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">create_model</span>():</span>
    model = tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(<span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)),
        tf.keras.layers.Conv2D(<span class="hljs-number">32</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>),
        tf.keras.layers.MaxPooling2D((<span class="hljs-number">2</span>, <span class="hljs-number">2</span>)),
        tf.keras.layers.Conv2D(<span class="hljs-number">64</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>),
        tf.keras.layers.MaxPooling2D((<span class="hljs-number">2</span>, <span class="hljs-number">2</span>)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(<span class="hljs-number">64</span>, activation=<span class="hljs-string">'relu'</span>),
        tf.keras.layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'softmax'</span>)
    ])
    <span class="hljs-keyword">return</span> model
</code></pre>
<p><strong>Compiling the Model</strong></p>
<p>After defining the model, the next step is to compile it. Compilation involves specifying the optimizer, loss function, and metrics that the model should use during training.</p>
<pre><code class="lang-python">model = create_model()
model.compile(optimizer=<span class="hljs-string">'adam'</span>,
              loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
              metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<p><strong>Model Summary</strong></p>
<p>It’s always a good practice to print the model summary to understand the architecture and ensure that the model is correctly defined.</p>
<pre><code class="lang-python">model.summary()
</code></pre>
<h2 id="heading-configuring-the-distributed-strategy">Configuring the Distributed Strategy</h2>
<p>TensorFlow's distribution strategies allow you to run your training on multiple GPUs, TPUs, or even across multiple machines. This section explains how to set up and configure different distributed strategies.</p>
<p><strong>MirroredStrategy</strong></p>
<p><code>tf.distribute.MirroredStrategy</code> is designed for synchronous training on multiple GPUs on a single machine. It replicates all model variables across the GPUs and then performs a synchronous update to keep them in sync.</p>
<pre><code class="lang-python">strategy = tf.distribute.MirroredStrategy()

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<p><strong>MultiWorkerMirroredStrategy</strong></p>
<p><code>tf.distribute.MultiWorkerMirroredStrategy</code> extends <code>MirroredStrategy</code> to multiple machines. You need to configure the cluster spec and set the environment variables appropriately.</p>
<p><strong>Setting Up Cluster Spec</strong></p>
<pre><code class="lang-python">cluster_spec = {
    <span class="hljs-string">'worker'</span>: [<span class="hljs-string">'worker1.example.com:2222'</span>, <span class="hljs-string">'worker2.example.com:2222'</span>]
}

os.environ[<span class="hljs-string">'TF_CONFIG'</span>] = json.dumps({
    <span class="hljs-string">'cluster'</span>: cluster_spec,
    <span class="hljs-string">'task'</span>: {<span class="hljs-string">'type'</span>: <span class="hljs-string">'worker'</span>, <span class="hljs-string">'index'</span>: <span class="hljs-number">0</span>}
})

strategy = tf.distribute.MultiWorkerMirroredStrategy()
</code></pre>
<p><strong>Training with MultiWorkerMirroredStrategy</strong></p>
<pre><code class="lang-python"><span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

    model.fit(train_dataset, epochs=<span class="hljs-number">5</span>)
    model.evaluate(test_dataset)
</code></pre>
<p><strong>TPUStrategy</strong></p>
<p><code>tf.distribute.TPUStrategy</code> is used to train models on Google's TPUs. It is optimized for high-performance training and requires minimal code changes from GPU training.</p>
<pre><code class="lang-python">resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=<span class="hljs-string">'your-tpu-address'</span>)
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

    model.fit(train_dataset, epochs=<span class="hljs-number">5</span>)
    model.evaluate(test_dataset)
</code></pre>
<p><strong>ParameterServerStrategy</strong></p>
<p><code>tf.distribute.experimental.ParameterServerStrategy</code> is an asynchronous training strategy where the computation is divided between parameter servers and workers. Parameter servers store model parameters, and workers perform the computations.</p>
<pre><code class="lang-python">cluster_spec = {
    <span class="hljs-string">'worker'</span>: [<span class="hljs-string">'worker1.example.com:2222'</span>, <span class="hljs-string">'worker2.example.com:2222'</span>],
    <span class="hljs-string">'ps'</span>: [<span class="hljs-string">'ps0.example.com:2222'</span>]
}

os.environ[<span class="hljs-string">'TF_CONFIG'</span>] = json.dumps({
    <span class="hljs-string">'cluster'</span>: cluster_spec,
    <span class="hljs-string">'task'</span>: {<span class="hljs-string">'type'</span>: <span class="hljs-string">'worker'</span>, <span class="hljs-string">'index'</span>: <span class="hljs-number">0</span>}

strategy = tf.distribute.experimental.ParameterServerStrategy()
</code></pre>
<p><strong>Training with ParameterServerStrategy</strong></p>
<pre><code class="lang-python"><span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

    model.fit(train_dataset, epochs=<span class="hljs-number">5</span>)
    model.evaluate(test_dataset)
</code></pre>
<h2 id="heading-monitoring-and-debugging">Monitoring and Debugging</h2>
<p>Monitoring and debugging distributed training can be challenging due to the complexity and scale of operations. TensorFlow provides several tools to help with this process, including TensorBoard, logging, and callbacks.</p>
<p><strong>Using TensorBoard</strong></p>
<p>TensorBoard is a powerful visualization tool that allows you to track and visualize metrics such as loss and accuracy during training. It can also display graphs, histograms, and other metrics to help you understand your model's behavior.</p>
<p>To use TensorBoard, you need to set up a TensorBoard callback during model training. This callback will log the metrics to a specified directory.</p>
<pre><code class="lang-python">log_dir = <span class="hljs-string">"logs/"</span>
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=<span class="hljs-number">1</span>)

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

model.fit(train_dataset, epochs=<span class="hljs-number">5</span>, callbacks=[tensorboard_callback])
model.evaluate(test_dataset)
</code></pre>
<p><strong>Launching TensorBoard</strong></p>
<p>To launch TensorBoard, run the following command in your terminal:</p>
<pre><code class="lang-bash">tensorboard --logdir=logs/
</code></pre>
<p>This will start a local server where you can visualize the training metrics. Open your browser and navigate to <a target="_blank" href="http://localhost:6006/"><code>http://localhost:6006/</code></a> to view the TensorBoard dashboard.</p>
<p><strong>Using Logging</strong></p>
<p>Logging is another useful tool for monitoring and debugging your training process. You can use Python’s built-in logging module to log messages and metrics during training.</p>
<pre><code class="lang-python">logging.basicConfig(level=logging.INFO)
logging.info(<span class="hljs-string">"Starting model training..."</span>)

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

model.fit(train_dataset, epochs=<span class="hljs-number">5</span>, callbacks=[tensorboard_callback])
model.evaluate(test_dataset)

logging.info(<span class="hljs-string">"Model training completed."</span>)
</code></pre>
<p><strong>Using Callbacks</strong></p>
<p>Callbacks are powerful tools that allow you to perform actions at various stages of the training process. TensorFlow provides several built-in callbacks, and you can also create custom callbacks to suit your needs.</p>
<p><strong>Built-In Callbacks</strong></p>
<p>TensorFlow includes several built-in callbacks, such as <code>EarlyStopping</code>, <code>ModelCheckpoint</code>, and <code>ReduceLROnPlateau</code>.</p>
<pre><code class="lang-python">early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor=<span class="hljs-string">'val_loss'</span>, patience=<span class="hljs-number">3</span>)
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath=<span class="hljs-string">'model.h5'</span>, save_best_only=<span class="hljs-literal">True</span>)
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(monitor=<span class="hljs-string">'val_loss'</span>, factor=<span class="hljs-number">0.2</span>, patience=<span class="hljs-number">2</span>)

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

    model.fit(train_dataset, epochs=<span class="hljs-number">5</span>, validation_data=test_dataset,
              callbacks=[tensorboard_callback, early_stopping_callback, model_checkpoint_callback, reduce_lr_callback])
</code></pre>
<p><strong>Custom Callbacks</strong></p>
<p>You can also create custom callbacks by subclassing <code>tf.keras.callbacks.Callback</code>.</p>
<pre><code class="lang-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">CustomCallback</span>(<span class="hljs-params">tf.keras.callbacks.Callback</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">on_epoch_end</span>(<span class="hljs-params">self, epoch, logs=None</span>):</span>
        logging.info(<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch}</span> ended with loss: <span class="hljs-subst">{logs[<span class="hljs-string">'loss'</span>]}</span> and accuracy: <span class="hljs-subst">{logs[<span class="hljs-string">'accuracy'</span>]}</span>"</span>)

<span class="hljs-keyword">with</span> strategy.scope():
    model = create_model()
    model.compile(optimizer=<span class="hljs-string">'adam'</span>,
                  loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>,
                  metrics=[<span class="hljs-string">'accuracy'</span>])

    model.fit(train_dataset, epochs=<span class="hljs-number">5</span>, validation_data=test_dataset,
              callbacks=[tensorboard_callback, CustomCallback()])
</code></pre>
<p><strong>Debugging with tf.debugging</strong></p>
<p>TensorFlow also provides debugging tools in</p>
<p>the <code>tf.debugging</code> module to catch and diagnose issues during training. For example, you can use <code>tf.debugging.assert_equal</code> to ensure that tensors have expected values.</p>
<pre><code class="lang-python">a = tf.constant(<span class="hljs-number">1</span>)
b = tf.constant(<span class="hljs-number">2</span>)

tf.debugging.assert_equal(a, b, message=<span class="hljs-string">"Tensors are not equal"</span>)
</code></pre>
<h2 id="heading-conclusion">Conclusion</h2>
<p>Distributed training with TensorFlow can significantly accelerate the training process of your models by leveraging multiple devices and machines. This article covered the basics of setting up and running distributed training using various distribution strategies provided by TensorFlow. By understanding and utilizing these strategies, you can scale your machine learning workflows to handle larger datasets and more complex models efficiently.</p>
<p>Here is a summary of what we covered:</p>
<ol>
<li><p><strong>Introduction to Distributed Training</strong>: Understanding the need and benefits of distributed training.</p>
</li>
<li><p><strong>Types of Distributed Strategies</strong>: Exploring different strategies like MirroredStrategy, MultiWorkerMirroredStrategy, TPUStrategy, and ParameterServerStrategy.</p>
</li>
<li><p><strong>Preparing the Data</strong>: Loading and preprocessing the dataset.</p>
</li>
<li><p><strong>Defining the Model</strong>: Creating a simple CNN model using TensorFlow's Keras API.</p>
</li>
<li><p><strong>Configuring the Distributed Strategy</strong>: Setting up the appropriate distribution strategy for your training.</p>
</li>
<li><p><strong>Monitoring and Debugging</strong>: Using TensorBoard to monitor and debug the training process.</p>
</li>
</ol>
<p>With this knowledge, you are now equipped to start leveraging the power of distributed training to build and train more efficient and scalable machine learning models. Happy coding!</p>
]]></content:encoded></item><item><title><![CDATA[Implementing Advanced Model Architecture with TensorFlow - Part II]]></title><description><![CDATA[if you are finding this for the first time, it means you've missed Part I, it's recommended that you start from the beginning, okay? Let's do that real quick here.
Done? Okay. Let's go...
Implementing Attention Mechanisms
Attention mechanisms have re...]]></description><link>https://kambale.dev/advanced-model-architecture-part-ii</link><guid isPermaLink="true">https://kambale.dev/advanced-model-architecture-part-ii</guid><category><![CDATA[TensorFlow]]></category><category><![CDATA[models]]></category><category><![CDATA[architecture]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Tue, 30 Jul 2024 00:20:53 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1722044079231/65d18b11-4320-4d72-9d9d-2eae1d93f3e6.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p><em>if you are finding this for the first time, it means you've missed</em> <a target="_blank" href="https://kambale.dev/advanced-model-architecture-part-i"><em>Part I</em></a><em>, it's recommended that you start from the beginning, okay? Let's do that real quick</em> <a target="_blank" href="https://kambale.dev/advanced-model-architecture-part-i"><em>here</em></a><em>.</em></p>
<p><strong>Done? Okay. Let's go...</strong></p>
<h2 id="heading-implementing-attention-mechanisms">Implementing Attention Mechanisms</h2>
<p>Attention mechanisms have revolutionized the field of deep learning, particularly in natural language processing (NLP) and computer vision. They allow models to focus on specific parts of the input sequence or data, effectively improving the model's ability to capture dependencies and relationships.</p>
<h3 id="heading-understanding-attention">Understanding Attention</h3>
<p>Attention mechanisms work by assigning different weights to different parts of the input, allowing the model to focus on the most relevant parts. This is particularly useful in sequence-to-sequence tasks such as machine translation, where certain words in the input sequence may be more important than others for generating the output sequence.</p>
<h4 id="heading-types-of-attention">Types of Attention</h4>
<ol>
<li><p><strong>Self-Attention</strong>: Computes attention weights within the same sequence, allowing each element to focus on other elements in the sequence.</p>
</li>
<li><p><strong>Cross-Attention</strong>: Computes attention weights between two different sequences, such as in encoder-decoder models.</p>
</li>
</ol>
<h3 id="heading-scaled-dot-product-attention">Scaled Dot-Product Attention</h3>
<p>The scaled dot-product attention mechanism is a common type of attention used in many models, including the Transformer. It involves three main components: queries (Q), keys (K), and values (V).</p>
<p>$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$</p><h3 id="heading-multi-head-attention">Multi-Head Attention</h3>
<p>Multi-head attention extends the concept of single attention by applying multiple attention mechanisms in parallel, allowing the model to focus on different parts of the input sequence simultaneously.</p>
<p>$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, ..., \text{head}_h)W^O$$</p><p>where each head is an independent attention mechanism.</p>
<h3 id="heading-implementing-attention-in-tensorflow">Implementing Attention in TensorFlow</h3>
<p>Here's an example of implementing scaled dot-product attention and multi-head attention in TensorFlow.</p>
<h4 id="heading-scaled-dot-product-attention-1">Scaled Dot-Product Attention</h4>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">scaled_dot_product_attention</span>(<span class="hljs-params">q, k, v, mask=None</span>):</span>
    matmul_qk = tf.matmul(q, k, transpose_b=<span class="hljs-literal">True</span>)
    dk = tf.cast(tf.shape(k)[<span class="hljs-number">-1</span>], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    <span class="hljs-keyword">if</span> mask <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
        scaled_attention_logits += (mask * <span class="hljs-number">-1e9</span>)

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=<span class="hljs-number">-1</span>)
    output = tf.matmul(attention_weights, v)

    <span class="hljs-keyword">return</span> output, attention_weights
</code></pre>
<h4 id="heading-multi-head-attention-1">Multi-Head Attention</h4>
<pre><code class="lang-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MultiHeadAttention</span>(<span class="hljs-params">tf.keras.layers.Layer</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, d_model, num_heads</span>):</span>
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        <span class="hljs-keyword">assert</span> d_model % self.num_heads == <span class="hljs-number">0</span>

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">split_heads</span>(<span class="hljs-params">self, x, batch_size</span>):</span>
        x = tf.reshape(x, (batch_size, <span class="hljs-number">-1</span>, self.num_heads, self.depth))
        <span class="hljs-keyword">return</span> tf.transpose(x, perm=[<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">3</span>])

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, v, k, q, mask</span>):</span>
        batch_size = tf.shape(q)[<span class="hljs-number">0</span>]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">1</span>, <span class="hljs-number">3</span>])
        concat_attention = tf.reshape(scaled_attention, (batch_size, <span class="hljs-number">-1</span>, self.d_model))

        output = self.dense(concat_attention)

        <span class="hljs-keyword">return</span> output, attention_weights
</code></pre>
<h3 id="heading-using-attention-in-a-transformer-model">Using Attention in a Transformer Model</h3>
<p>The Transformer model relies heavily on attention mechanisms. Here's a brief overview of how attention is used in the Transformer architecture.</p>
<h4 id="heading-transformer-encoder">Transformer Encoder</h4>
<p>The encoder consists of multiple layers, each containing a multi-head self-attention mechanism and a feed-forward neural network.</p>
<pre><code class="lang-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">EncoderLayer</span>(<span class="hljs-params">tf.keras.layers.Layer</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, d_model, num_heads, dff, rate=<span class="hljs-number">0.1</span></span>):</span>
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, num_heads)
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation=<span class="hljs-string">'relu'</span>),
            tf.keras.layers.Dense(d_model)
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=<span class="hljs-number">1e-6</span>)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=<span class="hljs-number">1e-6</span>)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, x, training, mask</span>):</span>
        attn_output, _ = self.mha(x, x, x, mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        <span class="hljs-keyword">return</span> out2
</code></pre>
<h2 id="heading-building-generative-models">Building Generative Models</h2>
<p>Generative models are a class of machine learning models that learn to generate new data samples that resemble the training data. Two popular types of generative models are Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs).</p>
<h3 id="heading-variational-autoencoders-vaes">Variational Autoencoders (VAEs)</h3>
<p>Variational Autoencoders (VAEs) are probabilistic graphical models that aim to learn a latent representation of the data, which can then be used to generate new samples. VAEs consist of two main components: the encoder and the decoder.</p>
<h4 id="heading-key-components-of-vaes">Key Components of VAEs</h4>
<ol>
<li><p><strong>Encoder</strong>: Maps the input data to a latent space, producing a mean and a variance for each dimension of the latent space.</p>
</li>
<li><p><strong>Decoder</strong>: Maps the latent representation back to the data space, generating new samples that resemble the original data.</p>
</li>
</ol>
<h4 id="heading-implementing-a-vae-in-tensorflow">Implementing a VAE in TensorFlow</h4>
<p>Here's an example of implementing a simple VAE for generating images.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers, models

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">Sampling</span>(<span class="hljs-params">layers.Layer</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, inputs</span>):</span>
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[<span class="hljs-number">0</span>]
        dim = tf.shape(z_mean)[<span class="hljs-number">1</span>]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        <span class="hljs-keyword">return</span> z_mean + tf.exp(<span class="hljs-number">0.5</span> * z_log_var) * epsilon

latent_dim = <span class="hljs-number">2</span>
encoder_inputs = tf.keras.Input(shape=(<span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>))
x = layers.Flatten()(encoder_inputs)
x = layers.Dense(<span class="hljs-number">512</span>, activation=<span class="hljs-string">'relu'</span>)(x)
z_mean = layers.Dense(latent_dim)(x)
z_log_var = layers.Dense(latent_dim)(x)
z = Sampling()([z_mean, z_log_var])
encoder = tf.keras.Model(encoder_inputs, [z_mean, z_log_var, z], name=<span class="hljs-string">"encoder"</span>)

decoder_inputs = tf.keras.Input(shape=(latent_dim,))
x = layers.Dense(<span class="hljs-number">512</span>, activation=<span class="hljs-string">'relu'</span>)(decoder_inputs)
x = layers.Dense(<span class="hljs-number">28</span> * <span class="hljs-number">28</span> * <span class="hljs-number">1</span>, activation=<span class="hljs-string">'sigmoid'</span>)(x)
decoder_outputs = layers.Reshape((<span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>))(x)
decoder = tf.keras.Model(decoder_inputs, decoder_outputs, name=<span class="hljs-string">"decoder"</span>)

<span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">VAE</span>(<span class="hljs-params">tf.keras.Model</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, encoder, decoder, **kwargs</span>):</span>
        super(VAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, inputs</span>):</span>
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        kl_loss = <span class="hljs-number">-0.5</span> * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + <span class="hljs-number">1</span>)
        self.add_loss(kl_loss)
        <span class="hljs-keyword">return</span> reconstructed

vae = VAE(encoder, decoder)
vae.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'binary_crossentropy'</span>)

(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(<span class="hljs-string">"float32"</span>) / <span class="hljs-number">255.0</span>
x_train = x_train.reshape(<span class="hljs-number">-1</span>, <span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)
x_test = x_test.astype(<span class="hljs-string">"float32"</span>) / <span class="hljs-number">255.0</span>
x_test = x_test.reshape(<span class="hljs-number">-1</span>, <span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)

vae.fit(x_train, x_train, epochs=<span class="hljs-number">30</span>, batch_size=<span class="hljs-number">128</span>, validation_data=(x_test, x_test))
</code></pre>
<ul>
<li><p><strong>Encoder</strong>: The encoder consists of a dense layer followed by two output layers: one for the mean and one for the log variance of the latent space.</p>
</li>
<li><p><strong>Decoder</strong>: The decoder maps the latent space back to the original data space.</p>
</li>
<li><p><strong>Sampling Layer</strong>: The sampling layer implements the reparameterization trick, which allows backpropagation through the stochastic latent space.</p>
</li>
<li><p><strong>VAE Model</strong>: The VAE model combines the encoder and decoder, adding the KL divergence loss to encourage the latent space to follow a standard normal distribution.</p>
</li>
</ul>
<h3 id="heading-generative-adversarial-networks-gans">Generative Adversarial Networks (GANs)</h3>
<p>Generative Adversarial Networks (GANs) consist of two neural networks: the generator and the discriminator. The generator learns to produce realistic data samples, while the discriminator learns to distinguish between real and generated samples. The two networks are trained in a competitive process.</p>
<h4 id="heading-key-components-of-gans">Key Components of GANs</h4>
<ol>
<li><p><strong>Generator</strong>: Takes random noise as input and generates data samples.</p>
</li>
<li><p><strong>Discriminator</strong>: Takes data samples as input and classifies them as real or fake.</p>
</li>
</ol>
<h4 id="heading-implementing-a-gan-in-tensorflow">Implementing a GAN in TensorFlow</h4>
<p>Here's an example of implementing a simple GAN for generating images.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers, models

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">build_generator</span>():</span>
    model = models.Sequential()
    model.add(layers.Dense(<span class="hljs-number">256</span>, activation=<span class="hljs-string">'relu'</span>, input_dim=<span class="hljs-number">100</span>))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(<span class="hljs-number">512</span>, activation=<span class="hljs-string">'relu'</span>))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(<span class="hljs-number">1024</span>, activation=<span class="hljs-string">'relu'</span>))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(<span class="hljs-number">28</span> * <span class="hljs-number">28</span> * <span class="hljs-number">1</span>, activation=<span class="hljs-string">'tanh'</span>))
    model.add(layers.Reshape((<span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)))
    <span class="hljs-keyword">return</span> model

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">build_discriminator</span>():</span>
    model = models.Sequential()
    model.add(layers.Flatten(input_shape=(<span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)))
    model.add(layers.Dense(<span class="hljs-number">512</span>, activation=<span class="hljs-string">'relu'</span>))
    model.add(layers.Dense(<span class="hljs-number">256</span>, activation=<span class="hljs-string">'relu'</span>))
    model.add(layers.Dense(<span class="hljs-number">1</span>, activation=<span class="hljs-string">'sigmoid'</span>))
    <span class="hljs-keyword">return</span> model

generator = build_generator()
discriminator = build_discriminator()
discriminator.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'binary_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])

gan_input = tf.keras.Input(shape=(<span class="hljs-number">100</span>,))
generated_image = generator(gan_input)
discriminator.trainable = <span class="hljs-literal">False</span>
gan_output = discriminator(generated_image)
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'binary_crossentropy'</span>)

(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype(<span class="hljs-string">"float32"</span>) / <span class="hljs-number">255.0</span>
x_train = x_train.reshape(<span class="hljs-number">-1</span>, <span class="hljs-number">28</span>, <span class="hljs-number">28</span>, <span class="hljs-number">1</span>)

<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np

batch_size = <span class="hljs-number">128</span>
epochs = <span class="hljs-number">10000</span>
half_batch = batch_size // <span class="hljs-number">2</span>

<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(epochs):
    idx = np.random.randint(<span class="hljs-number">0</span>, x_train.shape[<span class="hljs-number">0</span>], half_batch)
    real_images = x_train[idx]
    noise = np.random.normal(<span class="hljs-number">0</span>, <span class="hljs-number">1</span>, (half_batch, <span class="hljs-number">100</span>))
    fake_images = generator.predict(noise)

    d_loss_real = discriminator.train_on_batch(real_images, np.ones((half_batch, <span class="hljs-number">1</span>)))
    d_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((half_batch, <span class="hljs-number">1</span>)))
    d_loss = <span class="hljs-number">0.5</span> * np.add(d_loss_real, d_loss_fake)

    noise = np.random.normal(<span class="hljs-number">0</span>, <span class="hljs-number">1</span>, (batch_size, <span class="hljs-number">100</span>))
    valid_y = np.array([<span class="hljs-number">1</span>] * batch_size)
    g_loss = gan.train_on_batch(noise, valid_y)

    <span class="hljs-keyword">if</span> epoch % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>:
        print(<span class="hljs-string">f"<span class="hljs-subst">{epoch}</span> [D loss: <span class="hljs-subst">{d_loss[<span class="hljs-number">0</span>]}</span> | D accuracy: <span class="hljs-subst">{<span class="hljs-number">100</span>*d_loss[<span class="hljs-number">1</span>]}</span>] [G loss: <span class="hljs-subst">{g_loss}</span>]"</span>)
</code></pre>
<ul>
<li><p><strong>Generator</strong>: The generator network consists of dense layers followed by batch normalization and activation functions. It maps random noise to a data sample.</p>
</li>
<li><p><strong>Discriminator</strong>: The discriminator network consists of dense layers and activation functions. It classifies data samples as real or fake.</p>
</li>
<li><p><strong>Training Loop</strong>: The GAN is trained in a loop where the discriminator is trained on real and fake samples, followed by training the generator to produce samples that can fool the discriminator.</p>
</li>
</ul>
<h2 id="heading-hyperparameter-tuning-and-model-evaluation">Hyperparameter Tuning and Model Evaluation</h2>
<p>Hyperparameter tuning and model evaluation are crucial steps in the development of machine learning models. Proper tuning ensures optimal performance, while thorough evaluation helps understand the model's strengths and weaknesses.</p>
<h3 id="heading-hyperparameter-tuning">Hyperparameter Tuning</h3>
<p>Hyperparameters are settings that define the model structure and how it is trained, such as learning rate, batch size, number of layers, and units per layer. Unlike parameters learned during training, hyperparameters need to be set before the training process begins.</p>
<h4 id="heading-importance-of-hyperparameter-tuning">Importance of Hyperparameter Tuning</h4>
<p>Effective hyperparameter tuning can significantly improve model performance. Poorly chosen hyperparameters can lead to underfitting or overfitting, resulting in a model that performs poorly on unseen data.</p>
<h4 id="heading-techniques-for-hyperparameter-tuning">Techniques for Hyperparameter Tuning</h4>
<ol>
<li><p><strong>Grid Search</strong>: Exhaustively searches over a specified hyperparameter grid.</p>
</li>
<li><p><strong>Random Search</strong>: Samples hyperparameters randomly from a defined range.</p>
</li>
<li><p><strong>Bayesian Optimization</strong>: Uses probabilistic models to find the optimal hyperparameters.</p>
</li>
<li><p><strong>Hyperband</strong>: Combines random search and early stopping to efficiently find optimal hyperparameters.</p>
</li>
</ol>
<h4 id="heading-grid-search">Grid Search</h4>
<p>Grid search is a brute-force technique that searches over a predefined grid of hyperparameters.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> GridSearchCV
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier

param_grid = {
    <span class="hljs-string">'n_estimators'</span>: [<span class="hljs-number">100</span>, <span class="hljs-number">200</span>, <span class="hljs-number">300</span>],
    <span class="hljs-string">'max_depth'</span>: [<span class="hljs-number">10</span>, <span class="hljs-number">20</span>, <span class="hljs-number">30</span>],
    <span class="hljs-string">'min_samples_split'</span>: [<span class="hljs-number">2</span>, <span class="hljs-number">5</span>, <span class="hljs-number">10</span>]
}

grid_search = GridSearchCV(estimator=RandomForestClassifier(), param_grid=param_grid, cv=<span class="hljs-number">3</span>, n_jobs=<span class="hljs-number">-1</span>, verbose=<span class="hljs-number">2</span>)
grid_search.fit(X_train, y_train)

print(<span class="hljs-string">"Best Hyperparameters:"</span>, grid_search.best_params_)
</code></pre>
<h4 id="heading-random-search">Random Search</h4>
<p>Random search samples hyperparameters from a specified distribution, which can be more efficient than grid search.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> RandomizedSearchCV
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier
<span class="hljs-keyword">from</span> scipy.stats <span class="hljs-keyword">import</span> randint

param_dist = {
    <span class="hljs-string">'n_estimators'</span>: randint(<span class="hljs-number">100</span>, <span class="hljs-number">500</span>),
    <span class="hljs-string">'max_depth'</span>: randint(<span class="hljs-number">10</span>, <span class="hljs-number">50</span>),
    <span class="hljs-string">'min_samples_split'</span>: randint(<span class="hljs-number">2</span>, <span class="hljs-number">11</span>)
}

random_search = RandomizedSearchCV(estimator=RandomForestClassifier(), param_distributions=param_dist, n_iter=<span class="hljs-number">100</span>, cv=<span class="hljs-number">3</span>, n_jobs=<span class="hljs-number">-1</span>, verbose=<span class="hljs-number">2</span>)
random_search.fit(X_train, y_train)

print(<span class="hljs-string">"Best Hyperparameters:"</span>, random_search.best_params_)
</code></pre>
<h4 id="heading-bayesian-optimization">Bayesian Optimization</h4>
<p>Bayesian optimization uses a surrogate model to estimate the performance of hyperparameters and efficiently searches the space.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> skopt <span class="hljs-keyword">import</span> BayesSearchCV
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier

param_space = {
    <span class="hljs-string">'n_estimators'</span>: (<span class="hljs-number">100</span>, <span class="hljs-number">500</span>),
    <span class="hljs-string">'max_depth'</span>: (<span class="hljs-number">10</span>, <span class="hljs-number">50</span>),
    <span class="hljs-string">'min_samples_split'</span>: (<span class="hljs-number">2</span>, <span class="hljs-number">11</span>)
}

bayes_search = BayesSearchCV(estimator=RandomForestClassifier(), search_spaces=param_space, n_iter=<span class="hljs-number">32</span>, cv=<span class="hljs-number">3</span>, n_jobs=<span class="hljs-number">-1</span>, verbose=<span class="hljs-number">2</span>)
bayes_search.fit(X_train, y_train)

print(<span class="hljs-string">"Best Hyperparameters:"</span>, bayes_search.best_params_)
</code></pre>
<h4 id="heading-hyperband">Hyperband</h4>
<p>Hyperband combines random search with early stopping to find the best hyperparameters more efficiently.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> keras_tuner.tuners <span class="hljs-keyword">import</span> Hyperband
<span class="hljs-keyword">from</span> tensorflow.keras.models <span class="hljs-keyword">import</span> Sequential
<span class="hljs-keyword">from</span> tensorflow.keras.layers <span class="hljs-keyword">import</span> Dense

<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">build_model</span>(<span class="hljs-params">hp</span>):</span>
    model = Sequential()
    model.add(Dense(units=hp.Int(<span class="hljs-string">'units'</span>, min_value=<span class="hljs-number">32</span>, max_value=<span class="hljs-number">512</span>, step=<span class="hljs-number">32</span>), activation=<span class="hljs-string">'relu'</span>, input_shape=(input_dim,)))
    model.add(Dense(<span class="hljs-number">1</span>, activation=<span class="hljs-string">'sigmoid'</span>))
    model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'binary_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
    <span class="hljs-keyword">return</span> model

tuner = Hyperband(build_model, objective=<span class="hljs-string">'val_accuracy'</span>, max_epochs=<span class="hljs-number">10</span>, factor=<span class="hljs-number">3</span>, directory=<span class="hljs-string">'my_dir'</span>, project_name=<span class="hljs-string">'helloworld'</span>)
tuner.search(X_train, y_train, epochs=<span class="hljs-number">50</span>, validation_split=<span class="hljs-number">0.2</span>)

print(<span class="hljs-string">"Best Hyperparameters:"</span>, tuner.get_best_hyperparameters()[<span class="hljs-number">0</span>].values)
</code></pre>
<h3 id="heading-model-evaluation">Model Evaluation</h3>
<p>Model evaluation involves assessing the performance of a trained model using various metrics. This helps determine how well the model generalizes to new, unseen data.</p>
<h4 id="heading-evaluation-metrics">Evaluation Metrics</h4>
<ol>
<li><p><strong>Accuracy</strong>: Proportion of correctly predicted instances.</p>
</li>
<li><p><strong>Precision</strong>: Proportion of true positives among the predicted positives.</p>
</li>
<li><p><strong>Recall</strong>: Proportion of true positives among the actual positives.</p>
</li>
<li><p><strong>F1 Score</strong>: Harmonic mean of precision and recall.</p>
</li>
<li><p><strong>ROC-AUC</strong>: Area under the Receiver Operating Characteristic curve, measuring the trade-off between true positive rate and false positive rate.</p>
</li>
<li><p><strong>Mean Squared Error (MSE)</strong>: Average of the squared differences between predicted and actual values (for regression).</p>
</li>
<li><p><strong>Mean Absolute Error (MAE)</strong>: Average of the absolute differences between predicted and actual values (for regression).</p>
</li>
</ol>
<h4 id="heading-cross-validation">Cross-Validation</h4>
<p>Cross-validation is a technique for assessing model performance by splitting the data into multiple folds and training/testing the model on these folds. Common methods include k-fold cross-validation and stratified k-fold cross-validation.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> cross_val_score
<span class="hljs-keyword">from</span> sklearn.ensemble <span class="hljs-keyword">import</span> RandomForestClassifier

model = RandomForestClassifier(n_estimators=<span class="hljs-number">100</span>)
scores = cross_val_score(model, X, y, cv=<span class="hljs-number">5</span>, scoring=<span class="hljs-string">'accuracy'</span>)

print(<span class="hljs-string">"Cross-Validation Scores:"</span>, scores)
print(<span class="hljs-string">"Mean Accuracy:"</span>, scores.mean())
</code></pre>
<h4 id="heading-confusion-matrix">Confusion Matrix</h4>
<p>A confusion matrix provides a detailed breakdown of model predictions, showing the counts of true positives, true negatives, false positives, and false negatives.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> confusion_matrix, ConfusionMatrixDisplay

y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
plt.show()
</code></pre>
<h4 id="heading-roc-curve-and-auc">ROC Curve and AUC</h4>
<p>The ROC curve plots the true positive rate against the false positive rate at various threshold settings. The AUC represents the area under this curve.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> sklearn.metrics <span class="hljs-keyword">import</span> roc_curve, roc_auc_score
<span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt

y_pred_proba = model.predict_proba(X_test)[:, <span class="hljs-number">1</span>]
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
auc = roc_auc_score(y_test, y_pred_proba)

plt.plot(fpr, tpr, label=<span class="hljs-string">f'ROC Curve (AUC = <span class="hljs-subst">{auc:<span class="hljs-number">.2</span>f}</span>)'</span>)
plt.xlabel(<span class="hljs-string">'False Positive Rate'</span>)
plt.ylabel(<span class="hljs-string">'True Positive Rate'</span>)
plt.title(<span class="hljs-string">'Receiver Operating Characteristic'</span>)
plt.legend(loc=<span class="hljs-string">'lower right'</span>)
plt.show()
</code></pre>
<h2 id="heading-conclusion-part-ii">Conclusion - Part II</h2>
<p>Implementing advanced model architectures with TensorFlow encompasses a broad range of techniques and methodologies, each crucial for developing robust, efficient, and high-performing machine learning models. From setting up the development environment to fine-tuning hyperparameters and evaluating models, every step plays a vital role in the model development lifecycle.</p>
<h3 id="heading-key-takeaways">Key Takeaways</h3>
<ol>
<li><p><strong>Implementing Attention Mechanisms</strong>: Attention mechanisms, especially in the context of the Transformer architecture, have revolutionized the way models handle sequential data. By enabling models to focus on relevant parts of the input, attention mechanisms significantly enhance the capability of models to understand complex dependencies.</p>
</li>
<li><p><strong>Building Generative Models</strong>: Generative models like Variational Autoencoders (VAEs) and Generative Adversarial Networks (GANs) open up new possibilities in data generation and augmentation. These models are particularly powerful in applications such as image synthesis, data augmentation, and creative AI tasks.</p>
</li>
<li><p><strong>Hyperparameter Tuning and Model Evaluation</strong>: Hyperparameter tuning is a critical step in optimizing model performance. Techniques like grid search, random search, Bayesian optimization, and Hyperband provide systematic approaches to finding the best hyperparameters. Model evaluation metrics and methods ensure that the models are not only accurate but also generalize well to unseen data.</p>
</li>
</ol>
<h3 id="heading-final-thoughts">Final Thoughts</h3>
<p>Building and deploying advanced model architectures with TensorFlow requires a blend of theoretical knowledge and practical skills. By understanding and applying the concepts covered in this tutorial, developers can build sophisticated models capable of solving a wide range of real-world problems. The journey from setting up the environment to fine-tuning hyperparameters and evaluating model performance is iterative and requires continuous learning and experimentation. With TensorFlow’s powerful capabilities and a systematic approach, the possibilities for innovation in machine learning are vast and exciting.</p>
<p>Embarking on this journey will not only enhance your technical skills but also enable you to contribute to the rapidly advancing field of artificial intelligence, pushing the boundaries of what is possible with machine learning.</p>
]]></content:encoded></item><item><title><![CDATA[Implementing Advanced Model Architecture with TensorFlow - Part I]]></title><description><![CDATA[Introduction
Implementing advanced model architecture with TensorFlow is a crucial aspect of building powerful and effective machine learning models. TensorFlow, an open-source machine learning library, provides a versatile framework for designing, t...]]></description><link>https://kambale.dev/advanced-model-architecture-part-i</link><guid isPermaLink="true">https://kambale.dev/advanced-model-architecture-part-i</guid><category><![CDATA[TensorFlow]]></category><category><![CDATA[models]]></category><category><![CDATA[neural networks]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sun, 28 Jul 2024 08:19:41 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1719575000522/b0a1dc33-9088-4f4d-995a-23b6bb4f5759.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h1 id="heading-introduction">Introduction</h1>
<p>Implementing advanced model architecture with TensorFlow is a crucial aspect of building powerful and effective machine learning models. TensorFlow, an open-source machine learning library, provides a versatile framework for designing, training, and deploying various neural network architectures.</p>
<p>In the rapidly evolving field of machine learning, staying ahead often requires a deep understanding of advanced model architectures. In this article, we will take a dive through creating sophisticated models using TensorFlow, exploring foundational concepts and advanced techniques.</p>
<h3 id="heading-importance-of-advanced-model-architecture">Importance of Advanced Model Architecture</h3>
<p>While basic models are suitable for simple tasks, advanced model architectures are essential for tackling more complex problems and achieving state-of-the-art performance. As machine learning tasks become increasingly sophisticated, the need for specialized architectures, such as neural networks, attention mechanisms, generative models, hyper-parameter tuning and model evaluation.</p>
<h3 id="heading-setting-up-your-environment">Setting Up Your Environment</h3>
<p>Before diving into advanced model architectures, we need to set up our environment. We'll use TensorFlow, a powerful open-source library for machine learning.</p>
<p>We'll use Jupyter notebooks for this article. If you don't have Jupyter installed, you can install it using pip:</p>
<pre><code class="lang-bash">pip install jupyter
</code></pre>
<p><strong>Install TensorFlow</strong></p>
<p>TensorFlow can be installed in Python easily, just like any other module, with a terminal command using <code>pip</code>, the package manager for Python. Open a terminal or command prompt and enter the following command:</p>
<pre><code class="lang-bash">pip install tensorflow
</code></pre>
<p><em>Note: This general installation command may not work for all operating systems.</em></p>
<p><strong>macOS</strong></p>
<p>To install TensorFlow that is optimized for Apple's macOS processors (especially M1 and M2 chips) without going through the troubles of using the general installation, the following command is used:</p>
<pre><code class="lang-bash">pip install tensorflow-macos
</code></pre>
<p><strong>Windows &amp; Ubuntu without GPU</strong></p>
<p>You can install the CPU version of TensorFlow on Windows &amp; Ubuntu if you do not have an external GPU installed or wish to use the CPU:</p>
<pre><code class="lang-bash">pip install tensorflow-cpu
</code></pre>
<p><em>Note: Training neural networks with CPUs has performance issues in comparison to powerful GPUs.</em></p>
<p>TensorFlow also provides a high-level API, called <code>Keras</code>, which can simplify the creation and training of machine learning models. You can install Keras using the following command:</p>
<pre><code class="lang-bash">pip install tensorflow-keras
</code></pre>
<p>To start a Jupyter notebook, run:</p>
<pre><code class="lang-bash">jupyter notebook
</code></pre>
<p><strong>Importing Necessary Libraries</strong></p>
<p>In your Jupyter notebook, start by importing the necessary libraries:</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow <span class="hljs-keyword">import</span> keras
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers
<span class="hljs-keyword">import</span> numpy <span class="hljs-keyword">as</span> np
<span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
</code></pre>
<h1 id="heading-understanding-custom-layers-and-models">Understanding Custom Layers and Models</h1>
<p>Custom layers and models in TensorFlow give you the flexibility to build complex and tailored machine learning models. This section explores the creation of custom layers, custom models, and demonstrates how they can be integrated into a neural network.</p>
<h3 id="heading-creating-custom-layers">Creating Custom Layers</h3>
<p>Custom layers allow you to encapsulate custom operations in a reusable and modular way. TensorFlow's <code>Layer</code> class is the building block for creating custom layers. Let's break down the process of creating a custom layer.</p>
<h4 id="heading-basic-structure-of-a-custom-layer">Basic Structure of a Custom Layer</h4>
<p>A custom layer typically involves defining the following components:</p>
<ol>
<li><p><strong>Initialization (</strong><code>__init__</code> method): This method is where you define the attributes of the layer.</p>
</li>
<li><p><strong>Build (</strong><code>build</code> method): This method is where you define the weights and other variables that the layer will use.</p>
</li>
<li><p><strong>Call (</strong><code>call</code> method): This method contains the forward pass logic, specifying how the layer should process its inputs to produce outputs.</p>
</li>
</ol>
<h4 id="heading-custom-dense-layer">Custom Dense Layer</h4>
<p>Here's an example of a custom dense (fully connected) layer that adds a scalar value to the input.</p>
<pre><code class="lang-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyCustomLayer</span>(<span class="hljs-params">layers.Layer</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, units=<span class="hljs-number">32</span>, activation=None</span>):</span>
        super(MyCustomLayer, self).__init__()
        self.units = units
        self.activation = keras.activations.get(activation)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">build</span>(<span class="hljs-params">self, input_shape</span>):</span>
        self.w = self.add_weight(shape=(input_shape[<span class="hljs-number">-1</span>], self.units),
                                 initializer=<span class="hljs-string">'random_normal'</span>,
                                 trainable=<span class="hljs-literal">True</span>)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer=<span class="hljs-string">'zeros'</span>,
                                 trainable=<span class="hljs-literal">True</span>)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, inputs</span>):</span>
        <span class="hljs-keyword">return</span> self.activation(tf.matmul(inputs, self.w) + self.b)

inputs = keras.Input(shape=(<span class="hljs-number">784</span>,))
x = MyCustomLayer(<span class="hljs-number">64</span>, activation=<span class="hljs-string">'relu'</span>)(inputs)
outputs = MyCustomLayer(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'softmax'</span>)(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<ul>
<li><p><strong>Initialization (</strong><code>__init__</code> method): Initializes the number of units and activation function for the layer.</p>
</li>
<li><p><strong>Build (</strong><code>build</code> method): Defines the weights (<code>self.w</code>) and biases (<code>self.b</code>) for the layer.</p>
</li>
<li><p><strong>Call (</strong><code>call</code> method): Implements the forward pass, applying the weights, biases, and activation function to the inputs.</p>
</li>
</ul>
<h3 id="heading-creating-custom-models">Creating Custom Models</h3>
<p>Custom models allow you to define complex architectures beyond the sequential and functional APIs. TensorFlow's <code>Model</code> class is used to create custom models by subclassing it and defining the forward pass logic in the <code>call</code> method.</p>
<h4 id="heading-custom-model-with-functional-api">Custom Model with Functional API</h4>
<p>Let's create a custom model using the functional API, incorporating our custom layer.</p>
<pre><code class="lang-python"><span class="hljs-class"><span class="hljs-keyword">class</span> <span class="hljs-title">MyCustomModel</span>(<span class="hljs-params">keras.Model</span>):</span>
    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span>(<span class="hljs-params">self, units=<span class="hljs-number">32</span>, num_classes=<span class="hljs-number">10</span></span>):</span>
        super(MyCustomModel, self).__init__()
        self.dense1 = MyCustomLayer(units, activation=<span class="hljs-string">'relu'</span>)
        self.dense2 = MyCustomLayer(num_classes, activation=<span class="hljs-string">'softmax'</span>)

    <span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">call</span>(<span class="hljs-params">self, inputs</span>):</span>
        x = self.dense1(inputs)
        <span class="hljs-keyword">return</span> self.dense2(x)

model = MyCustomModel(units=<span class="hljs-number">64</span>, num_classes=<span class="hljs-number">10</span>)
model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<ul>
<li><p><strong>Initialization (</strong><code>__init__</code> method): Initializes two instances of the custom layer (<code>dense1</code> and <code>dense2</code>) with specified units and activation functions.</p>
</li>
<li><p><strong>Call (</strong><code>call</code> method): Implements the forward pass, applying the first custom layer followed by the second.</p>
</li>
</ul>
<h3 id="heading-custom-training-loop">Custom Training Loop</h3>
<p>For more control over the training process, you can write a custom training loop. This allows you to customize every aspect of the training process, including the forward pass, backward pass, and optimization.</p>
<h4 id="heading-custom-training-loop-1">Custom Training Loop</h4>
<p>Here's an example of a custom training loop for our custom model.</p>
<pre><code class="lang-python">(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train, x_test = x_train.astype(<span class="hljs-string">"float32"</span>) / <span class="hljs-number">255.0</span>, x_test.astype(<span class="hljs-string">"float32"</span>) / <span class="hljs-number">255.0</span>
x_train = x_train.reshape(<span class="hljs-number">-1</span>, <span class="hljs-number">784</span>)
x_test = x_test.reshape(<span class="hljs-number">-1</span>, <span class="hljs-number">784</span>)

epochs = <span class="hljs-number">5</span>
batch_size = <span class="hljs-number">64</span>
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.SparseCategoricalCrossentropy()

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=<span class="hljs-number">1024</span>).batch(batch_size)

model = MyCustomModel(units=<span class="hljs-number">64</span>, num_classes=<span class="hljs-number">10</span>)

<span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> range(epochs):
    print(<span class="hljs-string">f"Epoch <span class="hljs-subst">{epoch+<span class="hljs-number">1</span>}</span>/<span class="hljs-subst">{epochs}</span>"</span>)
    <span class="hljs-keyword">for</span> step, (x_batch, y_batch) <span class="hljs-keyword">in</span> enumerate(train_dataset):
        <span class="hljs-keyword">with</span> tf.GradientTape() <span class="hljs-keyword">as</span> tape:
            logits = model(x_batch, training=<span class="hljs-literal">True</span>)
            loss = loss_fn(y_batch, logits)
        gradients = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        <span class="hljs-keyword">if</span> step % <span class="hljs-number">100</span> == <span class="hljs-number">0</span>:
            print(<span class="hljs-string">f"Step <span class="hljs-subst">{step}</span>, Loss: <span class="hljs-subst">{loss.numpy():<span class="hljs-number">.4</span>f}</span>"</span>)
</code></pre>
<ul>
<li><p><strong>Data Preparation</strong>: Loads and preprocesses the MNIST dataset.</p>
</li>
<li><p><strong>Training Loop</strong>: Iterates over epochs and batches, computes gradients, and updates weights using the optimizer.</p>
</li>
</ul>
<h1 id="heading-implementing-neural-network-architectures">Implementing Neural Network Architectures</h1>
<p>In this section, we'll dive into various neural network architectures that are fundamental in deep learning: Convolutional Neural Networks (CNNs), Recurrent Neural Networks (RNNs), and using Transfer Learning with Pre-trained Models. Each architecture is tailored for specific types of tasks and data, and we'll explore how to implement them in TensorFlow.</p>
<h3 id="heading-convolutional-neural-networks-cnns">Convolutional Neural Networks (CNNs)</h3>
<p>Convolutional Neural Networks (CNNs) are specialized for processing grid-like data, such as images. They are designed to automatically and adaptively learn spatial hierarchies of features through backpropagation. CNNs are composed of convolutional layers, pooling layers, and fully connected layers.</p>
<h4 id="heading-key-components-of-cnns">Key Components of CNNs</h4>
<ol>
<li><p><strong>Convolutional Layers</strong>: Apply convolutional operations to the input, using filters to extract features like edges, textures, and patterns.</p>
</li>
<li><p><strong>Pooling Layers</strong>: Reduce the spatial dimensions (width and height) of the data, typically using max pooling or average pooling.</p>
</li>
<li><p><strong>Fully Connected Layers</strong>: Connect every neuron in one layer to every neuron in the next layer, used for classification tasks.</p>
</li>
</ol>
<h4 id="heading-implementing-a-simple-cnn-for-image-classification">Implementing a Simple CNN for Image Classification</h4>
<p>Let's implement a simple CNN for classifying images from the CIFAR-10 dataset.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers, models

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / <span class="hljs-number">255.0</span>, x_test / <span class="hljs-number">255.0</span>

model = models.Sequential()
model.add(layers.Conv2D(<span class="hljs-number">32</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>, input_shape=(<span class="hljs-number">32</span>, <span class="hljs-number">32</span>, <span class="hljs-number">3</span>)))
model.add(layers.MaxPooling2D((<span class="hljs-number">2</span>, <span class="hljs-number">2</span>)))
model.add(layers.Conv2D(<span class="hljs-number">64</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>))
model.add(layers.MaxPooling2D((<span class="hljs-number">2</span>, <span class="hljs-number">2</span>)))
model.add(layers.Conv2D(<span class="hljs-number">64</span>, (<span class="hljs-number">3</span>, <span class="hljs-number">3</span>), activation=<span class="hljs-string">'relu'</span>))

model.add(layers.Flatten())
model.add(layers.Dense(<span class="hljs-number">64</span>, activation=<span class="hljs-string">'relu'</span>))
model.add(layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'softmax'</span>))

model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])

model.fit(x_train, y_train, epochs=<span class="hljs-number">10</span>, validation_data=(x_test, y_test))
</code></pre>
<ul>
<li><p><strong>Convolutional Layers</strong>: The first layer has 32 filters of size 3x3 and ReLU activation. This is followed by a max pooling layer. This pattern repeats, with increasing filter sizes.</p>
</li>
<li><p><strong>Fully Connected Layers</strong>: After flattening the output from the convolutional layers, we add a dense layer with 64 units and ReLU activation, followed by a dense layer with 10 units and softmax activation for classification.</p>
</li>
</ul>
<h3 id="heading-recurrent-neural-networks-rnns">Recurrent Neural Networks (RNNs)</h3>
<p>Recurrent Neural Networks (RNNs) are designed for sequential data, such as time series or text. They maintain a hidden state that captures information about previous inputs, making them suitable for tasks where context or order matters.</p>
<h4 id="heading-key-components-of-rnns">Key Components of RNNs</h4>
<ol>
<li><p><strong>Recurrent Layers</strong>: Process each element of the sequence, maintaining a hidden state that is updated at each step.</p>
</li>
<li><p><strong>LSTM and GRU</strong>: Variants of RNNs that use gating mechanisms to better capture long-range dependencies and mitigate the vanishing gradient problem.</p>
</li>
</ol>
<h4 id="heading-implementing-an-rnn-for-text-classification">Implementing an RNN for Text Classification</h4>
<p>Let's implement an RNN for classifying sentiments in text data using the IMDB dataset.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers, models
<span class="hljs-keyword">from</span> tensorflow.keras.datasets <span class="hljs-keyword">import</span> imdb
<span class="hljs-keyword">from</span> tensorflow.keras.preprocessing <span class="hljs-keyword">import</span> sequence

max_features = <span class="hljs-number">10000</span>
maxlen = <span class="hljs-number">500</span>
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)
x_train = sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

model = models.Sequential()
model.add(layers.Embedding(max_features, <span class="hljs-number">128</span>))
model.add(layers.SimpleRNN(<span class="hljs-number">128</span>))
model.add(layers.Dense(<span class="hljs-number">1</span>, activation=<span class="hljs-string">'sigmoid'</span>))

model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=<span class="hljs-string">'binary_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])

model.fit(x_train, y_train, epochs=<span class="hljs-number">10</span>, validation_data=(x_test, y_test))
</code></pre>
<ul>
<li><p><strong>Embedding Layer</strong>: Converts the input sequences into dense vectors of fixed size.</p>
</li>
<li><p><strong>SimpleRNN Layer</strong>: Processes the sequence data, maintaining a hidden state that captures information about the sequence.</p>
</li>
<li><p><strong>Dense Layer</strong>: Outputs a single value with sigmoid activation for binary classification (positive or negative sentiment).</p>
</li>
</ul>
<h3 id="heading-transfer-learning-with-pre-trained-models">Transfer Learning with Pre-trained Models</h3>
<p>Transfer learning leverages pre-trained models, usually trained on large datasets, and fine-tunes them for a specific task. This approach is beneficial when you have limited data.</p>
<h4 id="heading-steps-in-transfer-learning">Steps in Transfer Learning</h4>
<ol>
<li><p><strong>Select a Pre-trained Model</strong>: Choose a model pre-trained on a large dataset, such as ImageNet.</p>
</li>
<li><p><strong>Load the Pre-trained Model</strong>: Load the model with pre-trained weights, excluding the top layers.</p>
</li>
<li><p><strong>Add Custom Layers</strong>: Add new layers for your specific task.</p>
</li>
<li><p><strong>Freeze the Base Layers</strong>: Freeze the weights of the pre-trained layers.</p>
</li>
<li><p><strong>Compile and Train the Model</strong>: Compile and train the model on your dataset.</p>
</li>
<li><p><strong>Unfreeze and Fine-tune</strong>: Optionally, unfreeze some layers and fine-tune the entire model.</p>
</li>
</ol>
<h4 id="heading-transfer-learning-with-resnet50">Transfer Learning with ResNet50</h4>
<p>Let's implement transfer learning using the ResNet50 model for image classification on the CIFAR-10 dataset.</p>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> tensorflow.keras.applications <span class="hljs-keyword">import</span> ResNet50
<span class="hljs-keyword">from</span> tensorflow.keras <span class="hljs-keyword">import</span> layers, models
<span class="hljs-keyword">from</span> tensorflow.keras.optimizers <span class="hljs-keyword">import</span> Adam

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / <span class="hljs-number">255.0</span>, x_test / <span class="hljs-number">255.0</span>

base_model = ResNet50(weights=<span class="hljs-string">'imagenet'</span>, include_top=<span class="hljs-literal">False</span>, input_shape=(<span class="hljs-number">32</span>, <span class="hljs-number">32</span>, <span class="hljs-number">3</span>))
x = base_model.output
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dense(<span class="hljs-number">1024</span>, activation=<span class="hljs-string">'relu'</span>)(x)
x = layers.Dense(<span class="hljs-number">10</span>, activation=<span class="hljs-string">'softmax'</span>)(x)

model = models.Model(inputs=base_model.input, outputs=x)

<span class="hljs-keyword">for</span> layer <span class="hljs-keyword">in</span> base_model.layers:
    layer.trainable = <span class="hljs-literal">False</span>

model.compile(optimizer=Adam(), loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])

model.fit(x_train, y_train, epochs=<span class="hljs-number">10</span>, validation_data=(x_test, y_test))

<span class="hljs-keyword">for</span> layer <span class="hljs-keyword">in</span> base_model.layers[<span class="hljs-number">-10</span>:]:
    layer.trainable = <span class="hljs-literal">True</span>

model.compile(optimizer=Adam(learning_rate=<span class="hljs-number">1e-5</span>), loss=<span class="hljs-string">'sparse_categorical_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
model.fit(x_train, y_train, epochs=<span class="hljs-number">10</span>, validation_data=(x_test, y_test))
</code></pre>
<ul>
<li><p><strong>Load the Pre-trained Model</strong>: The ResNet50 model is loaded with weights pre-trained on ImageNet, excluding the top layers.</p>
</li>
<li><p><strong>Add Custom Layers</strong>: Global average pooling is added to reduce the spatial dimensions, followed by dense layers for classification.</p>
</li>
<li><p><strong>Freeze the Base Layers</strong>: The base layers of ResNet50 are frozen to retain the learned features.</p>
</li>
<li><p><strong>Compile and Train</strong>: The model is compiled and initially trained. Then, some layers are unfrozen for fine-tuning with a lower learning rate.</p>
</li>
</ul>
<h2 id="heading-conclusion-part-i">Conclusion - Part I</h2>
<p>Implementing advanced model architectures with TensorFlow is a multifaceted process that requires a solid understanding of various components and techniques. In this first part, we have laid the groundwork for developing sophisticated machine learning models by covering the following key areas:</p>
<h3 id="heading-key-takeaways">Key Takeaways</h3>
<ol>
<li><p><strong>Setting Up Your Environment</strong>: Establishing a robust and efficient development environment is the first step towards successful model implementation. Essential tools such as TensorFlow, Keras, and Jupyter Notebooks provide a strong foundation for experimentation and development.</p>
</li>
<li><p><strong>Custom Layers and Models</strong>: Creating custom layers and models allows developers to tailor neural network architectures to specific tasks and data types. This customization enhances the flexibility and effectiveness of the models, enabling them to tackle complex challenges more efficiently.</p>
</li>
<li><p><strong>Implementing Neural Network Architectures</strong>: Understanding and implementing various neural network architectures is crucial for addressing different types of data and tasks. Convolutional Neural Networks (CNNs) excel at image processing, while Recurrent Neural Networks (RNNs) are well-suited for sequential data. Each architecture has its strengths and applications, and mastering them is essential for building effective models.</p>
</li>
<li><p><strong>Transfer Learning with Pre-trained Models</strong>: Transfer learning leverages existing, well-trained models to accelerate development and improve performance. By fine-tuning pre-trained models on new datasets, developers can achieve high accuracy and efficiency with less training time and data. This approach is particularly beneficial when dealing with limited data or complex tasks.</p>
</li>
</ol>
<p>Let's wait for <strong><em>Part II</em></strong> soon, shall we?</p>
]]></content:encoded></item><item><title><![CDATA[Fine-tuning BERT for text classification with KerasNLP]]></title><description><![CDATA[Introduction
Text classification is a basic job in natural language processing (NLP) that is used in sentiment analysis, spam detection, and content categorization. Transformer-based models, like BERT (Bi-directional Encoder Representations from Tran...]]></description><link>https://kambale.dev/fine-tuning-bert</link><guid isPermaLink="true">https://kambale.dev/fine-tuning-bert</guid><category><![CDATA[finetuning]]></category><category><![CDATA[BERT]]></category><category><![CDATA[kerasNLP]]></category><category><![CDATA[keras]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sat, 11 May 2024 19:50:40 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1710174607858/d6b96b98-a153-4237-9cc7-a20a4c393264.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h1 id="heading-introduction">Introduction</h1>
<p>Text classification is a basic job in natural language processing (NLP) that is used in sentiment analysis, spam detection, and content categorization. Transformer-based models, like BERT (Bi-directional Encoder Representations from Transformers), have become popular recently because of their outstanding performance in different NLP tasks.</p>
<p>In this article, we'll explore how to implement text classification using BERT and the KerasNLP library, providing examples and code snippets to guide you through the process.</p>
<h3 id="heading-understanding-bert">Understanding BERT</h3>
<p>BERT, introduced by Google in 2018, is a pre-trained transformer-based model created for understanding natural language. Unlike traditional models that analyze text in one direction, BERT looks at context from both sides, which helps it capture complex relationships within sentences effectively.</p>
<h3 id="heading-bert-architecture">BERT Architecture</h3>
<p>BERT's architecture consists of layers of attention mechanisms and feedforward neural networks. It employs a transformer encoder stack, allowing it to learn contextualized representations of words. The model is pre-trained on large corpora, gaining a deep understanding of language nuances.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1715456920422/68ac176b-cff4-408f-92b4-0776ab115bab.png" alt class="image--center mx-auto" /></p>
<h3 id="heading-tokenization-with-bert"><strong>Tokenization with BERT</strong></h3>
<p>Before delving into text classification, it's crucial to understand tokenization, a process that breaks down text into smaller units, such as words or subwords. BERT utilizes WordPiece tokenization, which divides text into subword tokens, enhancing its ability to handle out-of-vocabulary words.</p>
<p><img src="https://cdn.hashnode.com/res/hashnode/image/upload/v1710174138203/63b1bbb8-0bff-4cda-b9df-653e817d581a.jpeg" alt="BERT-enhanced tokenization. (c) Batuhan Gundogdu" class="image--center mx-auto" /></p>
<h1 id="heading-setting-up-the-environment">Setting Up the Environment</h1>
<p>To get started with BERT-based text classification, you need to set up your Python environment. Ensure you have the required libraries installed:</p>
<pre><code class="lang-bash">pip install tensorflow
pip install keras-nlp
pip install transformers
</code></pre>
<p>These packages include TensorFlow, KerasNLP, and the Hugging Face Transformers library, which provides pre-trained BERT models.</p>
<h3 id="heading-import-the-libraries">Import the libraries</h3>
<pre><code class="lang-python"><span class="hljs-keyword">from</span> keras_nlp <span class="hljs-keyword">import</span> load_bert_model
<span class="hljs-keyword">import</span> pandas <span class="hljs-keyword">as</span> pd
<span class="hljs-keyword">from</span> sklearn.model_selection <span class="hljs-keyword">import</span> train_test_split

<span class="hljs-keyword">from</span> keras.layers <span class="hljs-keyword">import</span> Input, Dense
<span class="hljs-keyword">from</span> keras.models <span class="hljs-keyword">import</span> Model
<span class="hljs-keyword">from</span> keras.optimizers <span class="hljs-keyword">import</span> Adam

<span class="hljs-keyword">from</span> keras_nlp <span class="hljs-keyword">import</span> Tokenizer

<span class="hljs-keyword">from</span> keras_nlp <span class="hljs-keyword">import</span> load_bert_finetuned_model
</code></pre>
<h3 id="heading-loading-bert-model-with-kerasnlp">Loading BERT Model with KerasNLP</h3>
<p>KerasNLP simplifies the process of working with BERT models in Keras. Let's load a pre-trained BERT model using KerasNLP:</p>
<pre><code class="lang-python">model_name = <span class="hljs-string">'bert-base-uncased'</span>
bert_model = load_bert_model(model_name)
</code></pre>
<p><em>Note: You can choose other variants based on your requirements, such as multilingual models or models fine-tuned for specific tasks.</em></p>
<h1 id="heading-text-classification">Text Classification</h1>
<p>Now, let's move on to text classification using BERT. For this example, we'll create a binary sentiment analysis model. Assume you have a dataset with labeled sentiments (positive or negative). First, load and preprocess the data</p>
<pre><code class="lang-python">data = pd.read_csv(<span class="hljs-string">'sentiment_data.csv'</span>)

train_data, test_data = train_test_split(data, test_size=<span class="hljs-number">0.2</span>, random_state=<span class="hljs-number">42</span>)

tokenizer = Tokenizer(model_name)
X_train = tokenizer.tokenize(train_data[<span class="hljs-string">'text'</span>].tolist())
X_test = tokenizer.tokenize(test_data[<span class="hljs-string">'text'</span>].tolist())

y_train = train_data[<span class="hljs-string">'sentiment'</span>].map({<span class="hljs-string">'negative'</span>: <span class="hljs-number">0</span>, <span class="hljs-string">'positive'</span>: <span class="hljs-number">1</span>}).values
y_test = test_data[<span class="hljs-string">'sentiment'</span>].map({<span class="hljs-string">'negative'</span>: <span class="hljs-number">0</span>, <span class="hljs-string">'positive'</span>: <span class="hljs-number">1</span>}).values
</code></pre>
<p>In this example, we assume that your dataset has a 'text' column containing the text data and a 'sentiment' column with labels ('negative' or 'positive'). Adjust the column names based on your dataset structure.</p>
<h2 id="heading-building-the-bert-text-classification-model"><strong>Building the BERT Text Classification Model</strong></h2>
<p>Now, let's build the BERT-based text classification model using Keras:</p>
<pre><code class="lang-python">input_layer = Input(shape=(tokenizer.max_seq_length,), dtype=<span class="hljs-string">'int32'</span>)

bert_output = bert_model(input_layer)

output_layer = Dense(<span class="hljs-number">1</span>, activation=<span class="hljs-string">'sigmoid'</span>)(bert_output[<span class="hljs-string">'pooled_output'</span>])

model = Model(inputs=input_layer, outputs=output_layer)

model.compile(optimizer=Adam(learning_rate=<span class="hljs-number">2e-5</span>), loss=<span class="hljs-string">'binary_crossentropy'</span>, metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<p>This code snippet creates a simple neural network for sentiment analysis. The BERT output is fed into a dense layer with a sigmoid activation function for binary classification. Adjust the architecture based on your specific task and requirements.</p>
<h2 id="heading-training-the-bert-text-classification-model"><strong>Training the BERT Text Classification Model</strong></h2>
<p>Now, let's train the BERT text classification model using the prepared data:</p>
<pre><code class="lang-python">model.fit(X_train, y_train, epochs=<span class="hljs-number">3</span>, batch_size=<span class="hljs-number">32</span>, validation_split=<span class="hljs-number">0.1</span>)

loss, accuracy = model.evaluate(X_test, y_test)
print(<span class="hljs-string">f'Test Loss: <span class="hljs-subst">{loss}</span>, Test Accuracy: <span class="hljs-subst">{accuracy}</span>'</span>)
</code></pre>
<p>This code snippet trains the model for three epochs with a batch size of 32 and validates on a 10% subset of the training data. After training, it evaluates the model on the test set, providing insights into its performance.</p>
<h2 id="heading-fine-tuning-bert-for-specific-tasks"><strong>Fine-Tuning BERT for Specific Tasks</strong></h2>
<p>While the above example demonstrates a basic BERT text classification model, fine-tuning allows you to adapt BERT to specific tasks or domains. Fine-tuning involves training the pre-trained BERT model on a task-specific dataset, enabling it to learn task-specific features.</p>
<h3 id="heading-loading-a-fine-tuned-bert-model"><strong>Loading a Fine-Tuned BERT Model</strong></h3>
<p>Assuming you have a fine-tuned BERT model saved, you can load it using KerasNLP:</p>
<pre><code class="lang-python">fine_tuned_model_path = <span class="hljs-string">'path/to/fine_tuned_model'</span>
fine_tuned_model = load_bert_finetuned_model(fine_tuned_model_path)
</code></pre>
<p><em>Replace 'path/to/fine_tuned_model' with the actual path to your fine-tuned BERT model.</em></p>
<h3 id="heading-fine-tuning-bert-for-text-classification"><strong>Fine-Tuning BERT for Text Classification</strong></h3>
<p>Let's explore how to fine-tune BERT for text classification using KerasNLP. Assume you have a task-specific dataset with text and corresponding labels:</p>
<pre><code class="lang-python">task_data = pd.read_csv(<span class="hljs-string">'task_specific_data.csv'</span>)

X_task = tokenizer.tokenize(task_data[<span class="hljs-string">'text'</span>].tolist())

y_task = task_data[<span class="hljs-string">'label'</span>].values
</code></pre>
<p>Now, fine-tune the BERT model on your task-specific dataset:</p>
<pre><code class="lang-python">fine_tuned_model.fit(X_task, y_task, epochs=<span class="hljs-number">5</span>, batch_size=<span class="hljs-number">16</span>, validation_split=<span class="hljs-number">0.1</span>)

fine_tuned_model.save(<span class="hljs-string">'path/to/save/fine_tuned_model'</span>)
</code></pre>
<p>This code snippet fine-tunes the BERT model on the task-specific dataset for five epochs with a batch size of 16, validating on a 10% subset. After fine-tuning, it saves the model for future use.</p>
<h2 id="heading-conclusion"><strong>Conclusion</strong></h2>
<p>In this article, we explored text classification using BERT and KerasNLP. We explained the fundamentals of BERT, prepared the environment, loaded a pre-trained BERT model, and created a basic text classification model. Furthermore, we talked about fine-tuning BERT for particular tasks, offering code snippets and examples to help you along the way.</p>
<p>Implementing text classification with BERT offers numerous opportunities for NLP applications. Whether you're focusing on sentiment analysis, spam detection, or any classification task, using BERT can greatly improve the accuracy and reliability of your models. As NLP progresses, keeping abreast of the newest developments and integrating them into your projects will keep you ahead in this dynamic and thrilling field.</p>
]]></content:encoded></item><item><title><![CDATA[Building, Compiling, and Fitting Models with TensorFlow]]></title><description><![CDATA[Introduction
TensorFlow is a free and open-source software library that can be used to build machine learning models. It includes the Keras API, which provides a user-friendly interface for building models. Machine learning engineers make decisions a...]]></description><link>https://kambale.dev/build-compile-and-fit-models-with-tensorflow</link><guid isPermaLink="true">https://kambale.dev/build-compile-and-fit-models-with-tensorflow</guid><category><![CDATA[TensorFlow]]></category><category><![CDATA[keras]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Tue, 16 Jan 2024 13:04:58 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1693603576873/57c366e8-5358-41f0-a53e-bbe6c1824815.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h3 id="heading-introduction">Introduction</h3>
<p>TensorFlow is a free and open-source software library that can be used to build machine learning models. It includes the Keras API, which provides a user-friendly interface for building models. Machine learning engineers make decisions about the architecture of a model based on the type of data they are working with, the task they are trying to accomplish, and the resources they have available. The best way to learn how to build models in TensorFlow is to start with a simple task and then gradually work your way up to more complex tasks.</p>
<h3 id="heading-why-and-how">Why and How?</h3>
<p>Machine learning engineers consider the type of problem, the properties of the data, and the intended performance of the model when choosing a model architecture. Here are some tips to help you understand the decisions they make:</p>
<p><strong>Start with a simple architecture</strong>: It is often a good idea to start with a simple architecture when building a new model and add complexity as needed. This allows you to quickly test and improve your ideas.</p>
<p><strong>Experiment with different architectures</strong>: There is no one-size-fits-all answer to model architecture. It is important to try different architectures to see which best solves your problem.</p>
<p><strong>Use prior knowledge</strong>: If you know about the problem you are trying to solve or the data you are using, you can use this information to inform your model architecture decisions.</p>
<p><strong>Stay up-to-date with the field</strong>: Machine learning is a constantly evolving field, with new methods and architectures being developed all the time. Staying up-to-date with the latest research can help you make informed decisions about your model design.</p>
<h3 id="heading-building-the-model"><strong>Building the model</strong></h3>
<p>Let's say you are building a model to classify images of cats and dogs. You could start with a simple architecture, such as a convolutional neural network (CNN). You could then experiment with different architectures, such as a recurrent neural network (RNN) or a long short-term memory (LSTM) network. You could also use prior knowledge about the problem, such as the fact that cats and dogs have different fur patterns, to inform your architectural decisions. Finally, you could stay up-to-date with the latest research on image classification to find new and improved architectures.</p>
<p><strong>Import libraries</strong></p>
<pre><code class="lang-python"><span class="hljs-comment"># import libraries already installed </span>
<span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorflow.keras.models <span class="hljs-keyword">import</span> Sequential
<span class="hljs-keyword">from</span> tensorflow.keras.layers <span class="hljs-keyword">import</span> Dense
</code></pre>
<p><strong>Prepare your data</strong></p>
<p>Load your dataset using appropriate methods (e.g., <code>tf.keras.datasets</code>, <code>pandas</code>, etc.).</p>
<p>Preprocess your data if needed (e.g., normalization, scaling, feature engineering).</p>
<p>Split your data into training and testing sets.</p>
<p><strong>Model architecture</strong></p>
<pre><code class="lang-python"><span class="hljs-comment"># Create a sequential model</span>
model = Sequential()

<span class="hljs-comment"># Add layers to the model</span>
model.add(Flatten(input_shape=(<span class="hljs-number">28</span>, <span class="hljs-number">28</span>)))  
model.add(Dense(<span class="hljs-number">128</span>, activation=<span class="hljs-string">'relu'</span>))                      
model.add(Dense(<span class="hljs-number">10</span>))
</code></pre>
<p>We are creating a basic neural network model using TensorFlow's Keras API. The first layer is a <code>Flatten</code> layer that takes an input image with a shape of (28, 28) and flattens it into a 1D array. The subsequent layer is a <code>Dense</code> layer with 128 neurons, using the <code>ReLU</code> activation function. Finally, we have another Dense layer with 10 neurons.</p>
<p>Our choice of the number of layers and neurons was based on previous knowledge and experimentation. For example, 128 neurons in the hidden layer have been shown to perform well on similar problems in the past. Similarly, using the ReLU activation function is a common choice as it has been proven to be effective in practice. For more detailed explanations and information, please refer to the provided link.</p>
<p>After constructing the model, the next step is to compile it. This involves instructing TensorFlow on how we want it to learn. Our ultimate goal is to enable our program to learn and become more intelligent to tackle future challenges.</p>
<h3 id="heading-compiling-a-model-in-tensorflow">Compiling a Model in TensorFlow</h3>
<p>Building a TensorFlow model is like building a Jenga tower, where the different layers of the model correspond to the different types of blocks in the tower. When building the model, we check that the tower is stable and that all the bricks are neatly stacked. Compiling the model is like adding the finishing touches to the tower.</p>
<p>To compile a machine learning model, we select components like the loss function (how well the model is performing) and the optimizer (which helps adjust the blocks to make the tower more stable). The loss function acts as a benchmark, while the optimizer acts as a shovel. Machine learning engineers select the best-performing loss function and optimizer for their problem.</p>
<p>Here’s an example of compiling a model in TensorFlow:</p>
<pre><code class="lang-python"><span class="hljs-comment"># Configure learning process</span>
model.compile(optimizer=<span class="hljs-string">'adam'</span>,
              loss=<span class="hljs-string">'categorical_crossentropy'</span>,
              metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<p>In TensorFlow, we compile a model to set up the loss function, optimizer, and metrics. This is like ensuring that all the Jenga blocks are properly placed. After creating the model, we can fit it with data to train it.</p>
<p>In this example, we are telling TensorFlow that we want to use <code>categorical_crossentropy</code> as our loss function and <code>adam</code> as our optimizer. We are also saying that we want to keep track of how accurate our model is by including accuracy in our list of metrics.</p>
<h3 id="heading-fitting-a-model-in-tensorflow">Fitting a Model in TensorFlow</h3>
<p>When working with TensorFlow, training a model is similar to playing Jenga. You must maintain the balance of your tower of blocks, with each layer representing a different level of the tower. As you add or remove blocks, you assess how well the tower can remain stable.</p>
<p>In the field of machine learning, fitting a model involves providing it with data to learn from. The model examines the data and tries to make predictions, then measures the accuracy of those predictions. If the results are not up to standard, the model tweaks its settings and tries again, aiming to enhance its accuracy with each attempt.</p>
<pre><code class="lang-python"><span class="hljs-comment"># Train the model on your training data</span>
model.fit(x_train, y_train, epochs=<span class="hljs-number">10</span>, batch_size=<span class="hljs-number">32</span>)
</code></pre>
<p>Training a TensorFlow model involves feeding it data (<code>x_train</code>, <code>y_train</code>) and letting it learn through multiple passes (epochs). We test its performance on unseen data (<code>x_test</code>, <code>y_test</code>) to gauge its progress.</p>
<p><strong>Why is fitting important?</strong></p>
<p>Model fitting is the core of machine learning. Just like a poorly stacked Jenga tower, a poorly fitted model won't be reliable for real-world decisions. Fitting finds the best internal settings (hyperparameters) for your data, allowing the model to extract key information and make accurate predictions.</p>
<p><strong>Think of it as automated tuning</strong></p>
<p>Fitting automatically adjusts your model's parameters to optimally solve your specific problem. This ensures high accuracy and eliminates manual parameter tweaking.</p>
<h3 id="heading-evaluate-your-model">Evaluate your model</h3>
<pre><code class="lang-python"><span class="hljs-comment"># Evaluate model performance on test data</span>
test_loss, test_acc = model.evaluate(x_test, y_test)
print(<span class="hljs-string">'Test accuracy:'</span>, test_acc)
</code></pre>
<h3 id="heading-why-choose-tensorflow">Why Choose TensorFlow?</h3>
<p><strong>Flexibility and Versatility</strong></p>
<p>TensorFlow supports various deep learning tasks like image recognition, natural language processing, and time series forecasting. It caters to a wide range of applications, making it a versatile choice for different projects. Its diverse backend options, including Python, C++, and Java, allow for integration with various existing systems and tools, enhancing flexibility.</p>
<p><strong>Scalability and Performance</strong></p>
<p>TensorFlow can handle large datasets and complex models efficiently, thanks to its distributed computing capabilities. This allows scaling up your training process for faster model development and deployment. Its integration with various cloud platforms like Google Cloud TPUs and NVIDIA GPUs further boosts performance and scalability.</p>
<p><strong>Eager Execution and Debugging</strong></p>
<p>TensorFlow offers eager execution, enabling line-by-line code evaluation and debugging. This makes it easier to understand and troubleshoot your model's behavior, leading to faster development cycles. Visualization tools like TensorBoard provide insights into your model's training process, allowing you to monitor performance and identify potential issues.</p>
<p><strong>Continuous Development and Innovation</strong></p>
<p>TensorFlow is constantly evolving, with regular updates and new features. This ensures access to cutting-edge advancements in the field of deep learning and machine learning. The active development team and community contribute to ongoing improvements in stability, performance, and usability, making TensorFlow a reliable and future-proof choice.</p>
<h3 id="heading-disadvantages">Disadvantages</h3>
<p><strong>Steep Learning Curve</strong></p>
<p>TensorFlow can have a steeper learning curve compared to some other frameworks, especially for beginners. It's complex API and diverse functionalities require a dedicated effort to master. While the extensive community and resources can help, initial setup and configuration might require additional time and effort.</p>
<p><strong>Resource Intensity</strong></p>
<p>Training complex models in TensorFlow can be resource-intensive, demanding powerful hardware and computing resources. This can be a constraint for smaller projects or those with limited budgets. Cloud platforms can alleviate this issue, but their costs need to be factored in when making the decision.</p>
<p><strong>Debugging Challenges</strong></p>
<p>Debugging complex models in TensorFlow can be challenging due to its intricate architecture and data flow. While eager execution helps, identifying the root cause of issues might require advanced knowledge and expertise. Investing in proper monitoring and logging practices can help mitigate this challenge.</p>
<p><strong>Potential for Overfitting</strong></p>
<p>TensorFlow's flexibility allows for building powerful models, but it also increases the risk of overfitting. This occurs when the model memorizes the training data instead of learning generalizable patterns. Techniques like regularization and early stopping can help prevent overfitting, but careful tuning might be necessary.</p>
<h3 id="heading-conclusion">Conclusion</h3>
<p>You've taken a major step into the world of deep learning by understanding the fundamentals of building, compiling, and fitting models with TensorFlow. This journey may have its challenges, but the rewards are significant – the ability to unlock powerful insights from your data and solve complex problems.</p>
<p>Here are some key takeaways to keep in mind as you continue your learning journey:</p>
<p><strong>Start small and iterate:</strong> Begin with simple models and gradually increase complexity as you gain confidence. Experiment with different architectures and hyperparameters to see their impact on performance.</p>
<p><strong>Leverage the community:</strong> Don't hesitate to seek help from the vast TensorFlow community. Utilize online resources, forums, and documentation to troubleshoot problems and learn from others' experiences.</p>
<p><strong>Practice makes perfect:</strong> The more you train models, the better you'll understand their behavior and potential pitfalls. Use diverse datasets and tasks to hone your skills and become a well-rounded machine learning practitioner.</p>
<p><strong>Stay curious and engaged:</strong> The field of deep learning is constantly evolving, with new tools and techniques emerging regularly. Keep up with the latest advancements and be open to exploring new ideas to stay ahead of the curve.</p>
<p>Remember, building and training effective models is not just about writing code. It's about understanding the problem you're trying to solve, choosing the right tools, and iteratively refining your approach. With dedication and a curious mind, you can harness the power of TensorFlow to build impactful solutions and become a valuable asset in the world of AI and machine learning.</p>
]]></content:encoded></item><item><title><![CDATA[Dear 2023, thank you!]]></title><description><![CDATA[Dearest gentle reader,
We won some. We lost some. There's nothing else I can add.
I wish you the very best of the 366 days ahead.
With love,
Wes.]]></description><link>https://kambale.dev/dear-2023-thank-you</link><guid isPermaLink="true">https://kambale.dev/dear-2023-thank-you</guid><category><![CDATA[2023]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sun, 31 Dec 2023 11:33:51 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1704011425544/622c2dc4-be64-4548-a4c8-879372038f2d.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<p><em>Dearest gentle reader</em>,</p>
<p>We won some. We lost some. There's nothing else I can add.</p>
<p>I wish you the very best of the 366 days ahead.</p>
<p><em>With love,</em></p>
<p><em>Wes</em>.</p>
]]></content:encoded></item><item><title><![CDATA[Setting Up a TensorBoard in Google Colab]]></title><description><![CDATA[Introduction
Setting up TensorBoard in Google Colab can be incredibly useful for visualizing your machine learning model's training progress and performance. TensorBoard is a powerful tool that helps you monitor various metrics, visualize model archi...]]></description><link>https://kambale.dev/setting-up-a-tensorboard</link><guid isPermaLink="true">https://kambale.dev/setting-up-a-tensorboard</guid><category><![CDATA[Tensorboard]]></category><category><![CDATA[TensorFlow]]></category><category><![CDATA[Google Colab]]></category><dc:creator><![CDATA[Wesley Kambale]]></dc:creator><pubDate>Sun, 03 Sep 2023 21:00:00 GMT</pubDate><enclosure url="https://cdn.hashnode.com/res/hashnode/image/upload/v1692189864692/167bdf89-45a4-40d6-82de-be582c2b72c7.png" length="0" type="image/jpeg"/><content:encoded><![CDATA[<h3 id="heading-introduction">Introduction</h3>
<p>Setting up TensorBoard in Google Colab can be incredibly useful for visualizing your machine learning model's training progress and performance. TensorBoard is a powerful tool that helps you monitor various metrics, visualize model architectures, and gain insights into your model's behavior. Here's a step-by-step tutorial with examples and code snippets to guide you through the process.</p>
<h3 id="heading-import-necessary-libraries">Import Necessary Libraries</h3>
<p>First, you need to import the required libraries. Make sure you have TensorFlow installed in your Colab environment.</p>
<pre><code class="lang-python"><span class="hljs-keyword">import</span> tensorflow <span class="hljs-keyword">as</span> tf
<span class="hljs-keyword">from</span> tensorboard <span class="hljs-keyword">import</span> notebook
</code></pre>
<h3 id="heading-load-and-prepare-your-data">Load and Prepare Your Data</h3>
<p>For demonstration purposes, let's use a simple dataset. Replace this with your actual dataset and preprocessing steps.</p>
<pre><code class="lang-python"><span class="hljs-comment"># Load and preprocess your data</span>
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / <span class="hljs-number">255.0</span>, x_test / <span class="hljs-number">255.0</span>  <span class="hljs-comment"># Normalize pixel values</span>
</code></pre>
<h3 id="heading-build-and-compile-your-model">Build and Compile Your Model</h3>
<p>Again, this is just a simple example. Replace it with your actual model architecture and configuration.</p>
<pre><code class="lang-python">model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(<span class="hljs-number">28</span>, <span class="hljs-number">28</span>)),
    tf.keras.layers.Dense(<span class="hljs-number">128</span>, activation=<span class="hljs-string">'relu'</span>),
    tf.keras.layers.Dropout(<span class="hljs-number">0.2</span>),
    tf.keras.layers.Dense(<span class="hljs-number">10</span>)
])

<span class="hljs-comment"># Compile the model</span>
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=<span class="hljs-literal">True</span>)
model.compile(optimizer=<span class="hljs-string">'adam'</span>, loss=loss_fn, metrics=[<span class="hljs-string">'accuracy'</span>])
</code></pre>
<h3 id="heading-set-up-tensorboard-callback">Set Up TensorBoard Callback</h3>
<p>Now, you'll create a TensorBoard callback that will save logs for visualization.</p>
<pre><code class="lang-python"><span class="hljs-comment"># Define the log directory</span>
log_dir = <span class="hljs-string">"/content/logs"</span>  <span class="hljs-comment"># You can modify this path</span>

<span class="hljs-comment"># Create a TensorBoard callback</span>
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=<span class="hljs-number">1</span>)
</code></pre>
<h3 id="heading-train-your-model">Train Your Model</h3>
<p>Train your model using the <code>fit</code> function and include the TensorBoard callback.</p>
<pre><code class="lang-python"><span class="hljs-comment"># Train the model</span>
model.fit(x_train, y_train, epochs=<span class="hljs-number">5</span>, callbacks=[tensorboard_callback])
</code></pre>
<h3 id="heading-start-tensorboard-in-colab">Start TensorBoard in Colab</h3>
<p>TensorBoard can be started directly within a Colab notebook using the <code>notebook</code> module.</p>
<pre><code class="lang-python"><span class="hljs-comment"># Load TensorBoard in Colab</span>
notebook.start(<span class="hljs-string">'--logdir '</span> + log_dir)
</code></pre>
<h3 id="heading-access-and-visualize-tensorboard">Access and Visualize TensorBoard</h3>
<p>After running the previous cell, you'll see a link to access TensorBoard. Click on that link to open TensorBoard within your Colab environment. You can navigate through various tabs to visualize different aspects of your training process.</p>
<h3 id="heading-stop-tensorboard">Stop TensorBoard</h3>
<p>Once you're done with TensorBoard, you can stop it using the "Stop" button in the TensorBoard UI, or you can run the following code to stop the TensorBoard instance:</p>
<pre><code class="lang-python">notebook.stop()
</code></pre>
<h3 id="heading-conclusion">Conclusion</h3>
<p>That's it! You've successfully set up TensorBoard in Google Colab to monitor and visualize your model's training progress.</p>
<p>Remember that this tutorial provided a basic example. Depending on your use case, you might need to adjust the code to suit your specific model architecture, dataset, and training configuration.</p>
]]></content:encoded></item></channel></rss>