import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error, r2_score
import jax
import jax.numpy as jnp
from jax import random, grad, jit
df = pd.read_csv("datasets/Housing.csv")
df.head()
|   | price | area | bedrooms | bathrooms | stories | mainroad | guestroom | basement | hotwaterheating | airconditioning | parking | prefarea | furnishingstatus |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 13300000 | 7420 | 4 | 2 | 3 | yes | no | no | no | yes | 2 | yes | furnished |
| 1 | 12250000 | 8960 | 4 | 4 | 4 | yes | no | no | no | yes | 3 | no | furnished |
| 2 | 12250000 | 9960 | 3 | 2 | 2 | yes | no | yes | no | no | 2 | yes | semi-furnished |
| 3 | 12215000 | 7500 | 4 | 2 | 2 | yes | no | yes | no | yes | 3 | yes | furnished |
| 4 | 11410000 | 7420 | 4 | 1 | 2 | yes | yes | yes | no | yes | 2 | no | furnished |
# Encode yes/no columns as 1/0
yes_no_cols = ["mainroad", "guestroom", "basement", "hotwaterheating", "airconditioning", "prefarea"]
for col in yes_no_cols:
    df[col] = df[col].map({'yes': 1, 'no': 0})
# One-hot encode furnishingstatus
df = pd.get_dummies(df, columns=['furnishingstatus'], drop_first=True)
# Features and target
X = df.drop("price", axis=1)
y = df["price"].values
# Scale features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# Train-test split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
Train scikit-learn MLP Regressor
mlp_model = MLPRegressor(hidden_layer_sizes=(64, 32), activation='relu',
                         solver='adam', learning_rate_init=0.001,
                         max_iter=100, random_state=1, verbose=True)
mlp_model.fit(X_train, y_train)
# Predict & evaluate
y_pred = mlp_model.predict(X_test)
print("Scikit-learn MLP Regressor MSE:", mean_squared_error(y_test, y_pred))
print("R² Score:", r2_score(y_test, y_pred))
# Plot actual vs predicted
plt.figure(figsize=(8,6))
plt.scatter(y_test, y_pred, alpha=0.5)
plt.xlabel("Actual Price")
plt.ylabel("Predicted Price")
plt.title("Actual vs Predicted House Prices (MLP Regressor)")
plt.show()
Iteration 1, loss = 12617395412237.01367188
Iteration 2, loss = 12617394416759.10351562
Iteration 3, loss = 12617393472502.44921875
...
Iteration 98, loss = 12615058433093.70898438
Iteration 99, loss = 12614976667262.54296875
Iteration 100, loss = 12614889307139.66992188
Scikit-learn MLP Regressor MSE: 30124185687573.863
R² Score: -4.959791677496043
/opt/conda/lib/python3.13/site-packages/sklearn/neural_network/_multilayer_perceptron.py:781: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (100) reached and the optimization hasn't converged yet. warnings.warn(
The scatter plot compares the actual house prices with the prices predicted by the MLP Regressor model. Ideally, the points should lie close to a straight diagonal line, which would indicate accurate predictions.
However, in this graph, the points are widely scattered and do not follow a clear pattern. This means the model is unable to capture the relationship between the input features and the house prices.
The high Mean Squared Error (MSE) and negative R² score confirm that the predictions are very poor: a negative R² means the model does worse than simply predicting the mean price for every house. The training log shows the loss barely changing over 100 iterations, so the scikit-learn network is severely underfitting, most likely because the target prices (in the millions) are left unscaled and training is cut off after only 100 iterations.
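A quick way to test this explanation is to standardize the target as well as the features and let the optimizer run longer. The snippet below is a sketch rather than part of the original experiment: it reuses the X_train/X_test split from above and wraps the same MLP in scikit-learn's TransformedTargetRegressor; the max_iter value of 2000 is an illustrative choice, not a tuned one.

# Sketch: same features, but the target is standardized too and the network trains longer
from sklearn.compose import TransformedTargetRegressor

mlp_scaled_y = TransformedTargetRegressor(
    regressor=MLPRegressor(hidden_layer_sizes=(64, 32), activation='relu',
                           solver='adam', learning_rate_init=0.001,
                           max_iter=2000, random_state=1),
    transformer=StandardScaler()  # standardizes y for fitting, inverts it for predictions
)
mlp_scaled_y.fit(X_train, y_train)
y_pred_scaled = mlp_scaled_y.predict(X_test)
print("MSE with scaled target:", mean_squared_error(y_test, y_pred_scaled))
print("R² with scaled target:", r2_score(y_test, y_pred_scaled))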
JAX Neural Network Regression
# Convert data to jax arrays
X_train_jax = jnp.array(X_train)
y_train_jax = jnp.array(y_train).reshape(-1,1)
X_test_jax = jnp.array(X_test)
y_test_jax = jnp.array(y_test).reshape(-1,1)
# Hyperparameters
input_size = X_train.shape[1]
hidden_size = 32
output_size = 1
learning_rate = 0.01
train_steps = 500
# Initialize random key
key = random.PRNGKey(0)
# Initialize parameters
def init_params(key, input_size, hidden_size, output_size):
    key1, key2 = random.split(key)
    W1 = 0.01 * random.normal(key1, (input_size, hidden_size))
    b1 = jnp.zeros(hidden_size)
    W2 = 0.01 * random.normal(key2, (hidden_size, output_size))
    b2 = jnp.zeros(output_size)
    return (W1, b1, W2, b2)
params = init_params(key, input_size, hidden_size, output_size)
# Forward pass
@jit
def forward(params, X):
    W1, b1, W2, b2 = params
    hidden = jnp.tanh(X @ W1 + b1)
    output = hidden @ W2 + b2
    return output
# Loss function (MSE)
@jit
def loss(params, X, y):
    y_pred = forward(params, X)
    return jnp.mean((y_pred - y) ** 2)
# Gradient update step (plain gradient descent on all parameters)
@jit
def update(params, X, y, lr):
    grads = grad(loss)(params, X, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
# Training loop
for step in range(train_steps):
    params = update(params, X_train_jax, y_train_jax, learning_rate)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss(params, X_train_jax, y_train_jax):.4f}")
# Evaluate on test set
y_pred_jax = forward(params, X_test_jax)
test_loss = jnp.mean((y_pred_jax - y_test_jax)**2)
print("JAX Neural Network Test MSE:", test_loss)
# Plot predicted vs actual
plt.figure(figsize=(8,6))
plt.scatter(y_test, np.array(y_pred_jax), alpha=0.5)
plt.xlabel("Actual Price")
plt.ylabel("Predicted Price")
plt.title("Actual vs Predicted House Prices (JAX NN)")
plt.show()
Step 0, Loss: 24327185498112.0000
Step 100, Loss: 3261978640384.0000
Step 200, Loss: 2168709971968.0000
Step 300, Loss: 1725492232192.0000
Step 400, Loss: 1531514716160.0000
JAX Neural Network Test MSE: 2.4613424e+12
Similar to the first graph, this plot compares the actual values with the predicted house prices generated by the JAX neural network. The points again fail to align around the diagonal line and remain scattered across the graph.
This indicates that the JAX model also struggles to learn meaningful patterns from the dataset. Although the loss decreases during training, the final test MSE is still extremely high, and the predictions remain far from the actual values.
This suggests that the target scale, the simple architecture, and the short training run are not well matched to the price data: as with the scikit-learn model, the unscaled target (prices in the millions) makes the loss enormous, and 500 full-batch gradient steps are not enough to close the gap. In summary, the JAX neural network also performs poorly at predicting house prices in its current form.
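The same remedy can be sketched for the JAX model: standardize the target before training and invert the scaling when evaluating. The snippet below is only a sketch; it reuses the init_params, forward and update functions and the arrays defined above, keeps the hyperparameters unchanged, and assumes nothing else.

# Sketch: train on a standardized target, then map predictions back to the price scale
y_mean, y_std = y_train_jax.mean(), y_train_jax.std()
y_train_std = (y_train_jax - y_mean) / y_std

params_std = init_params(key, input_size, hidden_size, output_size)
for step in range(train_steps):
    params_std = update(params_std, X_train_jax, y_train_std, learning_rate)

# Predictions come out in standardized units; undo the scaling before scoring
y_pred_rescaled = forward(params_std, X_test_jax) * y_std + y_mean
print("Test MSE (scaled target):", jnp.mean((y_pred_rescaled - y_test_jax) ** 2))
print("Test R² (scaled target):", r2_score(np.array(y_test_jax).ravel(), np.array(y_pred_rescaled).ravel()))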