Adrian Torres - Fab Lab León - Fab Futures - Data Science

Week 2: Machine Learning¶

November 28, 2025

After Neil's class, I reviewed the documentation for JAX and compared it with NumPy.
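
A quick sketch of the differences I was looking at (immutable arrays and jit compilation). This is a standalone toy example, not tied to the traveler data:

import numpy as np
import jax
import jax.numpy as jnp

a_np = np.arange(4.0)
a_np[0] = 10.0               # NumPy arrays can be modified in place

a_jx = jnp.arange(4.0)
a_jx = a_jx.at[0].set(10.0)  # JAX arrays are immutable; .at[...] returns an updated copy

@jax.jit
def sum_of_squares(x):
    return jnp.sum(x ** 2)   # compiled with XLA on first call

print(a_np, a_jx, sum_of_squares(a_jx))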

November 30, 2025

After reviewing Neil's code with an LLM, I'm going to apply a very similar prompt, but to my own data:

JAX¶

First example:¶

1️⃣ Import libraries

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import random

2️⃣ Load and prepare the traveler data (Barcelona)

In [3]:
# Load the CSV
df = pd.read_csv(
    "datasets/barcelona_viajeros_por_franja_csv.csv",
    encoding="latin-1",
    sep=";"
)

# We make sure to work only with the Barcelona core (in case there are more in the file)
df = df[df["NUCLEO_CERCANIAS"] == "BARCELONA"]

# Aggregate by time slot (summing over all stations)
agg = df.groupby("TRAMO_HORARIO")[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].sum().reset_index()

# Convert "HH:MM - HH:MM" → minutes since midnight (using the start of the segment)
def time_to_min(tramo):
    inicio = tramo.split("-")[0].strip()
    h, m = map(int, inicio.split(":"))
    return h * 60 + m

agg["minutos"] = agg["TRAMO_HORARIO"].apply(time_to_min)
agg = agg.sort_values("minutos").reset_index(drop=True)

agg.head()
Out[3]:
   TRAMO_HORARIO  VIAJEROS_SUBIDOS  VIAJEROS_BAJADOS  minutos
0  00:00 - 00:30               377               916        0
1  00:30 - 01:00                10               246       30
2  04:30 - 05:00               424                24      270
3  05:00 - 05:30              1120               584      300
4  05:30 - 06:00              2473              1049      330
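
As a quick check of the slot-to-minutes conversion defined above:

print(time_to_min("05:30 - 06:00"))   # 330, matching the last row shown above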

3️⃣ Construct input/output tensors for JAX

X: minutes since midnight (normalized)
Y: passengers boarded and alighted (normalized)

In [4]:
# Data in numpy
X_raw = agg["minutos"].values.astype(np.float32).reshape(-1, 1)
Y_raw = agg[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].values.astype(np.float32)

# Standardization (very important for the network to learn well)
X_mean = X_raw.mean(axis=0)
X_std  = X_raw.std(axis=0) + 1e-8
X = (X_raw - X_mean) / X_std   # (N, 1)

Y_mean = Y_raw.mean(axis=0)
Y_std  = Y_raw.std(axis=0) + 1e-8
Y = (Y_raw - Y_mean) / Y_std   # (N, 2)

# Switch to jax.numpy
X_jax = jnp.array(X)
Y_jax = jnp.array(Y)

X_raw.shape, Y_raw.shape
Out[4]:
((41, 1), (41, 2))

4️⃣ Define a minimal neural network in JAX (an MLP with one hidden layer)

In [5]:
input_dim = 1     # a single input feature: the minute of the day
hidden_dim = 32
output_dim = 2    # [boarded, alighted]

key = random.PRNGKey(0)

def init_params(key):
    k1, k2 = random.split(key)
    params = {
        "W1": random.normal(k1, (input_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros((hidden_dim,)),
        "W2": random.normal(k2, (hidden_dim, output_dim)) * 0.1,
        "b2": jnp.zeros((output_dim,)),
    }
    return params

def forward(params, x):
    # x: (N, 1)
    h = jnp.dot(x, params["W1"]) + params["b1"]   # (N, hidden)
    h = jnp.tanh(h)
    y_hat = jnp.dot(h, params["W2"]) + params["b2"]  # (N, 2)
    return y_hat

def loss_fn(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

def r2_score(params, x, y_true):
    y_pred = forward(params, x)
    ss_res = jnp.sum((y_true - y_pred)**2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true, axis=0))**2)
    return 1.0 - ss_res / ss_tot
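
Before training, a quick sanity check (just reusing the objects defined above) that the shapes line up and that the initial loss is close to 1, which is expected since the targets are standardized and the freshly initialized network predicts values near zero:

params0 = init_params(key)
print(forward(params0, X_jax).shape)           # (41, 2)
print(float(loss_fn(params0, X_jax, Y_jax)))   # ≈ 1.0 for standardized targets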

5️⃣ Training with plain gradient descent

In [6]:
learning_rate = 0.01
num_epochs = 5000

params = init_params(key)
loss_history = []

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss

for epoch in range(num_epochs):
    params, loss = train_step(params, X_jax, Y_jax)
    loss_history.append(float(loss))

    if (epoch + 1) % 500 == 0:
        r2 = float(r2_score(params, X_jax, Y_jax))
        print(f"Época {epoch+1}/{num_epochs} - Loss: {loss:.5f}  R² (train): {r2:.4f}")
Epoch 500/5000 - Loss: 0.99239  R² (train): 0.0076
Epoch 1000/5000 - Loss: 0.99193  R² (train): 0.0081
Epoch 1500/5000 - Loss: 0.99133  R² (train): 0.0087
Epoch 2000/5000 - Loss: 0.99024  R² (train): 0.0098
Epoch 2500/5000 - Loss: 0.98766  R² (train): 0.0123
Epoch 3000/5000 - Loss: 0.97997  R² (train): 0.0201
Epoch 3500/5000 - Loss: 0.95585  R² (train): 0.0442
Epoch 4000/5000 - Loss: 0.90069  R² (train): 0.0994
Epoch 4500/5000 - Loss: 0.82314  R² (train): 0.1770
Epoch 5000/5000 - Loss: 0.75034  R² (train): 0.2498
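
Plain gradient descent with this learning rate converges very slowly here (R² is still only about 0.25 after 5000 epochs). A variation I may try, sketched under the assumption that optax is installed alongside JAX, is to keep the same forward pass and loss but let Adam do the updates:

import optax   # assumption: optax is available in the environment

optimizer = optax.adam(learning_rate=1e-2)
params_adam = init_params(key)
opt_state = optimizer.init(params_adam)

@jax.jit
def train_step_adam(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

for epoch in range(num_epochs):
    params_adam, opt_state, loss = train_step_adam(params_adam, opt_state, X_jax, Y_jax)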

6️⃣ View the loss curve (model fit)

In [7]:
plt.figure(figsize=(6,4))
plt.plot(loss_history)
plt.xlabel("Time")
plt.ylabel("Loss MSE (normalized)")
plt.title("Evolution of the loss – JAX Network (Barcelona travelers)")
plt.grid(True)
plt.tight_layout()
plt.show()
[Figure: Evolution of the loss – JAX Network (Barcelona travelers)]

7️⃣ Compare predictions vs. actual data (boardings and alightings)

Here we generate a smooth curve over the entire day and denormalize it to recover actual passenger counts:

In [9]:
# Points for a smooth curve across the day (from the first to the last time slot)
x_plot_min = np.linspace(agg["minutos"].min(), agg["minutos"].max(), 300).astype(np.float32).reshape(-1,1)
x_plot_norm = (x_plot_min - X_mean) / X_std
x_plot_norm_jax = jnp.array(x_plot_norm)

# Normalized prediction
y_pred_norm = forward(params, x_plot_norm_jax)
y_pred_norm_np = np.array(y_pred_norm)

# Denormalize back to real travelers
y_pred = y_pred_norm_np * Y_std + Y_mean   # (300, 2)

# To overlay real data
minutos_real = agg["minutos"].values
subidos_real = agg["VIAJEROS_SUBIDOS"].values
bajados_real = agg["VIAJEROS_BAJADOS"].values

plt.figure(figsize=(12,6))

# Predicted curves
plt.plot(x_plot_min, y_pred[:,0], label="Predicted BOARDED (JAX model)")
plt.plot(x_plot_min, y_pred[:,1], label="Predicted ALIGHTED (JAX model)")

# Real data as points
plt.scatter(minutos_real, subidos_real, marker="o", s=40, label="Actual BOARDED")
plt.scatter(minutos_real, bajados_real, marker="x", s=40, label="Actual ALIGHTED")

# X-axis with time slot labels (as before)
plt.xticks(
    ticks=agg["minutos"],
    labels=agg["TRAMO_HORARIO"],
    rotation=90
)

plt.xlabel("Time slot")
plt.ylabel("Number of travelers")
plt.title("JAX Model – Prediction of passengers boarding and alighting per slot (Barcelona)")
plt.legend()
plt.tight_layout()
plt.show()
[Figure: JAX Model – Prediction of passengers boarding and alighting per slot (Barcelona)]
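
As a quick way to query the trained model at a single clock time, a small sketch reusing the variables already in scope (predict_at is a hypothetical helper, not part of the cells above):

def predict_at(hhmm):
    # hhmm like "08:00"; normalize, run the network, denormalize back to passenger counts
    h, m = map(int, hhmm.split(":"))
    x = np.array([[h * 60 + m]], dtype=np.float32)
    y_norm = np.array(forward(params, jnp.array((x - X_mean) / X_std)))
    return y_norm * Y_std + Y_mean   # [boarded, alighted]

print(predict_at("08:00"))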

Second example:¶

Now I want to split the model by station type (one JAX model per "typical" station type), using the same approach as before: minute of the day → [passengers boarded, passengers alighted]

The station type isn't included in the CSV file, so we'll create it with the mapping dictionary tipo_por_estacion. Stations and types can then be added or changed according to your actual criteria (interchange, airport, suburbs, etc.).
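
The mechanism is just pandas .map() with .fillna() for the default class; a toy sketch (the last station name is only for illustration):

import pandas as pd

s = pd.Series(["BARCELONA-SANTS", "MATARO", "SOME-OTHER-STOP"])
print(s.map({"BARCELONA-SANTS": "INTERCAMBIADOR", "MATARO": "PERIFERIA"}).fillna("OTRAS"))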

The workflow will be:

  1. Load barcelona_viajeros_por_franja_csv.csv
  2. Create the TIPO_ESTACION column from a dictionary
  3. For each TIPO_ESTACION:
  • Aggregate passengers per time slot
  • Train an MLP with JAX
  • Plot the loss curve
  • Plot the prediction curve vs. actual data
In [2]:
# ============================================
# 1. IMPORTS
# ============================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from jax import random


# ============================================
# 2. LOAD CSV AND CREATE TIPO_ESTACION
# ============================================
df = pd.read_csv(
    "datasets/barcelona_viajeros_por_franja_csv.csv",
    encoding="latin-1",
    sep=";"
)

# We make sure to work only with the core of Barcelona
df = df[df["NUCLEO_CERCANIAS"] == "BARCELONA"].copy()

# --- Mapping: Station -> Station Type ---
# ⚠️ EDIT THIS TABLE with your actual types
tipo_por_estacion = {
    # Interchanges / major stations
    "BARCELONA-SANTS": "INTERCAMBIADOR",
    "BARCELONA-ESTACIO DE FRANÇA": "INTERCAMBIADOR",
    "BARCELONA-PASSEIG DE GRACIA": "INTERCAMBIADOR",
    "BARCELONA-PLAÇA DE CATALUNYA": "INTERCAMBIADOR",

    # Airport
    "AEROPORT": "AEROPUERTO",

    # Periphery / regional node
    "SANT VICENÇ DE CALDERS": "PERIFERIA",
    "GRANOLLERS-CENTRE": "PERIFERIA",
    "MATARO": "PERIFERIA",

    # Examples of urban commuter rail (CHANGE according to your actual data)
    "HOSPITALET DE LLOBREGAT": "URBANA",
    "BADALONA": "URBANA",
    "SANT ANDREU COMTAL": "URBANA",
}

# Create the TIPO_ESTACION column, using "OTRAS" when a station is not in the dictionary
df["TIPO_ESTACION"] = df["NOMBRE_ESTACION"].map(tipo_por_estacion).fillna("OTRAS")

print("Tipos de estación disponibles:")
print(df["TIPO_ESTACION"].value_counts())


# ============================================
# 3. DEFINE THE NEURAL NETWORK IN JAX (MLP)
# ============================================
input_dim = 1    # minute of the day
hidden_dim = 32
output_dim = 2   # [passengers boarded, passengers alighted]

learning_rate = 0.01
num_epochs = 3000

def init_params(key):
    k1, k2 = random.split(key)
    params = {
        "W1": random.normal(k1, (input_dim, hidden_dim)) * 0.1,
        "b1": jnp.zeros((hidden_dim,)),
        "W2": random.normal(k2, (hidden_dim, output_dim)) * 0.1,
        "b2": jnp.zeros((output_dim,)),
    }
    return params

def forward(params, x):
    # x: (N, 1)
    h = jnp.dot(x, params["W1"]) + params["b1"]
    h = jnp.tanh(h)
    y_hat = jnp.dot(h, params["W2"]) + params["b2"]
    return y_hat

def loss_fn(params, x, y_true):
    y_pred = forward(params, x)
    return jnp.mean((y_pred - y_true) ** 2)

def r2_score(params, x, y_true):
    y_pred = forward(params, x)
    ss_res = jnp.sum((y_true - y_pred)**2)
    ss_tot = jnp.sum((y_true - jnp.mean(y_true, axis=0))**2)
    return 1.0 - ss_res / ss_tot

@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params = {k: params[k] - learning_rate * grads[k] for k in params}
    return new_params, loss


# ============================================
# 4. MAIN LOOP: ONE MODEL PER STATION_TYPE
# ============================================
tipos = df["TIPO_ESTACION"].unique()
print("\nEntrenando modelos por tipo de estación:")
print(tipos)

for tipo in tipos:
    print(f"\n======================================")
    print(f"   TIPO_ESTACION: {tipo}")
    print(f"======================================")

    df_t = df[df["TIPO_ESTACION"] == tipo].copy()
    if df_t.empty:
        print("  (No hay datos para este tipo, se omite)")
        continue

    # --- Aggregate by time slot (summing all stations of this type) ---
    agg = df_t.groupby("TRAMO_HORARIO")[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].sum().reset_index()

    # Function to convert "HH:MM - HH:MM" to minutes since midnight (using the start of the segment)
    def time_to_min(tramo):
        inicio = tramo.split("-")[0].strip()
        h, m = map(int, inicio.split(":"))
        return h * 60 + m

    agg["minutos"] = agg["TRAMO_HORARIO"].apply(time_to_min)
    agg = agg.sort_values("minutos").reset_index(drop=True)

    # --------------------------------
    # Prepare X, Y (numpy + normalize)
    # --------------------------------
    X_raw = agg["minutos"].values.astype(np.float32).reshape(-1, 1)
    Y_raw = agg[["VIAJEROS_SUBIDOS", "VIAJEROS_BAJADOS"]].values.astype(np.float32)

    # Normalization by type
    X_mean = X_raw.mean(axis=0)
    X_std  = X_raw.std(axis=0) + 1e-8
    X = (X_raw - X_mean) / X_std

    Y_mean = Y_raw.mean(axis=0)
    Y_std  = Y_raw.std(axis=0) + 1e-8
    Y = (Y_raw - Y_mean) / Y_std

    X_jax = jnp.array(X)
    Y_jax = jnp.array(Y)

    # --------------------------------
    # Train a model for this type
    # --------------------------------
    key = random.PRNGKey(0)
    params = init_params(key)
    loss_history = []

    for epoch in range(num_epochs):
        params, loss = train_step(params, X_jax, Y_jax)
        loss_history.append(float(loss))

        if (epoch + 1) % 500 == 0:
            r2 = float(r2_score(params, X_jax, Y_jax))
            print(f"  Época {epoch+1}/{num_epochs} - Loss: {loss:.5f}  R²: {r2:.4f}")

    # --------------------------------
    # Loss graph for this type
    # --------------------------------
    plt.figure(figsize=(5,3))
    plt.plot(loss_history)
    plt.xlabel("Época")
    plt.ylabel("Loss MSE (normalizado)")
    plt.title(f"Pérdida – TIPO: {tipo}")
    plt.tight_layout()
    plt.show()

    # --------------------------------
    # Prediction vs. actual data
    # --------------------------------
    # Smooth curve throughout the day
    x_plot_min = np.linspace(agg["minutos"].min(), agg["minutos"].max(), 300).astype(np.float32).reshape(-1,1)
    x_plot_norm = (x_plot_min - X_mean) / X_std
    x_plot_norm_jax = jnp.array(x_plot_norm)

    y_pred_norm = forward(params, x_plot_norm_jax)
    y_pred_norm_np = np.array(y_pred_norm)
    y_pred = y_pred_norm_np * Y_std + Y_mean   # denormalize back to real passenger counts

    minutos_real = agg["minutos"].values
    subidos_real = agg["VIAJEROS_SUBIDOS"].values
    bajados_real = agg["VIAJEROS_BAJADOS"].values

    plt.figure(figsize=(10,5))

    # Predicted curves
    plt.plot(x_plot_min, y_pred[:,0], label="Predicted BOARDED (JAX model)")
    plt.plot(x_plot_min, y_pred[:,1], label="Predicted ALIGHTED (JAX model)")

    # Actual data (points)
    plt.scatter(minutos_real, subidos_real, marker="o", s=40, label="Actual BOARDED")
    plt.scatter(minutos_real, bajados_real, marker="x", s=40, label="Actual ALIGHTED")

    # X-axis: time slots
    plt.xticks(
        ticks=agg["minutos"],
        labels=agg["TRAMO_HORARIO"],
        rotation=90
    )

    plt.xlabel("Time slot")
    plt.ylabel("Number of travelers")
    plt.title(f"JAX Model – STATION_TYPE: {tipo}")
    plt.legend()
    plt.tight_layout()
    plt.show()
Available station types:
TIPO_ESTACION
OTRAS             3990
PERIFERIA           80
INTERCAMBIADOR      79
URBANA              39
AEROPUERTO          38
Name: count, dtype: int64

Training models by station type:
['OTRAS' 'INTERCAMBIADOR' 'AEROPUERTO' 'PERIFERIA' 'URBANA']

======================================
   TIPO_ESTACION: OTRAS
======================================
  Epoch 500/3000 - Loss: 0.97513  R²: 0.0249
  Epoch 1000/3000 - Loss: 0.97447  R²: 0.0255
  Epoch 1500/3000 - Loss: 0.97351  R²: 0.0265
  Epoch 2000/3000 - Loss: 0.97152  R²: 0.0285
  Epoch 2500/3000 - Loss: 0.96613  R²: 0.0339
  Epoch 3000/3000 - Loss: 0.94949  R²: 0.0506
[Figure: Loss – TYPE: OTRAS]
[Figure: JAX Model – STATION_TYPE: OTRAS, prediction vs. actual data]
======================================
   TIPO_ESTACION: INTERCAMBIADOR
======================================
  Epoch 500/3000 - Loss: 0.89424  R²: 0.1058
  Epoch 1000/3000 - Loss: 0.89308  R²: 0.1069
  Epoch 1500/3000 - Loss: 0.89144  R²: 0.1086
  Epoch 2000/3000 - Loss: 0.88883  R²: 0.1112
  Epoch 2500/3000 - Loss: 0.88407  R²: 0.1159
  Epoch 3000/3000 - Loss: 0.87395  R²: 0.1261
[Figure: Loss – TYPE: INTERCAMBIADOR]
[Figure: JAX Model – STATION_TYPE: INTERCAMBIADOR, prediction vs. actual data]
======================================
   TIPO_ESTACION: AEROPUERTO
======================================
  Epoch 500/3000 - Loss: 0.93024  R²: 0.0698
  Epoch 1000/3000 - Loss: 0.92979  R²: 0.0702
  Epoch 1500/3000 - Loss: 0.92920  R²: 0.0708
  Epoch 2000/3000 - Loss: 0.92826  R²: 0.0717
  Epoch 2500/3000 - Loss: 0.92652  R²: 0.0735
  Epoch 3000/3000 - Loss: 0.92265  R²: 0.0774
[Figure: Loss – TYPE: AEROPUERTO]
[Figure: JAX Model – STATION_TYPE: AEROPUERTO, prediction vs. actual data]
======================================
   TIPO_ESTACION: PERIFERIA
======================================
  Epoch 500/3000 - Loss: 0.93283  R²: 0.0672
  Epoch 1000/3000 - Loss: 0.93172  R²: 0.0683
  Epoch 1500/3000 - Loss: 0.92969  R²: 0.0703
  Epoch 2000/3000 - Loss: 0.92487  R²: 0.0751
  Epoch 2500/3000 - Loss: 0.91262  R²: 0.0874
  Epoch 3000/3000 - Loss: 0.88676  R²: 0.1133
[Figure: Loss – TYPE: PERIFERIA]
[Figure: JAX Model – STATION_TYPE: PERIFERIA, prediction vs. actual data]
======================================
   TIPO_ESTACION: URBANA
======================================
  Epoch 500/3000 - Loss: 0.95985  R²: 0.0402
  Epoch 1000/3000 - Loss: 0.95855  R²: 0.0415
  Epoch 1500/3000 - Loss: 0.95610  R²: 0.0439
  Epoch 2000/3000 - Loss: 0.95013  R²: 0.0499
  Epoch 2500/3000 - Loss: 0.93359  R²: 0.0665
  Epoch 3000/3000 - Loss: 0.89358  R²: 0.1065
[Figure: Loss – TYPE: URBANA]
[Figure: JAX Model – STATION_TYPE: URBANA, prediction vs. actual data]
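
Most rows end up in OTRAS (3990 of them), so the mapping clearly needs more entries. A small sketch, run after the cell above and reusing its df, to list the stations that still have no explicit type and could be added to tipo_por_estacion:

sin_tipo = sorted(df.loc[df["TIPO_ESTACION"] == "OTRAS", "NOMBRE_ESTACION"].unique())
print(len(sin_tipo), "stations without an explicit type")
print(sin_tipo[:20])   # first few, to extend the mapping with real criteria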