Machine Learning
What is Machine Learning?
Machine Learning (ML) is a branch of Artificial Intelligence (AI) where systems learn from data, identify patterns, and make decisions or predictions with minimal human intervention, rather than being explicitly programmed for every task. It uses algorithms to analyze vast datasets, learn from them, and improve performance over time, enabling applications like fraud detection, recommendation engines, and medical diagnosis.
How it Works
- Data Input: ML models are fed large amounts of historical data (e.g., images, text, numbers).
- Training: Algorithms find patterns and build a model through repetitive analysis, adjusting internal parameters to optimize results.
- Learning/Generalization: The model learns to generalize from these examples, enabling it to perform tasks on new, unseen data.
- Prediction/Action: The trained model makes predictions or classifications, like identifying a cat in a photo or suggesting a movie.
Key Types of Machine Learning
- Supervised Learning: Trained on labeled data (e.g., images with "cat" or "dog" tags) to predict outcomes.
- Unsupervised Learning: Finds hidden patterns in unlabeled data, like grouping similar customers.
- Reinforcement Learning: Learns through trial and error, receiving rewards or penalties for actions, like training an agent to play a game.
- Deep Learning: A subset using complex neural networks to learn from raw, unstructured data, powering advanced AI.
Sources: "What is Machine Learning?"; "What is Machine Learning" (IBM).
Dataset used: Titanic
To learn ML hands-on, I worked through an exercise from a Kaggle competition: Titanic - Machine Learning from Disaster. The dataset records whether each Titanic passenger survived, together with gender, class, age, fare, etc.
# This example comes from the Kaggle Titanic tutorial. I used it to improve my understanding.
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load in
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
import os
ruta = "/home/jovyan/work/aristarco-cortes/datasets/"
for dirname, _, filenames in os.walk(ruta):
    for filename in filenames:
        print(os.path.join(dirname, filename))
# Any results you write to the current directory are saved as output.
/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv
/home/jovyan/work/aristarco-cortes/datasets/TENBIARE.csv
/home/jovyan/work/aristarco-cortes/datasets/enbiare_2021_fd.xlsx
/home/jovyan/work/aristarco-cortes/datasets/historico_accidentes.csv
/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv
/home/jovyan/work/aristarco-cortes/datasets/denue_inegi_21_.csv
/home/jovyan/work/aristarco-cortes/datasets/.gitignore
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/TENBIARE-checkpoint.csv
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/Titanic_train-checkpoint.csv
/home/jovyan/work/aristarco-cortes/datasets/.ipynb_checkpoints/Titanic_test-checkpoint.csv
# Load the training data
train_data = pd.read_csv("/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv")
train_data.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | male | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S |
| 1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | female | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C |
| 2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | female | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S |
| 3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | female | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S |
| 4 | 5 | 0 | 3 | Allen, Mr. William Henry | male | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S |
test_data = pd.read_csv("/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv")
test_data.head()
| | PassengerId | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 892 | 3 | Kelly, Mr. James | male | 34.5 | 0 | 0 | 330911 | 7.8292 | NaN | Q |
| 1 | 893 | 3 | Wilkes, Mrs. James (Ellen Needs) | female | 47.0 | 1 | 0 | 363272 | 7.0000 | NaN | S |
| 2 | 894 | 2 | Myles, Mr. Thomas Francis | male | 62.0 | 0 | 0 | 240276 | 9.6875 | NaN | Q |
| 3 | 895 | 3 | Wirz, Mr. Albert | male | 27.0 | 0 | 0 | 315154 | 8.6625 | NaN | S |
| 4 | 896 | 3 | Hirvonen, Mrs. Alexander (Helga E Lindqvist) | female | 22.0 | 1 | 1 | 3101298 | 12.2875 | NaN | S |
# Calculating the % of women who survived
women = train_data.loc[train_data.Sex == 'female']["Survived"]
rate_women = sum(women)/len(women)
print("Real % of women who survived:", rate_women)
Real % of women who survived: 0.7420382165605095
# Calculating the % of men who survived
men = train_data.loc[train_data.Sex == 'male']["Survived"]
rate_men = sum(men)/len(men)
print("% of men who survived:", rate_men)
% of men who survived: 0.18890814558058924
With these numbers I have the actual survival rates of the passengers, and I can now compare the results of the ML process against the real data.
# For the ML practice I used the Random Forest classifier suggested in the Kaggle tutorial as my example.
from sklearn.ensemble import RandomForestClassifier
y = train_data["Survived"]
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
X_test = pd.get_dummies(test_data[features])
model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X, y)
predictions = model.predict(X_test)
output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions, 'Sex': test_data.Sex})
output.to_csv('submission.csv', index=False)
print("Your submission was successfully saved!")
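# Read the saved file back (same relative path that to_csv wrote to above).
# The extra 'Sex' column is only there for this check; a Kaggle submission needs just PassengerId and Survived.
df = pd.read_csv('submission.csv')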
women = df.loc[df.Sex == 'female']["Survived"]
rate_women = sum(women)/len(women)
print("trained data % of women who survived:", rate_women)
Your submission was successfully saved!
trained data % of women who survived: 0.9473684210526315
import pandas as pd
# Create a series containing feature importances
feature_importances = pd.Series(model.feature_importances_, index=X.columns)
# Sort them to see the most influential features
print(feature_importances.sort_values(ascending=False))
Sex_female    0.342297
Sex_male      0.284667
Pclass        0.215508
Parch         0.080406
SibSp         0.077122
dtype: float64
# JAX Solution. Gemini prompt: Write the code to make the same exercise with JAX and explain line by line
import jax
import jax.numpy as jnp
from jax import grad, jit, random
import pandas as pd
import numpy as np
# 1. LOAD AND PREPROCESS DATA
# ---------------------------
# Load data (Assuming standard Kaggle paths or local files)
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')
test_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_test.csv')
features = ["Pclass", "Sex", "SibSp", "Parch"]
# Convert categorical data to numbers (One-Hot Encoding)
X = pd.get_dummies(train_data[features]).astype(float)
X_test = pd.get_dummies(test_data[features]).astype(float)
y = train_data["Survived"].values
# Convert to JAX arrays (called DeviceArray in older JAX releases, jax.Array in newer ones)
X_jax = jnp.array(X.values, dtype=jnp.float32)
y_jax = jnp.array(y, dtype=jnp.float32)
X_test_jax = jnp.array(X_test.values, dtype=jnp.float32)
# Normalize data (Neural Networks perform better with small numbers around 0)
mean = X_jax.mean(axis=0)
std = X_jax.std(axis=0)
X_jax = (X_jax - mean) / std
X_test_jax = (X_test_jax - mean) / std
# 2. DEFINE THE MODEL
# -------------------
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
# A simple 2-layer Neural Network: Input -> Hidden (10 units) -> Output (1 unit)
def predict(params, inputs):
    # Layer 1: Matrix multiplication + Bias
    hidden = jnp.dot(inputs, params['w1']) + params['b1']
    # Activation function (Relu)
    hidden = jnp.maximum(0, hidden)
    # Layer 2: Output
    logits = jnp.dot(hidden, params['w2']) + params['b2']
    return sigmoid(logits) # Squash to 0-1 probability
# 3. DEFINE LOSS AND UPDATE FUNCTIONS
# -----------------------------------
def loss_fn(params, inputs, targets):
    preds = predict(params, inputs)
    # Binary Cross Entropy Loss (Standard for Yes/No classification)
    # Adding 1e-7 prevents log(0) errors
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs + 1e-7))
# JIT (Just-In-Time) compile the update step for massive speed
@jit
def update(params, x, y, lr=0.1):
    # Calculate gradients automatically
    grads = grad(loss_fn)(params, x, y)
    # Update parameters: W = W - learning_rate * gradient
    new_params = {
        'w1': params['w1'] - lr * grads['w1'],
        'b1': params['b1'] - lr * grads['b1'],
        'w2': params['w2'] - lr * grads['w2'],
        'b2': params['b2'] - lr * grads['b2'],
    }
    return new_params
# 4. INITIALIZATION AND TRAINING
# ------------------------------
# JAX handles randomness explicitly with keys
key = random.PRNGKey(42)
input_dim = X_jax.shape[1]
hidden_dim = 10
# Randomly initialize weights
w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
b1 = jnp.zeros(hidden_dim)
w2 = random.normal(key, (hidden_dim, 1)) * 0.1
b2 = jnp.zeros(1)
params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}
print("Training Neural Network...")
# Training loop
for i in range(2000):
    params = update(params, X_jax, y_jax.reshape(-1, 1), lr=0.1)
    if i % 500 == 0:
        current_loss = loss_fn(params, X_jax, y_jax.reshape(-1, 1))
        print(f"Epoch {i}, Loss: {current_loss:.4f}")
# 5. PREDICTION AND SUBMISSION
# ----------------------------
final_probs = predict(params, X_test_jax)
# Convert probabilities to 0 or 1 (if > 0.5, then 1)
predictions = (final_probs > 0.5).astype(int).flatten()
# Save to CSV (back to Pandas for saving)
output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
output.to_csv('submission_jax.csv', index=False)
print("JAX submission saved successfully!")
Training Neural Network...
Epoch 0, Loss: 0.6980
Epoch 500, Loss: 0.4339
Epoch 1000, Loss: 0.4285
Epoch 1500, Loss: 0.4262
JAX submission saved successfully!
Explanation of the code line by line
1. Setup and Preprocessing
import jax.numpy as jnp
from jax import grad, jit, random
import jax.numpy as jnp: This is the JAX version of NumPy. It looks and feels like NumPy, but the arrays it creates can run on GPUs/TPUs and can be automatically differentiated.
grad: Takes a Python function and returns a new function that computes its derivative (gradient) by automatic differentiation.
jit: "Just-In-Time" compilation. It traces your Python function and compiles it with XLA (Accelerated Linear Algebra), which can make repeated calls much faster.
X_jax = jnp.array(X.values, dtype=jnp.float32)
# ... Normalization steps ...
* **Normalization**: Unlike Random Forests, Neural Networks are sensitive to the scale of their inputs. We subtract the mean and divide by the standard deviation so each feature has mean 0 and standard deviation 1 (most values end up within a few units of 0, although outliers can fall outside -1 to 1). If we don't do this, the math inside the network might "explode" (numbers get too big).
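To make the normalization step concrete, here is a minimal NumPy sketch on a made-up feature matrix (the values and the two stand-in columns are assumptions, not rows from the Titanic file). It applies the same subtract-the-mean, divide-by-the-standard-deviation recipe and shows that the result has mean 0 and standard deviation 1 per column, while an outlier can still land above 1.

```python
import numpy as np

# Hypothetical toy feature matrix: the columns stand in for Pclass and SibSp.
X_toy = np.array([[3.0, 1.0],
                  [1.0, 0.0],
                  [3.0, 0.0],
                  [2.0, 8.0]], dtype=np.float32)

mean = X_toy.mean(axis=0)
std = X_toy.std(axis=0)
X_scaled = (X_toy - mean) / std

print(X_scaled.mean(axis=0))  # approximately 0 for each column
print(X_scaled.std(axis=0))   # approximately 1 for each column
print(X_scaled.max())         # larger than 1: outliers are rescaled, not clipped
```

Note that a column with zero variance would make `std` zero and produce a division by zero; the four Titanic features used here all vary, so the notebook does not hit that case.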
## 2. Defining the "Brain" (The Model)
### Part 1: The Helper Function
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
def sigmoid(x):: Defines a function that takes a number (or a matrix of numbers) x.
return 1 / (1 + jnp.exp(-x)): This is the mathematical formula for the Sigmoid Curve.
jnp.exp(-x): Calculates $e^{-x}$ (exponential decay).
What it does: It takes any number, no matter how large (e.g., 1,000,000) or how small (e.g., -500), and "squashes" it to be between 0 and 1.
Why use it? We need our final output to represent a probability (0% to 100% chance of survival).
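As a quick check of the "squashing" claim, the sketch below evaluates the sigmoid in plain Python for the extreme inputs mentioned above. The helper names `sigmoid_naive` and `sigmoid_stable` are hypothetical. The stable variant is a common rewrite (not what the notebook uses) that only exponentiates non-positive numbers, because plain `math.exp` raises `OverflowError` for very large arguments.

```python
import math

def sigmoid_naive(x):
    # Direct translation of 1 / (1 + e^(-x)).
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_stable(x):
    # Only ever exponentiates a non-positive number, so exp() cannot overflow.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

for value in (-500.0, 0.0, 5.4, 1_000_000.0):
    print(value, sigmoid_stable(value))
```

In a JAX program, the built-in `jax.nn.sigmoid` is an alternative to writing the helper by hand.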
### Part 2: The Prediction Function
This function represents the entire structure of your Neural Network.
def predict(params, inputs):
def predict: This is our "forward pass." It calculates the prediction based on current data.
params: A dictionary containing the trained weights (w1, w2) and biases (b1, b2). In JAX, we pass these explicitly because JAX is functional (variables aren't stored "inside" the function).
inputs: The data for the passengers (here the encoded Pclass, Sex, SibSp and Parch columns).
hidden = jnp.dot(inputs, params['w1']) + params['b1']
jnp.dot(inputs, params['w1']): This is the Dot Product (Matrix Multiplication). It multiplies every input feature by its corresponding weight.
Example: If "Sex_female" is 1 and the weight for females is high, this result becomes a large number.
+ params['b1']: This adds the Bias. The bias acts like an intercept in a linear equation ($y = mx + b$). It allows the node to activate even if the inputs are zero.
hidden: The result is a set of "raw scores" for the 10 hidden neurons.
### Activation function (Relu)
hidden = jnp.maximum(0, hidden)
jnp.maximum(0, hidden): This is the ReLU (Rectified Linear Unit) activation function.
What it does: It looks at the numbers calculated in the previous step.
If a number is positive, it keeps it.
If a number is negative, it turns it into 0.
Why? This introduces Non-Linearity. Without this step, the two layers collapse into a single linear layer followed by a sigmoid, which is just logistic regression. ReLU allows the model to learn complex, "curvy" boundaries between survivors and non-survivors.
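The non-linearity argument can be checked numerically. The following NumPy sketch uses random made-up inputs and weights (shapes chosen to mirror the 5 encoded features and 10 hidden units; everything else is an assumption) to show that two stacked linear layers equal one linear layer, and that inserting ReLU breaks that equivalence.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5))      # 4 hypothetical passengers, 5 features
w1 = rng.normal(size=(5, 10))
w2 = rng.normal(size=(10, 1))

# Two stacked linear layers with no activation in between...
two_layers = (x @ w1) @ w2
# ...equal one linear layer whose weight matrix is w1 @ w2.
one_layer = x @ (w1 @ w2)
print(np.allclose(two_layers, one_layer))

# With ReLU in between, the collapse no longer holds in general.
with_relu = np.maximum(0, x @ w1) @ w2
print(np.allclose(with_relu, one_layer))
```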
# Layer 2: Output
logits = jnp.dot(hidden, params['w2']) + params['b2']
jnp.dot(hidden, params['w2']): We take the 10 processed signals from the hidden layer and combine them into one final score.
+ params['b2']: Add the final bias.
logits: In Machine Learning, the raw, unscaled output of the last layer (before probability conversion) is called a logit. This number could be anything (e.g., 5.4 or -12.1).
return sigmoid(logits) # Squash to 0-1 probability
sigmoid(logits): We pass the logit through the sigmoid helper from Part 1.
If the logit was 5.4, sigmoid returns ~0.99 (99% chance of survival).
If the logit was -12.1, sigmoid returns ~0.00 (0% chance of survival).
Summary of the Flow
Input (Passenger Data) ->
Layer 1 (Weigh features & Add Bias) ->
ReLU (Filter out negatives) ->
Layer 2 (Combine hidden signals) ->
Sigmoid (Convert to Probability) -> Output
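The flow above can be traced shape by shape. This is a NumPy re-statement of the forward pass (an illustrative sketch, not the JAX code itself), run on three made-up passengers with 5 encoded features, which is the number of columns `pd.get_dummies` produced for this feature set.

```python
import numpy as np

def forward(params, inputs):
    hidden = inputs @ params["w1"] + params["b1"]   # (n, 5) @ (5, 10) -> (n, 10)
    hidden = np.maximum(0, hidden)                  # ReLU keeps the shape
    logits = hidden @ params["w2"] + params["b2"]   # (n, 10) @ (10, 1) -> (n, 1)
    return 1 / (1 + np.exp(-logits))                # sigmoid, still (n, 1)

rng = np.random.default_rng(0)
params = {
    "w1": rng.normal(size=(5, 10)) * 0.1,
    "b1": np.zeros(10),
    "w2": rng.normal(size=(10, 1)) * 0.1,
    "b2": np.zeros(1),
}
batch = rng.normal(size=(3, 5))   # 3 made-up passengers, 5 encoded features
probs = forward(params, batch)
print(probs.shape)                # one probability per passenger, as a column
```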
## 3. DEFINE LOSS AND UPDATE FUNCTIONS
Part 1: The Loss Function (The Scorecard)
This function calculates a single number representing how "badly" the model performed.
def loss_fn(params, inputs, targets):
def loss_fn: Defines the loss function.
Arguments: It takes the current model weights (params), the passenger data (inputs), and the actual answers (targets: 0 for died, 1 for survived).
preds = predict(params, inputs)
preds = ...: It runs the predict function (explained in the previous step) to get the model's current guess. These are probabilities between 0 and 1.
# Binary Cross Entropy Loss (Standard for Yes/No classification)
# Adding 1e-7 prevents log(0) errors
label_probs = preds * targets + (1 - preds) * (1 - targets)
The Logic: This is a clever mathematical trick to select the probability corresponding to the correct answer.
If the passenger survived (targets = 1): The formula becomes preds * 1 + (1-preds) * 0. We keep preds. We want this to be high.
If the passenger died (targets = 0): The formula becomes preds * 0 + (1-preds) * 1. We keep 1 - preds. If the model predicted 0.1 (10% survival), 1-preds is 0.9 (90% death), which is high.
Result: label_probs is "How high of a probability did the model assign to the correct event happening?"
return -jnp.mean(jnp.log(label_probs + 1e-7))
label_probs + 1e-7: We add a tiny number (0.0000001) called "epsilon."
Why? Mathematically, log(0) is negative infinity, which would make the loss infinite and the gradients unusable (NaN). If the model is 100% wrong on a passenger, this tiny number keeps the calculation finite.
jnp.log(...): We take the logarithm.
Log(1.0) is 0 (Perfect score).
Log(0.1) is -2.3 (Bad score).
Log(0.01) is -4.6 (Terrible score).
-jnp.mean(...):
Since logs of probabilities are negative numbers, we flip the sign with - so the error is positive.
mean: We average the error across all passengers in the batch.
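A small worked example may help here. The sketch below re-implements the same formula in plain Python (the function name `binary_cross_entropy` and the three predictions are made up for illustration) so each intermediate value can be checked by hand: the selected probabilities are 0.9, 0.9 and 0.2, and the loss is the negative mean of their logs.

```python
import math

def binary_cross_entropy(preds, targets, eps=1e-7):
    total = 0.0
    for p, t in zip(preds, targets):
        # Pick the probability the model gave to the outcome that really happened.
        label_prob = p * t + (1 - p) * (1 - t)
        total += math.log(label_prob + eps)
    return -total / len(preds)

preds = [0.9, 0.1, 0.8]     # hypothetical model outputs
targets = [1, 0, 0]         # the third prediction is confidently wrong
loss = binary_cross_entropy(preds, targets)
print(round(loss, 4))       # about 0.6067

# A completely wrong prediction stays finite thanks to eps.
print(binary_cross_entropy([0.0], [1]))
```

Most of the loss in this example comes from the single confident mistake, which is the behaviour the log is meant to produce.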
Part 2: The Update Function (The Learning)
This is where the magic of JAX happens. It calculates how to adjust the weights to lower the error.
@jit
def update(params, x, y, lr=0.1):
@jit: "Just-In-Time" compilation. It tells JAX to analyze the update function and compile it into optimized XLA code that runs extremely fast on the hardware.
lr=0.1: The Learning Rate. This controls how big of a "step" we take during learning. Too big, and we overshoot; too small, and learning takes forever.
# Calculate gradients automatically
grads = grad(loss_fn)(params, x, y)
grad(loss_fn): This is JAX's superpower. It takes your python function loss_fn and mathematically calculates its derivative.
(params, x, y): We immediately call this new gradient function with our data.
grads: This variable now holds the Slope. It tells us: "If you increase weight w1, the error goes UP. If you decrease bias b1, the error goes DOWN."
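To see `grad` on something small enough to verify by hand, here is a minimal sketch with a one-parameter toy loss (the function `squared_error` is hypothetical and not part of the notebook). The autodiff result is compared with the derivative obtained from the chain rule.

```python
import jax

def squared_error(w):
    # Toy loss: the prediction is 3 * w and the target is 6.
    return (3.0 * w - 6.0) ** 2

d_loss = jax.grad(squared_error)   # a new function that returns d(loss)/dw

w = 1.0
print(d_loss(w))                   # autodiff result
print(2 * (3.0 * w - 6.0) * 3)     # hand-derived chain rule, for comparison
```

At `w = 1.0` both lines give -18: the slope is negative, so increasing `w` would lower the loss, and at `w = 2.0` (where the prediction equals the target) the slope is 0.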
# Update parameters: W = W - learning_rate * gradient
new_params = {
    'w1': params['w1'] - lr * grads['w1'],
    'b1': params['b1'] - lr * grads['b1'],
    'w2': params['w2'] - lr * grads['w2'],
    'b2': params['b2'] - lr * grads['b2'],
}
The Descent: This implements Gradient Descent.
The Formula: New Weight = Old Weight - (Learning Rate * Slope)
If the slope is positive (increasing weight increases error), we subtract (make weight smaller).
If the slope is negative (increasing weight decreases error), we subtract a negative (make weight bigger).
Why a new dictionary? JAX is designed to be "Functional." We generally do not modify variables in place; we create new updated versions of them.
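The update rule can be tried on a one-dimensional problem where the answer is known. This plain-Python sketch (the loss `(w - 3)^2` and the helper names are assumptions chosen for illustration) applies "new weight = old weight - learning rate * slope" repeatedly and also shows what happens when the learning rate is too large.

```python
def loss(w):
    return (w - 3.0) ** 2

def slope(w):
    # Derivative of (w - 3)^2 with respect to w.
    return 2.0 * (w - 3.0)

def gradient_descent(w, lr=0.1, steps=100):
    for _ in range(steps):
        w = w - lr * slope(w)   # same rule as in update()
    return w

w_final = gradient_descent(w=0.0)
print(w_final, loss(w_final))             # w approaches 3, loss approaches 0

w_diverged = gradient_descent(w=0.0, lr=1.1, steps=50)
print(w_diverged)                         # overshoots further on every step
```

With `lr=0.1` the distance to the minimum shrinks by a factor of 0.8 per step; with `lr=1.1` it grows by a factor of 1.2 per step, which is the "overshoot" case described for the learning rate.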
return new_params
return: Passes the improved "brain" (weights) back to the main loop to be used for the next round of training.
## 4. INITIALIZATION AND TRAINING
Here is the line-by-line explanation of the Initialization and Training phase. This is the setup where we create the model's "brain" and then force it to learn.
Part 1: Setting up Randomness & Architecture
# JAX handles randomness explicitly with keys
key = random.PRNGKey(42)
random.PRNGKey(42): JAX handles random numbers differently than standard Python or NumPy.
The Concept: JAX is "stateless." It doesn't remember what random number was generated last. You must provide a "Key" (a seed state) every time you want a random number.
42: The seed. Using the same seed ensures that if you run this code tomorrow, you get the exact same random weights (reproducibility).
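One detail worth illustrating: because a key fully determines the numbers drawn from it, passing the same key to two calls gives related draws. The notebook's code uses the same `key` for both `w1` and `w2`; the usual JAX idiom is to derive one subkey per draw with `random.split`. The sketch below shows that idiom (the names `key_w1` and `key_w2` are hypothetical); switching the notebook to it would change the initial weights and therefore the printed loss values.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(42)

# Reusing one key is deterministic: both calls return the same numbers.
a = random.normal(key, (3,))
b = random.normal(key, (3,))
print(jnp.array_equal(a, b))

# Splitting derives independent subkeys, one per random draw.
key_w1, key_w2 = random.split(key)
w1 = random.normal(key_w1, (5, 10)) * 0.1
w2 = random.normal(key_w2, (10, 1)) * 0.1
print(w1.shape, w2.shape)
```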
input_dim = X_jax.shape[1]
hidden_dim = 10
input_dim: We check our data (X_jax) to see how many features we have (e.g., Pclass, Sex_female, Sex_male, etc.).
hidden_dim = 10: We decide arbitrarily that our hidden layer will have 10 neurons. More neurons = more complex thinking, but higher risk of memorizing data (overfitting).
Part 2: Initializing Weights (The "Blank Slate")
Before the model learns, it needs starting values. We can't start with all zeros (or the model won't learn), so we start with small random numbers.
# Randomly initialize weights
w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
random.normal: Generates numbers fitting a "Bell Curve" (Gaussian distribution) centered at 0.
(input_dim, hidden_dim): The shape of the matrix. It connects every input feature to every hidden neuron.
* 0.1: Scaling. Neural networks hate big numbers. If weights start too large, the math "explodes" or gets stuck immediately. We shrink the random numbers to be small.
b1 = jnp.zeros(hidden_dim)
jnp.zeros: Biases are safe to initialize at 0. This is the "intercept" for each of the 10 hidden neurons.
w2 = random.normal(key, (hidden_dim, 1)) * 0.1
b2 = jnp.zeros(1)
w2: The weights connecting the 10 hidden neurons to the 1 final output neuron.
b2: The final bias for the output.
params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}
Packing: We wrap all these individual variables into a single Python dictionary called params. This makes it easy to pass the entire "brain" into functions like update and loss_fn.
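The statement that all-zero weights would prevent learning can be checked with a hand-written backward pass. The NumPy sketch below is an illustration under assumptions (made-up inputs and labels, manual gradients for the mean binary cross-entropy instead of `jax.grad`): with every weight at zero, the gradients for `w1` and `w2` come out exactly zero, so only the output bias would move.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 5))                                   # made-up inputs
y = np.array([[1], [0], [0], [0], [1], [0], [0], [0]], float) # made-up labels

w1, b1 = np.zeros((5, 10)), np.zeros(10)
w2, b2 = np.zeros((10, 1)), np.zeros(1)

# Forward pass (same structure as predict)
pre = x @ w1 + b1
hidden = np.maximum(0, pre)
probs = 1 / (1 + np.exp(-(hidden @ w2 + b2)))

# Backward pass for the mean binary cross-entropy
d_logits = (probs - y) / len(x)           # gradient with respect to the logits
grad_w2 = hidden.T @ d_logits             # zero because hidden is all zeros
grad_b2 = d_logits.sum(axis=0)
d_hidden = (d_logits @ w2.T) * (pre > 0)  # zero because w2 is all zeros
grad_w1 = x.T @ d_hidden

print(np.abs(grad_w1).max(), np.abs(grad_w2).max(), grad_b2)
```

Small random starting weights break this deadlock, which is why the notebook draws `w1` and `w2` from `random.normal` while leaving the biases at zero.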
Part 3: The Training Loop (The "Gym")
This is where the actual learning occurs.
print("Training Neural Network...")
# Training loop
for i in range(2000):
range(2000): We will show the model the data and correct its mistakes 2,000 times. Each time is called an Epoch.
params = update(params, X_jax, y_jax.reshape(-1, 1), lr=0.1)
params = ...: This is crucial in JAX. Since JAX variables are immutable (cannot be changed), the update function doesn't change params inside itself. Instead, it returns a new version of params with slightly better weights, and we overwrite our variable with it.
y_jax.reshape(-1, 1):
Our y_jax might look like a flat list: [1, 0, 1].
Our model outputs a column: [[1], [0], [1]].
Reshape forces them to match shapes so the math works. -1 means "whatever length is needed."
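The reason the shapes must match is broadcasting: multiplying a column of shape (n, 1) by a flat array of shape (n,) does not raise an error, it silently produces an (n, n) grid. A minimal NumPy sketch with made-up numbers (JAX follows the same broadcasting rules):

```python
import numpy as np

preds = np.array([[0.9], [0.2], [0.7]])   # model output, shape (3, 1)
targets_flat = np.array([1.0, 0.0, 1.0])  # labels, shape (3,)

mismatched = preds * targets_flat
print(mismatched.shape)     # (3, 3): every prediction paired with every label

targets_col = targets_flat.reshape(-1, 1)
matched = preds * targets_col
print(matched.shape)        # (3, 1): one product per passenger
```

Without the reshape, the loss would average over that (n, n) grid and the model would train against the wrong pairing of predictions and labels.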
if i % 500 == 0:
    current_loss = loss_fn(params, X_jax, y_jax.reshape(-1, 1))
    print(f"Epoch {i}, Loss: {current_loss:.4f}")
Monitoring: We don't want to print every single step (it's too fast).
i % 500 == 0: Every 500th step (0, 500, 1000, 1500), we pause to calculate the current error (loss_fn) and print it.
Goal: You should see the "Loss" number go down as the loop runs (in this run: 0.6980 -> 0.4339 -> 0.4285 -> 0.4262). This shows the model is fitting the training data better with each round.
## 5. PREDICTION AND SUBMISSION
Part 1: Generating Probabilities
final_probs = predict(params, X_test_jax)
* predict(...): We call the same function we used during training.
* params: We pass the final version of params. These contain the weights ($W$) and biases ($b$) that were optimized over 2000 loops.
* X_test_jax: This is the data for the passengers we need to submit to Kaggle. The model has never seen this data before.
* Result: final_probs is a list of numbers between 0 and 1 (e.g., [0.1, 0.9, 0.45, ...]).
Part 2: Making the Decision (Thresholding)
Kaggle doesn't want probabilities (e.g., "I'm 70% sure"); they want a binary answer (0 or 1).
predictions = (final_probs > 0.5).astype(int).flatten()
(final_probs > 0.5): This creates a boolean array.
If probability is 0.8: 0.8 > 0.5 is True.
If probability is 0.2: 0.2 > 0.5 is False.
.astype(int): Converts the Booleans into numbers.
True becomes 1 (Survived).
False becomes 0 (Died).
.flatten(): JAX often works with 2D column vectors (e.g., a shape of (418, 1)). Pandas prefers a simple 1D list (shape (418,)). flatten() squashes the column into a flat list.
Part 3: Formatting and Saving
output = pd.DataFrame({'PassengerId': test_data.PassengerId, 'Survived': predictions})
pd.DataFrame(...): We reconstruct the spreadsheet format Kaggle requires.
'PassengerId': test_data.PassengerId: We pull the original IDs (892, 893, etc.) from the test file we loaded at the very beginning.
'Survived': predictions: We attach our list of 0s and 1s next to the IDs.
output.to_csv('submission_jax.csv', index=False)
to_csv: Writes the dataframe to a physical text file.
index=False: Critical! If you don't say False, Pandas will add a row number (0, 1, 2...) as the very first column. Kaggle's automated grader will reject the file because it expects exactly two columns, not three.
print("JAX submission saved successfully!")
Confirmation: A simple message to let you know the script finished without errors and the file is ready to be downloaded or submitted.
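The thresholding chain from Part 2 can be run on a handful of made-up probabilities. This NumPy sketch (the values are assumptions) also shows the edge case: a probability of exactly 0.5 is not greater than 0.5, so it becomes 0.

```python
import numpy as np

final_probs = np.array([[0.8], [0.2], [0.5], [0.51]])  # made-up, shape (4, 1)
predictions = (final_probs > 0.5).astype(int).flatten()
print(predictions)         # [1 0 0 1]
print(predictions.shape)   # (4,)
```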
# 1. Load the prediction files
df_rf = pd.read_csv('submission.csv')
df_jax = pd.read_csv('submission_jax.csv')
df_rf.head()
| | PassengerId | Survived |
|---|---|---|
| 0 | 892 | 0 |
| 1 | 893 | 1 |
| 2 | 894 | 0 |
| 3 | 895 | 0 |
| 4 | 896 | 1 |
df_jax.head()
| | PassengerId | Survived |
|---|---|---|
| 0 | 892 | 0 |
| 1 | 893 | 1 |
| 2 | 894 | 0 |
| 3 | 895 | 0 |
| 4 | 896 | 0 |
# 2. MERGE THE TABLES
# We use on="PassengerId" to make sure we align the same passenger
# We use 'suffixes' to rename the columns automatically
comparacion = pd.merge(df_rf, df_jax, on='PassengerId', suffixes=('_rf', '_jax'))
print("--- Merged Table (Preview) ---")
print(comparacion)
print("\n")
--- Merged Table (Preview) ---
PassengerId Survived_rf Survived_jax
0 892 0 0
1 893 1 1
2 894 0 0
3 895 0 0
4 896 1 0
.. ... ... ...
413 1305 0 0
414 1306 1 1
415 1307 0 0
416 1308 0 0
417 1309 0 0
[418 rows x 3 columns]
# 3. FILTER THE DIFFERENCES
# Look for rows where the RF prediction is NOT EQUAL (!=) to the JAX prediction
diferencias = comparacion[comparacion['Survived_rf'] != comparacion['Survived_jax']]
print(f"--- Found {len(diferencias)} differences ---")
print(diferencias)
# Optional: compute the percentage of disagreement
total = len(comparacion)
diff_count = len(diferencias)
print(f"\nDisagreement percentage: {(diff_count/total)*100}%")
--- Found 20 differences ---
PassengerId Survived_rf Survived_jax
4 896 1 0
104 996 1 0
117 1009 1 0
127 1019 1 0
142 1034 1 0
165 1057 1 0
181 1073 0 1
194 1086 1 0
196 1088 0 1
217 1109 0 1
242 1134 0 1
263 1155 1 0
283 1175 1 0
284 1176 1 0
293 1185 0 1
308 1200 0 1
333 1225 1 0
376 1268 1 0
407 1299 0 1
409 1301 1 0
Disagreement percentage: 4.784688995215311%
The Random Forest and JAX predictions differ in 4.78% of the test rows.
I asked Gemini to repeat the Random Forest exercise with JAX, and the two sets of predictions differ in 20 of the 418 rows. I don't yet have enough knowledge to explain this difference. My guess is that a random number generator changes the result, but I can't tell.
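One way to dig a little further is to look at the direction of the disagreements. The sketch below copies the 20 (Survived_rf, Survived_jax) pairs from the output above into a plain Python list and counts each direction; the variable names are hypothetical.

```python
from collections import Counter

# (Survived_rf, Survived_jax) for the 20 rows listed in the output above, in order.
disagreements = [
    (1, 0), (1, 0), (1, 0), (1, 0), (1, 0), (1, 0), (0, 1), (1, 0), (0, 1), (0, 1),
    (0, 1), (1, 0), (1, 0), (1, 0), (0, 1), (0, 1), (1, 0), (1, 0), (0, 1), (1, 0),
]
total_rows = 418

direction = Counter(disagreements)
print(direction[(1, 0)], direction[(0, 1)])          # 13 7
print(f"{len(disagreements) / total_rows:.2%}")      # 4.78%
```

In 13 of the 20 rows the Random Forest predicts survival where the JAX model does not, and in 7 rows it is the other way round. With pandas, `pd.crosstab(comparacion['Survived_rf'], comparacion['Survived_jax'])` would give the full 2x2 table directly.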
Validation Set
Since you cannot verify your accuracy on the official Test data, you must simulate a test locally. You do this by taking your Training data (which does have the answers) and slicing off a small piece (e.g., 20%) to pretend it is test data. This is called a Validation Set.
Gemini prompt: Generate a script to show you how to create a "local test" so I can verify if JAX or Random Forest models are actually working
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
# 1. Load the Training Data (This is the only file with "Answers")
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')
# 2. Prepare features (Simple version)
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features])
y = train_data["Survived"]
# 3. SPLIT the data
# We hide 20% of the data from the model. This 20% becomes our "Local Test Set"
# random_state=42 ensures the split is the same every time
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"Training on {len(X_train)} passengers.")
print(f"Validating on {len(X_val)} passengers (Hidden from model).")
# 4. Train the model ONLY on the 80%
model = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=1)
model.fit(X_train, y_train)
# 5. Predict on the hidden 20%
predictions = model.predict(X_val)
# 6. COMPARE with the Real Results (y_val)
# This is possible because y_val comes from the original train.csv
accuracy = accuracy_score(y_val, predictions)
print("--------------------------")
print(f"Model Accuracy: {accuracy:.2%}")
print("--------------------------")
# Show a comparison table of the first 10 rows
comparison = pd.DataFrame({'Actual': y_val, 'Predicted': predictions})
print(comparison.head(10))
Training on 712 passengers.
Validating on 179 passengers (Hidden from model).
--------------------------
Model Accuracy: 79.89%
--------------------------
Actual Predicted
709 1 0
439 0 0
840 0 0
720 1 1
39 1 1
290 1 1
300 1 1
333 0 0
208 1 1
136 1 1
#Validation code for JAX Model
import pandas as pd
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, jit, random
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 1. Load the Training Data
train_data = pd.read_csv('/home/jovyan/work/aristarco-cortes/datasets/Titanic_train.csv')
# 2. Prepare features
features = ["Pclass", "Sex", "SibSp", "Parch"]
X = pd.get_dummies(train_data[features]).astype(float)
y = train_data["Survived"].values
# 3. SPLIT the data (80% Train, 20% Validation)
X_train_pd, X_val_pd, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# --- JAX SPECIFIC PREPROCESSING ---
# Convert to JAX arrays
X_train_jax = jnp.array(X_train_pd.values, dtype=jnp.float32)
X_val_jax = jnp.array(X_val_pd.values, dtype=jnp.float32)
y_train_jax = jnp.array(y_train, dtype=jnp.float32).reshape(-1, 1)
# NORMALIZE the data (Crucial for Neural Networks)
# We calculate mean/std on TRAINING data only to avoid "data leakage"
mean = X_train_jax.mean(axis=0)
std = X_train_jax.std(axis=0)
X_train_jax = (X_train_jax - mean) / std
X_val_jax = (X_val_jax - mean) / std # Apply same scaling to validation set
print(f"Training on {len(X_train_jax)} passengers.")
print(f"Validating on {len(X_val_jax)} passengers.")
# 4. DEFINE JAX MODEL FUNCTIONS
def sigmoid(x):
    return 1 / (1 + jnp.exp(-x))
def predict(params, inputs):
    # Layer 1: Matrix mul + Bias -> ReLU
    hidden = jnp.dot(inputs, params['w1']) + params['b1']
    hidden = jnp.maximum(0, hidden)
    # Layer 2: Output -> Sigmoid
    logits = jnp.dot(hidden, params['w2']) + params['b2']
    return sigmoid(logits)
def loss_fn(params, inputs, targets):
    preds = predict(params, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.mean(jnp.log(label_probs + 1e-7))
@jit
def update(params, x, y, lr=0.1):
    grads = grad(loss_fn)(params, x, y)
    return {
        'w1': params['w1'] - lr * grads['w1'],
        'b1': params['b1'] - lr * grads['b1'],
        'w2': params['w2'] - lr * grads['w2'],
        'b2': params['b2'] - lr * grads['b2'],
    }
# 5. INITIALIZE AND TRAIN
key = random.PRNGKey(42)
input_dim = X_train_jax.shape[1]
hidden_dim = 10
# Initialize weights slightly randomly
w1 = random.normal(key, (input_dim, hidden_dim)) * 0.1
b1 = jnp.zeros(hidden_dim)
w2 = random.normal(key, (hidden_dim, 1)) * 0.1
b2 = jnp.zeros(1)
params = {'w1': w1, 'b1': b1, 'w2': w2, 'b2': b2}
print("Starting training loop...")
for i in range(2000):
    params = update(params, X_train_jax, y_train_jax, lr=0.1)
    if i % 500 == 0:
        loss = loss_fn(params, X_train_jax, y_train_jax)
        print(f"Epoch {i}, Loss: {loss:.4f}")
# 6. PREDICT on Validation Set
print("Predicting on validation set...")
final_probs = predict(params, X_val_jax)
# Flatten JAX array and convert to standard Numpy array to avoid Pandas issues
predictions = np.array((final_probs > 0.5).astype(int).flatten())
# 7. COMPARE with Real Results
accuracy = accuracy_score(y_val, predictions)
print("--------------------------")
print(f"JAX Model Accuracy: {accuracy:.2%}")
print("--------------------------")
# Show comparison
# We use y_val which is already a numpy array (values) and our new predictions array
comparison = pd.DataFrame({'Actual': y_val, 'Predicted': predictions})
print(comparison.head(10))
Training on 712 passengers.
Validating on 179 passengers.
Starting training loop...
Epoch 0, Loss: 0.6977
Epoch 500, Loss: 0.4352
Epoch 1000, Loss: 0.4288
Epoch 1500, Loss: 0.4262
Predicting on validation set...
--------------------------
JAX Model Accuracy: 79.89%
--------------------------
   Actual  Predicted
0       1          0
1       0          0
2       0          0
3       1          1
4       1          1
5       1          1
6       1          1
7       0          0
8       1          1
9       1          1
Because the predictions differ, I found it suspicious that both models report the same validation accuracy (79.89%), so I asked Gemini about the differences, on the assumption that both ran the same Random Forest model.
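Identical accuracy does not by itself mean identical predictions: two models can make the same number of mistakes on different passengers. A tiny plain-Python sketch with made-up labels illustrates this (note also that the 20 differences were counted on the Kaggle test file, while the 79.89% was measured on the validation split, so the two numbers describe different rows).

```python
def accuracy(actual, predicted):
    return sum(a == p for a, p in zip(actual, predicted)) / len(actual)

actual  = [1, 0, 1, 0, 1]
model_a = [1, 0, 1, 0, 0]   # wrong on the last passenger
model_b = [0, 0, 1, 0, 1]   # wrong on the first passenger

print(accuracy(actual, model_a), accuracy(actual, model_b))   # 0.8 0.8
disagreements = sum(a != b for a, b in zip(model_a, model_b))
print(disagreements)                                          # 2
```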
Prompt: Act as an expert in data science and machine learning. I ran the Random Forest model on a Titanic dataset and then asked Gemini to run the same model using JAX to improve performance. In the results I find 20 differences in the data. Why does this happen if it is supposed to be the same model?
The answer
As a Data Science expert, I can tell you that encountering differences between a Random Forest implementation in Scikit-Learn and one in JAX is completely normal, and even expected. Even if the underlying logic is the same, the technical execution differs significantly.
Here is why you are seeing those 20 discrepancies:
1. Algorithmic Implementation
While both are called "Random Forest," the libraries often use different "engines":
Scikit-Learn: Uses a highly optimized version of the CART (Classification and Regression Trees) algorithm.
JAX: JAX does not have a native "Random Forest" module. The code Gemini provided likely uses jax.numpy to manually calculate splits or a library like Haiku or Equinox. Any tiny variation in how Gini Impurity or Entropy is calculated due to floating-point precision can change an entire branch of a tree.
2. Random Number Generation (PRNG)
This is the most likely culprit. Random Forest relies on randomness at two levels:
Bootstrapping: Randomly selecting samples for each tree.
Feature Selection: Randomly selecting a subset of columns at each node.
Scikit-Learn uses the NumPy random generator (state-based). JAX, by functional design, uses an explicit PRNG (Pseudo-Random Number Generation) system based on "keys" that you must manually split (jax.random.split). It is virtually impossible for these two different generators to produce the exact same sequence of numbers, meaning the trees will be structurally different.
3. Numerical Precision and Hardware
Float32 vs. Float64: JAX often defaults to float32 to maximize speed (especially on GPUs/TPUs), whereas Scikit-Learn typically uses float64 on CPU.
In the Titanic dataset, where many variables are categorical or small scales (like Fare), a difference of $10^{-7}$ in a decision threshold can push a passenger to the left branch instead of the right.
4. Tie-breaking Criteria
If two features provide the exact same information gain when trying to split a node, each library has its own internal rule to decide which one to pick. In small datasets like the Titanic, these "ties" happen more often than you'd think.
Learning Points
Go slow, think thrice. I learned that I still have many blind spots in Data Science. Asking AI about my doubts can be a very powerful way to learn and understand better.
Key takeaway: In Machine Learning, two models giving slightly different results doesn't mean one is "wrong." The 20 differences are out of 418 test rows, a disagreement of about 4.8%. Looking back at the code, the JAX version is a small two-layer neural network rather than a Random Forest, so the two models are different algorithms and some disagreement between them is expected, on top of the variation that random seeds and library implementations introduce.