Week 2: Machine Learning
November 28, 2025
After Neil's class, I was reviewing documentation related to JAX and comparing it to NumPy.
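As a quick refresher for myself, here is a minimal sketch of the kind of comparison I was making; the arrays and the function `f` are just illustrative, not part of Neil's material:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Same array API, different behavior: NumPy computes eagerly on CPU, while
# jax.numpy operations can be traced, jit-compiled and differentiated.
x_np = np.linspace(0.0, 1.0, 5)
x_jx = jnp.linspace(0.0, 1.0, 5)

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

print(np.sin(x_np).sum())   # plain NumPy, eager
print(jax.jit(f)(x_jx))     # XLA-compiled JAX version of the same computation
print(jax.grad(f)(x_jx))    # gradient of f w.r.t. the input vector (no NumPy equivalent)

# Another difference: JAX arrays are immutable.
y = x_jx.at[0].set(1.0)     # in NumPy you would simply write x_np[0] = 1.0
```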
November 30, 2025
Reviewing Neil's code with an LLM, I'm going to apply a very similar prompt, but to my own data:
JAX
First example:
1️⃣ Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import random
2️⃣ Load and prepare our traveler data (Barcelona)
# Load the CSV
df = pd.read_csv(
"datasets/barcelona_viajeros_por_franja_csv.csv",
encoding="latin-1",
sep=";"
)
# We make sure to work only with the Barcelona core (in case there are more in the file)
df = df[df["NUCLEO_CERCANIAS"] == "BARCELONA"]
# Aggregate by time slot (summing over all stations)
agg = df.groupby("TRAMO_HORARIO")[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].sum().reset_index()
# Convert "HH:MM - HH:MM" → minutes since midnight (using the start of the segment)
def time_to_min(tramo):
    inicio = tramo.split("-")[0].strip()
    h, m = map(int, inicio.split(":"))
    return h * 60 + m
agg["minutos"] = agg["TRAMO_HORARIO"].apply(time_to_min)
agg = agg.sort_values("minutos").reset_index(drop=True)
agg.head()
| | TRAMO_HORARIO | VIAJEROS_SUBIDOS | VIAJEROS_BAJADOS | minutos |
|---|---|---|---|---|
| 0 | 00:00 - 00:30 | 377 | 916 | 0 |
| 1 | 00:30 - 01:00 | 10 | 246 | 30 |
| 2 | 04:30 - 05:00 | 424 | 24 | 270 |
| 3 | 05:00 - 05:30 | 1120 | 584 | 300 |
| 4 | 05:30 - 06:00 | 2473 | 1049 | 330 |
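As a side note, the same conversion could probably be done without the Python-level `apply`, using vectorized pandas string operations. A small sketch, assuming every `TRAMO_HORARIO` value really has the form `"HH:MM - HH:MM"`:

```python
# Vectorized alternative to .apply(time_to_min): split "HH:MM - HH:MM",
# keep the start "HH:MM", and convert it to minutes since midnight.
inicio = agg["TRAMO_HORARIO"].str.split("-").str[0].str.strip()
horas = inicio.str.split(":").str[0].astype(int)
minutos = inicio.str.split(":").str[1].astype(int)
agg["minutos"] = horas * 60 + minutos
```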
3️⃣ Construct input/output tensors for JAX
X: minutes since midnight (normalized). Y: passengers boarded and alighted (normalized).
# Data in numpy
X_raw = agg["minutos"].values.astype(np.float32).reshape(-1, 1)
Y_raw = agg[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].values.astype(np.float32)
# Standardization (very important for the network to learn well)
X_mean = X_raw.mean(axis=0)
X_std = X_raw.std(axis=0) + 1e-8
X = (X_raw - X_mean) / X_std # (N, 1)
Y_mean = Y_raw.mean(axis=0)
Y_std = Y_raw.std(axis=0) + 1e-8
Y = (Y_raw - Y_mean) / Y_std # (N, 2)
# Switch to jax.numpy
X_jax = jnp.array(X)
Y_jax = jnp.array(Y)
X_raw.shape, Y_raw.shape
((41, 1), (41, 2))
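A quick sanity check I added myself (not part of the original prompt): after standardization, each column should have mean ≈ 0 and standard deviation ≈ 1.

```python
# Verify the standardization: roughly zero mean and unit variance per column.
print("X mean/std:", X.mean(axis=0), X.std(axis=0))
print("Y mean/std:", Y.mean(axis=0), Y.std(axis=0))
```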
4️⃣ Define a minimal neural network in JAX (MLP with 1 hidden layer)
input_dim = 1    # just the minute of the day
hidden_dim = 32
output_dim = 2   # [boarded, alighted]
key = random.PRNGKey(0)
def init_params(key):
    k1, k2 = random.split(key)
    params = {
        "W1": random.normal(k1, (input_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros((hidden_dim,)),
        "W2": random.normal(k2, (hidden_dim, output_dim)) * 0.1,
        "b2": jnp.zeros((output_dim,)),
    }
    return params
def forward(params, x):
    # x: (N, 1)
    h = jnp.dot(x, params["W1"]) + params["b1"]      # (N, hidden)
    h = jnp.tanh(h)
    y_hat = jnp.dot(h, params["W2"]) + params["b2"]  # (N, 2)
    return y_hat

def loss_fn(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

def r2_score(params, x, y_true):
    y_pred = forward(params, x)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true, axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot
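Since `params` is just a dictionary of arrays (a pytree, in JAX terms), the `jax.tree_util` helpers work on it directly. A small check I added to count the parameters of this 1-32-2 MLP:

```python
# Count trainable parameters by summing the sizes of all leaves of the pytree.
n_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(init_params(key)))
print(n_params)   # 1*32 + 32 + 32*2 + 2 = 130
```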
5️⃣ Training with simple gradient descent
learning_rate = 0.01
num_epochs = 5000
params = init_params(key)
loss_history = []
@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss
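The dictionary comprehension works because `params` is a flat dict. An equivalent way to write the same SGD update, which would also work for arbitrarily nested parameter pytrees, is `jax.tree_util.tree_map`; this is an alternative sketch, not what the notebook above actually uses:

```python
@jax.jit
def train_step_tree(params, x, y):
    # Same SGD update as train_step, but written with tree_map so it works
    # for any parameter structure, not just a flat dict.
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss
```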
for epoch in range(num_epochs):
    params, loss = train_step(params, X_jax, Y_jax)
    loss_history.append(float(loss))
    if (epoch + 1) % 500 == 0:
        r2 = float(r2_score(params, X_jax, Y_jax))
        print(f"Época {epoch+1}/{num_epochs} - Loss: {loss:.5f} R² (train): {r2:.4f}")
Época 500/5000 - Loss: 0.99239 R² (train): 0.0076
Época 1000/5000 - Loss: 0.99193 R² (train): 0.0081
Época 1500/5000 - Loss: 0.99133 R² (train): 0.0087
Época 2000/5000 - Loss: 0.99024 R² (train): 0.0098
Época 2500/5000 - Loss: 0.98766 R² (train): 0.0123
Época 3000/5000 - Loss: 0.97997 R² (train): 0.0201
Época 3500/5000 - Loss: 0.95585 R² (train): 0.0442
Época 4000/5000 - Loss: 0.90069 R² (train): 0.0994
Época 4500/5000 - Loss: 0.82314 R² (train): 0.1770
Época 5000/5000 - Loss: 0.75034 R² (train): 0.2498
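R² is still quite low after 5000 epochs, which suggests plain SGD at this learning rate converges very slowly on this curve. One common option (my own idea, not something from Neil's example) would be to swap the hand-written update for an optimizer from the optax library, e.g. Adam; a sketch, keeping the same model and loss:

```python
# Sketch: same model and loss, but updated with optax Adam instead of plain SGD.
# Requires `pip install optax`; the optimizer choice and learning rate are assumptions.
import optax

optimizer = optax.adam(learning_rate=1e-2)
params_adam = init_params(key)
opt_state = optimizer.init(params_adam)

@jax.jit
def train_step_adam(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for epoch in range(num_epochs):
    params_adam, opt_state, loss = train_step_adam(params_adam, opt_state, X_jax, Y_jax)
```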
6️⃣ View the loss curve (model fit)
plt.figure(figsize=(6,4))
plt.plot(loss_history)
plt.xlabel("Time")
plt.ylabel("Loss MSE (normalized)")
plt.title("Evolution of the loss – JAX Network (Barcelona travelers)")
plt.grid(True)
plt.tight_layout()
plt.show()
7️⃣ Compare prediction vs. actual data (boarded and alighted)
Here we generate a smooth curve over the entire day and denormalize it to recover actual travelers:
# Points to smoothly represent the curve (from 00:00 to 24:00)
x_plot_min = np.linspace(agg["minutos"].min(), agg["minutos"].max(), 300).astype(np.float32).reshape(-1,1)
x_plot_norm = (x_plot_min - X_mean) / X_std
x_plot_norm_jax = jnp.array(x_plot_norm)
# Normalized prediction
y_pred_norm = forward(params, x_plot_norm_jax)
y_pred_norm_np = np.array(y_pred_norm)
# Denormalize back to real travelers
y_pred = y_pred_norm_np * Y_std + Y_mean # (300, 2)
# To overlay real data
minutos_real = agg["minutos"].values
subidos_real = agg["VIAJEROS_SUBIDOS"].values
bajados_real = agg["VIAJEROS_BAJADOS"].values
plt.figure(figsize=(12,6))
# Predicted curves
plt.plot(x_plot_min, y_pred[:,0], label="Prediction of BOARDED (JAX model)")
plt.plot(x_plot_min, y_pred[:,1], label="Prediction of ALIGHTED (JAX model)")
# Real data as points
plt.scatter(minutos_real, subidos_real, marker="o", s=40, label="Actual data BOARDED")
plt.scatter(minutos_real, bajados_real, marker="x", s=40, label="Actual data ALIGHTED")
# X-axis with time slot labels (as before)
plt.xticks(
ticks=agg["minutos"],
labels=agg["TRAMO_HORARIO"],
rotation=90
)
plt.xlabel("Time slot")
plt.ylabel("Number of travelers")
plt.title("JAX Model – Prediction of passengers boarding and alighting per slot (Barcelona)")
plt.legend()
plt.tight_layout()
plt.show()
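To see how well each of the two outputs is fitted separately, the same R² idea can be applied per column. This is my own addition, reusing the trained `params` and the normalized tensors from above:

```python
# R² computed separately for boarded (column 0) and alighted (column 1).
y_pred_all = forward(params, X_jax)
ss_res = jnp.sum((Y_jax - y_pred_all) ** 2, axis=0)
ss_tot = jnp.sum((Y_jax - jnp.mean(Y_jax, axis=0)) ** 2, axis=0)
r2_per_output = 1.0 - ss_res / ss_tot
print("R² boarded:", float(r2_per_output[0]), "R² alighted:", float(r2_per_output[1]))
```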
Second example:
Now I want to split the model by station type (one JAX model per type of station), using the same approach as before: minute of the day → [passengers_on, passengers_off]
The station type isn't included in the CSV file, so we'll create it with a mapping dictionary, tipo_por_estacion. You can then add or change stations and types according to your actual criteria (interchange, airport, suburbs, etc.).
The workflow will be:
- Load barcelona_viajeros_por_franja_csv.csv
- Create the TIPO_ESTACION column from a dictionary
- For each TIPO_ESTACION:
    - Aggregate passengers per time slot
    - Train an MLP using JAX
    - Display the loss curve
    - Display the prediction curve vs. actual data
# ============================================
# 1. IMPORTS
# ============================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from jax import random
# ============================================
# 2. LOAD CSV AND CREATE STATION_TYPE
# ============================================
df = pd.read_csv(
"datasets/barcelona_viajeros_por_franja_csv.csv",
encoding="latin-1",
sep=";"
)
# We make sure to work only with the core of Barcelona
df = df[df["NUCLEO_CERCANIAS"] == "BARCELONA"].copy()
# --- Mapping: Station -> Station Type ---
# ⚠️ EDIT THIS TABLE with your actual types
tipo_por_estacion = {
# Interchanges / major stations
"BARCELONA-SANTS": "INTERCAMBIADOR",
"BARCELONA-ESTACIO DE FRANÇA": "INTERCAMBIADOR",
"BARCELONA-PASSEIG DE GRACIA": "INTERCAMBIADOR",
"BARCELONA-PLAÇA DE CATALUNYA": "INTERCAMBIADOR",
# Airport
"AEROPORT": "AEROPUERTO",
# Periphery / regional node
"SANT VICENÇ DE CALDERS": "PERIFERIA",
"GRANOLLERS-CENTRE": "PERIFERIA",
"MATARO": "PERIFERIA",
# Examples of urban commuter rail (CHANGE according to your actual data)
"HOSPITALET DE LLOBREGAT": "URBANA",
"BADALONA": "URBANA",
"SANT ANDREU COMTAL": "URBANA",
}
# Create the TIPO_ESTACION column, using "OTRAS" if the station is not in the dictionary
df["TIPO_ESTACION"] = df["NOMBRE_ESTACION"].map(tipo_por_estacion).fillna("OTRAS")
print("Tipos de estación disponibles:")
print(df["TIPO_ESTACION"].value_counts())
# ============================================
# 3. DEFINE THE NEURAL NETWORK IN JAX (MLP)
# ============================================
input_dim = 1 # minute of the day
hidden_dim = 32
output_dim = 2 # [passengers_onboard, passengers_offboard]
learning_rate = 0.01
num_epochs = 3000
def init_params(key):
    k1, k2 = random.split(key)
    params = {
        "W1": random.normal(k1, (input_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros((hidden_dim,)),
        "W2": random.normal(k2, (hidden_dim, output_dim)) * 0.1,
        "b2": jnp.zeros((output_dim,)),
    }
    return params

def forward(params, x):
    # x: (N, 1)
    h = jnp.dot(x, params["W1"]) + params["b1"]
    h = jnp.tanh(h)
    y_hat = jnp.dot(h, params["W2"]) + params["b2"]
    return y_hat

def loss_fn(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

def r2_score(params, x, y_true):
    y_pred = forward(params, x)
    ss_res = jnp.sum((y_true - y_pred) ** 2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true, axis=0)) ** 2)
    return 1.0 - ss_res / ss_tot

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss
# ============================================
# 4. MAIN LOOP: ONE MODEL PER STATION_TYPE
# ============================================
tipos = df["TIPO_ESTACION"].unique()
print("\nEntrenando modelos por tipo de estación:")
print(tipos)
for tipo in tipos:
print(f"\n======================================")
print(f" TIPO_ESTACION: {tipo}")
print(f"======================================")
df_t = df[df["TIPO_ESTACION"] == tipo].copy()
if df_t.empty:
print(" (No hay datos para este tipo, se omite)")
continue
# --- Add by time slot (adding all stations of this type) ---
agg = df_t.groupby("TRAMO_HORARIO")[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].sum().reset_index()
# Function to convert "HH:MM - HH:MM" to minutes from midnight (using the start of the segment)
def time_to_min(tramo):
inicio = tramo.split("-")[0].strip()
h, m = map(int, inicio.split(":"))
return h * 60 + m
agg["minutos"] = agg["TRAMO_HORARIO"].apply(time_to_min)
agg = agg.sort_values("minutos").reset_index(drop=True)
    # --------------------------------
    # Prepare X, Y (numpy + normalize)
    # --------------------------------
    X_raw = agg["minutos"].values.astype(np.float32).reshape(-1, 1)
    Y_raw = agg[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].values.astype(np.float32)

    # Normalization per type
    X_mean = X_raw.mean(axis=0)
    X_std = X_raw.std(axis=0) + 1e-8
    X = (X_raw - X_mean) / X_std

    Y_mean = Y_raw.mean(axis=0)
    Y_std = Y_raw.std(axis=0) + 1e-8
    Y = (Y_raw - Y_mean) / Y_std

    X_jax = jnp.array(X)
    Y_jax = jnp.array(Y)

    # --------------------------------
    # Train a model for this type
    # --------------------------------
    key = random.PRNGKey(0)
    params = init_params(key)
    loss_history = []

    for epoch in range(num_epochs):
        params, loss = train_step(params, X_jax, Y_jax)
        loss_history.append(float(loss))

        if (epoch + 1) % 500 == 0:
            r2 = float(r2_score(params, X_jax, Y_jax))
            print(f" Época {epoch+1}/{num_epochs} - Loss: {loss:.5f} R²: {r2:.4f}")
    # --------------------------------
    # Loss graph for this type
    # --------------------------------
    plt.figure(figsize=(5,3))
    plt.plot(loss_history)
    plt.xlabel("Época")
    plt.ylabel("Loss MSE (normalizado)")
    plt.title(f"Pérdida – TIPO: {tipo}")
    plt.tight_layout()
    plt.show()

    # --------------------------------
    # Prediction vs. actual data
    # --------------------------------
    # Smooth curve throughout the day
    x_plot_min = np.linspace(agg["minutos"].min(), agg["minutos"].max(), 300).astype(np.float32).reshape(-1,1)
    x_plot_norm = (x_plot_min - X_mean) / X_std
    x_plot_norm_jax = jnp.array(x_plot_norm)

    y_pred_norm = forward(params, x_plot_norm_jax)
    y_pred_norm_np = np.array(y_pred_norm)
    y_pred = y_pred_norm_np * Y_std + Y_mean  # denormalize

    minutos_real = agg["minutos"].values
    subidos_real = agg["VIAJEROS_SUBIDOS"].values
    bajados_real = agg["VIAJEROS_BAJADOS"].values
    plt.figure(figsize=(10,5))

    # Predicted curves
    plt.plot(x_plot_min, y_pred[:,0], label="Prediction of BOARDED (JAX model)")
    plt.plot(x_plot_min, y_pred[:,1], label="Prediction of ALIGHTED (JAX model)")

    # Actual data (points)
    plt.scatter(minutos_real, subidos_real, marker="o", s=40, label="Actual data BOARDED")
    plt.scatter(minutos_real, bajados_real, marker="x", s=40, label="Actual data ALIGHTED")

    # X-axis: time slots
    plt.xticks(
        ticks=agg["minutos"],
        labels=agg["TRAMO_HORARIO"],
        rotation=90
    )

    plt.xlabel("Time slot")
    plt.ylabel("Number of travelers")
    plt.title(f"JAX Model – STATION_TYPE: {tipo}")
    plt.legend()
    plt.tight_layout()
    plt.show()
Tipos de estación disponibles:
TIPO_ESTACION
OTRAS             3990
PERIFERIA           80
INTERCAMBIADOR      79
URBANA              39
AEROPUERTO          38
Name: count, dtype: int64

Entrenando modelos por tipo de estación:
['OTRAS' 'INTERCAMBIADOR' 'AEROPUERTO' 'PERIFERIA' 'URBANA']

======================================
 TIPO_ESTACION: OTRAS
======================================
 Época 500/3000 - Loss: 0.97513 R²: 0.0249
 Época 1000/3000 - Loss: 0.97447 R²: 0.0255
 Época 1500/3000 - Loss: 0.97351 R²: 0.0265
 Época 2000/3000 - Loss: 0.97152 R²: 0.0285
 Época 2500/3000 - Loss: 0.96613 R²: 0.0339
 Época 3000/3000 - Loss: 0.94949 R²: 0.0506

======================================
 TIPO_ESTACION: INTERCAMBIADOR
======================================
 Época 500/3000 - Loss: 0.89424 R²: 0.1058
 Época 1000/3000 - Loss: 0.89308 R²: 0.1069
 Época 1500/3000 - Loss: 0.89144 R²: 0.1086
 Época 2000/3000 - Loss: 0.88883 R²: 0.1112
 Época 2500/3000 - Loss: 0.88407 R²: 0.1159
 Época 3000/3000 - Loss: 0.87395 R²: 0.1261

======================================
 TIPO_ESTACION: AEROPUERTO
======================================
 Época 500/3000 - Loss: 0.93024 R²: 0.0698
 Época 1000/3000 - Loss: 0.92979 R²: 0.0702
 Época 1500/3000 - Loss: 0.92920 R²: 0.0708
 Época 2000/3000 - Loss: 0.92826 R²: 0.0717
 Época 2500/3000 - Loss: 0.92652 R²: 0.0735
 Época 3000/3000 - Loss: 0.92265 R²: 0.0774

======================================
 TIPO_ESTACION: PERIFERIA
======================================
 Época 500/3000 - Loss: 0.93283 R²: 0.0672
 Época 1000/3000 - Loss: 0.93172 R²: 0.0683
 Época 1500/3000 - Loss: 0.92969 R²: 0.0703
 Época 2000/3000 - Loss: 0.92487 R²: 0.0751
 Época 2500/3000 - Loss: 0.91262 R²: 0.0874
 Época 3000/3000 - Loss: 0.88676 R²: 0.1133

======================================
 TIPO_ESTACION: URBANA
======================================
 Época 500/3000 - Loss: 0.95985 R²: 0.0402
 Época 1000/3000 - Loss: 0.95855 R²: 0.0415
 Época 1500/3000 - Loss: 0.95610 R²: 0.0439
 Época 2000/3000 - Loss: 0.95013 R²: 0.0499
 Época 2500/3000 - Loss: 0.93359 R²: 0.0665
 Época 3000/3000 - Loss: 0.89358 R²: 0.1065
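One thing the loop does not do is keep the trained parameters around. If I wanted to reuse each per-type model afterwards, a small extension (my own, with hypothetical names modelos and predict_viajeros) would be to store the params together with the normalization constants:

```python
# Hypothetical extension: keep each trained model plus its normalization constants,
# so predictions can be made later for any minute of the day.
modelos = {}

# ... inside the training loop, after fitting a given `tipo`:
# modelos[tipo] = {
#     "params": params,
#     "X_mean": X_mean, "X_std": X_std,
#     "Y_mean": Y_mean, "Y_std": Y_std,
# }

def predict_viajeros(tipo, minuto):
    """Predicted (boarded, alighted) for a station type and a minute of the day."""
    m = modelos[tipo]
    x = (jnp.array([[float(minuto)]]) - m["X_mean"]) / m["X_std"]
    y_norm = forward(m["params"], x)
    return np.array(y_norm) * m["Y_std"] + m["Y_mean"]

# Example (assuming the loop above stored the models):
# predict_viajeros("INTERCAMBIADOR", 8 * 60)   # 08:00
```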