Aristarco - Fab Futures - Data Science

Machine Learning

What is Machine Learning?

Machine Learning (ML) is a branch of Artificial Intelligence (AI) where systems learn from data, identify patterns, and make decisions or predictions with minimal human intervention, rather than being explicitly programmed for every task. It uses algorithms to analyze vast datasets, learn from them, and improve performance over time, enabling applications like fraud detection, recommendation engines, and medical diagnosis.

How it Works

  • Data Input: ML models are fed large amounts of historical data (e.g., images, text, numbers).
  • Training: Algorithms find patterns and build a model through repetitive analysis, adjusting internal parameters to optimize results.
  • Learning/Generalization: The model learns to generalize from these examples, enabling it to perform tasks on new, unseen data.
  • Prediction/Action: The trained model makes predictions or classifications, like identifying a cat in a photo or suggesting a movie.

Key Types of Machine Learning

  • Supervised Learning: Trained on labeled data (e.g., images with "cat" or "dog" tags) to predict outcomes.
  • Unsupervised Learning: Finds hidden patterns in unlabeled data, like grouping similar customers.
  • Reinforcement Learning: Learns through trial and error, receiving rewards or penalties for actions, like training an agent to play a game.
  • Deep Learning: A subset of ML that uses neural networks with many layers to learn from raw, unstructured data such as images, audio and text; it powers many advanced AI systems.

Sources: "What is Machine Learning?"; "What is Machine Learning?" (IBM).

Dataset used: Titanic

To learn ML, I used the dataset from a Kaggle competition: Titanic - Machine Learning from Disaster. It records which Titanic passengers survived, together with their sex, class, age, fare, etc.

You can find the dataset here

In [1]:
# This example comes from the Kaggle Titanic tutorial. I used it to understand the workflow better.

# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are in the local datasets folder below (on Kaggle they live in "../input/").
# Running this cell (click Run or press Shift+Enter) lists all files under that folder.

import os
ruta = "/home/jovyan/work/aristarco-cortes/datasets/"   # folder that holds the CSV files

for dirname, _, filenames in os.walk(ruta):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Any results you write to the current directory are saved as output.
/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv
/home/jovyan/work/aristarco-cortes/datasets/TENBIARE.csv
/home/jovyan/work/aristarco-cortes/datasets/enbiare_2021_fd.xlsx
/home/jovyan/work/aristarco-cortes/datasets/historico_accidentes.csv
/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv
/home/jovyan/work/aristarco-cortes/datasets/denue_inegi_21_.csv
/home/jovyan/work/aristarco-cortes/datasets/.gitignore
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/TENBIARE-checkpoint.csv
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/Titanic_train-checkpoint.csv
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/Titanic_test-checkpoint.csv
In [2]:
# Load the training data

train_data = pd.read_csv("/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv")
train_data.head()
Out[2]:
|   | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
In [3]:
test_data = pd.read_csv("/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv")
test_data.head()
Out[3]:
|   | PassengerId | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 892 | 3 | Kelly, Mr. James | male | 34.5 | 0 | 0 | 330911 | 7.8292 | NaN | Q |
| 1 | 893 | 3 | Wilkes, Mrs. James (Ellen Needs) | female | 47.0 | 1 | 0 | 363272 | 7.0000 | NaN | S |
| 2 | 894 | 2 | Myles, Mr. Thomas Francis | male | 62.0 | 0 | 0 | 240276 | 9.6875 | NaN | Q |
| 3 | 895 | 3 | Wirz, Mr. Albert | male | 27.0 | 0 | 0 | 315154 | 8.6625 | NaN | S |
| 4 | 896 | 3 | Hirvonen, Mrs. Alexander (Helga E Lindqvist) | female | 22.0 | 1 | 1 | 3101298 | 12.2875 | NaN | S |
In [24]:
# Calculating the % of women who survived
women = train_data.loc[train_data.Sex == 'female']["Survived"]
rate_women = sum(women)/len(women)

print("Real % of women who survived:", rate_women)
Real % of women who survived: 0.7420382165605095
In [5]:
# Calculating the % of men who survived

men = train_data.loc[train_data.Sex == 'male']["Survived"]
rate_men = sum(men)/len(men)

print("% of men who survived:", rate_men)
% of men who survived: 0.18890814558058924
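Both rates can also be computed in a single line with groupby, because the mean of a 0/1 column is the fraction of 1s. A minimal sketch on a tiny hypothetical DataFrame (six made-up rows, not the Titanic file); applied to train_data, the same line should give the 0.742 (female) and 0.189 (male) shown above:

```python
import pandas as pd

# Hypothetical mini-dataset with the same column names as Titanic_train.csv
df = pd.DataFrame({
    "Sex": ["female", "female", "female", "male", "male", "male"],
    "Survived": [1, 1, 0, 0, 0, 1],
})

# Mean of a 0/1 column = fraction of 1s, i.e. the survival rate of each group
rates = df.groupby("Sex")["Survived"].mean()
print(rates)
```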

With these numbers I have the actual survival rates in the training data, and now I can compare them with the results of the ML model.

In [22]:
# For the ML practice I used the Random Forest classifier suggested in the Kaggle tutorial.

from sklearn.ensemble import RandomForestClassifier

y = train_data["Survived"]

features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
X_test = pd.get_dummies(test_data[features])

model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X, y)
predictions = model.predict(X_test)

# Kaggle expects exactly two columns: PassengerId and Survived
output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
output.to_csv('submission.csv', index=False)
print("Your submission was successfully saved!")

# Predicted survival rate of women in the test set.
# The test file has no Survived column, so this comes from the model, not from real data.
women = output.loc[test_data.Sex == 'female', "Survived"]
rate_women = sum(women)/len(women)

print("Predicted % of women who survived (test set):", rate_women)
Your submission was successfully saved!
Predicted % of women who survived (test set): 0.9473684210526315
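pd.get_dummies is applied separately to the training and test frames. That works here because both files contain male and female passengers, but if a category were missing from one file, the two matrices would have different columns and model.predict would likely raise an error. A sketch of the issue and one common fix (reindexing the test columns to the training columns), using hypothetical rows:

```python
import pandas as pd

features = ["Pclass", "Sex"]
train = pd.DataFrame({"Pclass": [1, 3, 2], "Sex": ["female", "male", "male"]})
# Hypothetical test set in which no passenger is female
test = pd.DataFrame({"Pclass": [3, 1], "Sex": ["male", "male"]})

X = pd.get_dummies(train[features])
X_test = pd.get_dummies(test[features])
print(list(X.columns))       # ['Pclass', 'Sex_female', 'Sex_male']
print(list(X_test.columns))  # ['Pclass', 'Sex_male']

# Align the test columns with the training columns; missing dummy columns are filled with 0
X_test = X_test.reindex(columns=X.columns, fill_value=0)
print(X_test)
```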
In [7]:
import pandas as pd

# Create a series containing feature importances
feature_importances = pd.Series(model.feature_importances_, index=X.columns)

# Sort them to see the most influential features
print(feature_importances.sort_values(ascending=False))
Sex_female    0.342297
Sex_male      0.284667
Pclass        0.215508
Parch         0.080406
SibSp         0.077122
dtype: float64

JAX Solution

Gemini prompt: "Write the code to make the same exercise with JAX and explain line by line"

In [8]:
# JAX Solution. Gemini prompt: Write the code to make the same exercise with JAX and explain line by line

import jax
import jax.numpy as jnp
from jax import grad, jit, random
import pandas as pd
import numpy as np

# 1. LOAD AND PREPROCESS DATA
# ---------------------------
# Load the data from the local datasets folder
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')
test_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv')

features = ["Pclass", "Sex", "SibSp", "Parch"]

# Convert categorical data to numbers (One-Hot Encoding)
X = pd.get_dummies(train_data[features]).astype(float)
X_test = pd.get_dummies(test_data[features]).astype(float)
y = train_data["Survived"].values

# Convert to JAX arrays (jax.Array)
X_jax = jnp.array(X.values, dtype=jnp.float32)
y_jax = jnp.array(y, dtype=jnp.float32)
X_test_jax = jnp.array(X_test.values, dtype=jnp.float32)

# Normalize data (Neural Networks perform better with small numbers around 0)
mean = X_jax.mean(axis=0)
std = X_jax.std(axis=0)
X_jax = (X_jax - mean) / std
X_test_jax = (X_test_jax - mean) / std

# 2. DEFINE THE MODEL
# -------------------
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

# A simple 2-layer Neural Network: Input -> Hidden (10 units) -> Output (1 unit)
def predict(params, inputs):
    # Layer 1: Matrix multiplication + Bias
    hidden = jnp.dot(inputs, params['w1']) + params['b1']
    # Activation function (ReLU)
    hidden = jnp.maximum(0, hidden)
    # Layer 2: Output
    logits = jnp.dot(hidden, params['w2']) + params['b2']
    return sigmoid(logits) # Squash to 0-1 probability

# 3. DEFINE LOSS AND UPDATE FUNCTIONS
# -----------------------------------
def loss_fn(params, inputs, targets):
    preds = predict(params, inputs)
    # Binary Cross Entropy Loss (Standard for Yes/No classification)
    # Adding 1e-7 prevents log(0) errors
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs + 1e-7))

# JIT (Just-In-Time) compile the update step for massive speed
@jit
def update(params, x, y, lr=0.1):
    # Calculate gradients automatically
    grads = grad(loss_fn)(params, x, y)
    
    # Update parameters: W = W - learning_rate * gradient
    new_params = {
        'w1': params['w1'] - lr * grads['w1'],
        'b1': params['b1'] - lr * grads['b1'],
        'w2': params['w2'] - lr * grads['w2'],
        'b2': params['b2'] - lr * grads['b2'],
    }
    return new_params

# 4. INITIALIZATION AND TRAINING
# ------------------------------
# JAX handles randomness explicitly with keys
key = random.PRNGKey(42)
input_dim = X_jax.shape[1]
hidden_dim = 10

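# Randomly initialize weights
# Note: w1 and w2 below reuse the same key; the JAX docs recommend random.split(key)
# so that each draw gets its own independent key.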
w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
b1 = jnp.zeros(hidden_dim)
w2 = random.normal(key, (hidden_dim, 1)) * 0.1
b2 = jnp.zeros(1)

params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

print("Training Neural Network...")
# Training loop
for i in range(2000):
    params = update(params, X_jax, y_jax.reshape(-1, 1), lr=0.1)
    if i % 500 == 0:
        current_loss = loss_fn(params, X_jax, y_jax.reshape(-1, 1))
        print(f"Epoch {i}, Loss: {current_loss:.4f}")

# 5. PREDICTION AND SUBMISSION
# ----------------------------
final_probs = predict(params, X_test_jax)
# Convert probabilities to 0 or 1 (if > 0.5, then 1)
predictions = (final_probs > 0.5).astype(int).flatten()

# Save to CSV (back to Pandas for saving)
output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
output.to_csv('submission_jax.csv', index=False)
print("JAX submission saved successfully!")
Training Neural Network...
Epoch 0, Loss: 0.6980
Epoch 500, Loss: 0.4339
Epoch 1000, Loss: 0.4285
Epoch 1500, Loss: 0.4262
JAX submission saved successfully!

Explanation of the code line by line

## 1. Setup and Preprocessing

    import jax.numpy as jnp
    from jax import grad, jit, random

  • import jax.numpy as jnp: This is the JAX version of NumPy. It looks and feels like NumPy, but the arrays it creates can run on GPUs/TPUs and can be automatically differentiated.
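  • grad: The magic function. It takes a Python function that returns a single number (such as the loss) and returns a new function that computes its gradient, by default with respect to the first argument, using automatic differentiation.
  • jit: "Just-In-Time" compilation. It traces your Python function and compiles it with XLA (Accelerated Linear Algebra), so that repeated calls run much faster.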
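
    X_jax = jnp.array(X.values, dtype=jnp.float32)
    # ... Normalization steps ...

  • Normalization: Unlike Random Forests, neural networks are sensitive to the scale of the inputs. We subtract the mean and divide by the standard deviation so that every feature has mean 0 and standard deviation 1. Most values end up within a few units of 0, although rare large values (such as a big SibSp) can land further out. Without this step, features with large values dominate the weighted sums, and gradient descent can become slow or unstable (the numbers can "explode").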
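A quick numerical check of what standardization does, on a made-up SibSp-like column (hypothetical values: mostly 0 or 1, plus one large family). After scaling, the mean is 0 and the standard deviation is 1, but the large value still ends up well above 1:

```python
import numpy as np

# Hypothetical SibSp-like column: mostly 0/1, with one large family
x = np.array([0, 0, 1, 0, 1, 0, 0, 8], dtype=np.float32)

mean, std = x.mean(), x.std()
z = (x - mean) / std          # same formula as (X_jax - mean) / std

print(z.mean(), z.std())      # approximately 0 and 1
print(z.max())                # the outlier is still far above 1
```

The notebook computes mean and std on the training data and reuses them for the test data, so both sets are scaled in exactly the same way.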

## 2. Defining the "Brain" (The Model)

### Part 1: The Helper Function

    def sigmoid(x):
        return 1 / (1 + jnp.exp(-x))

def sigmoid(x): Defines a function that takes a number (or an array of numbers) x.
return 1 / (1 + jnp.exp(-x)): This is the mathematical formula of the sigmoid curve, $\sigma(x) = \frac{1}{1 + e^{-x}}$.
jnp.exp(-x): Calculates $e^{-x}$ (exponential decay).
What it does: It takes any number, no matter how large (e.g., 1,000,000) or how small (e.g., -500), and "squashes" it to be between 0 and 1.
Why use it? We need our final output to represent a probability (0% to 100% chance of survival).
    
### Part 2: The Prediction Function
This function represents the entire structure of your Neural Network.

    def predict(params, inputs):
def predict: This is our "forward pass": it computes a prediction from the input data and the current parameters.

params: A dictionary containing the trained weights (w1, w2) and biases (b1, b2). In JAX, we pass these explicitly because JAX is functional (variables aren't stored "inside" the function).

inputs: The passenger data, i.e. the standardized features Pclass, Sex_female, Sex_male, SibSp and Parch (Age is not used in this model).

    hidden = jnp.dot(inputs, params['w1']) + params['b1']

jnp.dot(inputs, params['w1']): This is the dot product (matrix multiplication). Each of the 10 hidden neurons computes a weighted sum of all the input features, using its own column of weights in w1.
    
Example: If a passenger is female and a hidden neuron has a large weight on "Sex_female", that neuron's weighted sum becomes a large number.
    
+ params['b1']: This adds the Bias. The bias acts like an intercept in a linear equation ($y = mx + b$). It allows the node to activate even if the inputs are zero.
    
hidden: The result is a set of "raw scores" for the 10 hidden neurons.

### Activation function (ReLU)
    hidden = jnp.maximum(0, hidden)

jnp.maximum(0, hidden): This is the ReLU (Rectified Linear Unit) activation function.

What it does: It looks at the numbers calculated in the previous step.

If a number is positive, it keeps it.

If a number is negative, it turns it into 0.

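Why? This introduces non-linearity. Without this step, the two layers would collapse into a single linear function, and the whole network would be no more powerful than logistic regression (a linear model followed by a sigmoid). ReLU allows the model to learn complex, "curvy" boundaries between survivors and non-survivors.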
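The collapse of two linear layers into one can be checked numerically. A NumPy sketch with random placeholder weights in the notebook's shapes (5 inputs, 10 hidden units, 1 output): without ReLU, (x @ w1 + b1) @ w2 + b2 equals x @ W + B with W = w1 @ w2 and B = b1 @ w2 + b2; with ReLU in between, it no longer does.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))                       # 4 hypothetical passengers, 5 features
w1, b1 = rng.normal(size=(5, 10)), rng.normal(size=10)
w2, b2 = rng.normal(size=(10, 1)), rng.normal(size=1)

# Two layers WITHOUT an activation function
two_layers = (x @ w1 + b1) @ w2 + b2

# One equivalent linear layer
W = w1 @ w2           # shape (5, 1)
B = b1 @ w2 + b2      # shape (1,)
one_layer = x @ W + B

print(np.allclose(two_layers, one_layer))   # True

# With ReLU between the layers the shortcut no longer works
with_relu = np.maximum(0, x @ w1 + b1) @ w2 + b2
print(np.allclose(with_relu, one_layer))    # almost surely False
```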

    # Layer 2: Output
    logits = jnp.dot(hidden, params['w2']) + params['b2']

jnp.dot(hidden, params['w2']): We take the 10 processed signals from the hidden layer and combine them into one final score.

+ params['b2']: Add the final bias.

logits: In Machine Learning, the raw, unscaled output of the last layer (before probability conversion) is called a logit. This number could be anything (e.g., 5.4 or -12.1).

    return sigmoid(logits) # Squash to 0-1 probability

This applies the sigmoid helper from Part 1:

If the logit was 5.4, sigmoid returns ~0.99 (99% chance of survival).

If the logit was -12.1, sigmoid returns ~0.00 (0% chance of survival).
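A NumPy copy of the helper (NumPy instead of jax.numpy only so that it runs without JAX) reproduces these two examples and the midpoint: 5.4 maps to about 0.9955, -12.1 to about 0.0000056, and a logit of 0 gives exactly 0.5.

```python
import numpy as np

def sigmoid(x):
    # Same formula as the notebook's sigmoid, written with NumPy
    return 1 / (1 + np.exp(-x))

for logit in [5.4, 0.0, -12.1]:
    print(f"logit {logit:6.1f} -> probability {sigmoid(logit):.6f}")
```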

Summary of the Flow
Input (Passenger Data) ->

Layer 1 (Weigh features & Add Bias) ->

ReLU (Filter out negatives) ->

Layer 2 (Combine hidden signals) ->

Sigmoid (Convert to Probability) -> Output
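The same flow written as a NumPy sketch, mainly to show the array shapes at each step. The weights are random placeholders with the notebook's sizes (5 input features, 10 hidden units, 1 output), not trained values, and the three input rows are made up:

```python
import numpy as np

def predict(params, inputs):
    # Same structure as the notebook's predict(), written with NumPy
    hidden = np.maximum(0, inputs @ params["w1"] + params["b1"])   # (n, 10): Layer 1 + ReLU
    logits = hidden @ params["w2"] + params["b2"]                  # (n, 1): Layer 2
    return 1 / (1 + np.exp(-logits))                               # (n, 1): sigmoid

rng = np.random.default_rng(42)
params = {
    "w1": rng.normal(size=(5, 10)) * 0.1,
    "b1": np.zeros(10),
    "w2": rng.normal(size=(10, 1)) * 0.1,
    "b2": np.zeros(1),
}
x = rng.normal(size=(3, 5))   # 3 hypothetical, already standardized passengers
probs = predict(params, x)
print(probs.shape)            # (3, 1): one probability per passenger
```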

## 3. DEFINE LOSS AND UPDATE FUNCTIONS
### Part 1: The Loss Function (The Scorecard)
This function calculates a single number representing how "badly" the model performed.

    def loss_fn(params, inputs, targets):
def loss_fn: Defines the loss function.

Arguments: It takes the current model weights (params), the passenger data (inputs), and the actual answers (targets: 0 for died, 1 for survived).

    preds = predict(params, inputs)
preds = ...: It runs the predict function (explained in the previous step) to get the model's current guess. These are probabilities between 0 and 1.

    # Binary Cross Entropy Loss (Standard for Yes/No classification)
    # Adding 1e-7 prevents log(0) errors
    label_probs = preds * targets + (1 - preds) * (1 - targets)
The Logic: This is a clever mathematical trick to select the probability corresponding to the correct answer.

If the passenger survived (targets = 1): The formula becomes preds * 1 + (1-preds) * 0. We keep preds. We want this to be high.

If the passenger died (targets = 0): The formula becomes preds * 0 + (1-preds) * 1. We keep 1 - preds. If the model predicted 0.1 (10% survival), 1-preds is 0.9 (90% death), which is high.

Result: label_probs is "How high of a probability did the model assign to the correct event happening?"

    return -jnp.mean(jnp.log(label_probs + 1e-7))
label_probs + 1e-7: We add a tiny number (0.0000001) called "epsilon."

Why? Mathematically, log(0) is negative infinity. If the model assigned probability 0 to the correct answer, the loss would become infinite and the gradients would turn into NaN; this tiny number keeps the calculation finite.

jnp.log(...): We take the logarithm.

Log(1.0) is 0 (Perfect score).

Log(0.1) is -2.3 (Bad score).

Log(0.01) is -4.6 (Terrible score).

-jnp.mean(...):

Since logs of probabilities are negative numbers, we flip the sign with - so the error is positive.

mean: We average the error across all passengers in the batch.
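The "selection trick" gives exactly the textbook binary cross-entropy, $-\frac{1}{N}\sum_i \left[y_i \log p_i + (1 - y_i)\log(1 - p_i)\right]$, because for each passenger one of the two terms is multiplied by zero. A NumPy check with hypothetical predictions and labels:

```python
import numpy as np

preds = np.array([0.9, 0.2, 0.6, 0.1])     # hypothetical model outputs
targets = np.array([1.0, 0.0, 0.0, 1.0])   # hypothetical true labels
eps = 1e-7

# Notebook's version: keep the probability assigned to the correct label
label_probs = preds * targets + (1 - preds) * (1 - targets)
loss_trick = -np.mean(np.log(label_probs + eps))

# Textbook binary cross-entropy
loss_bce = -np.mean(targets * np.log(preds + eps) + (1 - targets) * np.log(1 - preds + eps))

print(label_probs)            # probability given to the correct answer for each passenger
print(loss_trick, loss_bce)   # the two losses are identical
```

The last passenger (survived, but predicted 0.1) contributes the largest term, which is exactly the "terrible score" case described above.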

### Part 2: The Update Function (The Learning)
This is where the magic of JAX happens. It calculates how to adjust the weights to lower the error.

    @jit
    def update(params, x, y, lr=0.1):
@jit: "Just-In-Time" compilation. It tells JAX to trace the update function and compile it into optimized XLA code. The first call pays the compilation cost; later calls reuse the compiled code and run much faster.

lr=0.1: The Learning Rate. This controls how big of a "step" we take during learning. Too big, and we overshoot; too small, and learning takes forever.

    # Calculate gradients automatically
    grads = grad(loss_fn)(params, x, y)
grad(loss_fn): This is JAX's superpower. It takes the Python function loss_fn and returns a new function that computes the gradient of the loss with respect to the first argument, params (automatic differentiation).

(params, x, y): We immediately call this new gradient function with our data.

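grads: A dictionary with the same keys and shapes as params, holding the slope of the loss with respect to every single weight and bias. A positive slope means that increasing that parameter increases the error; a negative slope means that increasing it decreases the error.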
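A minimal JAX illustration, using a made-up linear model with squared error instead of the notebook's network: jax.grad returns a dictionary with the same keys and shapes as params, and the values match the derivatives worked out by hand. For these numbers every residual is -2.5, so $\partial L/\partial b = \frac{2}{N}\sum_i (\hat y_i - y_i) = -5$ and $\partial L/\partial w = [-10, -15]$.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, x, y):
    # Hypothetical tiny linear model with mean squared error, only to show grad's output
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

params = {"w": jnp.array([0.5, -1.0]), "b": jnp.array(0.0)}
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 0.0])

grads = jax.grad(loss_fn)(params, x, y)   # gradient w.r.t. the FIRST argument (params) only
print(grads["w"], grads["b"])
```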

    # Update parameters: W = W - learning_rate * gradient
    new_params = {
        'w1': params['w1'] - lr * grads['w1'],
        'b1': params['b1'] - lr * grads['b1'],
        'w2': params['w2'] - lr * grads['w2'],
        'b2': params['b2'] - lr * grads['b2'],
    }
The Descent: This implements Gradient Descent.

The Formula: New Weight = Old Weight - (Learning Rate * Slope)

If the slope is positive (increasing weight increases error), we subtract (make weight smaller).

If the slope is negative (increasing weight decreases error), we subtract a negative (make weight bigger).

Why a new dictionary? JAX is designed to be "Functional." We generally do not modify variables in place; we create new updated versions of them.

    return new_params
return: Passes the improved "brain" (weights) back to the main loop to be used for the next round of training.
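The update rule can be tried on a one-parameter toy problem in plain Python (a made-up loss, not the Titanic network). With a learning rate of 0.1 the weight converges to the minimum; with 1.1 every step overshoots further, which is the "too big, and we overshoot" case described for lr above.

```python
# Hypothetical 1-D example: loss L(w) = (w - 3)**2, smallest at w = 3
def slope(w):
    return 2 * (w - 3)          # derivative dL/dw

# Learning rate 0.1: each step moves w closer to 3
w = 0.0
for step in range(100):
    w = w - 0.1 * slope(w)      # New Weight = Old Weight - (Learning Rate * Slope)
print(round(w, 4))              # 3.0

# Learning rate 1.1: each step jumps past 3 by more than the last, so w drifts away
w_big = 0.0
for step in range(20):
    w_big = w_big - 1.1 * slope(w_big)
print(w_big)
```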

## 4. INITIALIZATION AND TRAINING

Here is the line-by-line explanation of the Initialization and Training phase. This is the setup where we create the model's "brain" and then force it to learn.

### Part 1: Setting up Randomness & Architecture

    # JAX handles randomness explicitly with keys
    key = random.PRNGKey(42)
random.PRNGKey(42): JAX handles random numbers differently than standard Python or NumPy.

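The Concept: JAX is "stateless." It doesn't remember which random numbers were generated last. You must pass a "key" (an explicit random state) every time you want random numbers, and the same key always produces the same numbers. To get new, independent numbers, you split the key with random.split.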

42: The seed. Using the same seed ensures that if you run this code tomorrow, you get the exact same random weights (reproducibility).
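A small JAX sketch of both points. It also shows the pattern the JAX documentation recommends for the initialization step: split the key so that w1 and w2 each get their own sub-key, instead of reusing key for both as the notebook cell does. The sub-key names k1 and k2 are illustrative.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)

# Same key -> exactly the same "random" numbers
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(jnp.array_equal(a, b))   # True

# Split once, then give each layer its own sub-key
key, k1, k2 = random.split(key, 3)
w1 = random.normal(k1, (5, 10)) * 0.1
w2 = random.normal(k2, (10, 1)) * 0.1
print(w1.shape, w2.shape)      # (5, 10) (10, 1)
```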

    input_dim = X_jax.shape[1]
    hidden_dim = 10
input_dim: We check our data (X_jax) to see how many features we have (e.g., Pclass, Sex_female, Sex_male, etc.).

hidden_dim = 10: We decide arbitrarily that our hidden layer will have 10 neurons. More neurons = more complex thinking, but higher risk of memorizing data (overfitting).

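### Part 2: Initializing Weights (The "Blank Slate")
Before the model learns, it needs starting values. If all the weights started at zero, every hidden neuron would compute the same value and receive the same update, so the neurons could never learn different patterns (in this network the weight gradients would even be exactly zero). So we start with small random numbers.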

    # Randomly initialize weights
    w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
random.normal: Generates numbers fitting a "Bell Curve" (Gaussian distribution) centered at 0.

(input_dim, hidden_dim): The shape of the matrix. It connects every input feature to every hidden neuron.

Multiplying by 0.1: Scaling. Neural networks train poorly when the weights start large: the math can "explode" or get stuck immediately. Shrinking the random numbers keeps the starting weights small.

    b1 = jnp.zeros(hidden_dim)
jnp.zeros: Biases are safe to initialize at 0. This is the "intercept" for each of the 10 hidden neurons.

    w2 = random.normal(key, (hidden_dim, 1)) * 0.1
    b2 = jnp.zeros(1)
w2: The weights connecting the 10 hidden neurons to the 1 final output neuron.

b2: The final bias for the output.

    params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}
Packing: We wrap all these individual variables into a single Python dictionary called params. This makes it easy to pass the entire "brain" into functions like update and loss_fn.

### Part 3: The Training Loop (The "Gym")
This is where the actual learning occurs.

    print("Training Neural Network...")
    # Training loop
    for i in range(2000):

range(2000): We show the model the whole training set and correct its mistakes 2,000 times. Since every step uses all the training data at once (full-batch gradient descent), each step is also one epoch.

    params = update(params, X_jax, y_jax.reshape(-1, 1), lr=0.1)
params = ...: This is crucial in JAX. JAX arrays are immutable (they cannot be changed in place), so the update function does not modify params. Instead, it returns a new version of params with slightly better weights, and we overwrite our variable with it.
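A tiny JAX sketch of what "immutable" means in practice: item assignment is not allowed, and the .at[...].set(...) syntax returns a new array while the original stays unchanged.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# x[0] = 10.0 would raise a TypeError: JAX arrays cannot be changed in place.
y = x.at[0].set(10.0)   # builds a NEW array with the first element replaced

print(x)  # original array, unchanged
print(y)  # new array with the change
```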

y_jax.reshape(-1, 1):

Our y_jax might look like a flat list: [1, 0, 1].

Our model outputs a column: [[1], [0], [1]].

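Reshape makes the shapes match so the math works (-1 means "whatever length is needed"). Without it, broadcasting would combine the (891, 1) predictions with the (891,) targets into an (891, 891) matrix, and the loss would be silently wrong without any error message.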
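jax.numpy follows NumPy's broadcasting rules, so the pitfall can be shown with plain NumPy. With three hypothetical passengers, multiplying a (3, 1) column by a (3,) vector silently produces a 3-by-3 matrix instead of one value per passenger:

```python
import numpy as np

preds = np.array([[0.9], [0.2], [0.7]])   # model output: a column, shape (3, 1)
targets = np.array([1.0, 0.0, 1.0])        # labels: flat, shape (3,)

wrong = preds * targets                     # broadcasts to shape (3, 3), no error raised
right = preds * targets.reshape(-1, 1)      # shape (3, 1): one value per passenger

print(wrong.shape, right.shape)             # (3, 3) (3, 1)
```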

    if i % 500 == 0:
        current_loss = loss_fn(params, X_jax, y_jax.reshape(-1, 1))
        print(f"Epoch {i}, Loss: {current_loss:.4f}")
Monitoring: We don't print every single step, since 2,000 lines of output would be hard to read.

i % 500 == 0: Every 500th step (0, 500, 1000, 1500), we pause to calculate the current error (loss_fn) and print it.

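Goal: The "Loss" number should go down as the loop runs (in the run above: 0.6980 -> 0.4339 -> 0.4285 -> 0.4262). This shows that the model is fitting the training data; whether it also works on new passengers has to be checked separately.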

## 5. PREDICTION AND SUBMISSION

### Part 1: Generating Probabilities

    final_probs = predict(params, X_test_jax)
    
* predict(...): We call the same function we used during training.
* params: We pass the final version of params. These contain the weights ($W$) and biases ($b$) optimized over 2,000 training steps.
* X_test_jax: This is the data for the passengers we need to submit to Kaggle. The model has never seen this data before.
* Result: final_probs is an array of numbers between 0 and 1 with one row per passenger, shape (418, 1) (e.g., [[0.1], [0.9], [0.45], ...]).
    
### Part 2: Making the Decision (Thresholding)

Kaggle doesn't want probabilities (e.g., "I'm 70% sure"); it wants a binary answer (0 or 1).

    predictions = (final_probs > 0.5).astype(int).flatten()

(final_probs > 0.5): This creates a boolean array. If the probability is 0.8, 0.8 > 0.5 is True; if it is 0.2, 0.2 > 0.5 is False.

.astype(int): Converts the Booleans into numbers: True becomes 1 (Survived) and False becomes 0 (Died).

.flatten(): JAX often works with 2D column vectors (e.g., a shape of (418, 1)). Pandas prefers a simple 1D array (shape (418,)). flatten() squashes the column into a flat array.

### Part 3: Formatting and Saving

    output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})

pd.DataFrame(...): We rebuild the table format Kaggle requires.

'PassengerId': test_data.PassengerId: We take the original IDs (892, 893, etc.) from the test file loaded at the very beginning.

'Survived': predictions: We attach our 0s and 1s next to the IDs.

    output.to_csv('submission_jax.csv', index=False)

to_csv: Writes the DataFrame to a CSV text file.

index=False: Critical! Without it, pandas adds the row numbers (0, 1, 2...) as an extra first column. Kaggle expects exactly two columns (PassengerId and Survived), so a third column would make the submission fail.

    print("JAX submission saved successfully!")

Confirmation: A simple message to let you know the script finished without errors and the file is ready to be downloaded or submitted.
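Thresholding and saving can be checked end to end without touching the disk by writing the CSV to an in-memory buffer. A sketch with four hypothetical probabilities (the IDs are the first four test IDs, the probabilities are made up); note that exactly 0.5 is not "> 0.5" and becomes 0:

```python
import io

import numpy as np
import pandas as pd

final_probs = np.array([[0.8], [0.2], [0.5], [0.51]])    # hypothetical (n, 1) model output
predictions = (final_probs > 0.5).astype(int).flatten()  # [1, 0, 0, 1]

output = pd.DataFrame({"PassengerId": [892, 893, 894, 895], "Survived": predictions})

buf = io.StringIO()                     # write to memory instead of a file
output.to_csv(buf, index=False)
print(buf.getvalue())                   # header: PassengerId,Survived

buf2 = io.StringIO()
output.to_csv(buf2)                     # default index=True adds an unnamed first column
print(buf2.getvalue().splitlines()[0])  # ,PassengerId,Survived
```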
                                             

Comparison between the Random Forest and JAX model predictions

Gemini prompt: "How can I compare two tables where the PassengerId field is the same and the next column, Survived, changes?"

In [9]:
# 1. Load the prediction files

df_rf = pd.read_csv('submission.csv')
df_jax = pd.read_csv('submission_jax.csv')

df_rf.head()
Out[9]:
|   | PassengerId | Survived |
|---|---|---|
| 0 | 892 | 0 |
| 1 | 893 | 1 |
| 2 | 894 | 0 |
| 3 | 895 | 0 |
| 4 | 896 | 1 |
In [10]:
df_jax.head()
Out[10]:
|   | PassengerId | Survived |
|---|---|---|
| 0 | 892 | 0 |
| 1 | 893 | 1 |
| 2 | 894 | 0 |
| 3 | 895 | 0 |
| 4 | 896 | 0 |
In [11]:
# 2. MERGE THE TABLES
# on="PassengerId" makes sure each row lines up the same passenger
# suffixes renames the two Survived columns automatically

comparacion = pd.merge(df_rf, df_jax, on='PassengerId', suffixes=('_rf', '_jax'))

print("--- Merged Table (Preview) ---")
print(comparacion)
print("\n")
--- Merged Table (Preview) ---
     PassengerId  Survived_rf  Survived_jax
0            892            0             0
1            893            1             1
2            894            0             0
3            895            0             0
4            896            1             0
..           ...          ...           ...
413         1305            0             0
414         1306            1             1
415         1307            0             0
416         1308            0             0
417         1309            0             0

[418 rows x 3 columns]


In [12]:
# 3. FILTER THE DIFFERENCES
# Find the rows where the RF prediction is NOT EQUAL (!=) to the JAX prediction
diferencias = comparacion[comparacion['Survived_rf'] != comparacion['Survived_jax']]

print(f"--- Found {len(diferencias)} differences ---")
print(diferencias)

# Optional: compute the disagreement percentage
total = len(comparacion)
diff_count = len(diferencias)
print(f"\nDisagreement percentage: {(diff_count/total)*100}%")
--- Found 20 differences ---
     PassengerId  Survived_rf  Survived_jax
4            896            1             0
104          996            1             0
117         1009            1             0
127         1019            1             0
142         1034            1             0
165         1057            1             0
181         1073            0             1
194         1086            1             0
196         1088            0             1
217         1109            0             1
242         1134            0             1
263         1155            1             0
283         1175            1             0
284         1176            1             0
293         1185            0             1
308         1200            0             1
333         1225            1             0
376         1268            1             0
407         1299            0             1
409         1301            1             0

Disagreement percentage: 4.784688995215311%
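Two pandas details can make this comparison more robust: validate="one_to_one" makes merge raise an error if a PassengerId appears more than once in either file, and the mean of a boolean column is directly the disagreement rate. pd.crosstab additionally shows in which direction the models disagree. A sketch using only the first five rows printed above (Out[9] and Out[10]):

```python
import pandas as pd

# First five rows of each prediction file, as shown in Out[9] and Out[10]
df_rf = pd.DataFrame({"PassengerId": [892, 893, 894, 895, 896], "Survived": [0, 1, 0, 0, 1]})
df_jax = pd.DataFrame({"PassengerId": [892, 893, 894, 895, 896], "Survived": [0, 1, 0, 0, 0]})

merged = df_rf.merge(df_jax, on="PassengerId", suffixes=("_rf", "_jax"), validate="one_to_one")

disagree = merged["Survived_rf"] != merged["Survived_jax"]
print(disagree.sum(), f"{disagree.mean():.2%}")   # 1 20.00%

# Rows: Random Forest prediction, columns: JAX prediction
print(pd.crosstab(merged["Survived_rf"], merged["Survived_jax"]))
```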

Random Forest vs. JAX model: the two disagree on 4.78% of the test-set predictions.

I asked Gemini to repeat the same exercise with JAX, and the two prediction files differ on 20 passengers. I don't have enough knowledge yet to understand why. I think there may be a random number generator that changes the result, but I can't tell. One thing is visible in the code above: the JAX version is not a Random Forest but the two-layer neural network explained line by line, so the two files come from different models.

Validation Set

Since you cannot check your accuracy on the official test data (the Kaggle test file has no Survived column), you must simulate a test locally. You do this by taking your training data (which does have the answers) and setting aside a small piece (e.g., 20%) to act as test data. This is called a validation set.

Gemini prompt: Generate a script to show you how to create a "local test" so I can verify if JAX or Random Forest models are actually working

In [13]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

# 1. Load the Training Data (This is the only file with "Answers")
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')

# 2. Prepare features (Simple version)
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
y = train_data["Survived"]

# 3. SPLIT the data
# We hide 20% of the data from the model. This 20% becomes our "Local Test Set"
# random_state=42 ensures the split is the same every time
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

print(f"Training on {len(X_train)} passengers.")
print(f"Validating on {len(X_val)} passengers (Hidden from model).")

# 4. Train the model ONLY on the 80%
model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X_train, y_train)

# 5. Predict on the hidden 20%
predictions = model.predict(X_val)

# 6. COMPARE with the Real Results (y_val)
# This is possible because y_val comes from the original train.csv
accuracy = accuracy_score(y_val, predictions)

print("--------------------------")
print(f"Model Accuracy: {accuracy:.2%}")
print("--------------------------")

# Show a comparison table of the first 10 rows
comparison = pd.DataFrame({'Actual': y_val, 'Predicted': predictions})
print(comparison.head(10))
Training on 712 passengers.
Validating on 179 passengers (Hidden from model).
--------------------------
Model Accuracy: 79.89%
--------------------------
     Actual  Predicted
709       1          0
439       0          0
840       0          0
720       1          1
39        1          1
290       1          1
300       1          1
333       0          0
208       1          1
136       1          1

Validation for JAX Model

Gemini prompt: write the validation strategy code for JAX model

In [14]:
# Validation code for the JAX model
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 1. Load the Training Data
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')

# 2. Prepare features
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features]).astype(float)
y = train_data["Survived"].values

# 3. SPLIT the data (80% Train, 20% Validation)
X_train_pd, X_val_pd, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# --- JAX SPECIFIC PREPROCESSING ---

# Convert to JAX arrays
X_train_jax = jnp.array(X_train_pd.values, dtype=jnp.float32)
X_val_jax = jnp.array(X_val_pd.values, dtype=jnp.float32)
y_train_jax = jnp.array(y_train, dtype=jnp.float32).reshape(-1, 1)

# NORMALIZE the data (Crucial for Neural Networks)
# We calculate mean/std on TRAINING data only to avoid "data leakage"
mean = X_train_jax.mean(axis=0)
std = X_train_jax.std(axis=0)

X_train_jax = (X_train_jax - mean) / std
X_val_jax = (X_val_jax - mean) / std # Apply same scaling to validation set

print(f"Training on {len(X_train_jax)} passengers.")
print(f"Validating on {len(X_val_jax)} passengers.")

# 4. DEFINE JAX MODEL FUNCTIONS
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))

def predict(params, inputs):
    # Layer 1: Matrix mul + Bias -> ReLU
    hidden = jnp.dot(inputs, params['w1']) + params['b1']
    hidden = jnp.maximum(0, hidden)
    # Layer 2: Output -> Sigmoid
    logits = jnp.dot(hidden, params['w2']) + params['b2']
    return sigmoid(logits)

def loss_fn(params, inputs, targets):
    preds = predict(params, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs + 1e-7))

@jit
def update(params, x, y, lr=0.1):
    grads = grad(loss_fn)(params, x, y)
    return {
        'w1': params['w1'] - lr * grads['w1'],
        'b1': params['b1'] - lr * grads['b1'],
        'w2': params['w2'] - lr * grads['w2'],
        'b2': params['b2'] - lr * grads['b2'],
    }

# 5. INITIALIZE AND TRAIN
key = random.PRNGKey(42)
input_dim = X_train_jax.shape[1]
hidden_dim = 10

# Initialize weights with small random values
# Note: w1 and w2 reuse the same key; random.split(key) would give each draw its own key
w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
b1 = jnp.zeros(hidden_dim)
w2 = random.normal(key, (hidden_dim, 1)) * 0.1
b2 = jnp.zeros(1)

params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}

print("Starting training loop...")
for i in range(2000):
    params = update(params, X_train_jax, y_train_jax, lr=0.1)
    if i % 500 == 0:
        loss = loss_fn(params, X_train_jax, y_train_jax)
        print(f"Epoch {i}, Loss: {loss:.4f}")

# 6. PREDICT on Validation Set
print("Predicting on validation set...")
final_probs = predict(params, X_val_jax)
# Flatten JAX array and convert to standard Numpy array to avoid Pandas issues
predictions = np.array((final_probs > 0.5).astype(int).flatten())

# 7. COMPARE with Real Results
accuracy = accuracy_score(y_val, predictions)

print("--------------------------")
print(f"JAX Model Accuracy: {accuracy:.2%}")
print("--------------------------")

# Show comparison
# We use y_val which is already a numpy array (values) and our new predictions array
comparison = pd.DataFrame({'Actual': y_val, 'Predicted': predictions})
print(comparison.head(10))
Training on 712 passengers.
Validating on 179 passengers.
Starting training loop...
Epoch 0, Loss: 0.6977
Epoch 500, Loss: 0.4352
Epoch 1000, Loss: 0.4288
Epoch 1500, Loss: 0.4262
Predicting on validation set...
--------------------------
JAX Model Accuracy: 79.89%
--------------------------
   Actual  Predicted
0       1          0
1       0          0
2       0          0
3       1          1
4       1          1
5       1          1
6       1          1
7       0          0
8       1          1
9       1          1
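The normalization step in the cell above fits the mean and standard deviation on the training split only and reuses them for the validation split. A minimal sketch of that pattern as two small functions, assuming toy data; `fit_standardizer` and `apply_standardizer` are hypothetical names, and the `eps` term is an assumption that keeps a constant column (standard deviation 0) from causing a division by zero.

```python
import jax.numpy as jnp

def fit_standardizer(X_train, eps=1e-8):
    """Per-column mean and std, computed on the TRAINING split only."""
    mean = X_train.mean(axis=0)
    # eps avoids dividing by zero when a column is constant in the training split
    std = X_train.std(axis=0) + eps
    return mean, std

def apply_standardizer(X, mean, std):
    """Scale any split (train, validation, test) with the training statistics."""
    return (X - mean) / std

# Hypothetical toy data: 3 features, the last one constant in training
X_train = jnp.array([[1.0, 0.0, 1.0],
                     [2.0, 1.0, 1.0],
                     [3.0, 0.0, 1.0],
                     [4.0, 1.0, 1.0]])
X_val = jnp.array([[2.5, 1.0, 1.0],
                   [5.0, 0.0, 1.0]])

mean, std = fit_standardizer(X_train)
X_train_s = apply_standardizer(X_train, mean, std)
X_val_s = apply_standardizer(X_val, mean, std)  # no statistics from X_val are used

print(X_train_s.mean(axis=0))  # close to 0 for every column
print(X_val_s)
```

Computing `mean` and `std` from the validation rows would let information about them leak into preprocessing, which is the leakage the original comment warns about.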

Both models reached the same accuracy even though their predictions differ in some rows. That seemed suspicious to me, so I asked Gemini about these differences, since both were supposed to run the same Random Forest model.

Prompt: Act as an expert in data science and Machine Learning. I ran the Random Forest model on a Titanic dataset and then asked Gemini to run the same model using JAX to improve performance. When I got the result, I found 20 differences in the data. Why does this happen if it is supposed to be the same model?

The answer

As a Data Science expert, I can tell you that encountering differences between a Random Forest implementation in Scikit-Learn and one in JAX is completely normal, and even expected. Even if the underlying logic is the same, the technical execution differs significantly.

Here is why you are seeing those 20 discrepancies:

1. Algorithmic Implementation

While both are called "Random Forest," the libraries often use different "engines":

  • Scikit-Learn: Uses a highly optimized version of the CART (Classification and Regression Trees) algorithm.
  • JAX: JAX does not have a native "Random Forest" module. The code Gemini provided likely uses jax.numpy to manually calculate splits, or a library like Haiku or Equinox. Any tiny variation in how Gini Impurity or Entropy is calculated due to floating-point precision can change an entire branch of a tree.

2. Random Number Generation (PRNG)

This is the most likely culprit. Random Forest relies on randomness at two levels:

  • Bootstrapping: Randomly selecting samples for each tree.
  • Feature Selection: Randomly selecting a subset of columns at each node.

Scikit-Learn uses the NumPy random generator (state-based). JAX, by functional design, uses an explicit PRNG (Pseudo-Random Number Generation) system based on "keys" that you must manually split (jax.random.split). It is virtually impossible for these two different generators to produce the exact same sequence of numbers, meaning the trees will be structurally different.

3. Numerical Precision and Hardware

Float32 vs. Float64: JAX often defaults to float32 to maximize speed (especially on GPUs/TPUs), whereas Scikit-Learn typically uses float64 on CPU. In the Titanic dataset, where many variables are categorical or on small scales (like Fare), a difference of $10^{-7}$ in a decision threshold can push a passenger into the left branch instead of the right.

4. Tie-breaking Criteria

If two features provide the exact same information gain when trying to split a node, each library has its own internal rule to decide which one to pick. In small datasets like the Titanic, these "ties" happen more often than you might think.
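The key-based PRNG described in point 2 can be seen in a few lines. A minimal sketch, assuming `jax` and `numpy` are installed: the same JAX key returns the same numbers every time, `jax.random.split` derives independent keys, and a NumPy `Generator` advances hidden state on every call. Seeding both with 42 does not make them produce the same sequence, because they use different generator algorithms.

```python
import numpy as np
from jax import random

# JAX: random functions are pure, so the same key always gives the same numbers
key = random.PRNGKey(42)
a = random.normal(key, (3,))
b = random.normal(key, (3,))  # identical to a

# Splitting produces new, independent keys (e.g. one per weight matrix or per tree)
key_a, key_b = random.split(key)
c = random.normal(key_a, (3,))
d = random.normal(key_b, (3,))  # different from c

# NumPy: a Generator carries hidden state that advances on every call
rng = np.random.default_rng(42)
e = rng.normal(size=3)
f = rng.normal(size=3)  # different from e, without any new seed

print(bool((a == b).all()))  # True
print(bool((c == d).all()))  # False
print(np.array_equal(e, f))  # False
```

This is also why separate weight matrices in a JAX model are usually initialized from split keys rather than from one shared key.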
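Points 1 and 4 mention Gini impurity and tie-breaking. The following is a toy sketch, not how scikit-learn or any JAX library actually searches for splits: it scores one candidate threshold per feature by its Gini gain and picks the best feature with `jnp.argmax`, which returns the first index when values tie. `gini` and `split_gain` are hypothetical helper names, and the data is made up.

```python
import jax.numpy as jnp

def gini(labels):
    """Gini impurity 1 - sum_k p_k^2 for binary 0/1 labels."""
    p = labels.mean()
    return 1.0 - p**2 - (1.0 - p)**2

def split_gain(feature, labels, threshold):
    """Decrease in Gini impurity when splitting on feature <= threshold.

    Assumes both sides of the split are non-empty; boolean indexing
    like this works in eager JAX (not inside jit).
    """
    left = feature <= threshold
    n = labels.shape[0]
    n_left = left.sum()
    n_right = n - n_left
    weighted = (n_left / n) * gini(labels[left]) + (n_right / n) * gini(labels[~left])
    return gini(labels) - weighted

labels = jnp.array([0.0, 0.0, 1.0, 1.0])
# Two different features that separate the labels equally well
features = jnp.array([[0.0, 1.0],
                      [0.0, 1.0],
                      [1.0, 0.0],
                      [1.0, 0.0]])

gains = jnp.array([split_gain(features[:, j], labels, 0.5)
                   for j in range(features.shape[1])])
best = int(jnp.argmax(gains))  # tie -> first index, i.e. column 0
# Same gains with the columns in reverse order: argmax now lands on original column 1
best_if_reversed = features.shape[1] - 1 - int(jnp.argmax(gains[::-1]))

print(gains)
print(best, best_if_reversed)
```

Both features give the same gain, so only the column order decides which one is used; a library with a different tie-breaking rule could pick the other feature and grow a different tree.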

Learning Points

Go slow, think thrice. I learned that I still have many blind spots in Data Science. Asking AI about my doubts can be a very powerful tool to learn and understand better.

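Key takeaway: In Machine Learning, two models giving slightly different results doesn't mean one is "wrong." 20 differences in the Titanic dataset (which has 891 training rows and 418 test rows) represent roughly 2.2% to 4.8% of the rows, depending on which set was compared. This is well within the range of variation expected when changing random seeds or library implementations. The JAX cell above also trains a different kind of model: a small two-layer neural network (ReLU hidden layer, sigmoid output) rather than a Random Forest, so its predictions would not be expected to match the Random Forest's exactly.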
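The percentages in the takeaway, and the earlier suspicion about equal accuracy despite different predictions, can both be checked directly. A short sketch: the 891 and 418 row counts come from the text, while the six labels and two prediction vectors are made up for illustration.

```python
import numpy as np

# Share of rows that 20 differing predictions represent (row counts from the text)
n_diff = 20
shares = {n_rows: n_diff / n_rows for n_rows in (891, 418)}
for n_rows, share in shares.items():
    print(f"{n_diff} of {n_rows} rows = {share:.1%}")

# Made-up labels: two models with the same accuracy can still disagree
y_true = np.array([1, 0, 1, 0, 1, 0])
pred_a = np.array([1, 0, 1, 0, 0, 1])  # wrong on the last two rows
pred_b = np.array([0, 1, 1, 0, 1, 0])  # wrong on the first two rows

acc_a = (pred_a == y_true).mean()
acc_b = (pred_b == y_true).mean()
n_disagree = int((pred_a != pred_b).sum())

print(f"accuracy A = {acc_a:.2%}, accuracy B = {acc_b:.2%}")
print(f"predictions that differ: {n_disagree} of {len(y_true)}")
```

Accuracy only counts how many passengers each model gets right, not which ones, so two models can score the same while their errors fall on different passengers.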
