Jigme Tenzin - Fab Futures - Data Science

JAX

JAX is a high-performance numerical computing library that feels like NumPy but adds:

- automatic differentiation with grad
- just-in-time compilation via jit, backed by XLA
- vectorization of single-example functions with vmap
- seamless execution on CPU, GPU, and TPU

It is ideal for data science and ML tasks that benefit from gradients and accelerators.
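The cells below demonstrate grad, but not jit or vmap. As a minimal sketch of those two (the square function here is just an illustrative example, not from the original notebook):

import jax
import jax.numpy as jnp

# JIT-compile a function so XLA can optimize it
@jax.jit
def square(x):
    return x * x

# vmap turns a single-example function into a batched one
batched_square = jax.vmap(square)

print(square(3.0))                      # 9.0
print(batched_square(jnp.arange(4.0)))  # [0. 1. 4. 9.]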

JAX works like NumPy

In [21]:
import jax.numpy as jnp

# Create an array
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Basic operation
c = a + b
print(c)
[5 7 9]

Automatic Differentiation

In [22]:
import jax.numpy as jnp
from jax import grad

# Define a simple function
def f(x):
    return x**2

# Compute derivative
df = grad(f)

print(df(3.0))
6.0
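The printed 6.0 matches the analytic derivative 2x at x = 3. Because grad returns an ordinary Python function, it can also be composed; a small sketch continuing the cell above (f as defined there):

from jax import grad

# The second derivative of x**2 is the constant 2
d2f = grad(grad(f))
print(d2f(3.0))   # 2.0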

XOR

What is XOR?

XOR stands for Exclusive OR.

It is a logical operation used throughout computer science, and in data science and machine learning it is a classic toy problem because its two classes cannot be separated by a single straight line.

Rule of XOR:

- Output is 1 (True) only when the inputs are different.
- Output is 0 (False) when the inputs are the same.
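This rule can be checked directly with jnp.logical_xor before encoding any data for a model; a quick sketch:

import jax.numpy as jnp

# True only when the two inputs differ
a = jnp.array([0, 0, 1, 1], dtype=bool)
b = jnp.array([0, 1, 0, 1], dtype=bool)
print(jnp.logical_xor(a, b))   # [False  True  True False]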

In [23]:
import jax.numpy as jnp

# XOR input data
X = jnp.array([
    [0.0, 0.0],
    [0.0, 1.0],
    [1.0, 0.0],
    [1.0, 1.0]
])

# XOR output
y = jnp.array([
    [0.0],
    [1.0],
    [1.0],
    [0.0]
])

print("Inputs:\n", X)
print("Outputs:\n", y)
Inputs:
 [[0. 0.]
 [0. 1.]
 [1. 0.]
 [1. 1.]]
Outputs:
 [[0.]
 [1.]
 [1.]
 [0.]]
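As a quick sanity check (a small sketch, not part of the original cells), the labels can be recomputed from the inputs with the XOR rule:

# Recompute the labels: elementwise XOR of the two input columns
y_check = jnp.logical_xor(X[:, 0], X[:, 1]).astype(jnp.float32)
print(y_check)   # [0. 1. 1. 0.]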