NumPy vs JAX: Optimizing Performance and Functionality

A beginner's guide and comparison of NumPy and JAX for scientific computing and machine learning

Introduction

NumPy and JAX are both Python libraries that are widely used in numerical and scientific computing. Their array APIs are similar, but JAX adds features such as just-in-time compilation, automatic vectorization, and accelerator support that make it particularly useful for machine learning applications. In this tutorial, we explore some of the differences between NumPy and JAX and provide code snippets to illustrate them.

Installation

Both NumPy and JAX can be installed using pip. To install NumPy, run the following command in your terminal:

pip install numpy

To install JAX, run the following command:

pip install jax

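Note that JAX also depends on the jaxlib library. Recent JAX releases install it automatically as a dependency of the command above; if your environment does not, it can be installed explicitly with: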

pip install jaxlib
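Once both packages are installed, a quick check like the following can confirm that they import correctly. This is a minimal sketch, not part of the original tutorial; the list of devices it prints depends on your machine, and on a CPU-only install it should contain a single CPU device.

```python
import numpy as np
import jax

# Report the installed versions of both libraries.
numpy_version = np.__version__
jax_version = jax.__version__
print("NumPy version:", numpy_version)
print("JAX version:", jax_version)

# List the devices JAX can see (CPU, and GPU/TPU if available and configured).
devices = jax.devices()
print("JAX devices:", devices)
```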

Functionality

Let us look at how NumPy and JAX code compares for both simple and more complex scientific computing and machine learning tasks.

Array creation

NumPy and JAX both provide functions for creating arrays. The simplest way to create an array in NumPy is to use the array function:

import numpy as np

# create a 1D array
npArrOne = np.array([1, 2, 3])
print(npArrOne)

# create a 2D array
npArrTwo = np.array([[1, 2], [3, 4]])
print(npArrTwo)

In JAX, the corresponding function is jnp.array:

import jax.numpy as jnp

# create a 1D array
jxArrOne = jnp.array([1, 2, 3])
print(jxArrOne)

# create a 2D array
jxArrTwo = jnp.array([[1, 2], [3, 4]])
print(jxArrTwo)

Note that the syntax is nearly identical in both libraries; in JAX we import the NumPy-compatible functions from the jax.numpy module rather than from the numpy package.
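One difference that the similar syntax can hide is the default precision. The sketch below assumes JAX's default configuration (64-bit mode not enabled), in which floating-point arrays are created as float32 while NumPy uses float64. It also shows how arrays can be converted between the two libraries. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

np_floats = np.array([1.0, 2.0, 3.0])
jx_floats = jnp.array([1.0, 2.0, 3.0])

print(np_floats.dtype)  # float64
print(jx_floats.dtype)  # float32 under JAX's default settings

# Arrays can be converted in either direction.
back_to_numpy = np.asarray(jx_floats)   # JAX array -> NumPy array
from_numpy = jnp.asarray(np_floats)     # NumPy array -> JAX array
print(type(back_to_numpy))
print(type(from_numpy))
```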

Array indexing

Both NumPy and JAX provide powerful indexing capabilities for arrays. The basic syntax is the same in both libraries:

# indexing in numpy
npArr = np.array([1, 2, 3])
print(npArr[0])  
# output: 1

# indexing in jax
jxArr = jnp.array([1, 2, 3])
print(jxArr[0])  
# output: 1

However, the two libraries differ in how elements are updated. NumPy arrays are mutable, so an element can be assigned in place. JAX arrays are immutable, so JAX instead provides the .at property, which returns an updated copy of the array. (Older tutorials mention jax.ops.index_update; it has been removed from current JAX releases in favor of .at.)

# update an element in numpy (in place)
npArr = np.array([1, 2, 3])
npArr[0] = 4
print(npArr)
# output: [4 2 3]

# update an element in jax (returns a new array)
jxArr = jnp.array([1, 2, 3])
jxArr = jxArr.at[0].set(4)
print(jxArr)
# output: [4 2 3]

Note that in JAX we cannot modify an array in place, so we reassign the result of .at[0].set(4) to the original variable name.
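To make the immutability concrete, the following sketch shows what happens when in-place assignment is attempted on a JAX array, and a couple of other functional updates available through the .at property of JAX arrays. The variable names are illustrative only.

```python
import jax.numpy as jnp

jx_arr = jnp.array([1, 2, 3])

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    jx_arr[0] = 4
    assignment_failed = False
except TypeError:
    assignment_failed = True
print("in-place assignment failed:", assignment_failed)

# Each .at update returns a new array and leaves the original untouched.
added = jx_arr.at[0].add(10)         # add 10 to element 0
scaled = jx_arr.at[1:].multiply(2)   # double elements 1 and 2
print(added)    # values 11, 2, 3
print(scaled)   # values 1, 4, 6
print(jx_arr)   # still 1, 2, 3
```

Because every update produces a new array, the same code remains valid inside transformed functions such as those compiled with jit.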

Array broadcasting

One of the most powerful features of NumPy and JAX is broadcasting, which lets you perform element-wise operations on arrays with different but compatible shapes. Here is an example of broadcasting in NumPy:

# broadcasting in numpy
npArrOne = np.array([[1, 2], [3, 4]])
npArrTwo = np.array([10, 20])
print(npArrOne + npArrTwo)  
# output:
# [[11 22]
#  [13 24]]

In JAX, broadcasting works in the same way:

# broadcasting in jax
jxArrOne = jnp.array([[1, 2], [3, 4]])
jxArrTwo = jnp.array([10, 20])
print(jxArrOne + jxArrTwo) 
# output:
# [[11 22]
#  [13 24]]

In the examples above, NumPy and JAX add a 2D array (npArrOne, jxArrOne) and a 1D array (npArrTwo, jxArrTwo) by automatically broadcasting the 1D array across each row of the 2D array so that the shapes match.
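The broadcasting step can be made explicit. The sketch below uses NumPy (the same calls exist under jax.numpy) to show the array that broadcasting conceptually produces, and how to broadcast along columns instead of rows. The variable names are illustrative only.

```python
import numpy as np

matrix = np.array([[1, 2], [3, 4]])   # shape (2, 2)
row = np.array([10, 20])              # shape (2,)

# Conceptually, the 1D array is repeated along a new leading axis.
expanded = np.broadcast_to(row, matrix.shape)
print(expanded)           # both rows are 10, 20
print(matrix + expanded)  # same result as matrix + row

# To broadcast down the columns instead, give the 1D array shape (2, 1).
column = row[:, np.newaxis]
column_sum = matrix + column
print(column_sum)         # rows are 11, 12 and 23, 24
```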

JAX also provides a transformation that is particularly useful for machine learning: jax.vmap (a vectorizing map). It takes a function and returns a version that is applied along an array axis (the leading axis by default), so a function written for a single example can be run over a whole batch without an explicit loop:

# vmap in JAX
import jax

def add(x, y):
    return x + y

jxArrOne = jnp.array([[1, 2], [3, 4]])
jxArrTwo = jnp.array([[10, 20], [30, 40]])

# apply add to each pair of rows of the two inputs
jxArrThree = jax.vmap(add)(jxArrOne, jxArrTwo)
print(jxArrThree)
# output:
# [[11 22]
#  [33 44]]

In the code snippet above, we define a function add that adds two arrays element-wise. We then use jax.vmap to apply this function to each pair of rows of jxArrOne and jxArrTwo, which have the same shape. The result is a new array jxArrThree with the same shape as the inputs, where each element is the sum of the corresponding elements of jxArrOne and jxArrTwo. Because addition is already element-wise, the result here matches jxArrOne + jxArrTwo; vmap becomes more useful for functions that are written for a single example.
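A case where the mapping matters more is a function written for single vectors. The sketch below is an illustration rather than part of the original tutorial: dot_single is a hypothetical helper that computes the dot product of two 1D vectors, and jax.vmap turns it into a per-row operation over a batch. The in_axes argument controls which inputs are mapped over.

```python
import jax
import jax.numpy as jnp

def dot_single(x, y):
    # Written for two 1D vectors; it knows nothing about batches.
    return jnp.dot(x, y)

batch_x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
batch_y = jnp.array([[10.0, 20.0], [30.0, 40.0]])

# Map over axis 0 of both inputs: one dot product per pair of rows.
row_dots = jax.vmap(dot_single)(batch_x, batch_y)
print(row_dots)  # values 50 and 250

# in_axes=(0, None): map over the rows of batch_x, reuse fixed_y for every row.
fixed_y = jnp.array([1.0, 1.0])
row_sums = jax.vmap(dot_single, in_axes=(0, None))(batch_x, fixed_y)
print(row_sums)  # values 3 and 7
```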

Performance

One of the key advantages of JAX over NumPy is its ability to compile code using the XLA compiler. This allows JAX to run code on GPUs and TPUs as well as CPUs, which can significantly improve performance for large-scale machine learning applications.

Feature | NumPy | JAX
Hardware | CPU | CPU, GPU, TPU
Execution | Synchronously | Asynchronously
Parallel computation | No | Yes

In the example below, we time a simple matrix multiplication in NumPy and in JAX:

NumPy:

# matrix multiplication in numpy
import numpy as np
import time

npArrOne = np.random.rand(1000, 1000)
npArrTwo = np.random.rand(1000, 1000)

start = time.time()
npArrThree = np.dot(npArrOne, npArrTwo)
end = time.time()

print("NumPy time:", end - start)

Output:

NumPy time: 0.5427792072296143

JAX:

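# matrix multiplication in jax
import time

import jax
import jax.numpy as jnp
from jax import jit

# jax.numpy has no random submodule; jax.random uses explicit PRNG keys
keyOne, keyTwo = jax.random.split(jax.random.PRNGKey(0))
jxArrOne = jax.random.uniform(keyOne, (1000, 1000))
jxArrTwo = jax.random.uniform(keyTwo, (1000, 1000))

@jit
def matmul(a, b):
    return jnp.dot(a, b)

start = time.time()
# block_until_ready waits for JAX's asynchronous computation to finish
jxArrThree = matmul(jxArrOne, jxArrTwo).block_until_ready()
end = time.time()

print("JAX time:", end - start)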

Output:

JAX time: 0.03486919403076172

In the code snippets above, we generate two random 1000 x 1000 matrices in each library, using the np.dot function to perform the matrix multiplication in NumPy and the jnp.dot function in JAX. We also use the @jit decorator to compile the matmul function using the XLA compiler.

In the timings reported above, JAX is significantly faster than NumPy. Treat such numbers as indicative only: they depend on the hardware, the first call to a jitted function includes compilation time, and JAX dispatches work asynchronously, so the result must be waited on (for example with block_until_ready) for the measurement to be meaningful.
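To see how much of a measurement is compilation, the first and second calls to a jitted function can be timed separately. The following is a minimal sketch under stated assumptions: timed is a hypothetical helper, the 500 x 500 size is arbitrary, and the timings it prints will vary from machine to machine.

```python
import time

import jax
import jax.numpy as jnp

# jax.random needs explicit keys; split one key to get two independent ones.
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.uniform(key_a, (500, 500))
b = jax.random.uniform(key_b, (500, 500))

@jax.jit
def matmul(x, y):
    return jnp.dot(x, y)

def timed(fn, *args):
    # Hypothetical helper: run fn, wait for the result, return (result, seconds).
    start = time.perf_counter()
    result = fn(*args).block_until_ready()
    return result, time.perf_counter() - start

_, first_call = timed(matmul, a, b)        # traces and compiles, then runs
result, second_call = timed(matmul, a, b)  # reuses the compiled function

print("first call (includes compilation):", first_call)
print("second call (compiled code only):", second_call)
```

Comparing the second call against NumPy gives a fairer picture of steady-state performance than timing the first call alone.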

Conclusion

Summary table of the similarities and differences between NumPy and JAX

Feature | NumPy | JAX
Functionality | Basic array operations, linear algebra, statistics, image processing, Fourier transform | Most of the NumPy API (through jax.numpy), plus automatic differentiation, just-in-time compilation, automatic vectorization, and parallel execution
Performance | Good for simple tasks on the CPU | Excellent for complex, compute-heavy tasks, especially on accelerators
Scalability | Good for small to medium datasets | Excellent for large datasets
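The automatic differentiation entry in the table can be illustrated with jax.grad, which turns a scalar-valued function into a function that returns its gradient. This is a minimal sketch; sum_of_squares is a hypothetical example function chosen because its gradient, 2x, is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    # Scalar-valued function: f(x) = sum(x_i ** 2)
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes df/dx.
grad_fn = jax.grad(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
gradient = grad_fn(x)
print(gradient)  # values 2, 4, 6, that is 2 * x
```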

We have explored some of the differences between NumPy and JAX, two powerful Python libraries that are widely used in numerical and scientific computing.

While they have similar functionality, JAX provides additional features that are particularly useful for machine learning applications, including functional array updates, vectorization with vmap, and the ability to compile code with the XLA compiler and run it on GPUs and TPUs.

By understanding the strengths and weaknesses of each library, you can choose the one that best suits your needs, and maximize the performance and functionality of your scientific computing projects.

Good luck!