Class 4: Machine Learning
I tried the MNIST example from Neil's page:
In [6]:
from sklearn.neural_network import MLPClassifier
import numpy as np
xtrain = np.load('datasets/MNIST/xtrain.npy')
ytrain = np.load('datasets/MNIST/ytrain.npy')
xtest = np.load('datasets/MNIST/xtest.npy')
ytest = np.load('datasets/MNIST/ytest.npy')
print(f"read {xtrain.shape[1]} byte data records, {xtrain.shape[0]} training examples, {xtest.shape[0]} testing examples\n")
# one hidden layer of 100 units; stop once the loss improvement stays below tol for 10 epochs
classifier = MLPClassifier(solver='adam', hidden_layer_sizes=(100,), activation='relu', random_state=1, verbose=True, tol=0.05)
classifier.fit(xtrain,ytrain)
print(f"\ntest score: {classifier.score(xtest,ytest)}\n")
predictions = classifier.predict(xtest)
read 784 byte data records, 60000 training examples, 10000 testing examples

Iteration 1, loss = 3.36992820
Iteration 2, loss = 1.13264743
Iteration 3, loss = 0.67881655
Iteration 4, loss = 0.44722907
Iteration 5, loss = 0.31658618
Iteration 6, loss = 0.23506685
Iteration 7, loss = 0.19331921
Iteration 8, loss = 0.15768276
Iteration 9, loss = 0.13673548
Iteration 10, loss = 0.12379790
Iteration 11, loss = 0.10733766
Iteration 12, loss = 0.11199584
Iteration 13, loss = 0.09769195
Iteration 14, loss = 0.09220702
Iteration 15, loss = 0.09282348
Iteration 16, loss = 0.08964422
Iteration 17, loss = 0.08613192
Training loss did not improve more than tol=0.050000 for 10 consecutive epochs. Stopping.

test score: 0.9548
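A 95.5% test score doesn't say which digits get confused with which. A quick check, a sketch using scikit-learn's confusion matrix on the predictions above:

from sklearn.metrics import confusion_matrix
# rows = true digit, columns = predicted digit; the diagonal holds the correct counts
print(confusion_matrix(ytest, predictions))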
In [7]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 5)
for i in range(5):
    axs[i].imshow(jnp.reshape(xtest[i], (28, 28)))
    axs[i].axis('off')
    axs[i].set_title(f"predict: {predictions[i]}")
plt.tight_layout()
plt.show()
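It's also instructive to look at the examples the network gets wrong. A minimal sketch, assuming ytest holds the true digit labels:

import numpy as np
wrong = np.where(predictions != ytest)[0]   # indices of misclassified test digits
fig, axs = plt.subplots(1, 5)
for ax, idx in zip(axs, wrong[:5]):
    ax.imshow(np.reshape(xtest[idx], (28, 28)))
    ax.axis('off')
    ax.set_title(f"pred {predictions[idx]}, true {ytest[idx]}")
plt.tight_layout()
plt.show()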
Then I attempted a neural network fit of the simple drop test data from last week. A falling object is a trivially easy problem for a neural network, but it's nice to be able to apply machine learning to any sensor data that I have. Let's load the drop test data:
In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
df = pd.read_csv('datasets/sensors/drop02.csv')
height = -df['0'][600:-250]                    # trim samples before release and after impact; flip sign
time = np.linspace(0, len(height) - 1, len(height))
scale_factor = (0.5 - 0) / (520 - 22)          # map raw sensor counts (22 to 520) onto 0 to 0.5 m
height_scaled = scale_factor * height
time_scaled = time / 1000                      # convert sample index to seconds (assuming 1 kHz sampling)
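Before reaching for a network it's worth checking the scaled data against the free-fall model h(t) = h0 - (1/2) g t^2. A minimal sketch, assuming the trimmed samples are pure free fall and the time/height scaling above is right:

# fit h(t) = a*t^2 + b*t + c; for free fall, a should be about -g/2
a, b, c = np.polyfit(time_scaled, height_scaled, 2)
print(f"estimated g = {-2 * a:.2f} m/s^2 (expect roughly 9.8)")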
And set up TensorFlow (as shown here). Before running the code below I had to open a terminal in JupyterLab and run pip install tensorflow.
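Alternatively, the install can be done directly from a notebook cell with the %pip magic, which installs into the kernel's own environment:

%pip install tensorflow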
In [5]:
from tensorflow import keras
import tensorflow as tf
# Create the model: 1 input -> two 64-unit ReLU hidden layers -> 1 linear output
model = keras.Sequential()
model.add(keras.Input(shape=(1,)))  # declare the input shape up front, as the Keras deprecation warning suggests
model.add(keras.layers.Dense(units=1, activation='linear'))
model.add(keras.layers.Dense(units=64, activation='relu'))
model.add(keras.layers.Dense(units=64, activation='relu'))
model.add(keras.layers.Dense(units=1, activation='linear'))
model.compile(loss='mse', optimizer='adam')
# Display the model
model.summary()
Model: "sequential"
βββββββββββββββββββββββββββββββββββ³βββββββββββββββββββββββββ³ββββββββββββββββ β Layer (type) β Output Shape β Param # β β‘βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ© β dense (Dense) β (None, 1) β 2 β βββββββββββββββββββββββββββββββββββΌβββββββββββββββββββββββββΌββββββββββββββββ€ β dense_1 (Dense) β (None, 64) β 128 β βββββββββββββββββββββββββββββββββββΌβββββββββββββββββββββββββΌββββββββββββββββ€ β dense_2 (Dense) β (None, 64) β 4,160 β βββββββββββββββββββββββββββββββββββΌβββββββββββββββββββββββββΌββββββββββββββββ€ β dense_3 (Dense) β (None, 1) β 65 β βββββββββββββββββββββββββββββββββββ΄βββββββββββββββββββββββββ΄ββββββββββββββββ
Total params: 4,355 (17.01 KB)
Trainable params: 4,355 (17.01 KB)
Non-trainable params: 0 (0.00 B)
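The parameter counts check out by hand: a Dense layer has (n_inputs + 1) × n_units parameters (weights plus biases), so the four layers contribute (1+1)×1 = 2, (1+1)×64 = 128, (64+1)×64 = 4,160, and (64+1)×1 = 65, for 4,355 total.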
In [7]:
# Training
model.fit(time_scaled, height_scaled, epochs=100, verbose=1)
Epoch 1/100
8/8 - 1s 4ms/step - loss: 0.1930
Epoch 2/100
8/8 - 0s 4ms/step - loss: 0.1363
[... loss decreases steadily each epoch ...]
Epoch 99/100
8/8 - 0s 4ms/step - loss: 6.4054e-04
Epoch 100/100
8/8 - 0s 4ms/step - loss: 6.1879e-04
Out[7]:
<keras.src.callbacks.history.History at 0xe8ad0c624050>
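The fit call returns the History object shown above, which records the loss at every epoch. A sketch for checking convergence with it; note that rerunning fit like this continues training the already-fitted model rather than starting fresh:

history = model.fit(time_scaled, height_scaled, epochs=100, verbose=0)  # continues training the existing weights
plt.plot(history.history['loss'])
plt.yscale('log')                   # the loss spans orders of magnitude
plt.xlabel('epoch')
plt.ylabel('MSE loss')
plt.show()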
In [14]:
# Compute the output
height_predicted = model.predict(time_scaled)
# Display the result
plt.scatter(time_scaled, height_scaled)
plt.plot(time_scaled, height_predicted, 'r', linewidth=4)
plt.title("Drop test with neural network fit")
plt.xlabel("time (s)")
plt.ylabel("height (m)")
plt.legend(['Drop test data', '100 epoch neural network fit with TensorFlow'])
# plt.grid()
plt.show()
8/8 - 0s 3ms/step
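One caveat: the network has only learned the curve over the training interval, and a small fully-connected model like this one extrapolates poorly. A quick sketch to see that, extending the time axis 50% past the data (the reshape matches the model's single-feature input):

t_ext = np.linspace(time_scaled.min(), 1.5 * time_scaled.max(), 200).reshape(-1, 1)
h_ext = model.predict(t_ext, verbose=0)
plt.scatter(time_scaled, height_scaled, label='drop test data')
plt.plot(t_ext, h_ext, 'r', label='network prediction')
plt.legend()
plt.show()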