[Tenzin Namgyal] - Fab Futures - Data Science

Session 4: Machine Learning 28/11/2025¶

JAX¶

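To my understanding, JAX is a Python library that helps accelerate machine learning analysis. It offers a NumPy-like array API together with automatic differentiation (grad), vectorization (vmap), and just-in-time compilation (jit).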
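As a small illustration of the acceleration idea, here is a minimal sketch using `jax.jit`, which compiles a function with XLA so that repeated calls reuse the compiled version. The function name `normalize` and the input values are made up for this example, and any real speed-up depends on the hardware and the array size.

```python
import jax
import jax.numpy as jnp

def normalize(v):
    # subtract the mean and divide by the standard deviation
    return (v - jnp.mean(v)) / jnp.std(v)

# jit compiles the function on the first call and reuses it afterwards
fast_normalize = jax.jit(normalize)

v = jnp.arange(10.0)
out = fast_normalize(v)
print("mean:", out.mean())
print("std:", out.std())
```

The jitted function returns the same values as the plain one; only the way it is executed changes.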


  • Understand¶

Basic NumPy-like operations¶

In [3]:
import jax.numpy as jnp

# create arrays
x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

print("x + y =", x + y)
print("x * y =", x * y)
print("sin(x) =", jnp.sin(x))
x + y = [5 7 9]
x * y = [ 4 10 18]
sin(x) = [0.84147096 0.9092974  0.14112   ]

Automatic Differentiation (grad)¶

In [4]:
from jax import grad

# function: f(x) = x^2 + 3x
def f(x):
    return x**2 + 3*x

df_dx = grad(f)    # derivative of f

print(df_dx(5.0))
13.0
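The result can be checked by hand: for f(x) = x^2 + 3x the derivative is 2x + 3, which gives 13 at x = 5. A minimal sketch of that check follows. It also uses `jax.value_and_grad` to get the function value and the derivative in one call. The helper name `analytic_df` is introduced here only for illustration.

```python
import jax

def f(x):
    return x**2 + 3*x

def analytic_df(x):
    # hand-derived derivative of f
    return 2*x + 3

# value_and_grad returns (f(x), f'(x)) from a single call
value, derivative = jax.value_and_grad(f)(5.0)

print("f(5) =", value)              # 25 + 15 = 40.0
print("df/dx at 5 =", derivative)   # 13.0
print("analytic =", analytic_df(5.0))
```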

Gradient of multi-parameter function¶

In [5]:
import jax.numpy as jnp
from jax import grad

def g(x, y):
    return x**2 + y**3

dg_dx = grad(g, argnums=0)
dg_dy = grad(g, argnums=1)

print("dg/dx =", dg_dx(2.0, 3.0))
print("dg/dy =", dg_dy(2.0, 3.0))
dg/dx = 4.0
dg/dy = 27.0
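Both partial derivatives can also be requested at once by passing a tuple to `argnums`. This is a short sketch of that variant on the same function `g`; the variable names `gx` and `gy` are chosen here for illustration.

```python
from jax import grad

def g(x, y):
    return x**2 + y**3

# argnums=(0, 1) makes grad return a tuple: (dg/dx, dg/dy)
dg = grad(g, argnums=(0, 1))
gx, gy = dg(2.0, 3.0)

print("dg/dx =", gx)   # 2*x    = 4.0
print("dg/dy =", gy)   # 3*y**2 = 27.0
```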

Vectorization (vmap) – apply function to array automatically¶

In [6]:
from jax import vmap

def square(x):
    return x * x

# apply across array without writing loops
xs = jnp.arange(5)
v_square = vmap(square)

print(v_square(xs))
[ 0  1  4  9 16]
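`vmap` also combines with `grad`. As a sketch, the derivative of the earlier function f(x) = x^2 + 3x can be evaluated at many points without a loop. The names `df_batch` and `slopes` are made up for this example. The inputs are floats because `grad` does not accept integer inputs.

```python
import jax.numpy as jnp
from jax import grad, vmap

def f(x):
    return x**2 + 3*x

# grad(f) works on a single scalar; vmap maps it over every array element
xs = jnp.arange(5.0)      # floats: grad needs floating-point inputs
df_batch = vmap(grad(f))
slopes = df_batch(xs)

print(slopes)             # 2*x + 3 for x = 0, 1, 2, 3, 4
```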

Assignment - Fitting a machine learning model to the data¶

  • My dataset has a YouTube video ID column that is a string, so I used pandas (pandas.read_csv()) to read the file.
In [7]:
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit
import matplotlib.pyplot as plt

# ===============================
# 1. Load the CSV correctly
# ===============================
df = pd.read_csv("datasets/youtube_video.csv")

print("Columns:", df.columns)
print(df.head())

# Identify numeric columns (ignore video id / text)
numeric_df = df.select_dtypes(include=[np.number])

print("\nNumeric columns detected:", numeric_df.columns)

# Convert to JAX arrays
data = jnp.array(numeric_df.values)

# Use the first two numeric columns:
# x = first column (view_count), y = second column (like_count)
x = data[:, 0]
y = data[:, 1]

print("\nLoaded numeric data successfully!")
print("x shape:", x.shape)
print("y shape:", y.shape)

# ===============================
# 2. Fit a simple linear model using JAX
# y = a*x + b
# ===============================

def predict(params, x):
    a, b = params
    return a * x + b

def loss(params, x, y):
    y_pred = predict(params, x)
    return jnp.mean((y - y_pred)**2)

grad_loss = grad(loss)

# Initialize parameters
params = jnp.array([0.1, 0.1])
learning_rate = 0.01

# Gradient descent
for step in range(500):
    grads = grad_loss(params, x, y)
    params = params - learning_rate * grads

a, b = params
print("\nLearned parameters:")
print("a =", float(a))
print("b =", float(b))

# ===============================
# 3. Plot the data + fitted line
# ===============================
plt.figure(figsize=(8,5))
plt.scatter(np.array(x), np.array(y), label="Data")
plt.plot(np.array(x), np.array(a*x+b), label="Fitted Line", linewidth=3)
plt.xlabel("X")
plt.ylabel("Y")
plt.title("Fit Using JAX")
plt.legend()
plt.show()

# ===============================
# 4. Print sample prediction
# ===============================
print("\nSample prediction:")
print("For x = 10 → predicted y =", float(predict(params, 10)))
Columns: Index(['video_id', 'title', 'channel_name', 'channel_id', 'view_count',
       'like_count', 'comment_count', 'published_date', 'thumbnail'],
      dtype='object')
      video_id                                              title  \
0  QwtOeDR-N3g  Why do Human Feet Wash up on This Beach? | Fas...   
1  K0fw1uiSGE0  Range Rover SV (2025) - The King of the City, ...   
2  StuKQpzfyjI  Bruce Wayne Becomes Batman - NEW Gotham City U...   
3  C91854vbs8I            How the music of Mario Kart has evolved   
4  1DaHVSmx160     My bulldog loves mess... you can see that ))))   

           channel_name                channel_id  view_count  like_count  \
0    Fascinating Horror  UCFXad0mx4WxY1fXdbvtg0CQ       33955        1817   
1                CAR TV  UCs_gcVRMHEx5mVXfQ7icQTg       59659        1130   
2           Just4fun290  UCdPG86Ahrf9A8YzBnFDAaEg        8658          89   
3      Thomas Game Docs  UCyy7dZhgfeMMctSoo3wDXlQ      161045        5391   
4  Enzo English Bulldog  UCkk_2IfQ52OJcCUXOrJBtnw        7620          94   

   comment_count        published_date  \
0            119  2025-10-03T09:15:02Z   
1             22  2025-10-08T09:10:46Z   
2              0  2025-10-04T15:00:06Z   
3            239  2025-05-08T15:32:15Z   
4             24  2021-08-12T15:33:31Z   

                                          thumbnail  
0  https://i.ytimg.com/vi/QwtOeDR-N3g/hqdefault.jpg  
1  https://i.ytimg.com/vi/K0fw1uiSGE0/hqdefault.jpg  
2  https://i.ytimg.com/vi/StuKQpzfyjI/hqdefault.jpg  
3  https://i.ytimg.com/vi/C91854vbs8I/hqdefault.jpg  
4  https://i.ytimg.com/vi/1DaHVSmx160/hqdefault.jpg  

Numeric columns detected: Index(['view_count', 'like_count', 'comment_count'], dtype='object')

Loaded numeric data successfully!
x shape: (155669,)
y shape: (155669,)

Learned parameters:
a = nan
b = nan

[Figure: plot titled "Fit Using JAX", with X and Y axes]

Sample prediction:
For x = 10 → predicted y = nan
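
The nan values above most likely come from gradient descent diverging. view_count is in the tens of thousands or more, so with learning_rate = 0.01 the gradient steps get bigger each iteration until they overflow. Missing values in the numeric columns would also produce nan. Both explanations are guesses from the output and have not been checked against the file. Below is a minimal sketch of one common fix: standardize x and y before fitting, then convert the slope and intercept back to the original units. It runs on synthetic data, not on the real CSV. The stand-in values, the learning rate of 0.1, and the names `x_raw`, `y_raw`, `a_std`, `b_std` are assumptions for illustration.

```python
import numpy as np
import jax.numpy as jnp
from jax import grad, jit

# synthetic stand-in for (view_count, like_count); NOT the real dataset
rng = np.random.default_rng(0)
x_raw = rng.uniform(1e3, 1e6, size=1000)
y_raw = 0.03 * x_raw + rng.normal(0.0, 500.0, size=1000)

# drop rows with missing or infinite values, in case real data has any
mask = np.isfinite(x_raw) & np.isfinite(y_raw)
x_raw, y_raw = x_raw[mask], y_raw[mask]

# standardize to zero mean and unit variance so the gradients stay small
x_mean, x_std = x_raw.mean(), x_raw.std()
y_mean, y_std = y_raw.mean(), y_raw.std()
x = jnp.asarray((x_raw - x_mean) / x_std)
y = jnp.asarray((y_raw - y_mean) / y_std)

def loss(params, x, y):
    a, b = params
    return jnp.mean((y - (a * x + b)) ** 2)

grad_loss = jit(grad(loss))

params = jnp.array([0.1, 0.1])
learning_rate = 0.1
for step in range(500):
    params = params - learning_rate * grad_loss(params, x, y)

# convert slope and intercept back to the original units
a_std, b_std = float(params[0]), float(params[1])
a = a_std * y_std / x_std
b = y_mean + b_std * y_std - a * x_mean

print("a =", a)
print("b =", b)
```

For the real data, `x_raw` and `y_raw` would be taken from the view_count and like_count columns of numeric_df instead of the random generator. A log transform of the counts is another option that could be tried.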