JAX
A high-performance numerical computing library that feels like NumPy but adds:
Automatic differentiation with grad
Just-in-time compilation via jit, using XLA
Single-function vectorization with vmap
Seamless execution on CPU, GPU, and TPU
Ideal for data science and ML tasks that benefit from gradients and accelerators.
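As a rough illustration of jit and vmap from the list above, here is a minimal sketch. The function `squared_norm` and the sample arrays are hypothetical and chosen only for this example; `jit` and `vmap` are the standard JAX transformations.

```python
import jax.numpy as jnp
from jax import jit, vmap

# Hypothetical example function: squared Euclidean norm of one vector
def squared_norm(v):
    return jnp.sum(v ** 2)

# jit traces the function and compiles it with XLA on the first call
fast_squared_norm = jit(squared_norm)

# vmap turns the single-vector function into one that maps over
# the leading axis of a batch, without writing an explicit loop
batched_squared_norm = vmap(squared_norm)

v = jnp.array([3.0, 4.0])
batch = jnp.array([[1.0, 2.0], [3.0, 4.0]])

single = fast_squared_norm(v)          # 3^2 + 4^2 = 25
per_row = batched_squared_norm(batch)  # one value per row: 5 and 25
print(single)
print(per_row)
```

The same Python function is reused unchanged in both cases, which is the usual appeal of these transformations: the numerical code is written once and compiled or vectorized afterwards.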
JAX works like NumPy
In [21]:
import jax.numpy as jnp
# Create an array
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])
# Basic operation
c = a + b
print(c)
[5 7 9]
Automatic Differentiation
In [22]:
import jax.numpy as jnp
from jax import grad
# Define a simple function
def f(x):
    return x**2
# Compute derivative
df = grad(f)
print(df(3.0))
6.0
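The result 6.0 agrees with the analytic derivative f'(x) = 2x at x = 3. As a small extension, the sketch below assumes the same f and shows two related calls: applying grad twice to get the second derivative, and value_and_grad to get the function value and its derivative together. The variable names are illustrative only.

```python
from jax import grad, value_and_grad

def f(x):
    return x ** 2

# grad composes: differentiating the derivative gives the second derivative
d2f = grad(grad(f))

# value_and_grad returns f(x) and f'(x) from a single call
value, slope = value_and_grad(f)(3.0)

second = d2f(3.0)
print(value, slope, second)  # f(3) = 9, f'(3) = 6, f''(3) = 2
```

Note that grad expects a floating-point input, which is why 3.0 is passed rather than 3.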
XOR
What is XOR?
XOR stands for Exclusive OR.
It is a logical operation used in data science, machine learning, and computer science.
Rule of XOR:
Output is 1 (True) only when the inputs are different
Output is 0 (False) when the inputs are the same
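The rule above can be checked directly with jnp.logical_xor. This is a minimal sketch; the names `left` and `right` are hypothetical and simply enumerate the four input combinations.

```python
import jax.numpy as jnp

# All four combinations of two binary inputs
left = jnp.array([0, 0, 1, 1])
right = jnp.array([0, 1, 0, 1])

# logical_xor is True exactly where the two inputs differ;
# cast back to integers to get 0/1 outputs
xor = jnp.logical_xor(left, right).astype(jnp.int32)

# Print the truth table row by row
for l, r, out in zip(left.tolist(), right.tolist(), xor.tolist()):
    print(l, r, "->", out)
```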
In [23]:
import jax.numpy as jnp
# XOR input data
X = jnp.array([
[0.0, 0.0],
[0.0, 1.0],
[1.0, 0.0],
[1.0, 1.0]
])
# XOR output
y = jnp.array([
[0.0],
[1.0],
[1.0],
[0.0]
])
print("Inputs:\n", X)
print("Outputs:\n", y)
Inputs:
 [[0. 0.]
 [0. 1.]
 [1. 0.]
 [1. 1.]]
Outputs:
 [[0.]
 [1.]
 [1.]
 [0.]]