Karma Tshomo - Fab Futures - Data Science
In [22]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error, r2_score
import jax
import jax.numpy as jnp
from jax import random, grad, jit
In [23]:
df = pd.read_csv("datasets/Housing.csv")
df.head()
Out[23]:
price area bedrooms bathrooms stories mainroad guestroom basement hotwaterheating airconditioning parking prefarea furnishingstatus
0 13300000 7420 4 2 3 yes no no no yes 2 yes furnished
1 12250000 8960 4 4 4 yes no no no yes 3 no furnished
2 12250000 9960 3 2 2 yes no yes no no 2 yes semi-furnished
3 12215000 7500 4 2 2 yes no yes no yes 3 yes furnished
4 11410000 7420 4 1 2 yes yes yes no yes 2 no furnished
In [24]:
# Encode yes/no columns as 1/0
yes_no_cols = ["mainroad", "guestroom", "basement", "hotwaterheating", "airconditioning", "prefarea"]
for col in yes_no_cols:
    df[col] = df[col].map({'yes':1,'no':0})

# One-hot encode furnishingstatus
df = pd.get_dummies(df, columns=['furnishingstatus'], drop_first=True)

# Features and target
X = df.drop("price", axis=1)
y = df["price"].values

# Scale features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

Train scikit-learn MLP Regressor

In [30]:
 

mlp_model = MLPRegressor(hidden_layer_sizes=(64,32), activation='relu',
                         solver='adam', learning_rate_init=0.001,
                         max_iter=100, random_state=1, verbose=True)

mlp_model.fit(X_train, y_train)

# Predict & evaluate
y_pred = mlp_model.predict(X_test)
print("Scikit-learn MLP Regressor MSE:", mean_squared_error(y_test, y_pred))
print("R² Score:", r2_score(y_test, y_pred))

# Plot actual vs predicted
plt.figure(figsize=(8,6))
plt.scatter(y_test, y_pred, alpha=0.5)
plt.xlabel("Actual Price")
plt.ylabel("Predicted Price")
plt.title("Actual vs Predicted House Prices (MLP Regressor)")
plt.show()
Iteration 1, loss = 12617395412237.01367188
Iteration 2, loss = 12617394416759.10351562
Iteration 3, loss = 12617393472502.44921875
... (iterations 4–98 omitted; the loss decreases only marginally) ...
Iteration 99, loss = 12614976667262.54296875
Iteration 100, loss = 12614889307139.66992188
Scikit-learn MLP Regressor MSE: 30124185687573.863
R² Score: -4.959791677496043
/opt/conda/lib/python3.13/site-packages/sklearn/neural_network/_multilayer_perceptron.py:781: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet.
  warnings.warn(
[Figure: scatter plot of actual vs. predicted house prices (MLP Regressor)]

The scatter plot compares the actual house prices with the prices predicted by the MLP Regressor. Ideally, the points should lie close to a straight diagonal line, which would indicate accurate predictions.

Here, however, the points are widely scattered and show no clear pattern, so the model has not captured the relationship between the input features and the house prices.

The very high Mean Squared Error and the negative R² score confirm this: a negative R² means the model predicts worse than simply guessing the average price. A likely reason is that the target was left on its original scale (prices in the millions), so 100 Adam iterations at the default learning rate barely move the weights away from their initial values; the training loss only falls from about 1.2617e13 to 1.2615e13, and the ConvergenceWarning shows the optimizer stopped long before converging. In short, the scikit-learn network is severely underfitting.
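A quick way to test this explanation (a sketch, not part of the original notebook) is to standardize the target with a second StandardScaler, train the same architecture for more iterations, and map the predictions back to price units before scoring. It reuses X_train, X_test, y_train, and y_test from the split above; the name y_scaler and the max_iter value are my own assumptions.

# Hypothetical follow-up: scale the target as well, then train longer.
y_scaler = StandardScaler()                    # second scaler, for the prices
y_train_s = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()

mlp_scaled = MLPRegressor(hidden_layer_sizes=(64, 32), activation='relu',
                          solver='adam', learning_rate_init=0.001,
                          max_iter=2000, random_state=1)
mlp_scaled.fit(X_train, y_train_s)

# Predictions are in standardized units; invert the scaling before evaluating
y_pred_s = y_scaler.inverse_transform(mlp_scaled.predict(X_test).reshape(-1, 1)).ravel()
print("MSE (scaled-target MLP):", mean_squared_error(y_test, y_pred_s))
print("R² (scaled-target MLP):", r2_score(y_test, y_pred_s))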

JAX Neural Network Regression

In [26]:
# Convert data to jax arrays
X_train_jax = jnp.array(X_train)
y_train_jax = jnp.array(y_train).reshape(-1,1)
X_test_jax = jnp.array(X_test)
y_test_jax = jnp.array(y_test).reshape(-1,1)

# Hyperparameters
input_size = X_train.shape[1]
hidden_size = 32
output_size = 1
learning_rate = 0.01
train_steps = 500

# Initialize random key
key = random.PRNGKey(0)

# Initialize parameters
def init_params(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    W1 = 0.01*random.normal(key1, (input_size, hidden_size))
    b1 = jnp.zeros(hidden_size)
    W2 = 0.01*random.normal(key2, (hidden_size, output_size))
    b2 = jnp.zeros(output_size)
    return (W1,b1,W2,b2)

params = init_params(key, input_size, hidden_size, output_size)

# Forward pass
@jit
def forward(params, X):
    W1,b1,W2,b2 = params
    hidden = jnp.tanh(X @ W1 + b1)
    output = hidden @ W2 + b2
    return output

# Loss function (MSE)
@jit
def loss(params, X, y):
    y_pred = forward(params, X)
    return jnp.mean((y_pred - y)**2)

# Gradient update (one step of full-batch gradient descent)
@jit
def update(params, X, y, lr):
    grads = grad(loss)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Training loop
for step in range(train_steps):
    params = update(params, X_train_jax, y_train_jax, learning_rate)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss(params, X_train_jax, y_train_jax):.4f}")

# Evaluate on test set
y_pred_jax = forward(params, X_test_jax)
test_loss = jnp.mean((y_pred_jax - y_test_jax)**2)
print("JAX Neural Network Test MSE:", test_loss)

# Plot predicted vs actual
plt.figure(figsize=(8,6))
plt.scatter(y_test, np.array(y_pred_jax), alpha=0.5)
plt.xlabel("Actual Price")
plt.ylabel("Predicted Price")
plt.title("Actual vs Predicted House Prices (JAX NN)")
plt.show()
Step 0, Loss: 24327185498112.0000
Step 100, Loss: 3261978640384.0000
Step 200, Loss: 2168709971968.0000
Step 300, Loss: 1725492232192.0000
Step 400, Loss: 1531514716160.0000
JAX Neural Network Test MSE: 2.4613424e+12
[Figure: scatter plot of actual vs. predicted house prices (JAX NN)]

Like the first graph, this plot compares the actual prices with the predictions from the JAX neural network. The points again fail to line up along the diagonal and remain scattered across the graph.

The JAX model does learn something: the training loss drops from about 2.4e13 at step 0 to about 1.5e12 by step 400, and its test MSE of roughly 2.5e12 is an order of magnitude lower than the scikit-learn model's 3.0e13. Even so, the predictions are still far from the actual values.

As with the scikit-learn model, the unscaled target (prices in the millions), the small network, and the short training run all limit performance. More training steps, a scaled or log-transformed target, or a stronger optimizer would likely be needed before either network predicts house prices accurately.
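One way to probe this (again a sketch, not part of the original notebook) is to standardize the target before training the same JAX network and rescale the predictions afterwards. It reuses init_params, update, forward, and the hyperparameters from the cell above; y_mean, y_std, and params_std are names introduced here for illustration.

# Hypothetical follow-up: train the same JAX network on a standardized target.
y_mean, y_std = y_train_jax.mean(), y_train_jax.std()
y_train_std = (y_train_jax - y_mean) / y_std

params_std = init_params(key, input_size, hidden_size, output_size)
for step in range(train_steps):
    params_std = update(params_std, X_train_jax, y_train_std, learning_rate)

# Predictions come back in standardized units; undo the transform for evaluation
y_pred_rescaled = forward(params_std, X_test_jax) * y_std + y_mean
print("JAX NN (scaled target) Test MSE:", jnp.mean((y_pred_rescaled - y_test_jax) ** 2))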