Week 2: Model fitting¶
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import math
df = pd.read_csv("datasets/data_oo23_181a.csv", sep=",") # import as data frame
df = df.iloc[::10].reset_index(drop=True) # Keep every 10th row
df[['pitch','roll','head']] = df[['pitch','roll','head']]/math.pi*180 # Convert prh from radians to degrees
# Create categorical variable with deep vs shallow
df = df.assign(deep=df["depth"] > 25) # boolean Series: True where depth exceeds 25 m
df["deep"] = df["deep"].map({True: ">25m", False: "<=25m"}).astype("category")
Fit polynomials¶
X-axis acceleration vs pitch data looked pretty nice for this.
x = df["pitch"].values # pulling data from df not array, .values to ensure ndim=1
y = df["ax"].values
xmin = min(x)
xmax = max(x)
npts = 100
coeff1 = np.polyfit(x,y,1) # fit 1st-order polynomial
coeff2 = np.polyfit(x,y,2) # fit 2nd-order polynomial
coeff3 = np.polyfit(x,y,3) # fit 3rd-order polynomial
xfit = np.linspace(xmin, xmax, npts) # evenly spaced x values for evaluating the fits
pfit1 = np.poly1d(coeff1)
yfit1 = pfit1(xfit) # evaluate fit
pfit2 = np.poly1d(coeff2)
yfit2 = pfit2(xfit) # evaluate fit
pfit3 = np.poly1d(coeff3)
yfit3 = pfit3(xfit) # evaluate fit
print(f"fit coefficients: {coeff3}")
plt.plot(x,y,'o')
plt.plot(xfit,yfit1,'g-',label='1st order')
plt.plot(xfit,yfit2,'r-',label='2nd order')
plt.plot(xfit,yfit3,'y-',label='3rd order')
plt.xlabel("Pitch (degrees)")
plt.ylabel("Ax (g)")
plt.legend()
plt.show()
fit coefficients: [-7.69397245e-07 9.77174679e-07 1.72519368e-02 -9.37386999e-05]
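To decide which polynomial order is actually worth keeping, comparing residuals is more useful than eyeballing the curves. This is a sketch of a follow-up check (not part of the original cell); it reuses x, y and the pfit objects from above to print the root-mean-square error of each fit.
# Compare the three fits by RMSE at the data points (sketch)
for name, pfit in [("1st", pfit1), ("2nd", pfit2), ("3rd", pfit3)]:
    rmse = np.sqrt(np.mean((y - pfit(x))**2))
    print(f"{name}-order fit: RMSE = {rmse:.4f} g")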
Humpback whale hydrophone data¶
I don't know if it makes sense to analyse the highly autocorrelated killer whale tag data using machine learning (it probably does!), so instead I'm going to use acoustic data from a bottom-moored recorder called the Soundtrap. The recorder sits on a small subsurface mooring in Vestmannaeyjar and is redeployed twice a year, so there are literally years of data to go through!
The 15-min sound clip (downsampled to fs=12 kHz and noise-reduced) contains a partial humpback whale song, i.e. a sequence of discrete and stereotyped sounds or song units, so it should be perfect for relatively basic pattern recognition, I hope! The same recording is available on our YouTube channel.
Load and visualise data set¶
I asked ChatGPT:
- Can I read a wav file and generate and plot a spectrogram in Python? Which library?
It recommended scipy + matplotlib for basic stuff and librosa for audio analysis & machine learning. I went with the first option and adapted the example code based on prior knowledge and the standard documentation.
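For reference, the librosa route would have looked roughly like the sketch below (untested here, and the STFT settings are placeholders rather than values I tuned):
import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
# Rough librosa equivalent (sketch, not the code used below)
y, sr = librosa.load("datasets/humpback-whale-song_noise_reduced.wav", sr=None)  # keep native sample rate
S_db = librosa.amplitude_to_db(np.abs(librosa.stft(y, n_fft=512)), ref=1.0)      # magnitude STFT in dB
img = librosa.display.specshow(S_db, sr=sr, x_axis="time", y_axis="hz")
plt.colorbar(img, label="level (dB)")
plt.show()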
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy.signal import spectrogram
# Read WAV file
fs, x = wavfile.read("datasets/humpback-whale-song_noise_reduced.wav")
# Create spectrogram
freqs, t, Sxx = spectrogram(x, fs, window='hann', nperseg=512)
# Plot three 60-s segments as an example
tstart = np.array([0, 120, 600])  # segment start times (s)
tend = np.array([60, 180, 660])   # segment end times (s)
plt.figure(figsize=(8, 6))        # one figure with three stacked panels
for i in range(3):
    idx = (t >= tstart[i]) & (t < tend[i])
    plt.subplot(3, 1, i+1)
    plt.pcolormesh(t[idx], freqs/1000, 10 * np.log10(Sxx[:, idx]), shading='gouraud')
    plt.ylabel('freq. (kHz)')
    plt.xlabel('t (s)')
    #plt.colorbar(label='level (dBFS)')
plt.tight_layout()
plt.show()
The three 1-min segments shown here illustrate some of the variability in song unit types.
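Since the clustering step below works on fixed-length blocks of spectrogram columns, it helps to know what one column corresponds to. A small sketch using the freqs and t arrays returned above (with scipy's default noverlap of nperseg // 8):
# Time and frequency resolution of the spectrogram (uses freqs and t from above)
dt = t[1] - t[0]             # seconds per spectrogram column
df_hz = freqs[1] - freqs[0]  # Hz per frequency bin (fs / nperseg)
print(f"time step: {dt*1000:.1f} ms per column")
print(f"frequency resolution: {df_hz:.1f} Hz per bin")
print(f"columns per 2-s block: {2/dt:.0f}")  # motivates segment_len below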
Fit unsupervised machine learning model¶
Then I asked ChatGPT:
- I have a spectrogram Sxx in Python. Its shape is (257 freqs, 24107 time steps). Total duration is about 15 min. Please write a minimal JAX program, using no other libraries, to train a simple unsupervised machine learning algorithm to classify sounds in the spectrogram data, and use matplotlib to show examples of the output. Use segments of about 2 seconds each, with 50% overlap.
I then made some minor adjustments, such as adding the 10*log10 scaling when plotting.
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
# -----------------------------
# Parameters
# -----------------------------
segment_len = 54 # ~2 seconds
hop_len = segment_len // 2 # 50% overlap
k = 7 # number of clusters
examples_per_cluster = 3 # for plotting
# -----------------------------
# 1. Create overlapping segments
# -----------------------------
num_segments = (Sxx.shape[1] - segment_len) // hop_len + 1
segments = []
for i in range(num_segments):
    start = i * hop_len
    end = start + segment_len
    segments.append(Sxx[:, start:end])
segments = jnp.stack(segments, axis=0) # (num_segments, freqs, segment_len)
X = segments.reshape(num_segments, -1) # flatten for k-means
# -----------------------------
# 2. Simple K-Means in JAX
# -----------------------------
key = jax.random.PRNGKey(0)
centroids = X[jax.random.choice(key, X.shape[0], (k,), replace=False)]
def kmeans_step(centroids, X):
    # Assign each segment to its nearest centroid (Euclidean distance)
    dists = jnp.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    labels = jnp.argmin(dists, axis=1)
    def update_centroid(i):
        pts = X[labels == i]          # all segments currently in cluster i
        return jnp.mean(pts, axis=0)  # new centroid = mean of those segments
    new_centroids = jnp.stack([update_centroid(i) for i in range(k)], axis=0)
    return new_centroids, labels
# Run 10 iterations
for _ in range(10):
    centroids, labels = kmeans_step(centroids, X)
# -----------------------------
# 3. Visualization of clusters
# -----------------------------
fig, axs = plt.subplots(k, examples_per_cluster, figsize=(8, 12))
for cluster in range(k):
    cluster_indices = jnp.where(labels == cluster)[0]
    cluster_indices = cluster_indices[:examples_per_cluster]
    for j, idx in enumerate(cluster_indices):
        axs[cluster, j].imshow(10*jnp.log10(segments[idx]), aspect='auto', origin='lower')
        axs[cluster, j].axis('off')
        axs[cluster, j].set_title(f"C{cluster}-{j}", fontsize=8)
plt.tight_layout()
plt.show()
Not bad for a first cut!
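One quick way to judge whether the clusters mean anything is to look at how the labels are ordered in time, since song units should repeat in phrases. A sketch of that check (it reuses labels, hop_len, num_segments and the spectrogram time vector t from above, and is not something I have tuned yet):
# Cluster label of each ~2-s segment against its start time (sketch)
seg_times = np.array([t[i * hop_len] for i in range(num_segments)])  # segment start times (s)
plt.figure(figsize=(8, 3))
plt.step(seg_times, np.asarray(labels), where="post")
plt.yticks(range(k))
plt.xlabel("t (s)")
plt.ylabel("cluster label")
plt.show()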