Paul Wensveen - Fab Futures - Data Science


Week 2: Model fitting

Killer whale tag data

Load data set & pre-processing

Here I reload the killer whale data set from week 1 and repeat a few pre-processing steps that may come in handy.

In [1]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import math

df = pd.read_csv("datasets/data_oo23_181a.csv", sep=",") # import as data frame
df = df.iloc[::10].reset_index(drop=True) # Keep every 10th row
df[['pitch','roll','head']] = df[['pitch','roll','head']]/math.pi*180 # Convert prh from radians to degrees

# Create categorical variable with deep vs shallow
df = df.assign(deep=df['depth'] > 25) # use the 'depth' Series so a 1-D boolean column is assigned
df["deep"] = df["deep"].map({True: ">25m", False: "<=25m"}).astype("category")

Fit polynomials

The x-axis acceleration vs. pitch data looked pretty nice for this.

In [2]:
x = df["pitch"].values # pulling data from df not array, .values to ensure ndim=1
y = df["ax"].values
xmin = min(x)
xmax = max(x)
npts = 100

coeff1 = np.polyfit(x,y,1) # fit 1st-order polynomial
coeff2 = np.polyfit(x,y,2) # fit 2nd-order polynomial
coeff3 = np.polyfit(x,y,3) # fit 3rd-order polynomial
xfit = np.linspace(xmin, xmax, npts) # evenly spaced grid over the data range for plotting the fits
pfit1 = np.poly1d(coeff1)
yfit1 = pfit1(xfit) # evaluate fit
pfit2 = np.poly1d(coeff2)
yfit2 = pfit2(xfit) # evaluate fit
pfit3 = np.poly1d(coeff3)
yfit3 = pfit3(xfit) # evaluate fit
print(f"fit coefficients: {coeff3}")
plt.plot(x,y,'o')
plt.plot(xfit,yfit1,'g-',label='1st order')
plt.plot(xfit,yfit2,'r-',label='2nd order')
plt.plot(xfit,yfit3,'y-',label='3rd order')
plt.xlabel("Pitch (degrees)")
plt.ylabel("Ax (g)")
plt.legend()
plt.show()
fit coefficients: [-7.69397245e-07  9.77174679e-07  1.72519368e-02 -9.37386999e-05]
[Figure: Ax vs. pitch scatter with 1st-, 2nd-, and 3rd-order polynomial fits]
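To put numbers on how the three fits compare, one simple option is the coefficient of determination. This is a sketch I added; the r_squared helper is my own, not from the original code:

def r_squared(y, yhat):
    # R^2 = 1 - (residual sum of squares) / (total sum of squares)
    ss_res = np.sum((y - yhat)**2)
    ss_tot = np.sum((y - np.mean(y))**2)
    return 1 - ss_res / ss_tot

for order, pfit in zip([1, 2, 3], [pfit1, pfit2, pfit3]):
    print(f"order {order}: R^2 = {r_squared(y, pfit(x)):.4f}")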

Humpback whale hydrophone data

I don't know if it makes sense to analyse the highly autocorrelated killer whale tag data using machine learning (it probably does!), so instead I'm going to use acoustic data from a bottom-moored recorder, a SoundTrap. The recorder sits on a small subsurface mooring in Vestmannaeyjar and is redeployed twice a year, so there are literally years of data to go through!

The 15-min sound clip (downsampled to fs = 12 kHz and noise-reduced) contains a partial humpback whale song: a sequence of discrete, stereotyped sounds or song units, so hopefully perfect for relatively basic pattern recognition! The same recording is available on our YouTube channel.

A useful reference: Chicco, C., et al. (2024). Using acoustic monitoring to reveal nearly year-round presence of humpback whales (Megaptera novaeangliae) in the waters of southern Iceland. Marine Mammal Science.

Load and visualise data set

I asked ChatGPT:

  • Can I read a wav file and generate and plot a spectrogram in Python? Which library?

It recommended scipy + matplotlib for basic stuff and librosa for audio analysis & machine learning. I went with the first option and adapted the example code based on prior knowledge and the standard documentation.

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import wavfile
from scipy.signal import spectrogram

# Read WAV file
fs, x = wavfile.read("datasets/humpback-whale-song_noise_reduced.wav")

# Create spectrogram
freqs, t, Sxx = spectrogram(x, fs, window='hann', nperseg=512)

# Plot three 1-min segments as an example
tstart = np.array([0, 120, 600])
tend = np.array([60, 180, 660])

plt.figure(figsize=(8, 6))  # one figure holding all three stacked panels
for i in range(3):

    idx = (t >= tstart[i]) & (t < tend[i])

    plt.subplot(3, 1, i + 1)
    plt.pcolormesh(t[idx], freqs/1000, 10 * np.log10(Sxx[:, idx]), shading='gouraud')
    plt.ylabel('freq. (kHz)')
    plt.xlabel('t (s)')
    #plt.colorbar(label='level (dBFS)')

plt.tight_layout()
plt.show()
[Figures: spectrograms of the three 1-min segments]

The three 1-min segments illustrate some of the variability in song unit types.

Fit unsupervised machine learning model

Then I asked ChatGPT:

  • I have a spectrogram Sxx in Python. Its shape is (257 freqs, 24107 time steps). Total duration is about 15 min. Please write a minimal JAX program, using no other libraries, to train a simple unsupervised machine learning algorithm to classify sounds in the spectrogram data, and use matplotlib to show examples of the output. Use segments of about 2 seconds each, with 50% overlap.

From there I made some minor adjustments, such as the 10*log10 scaling in the plots.
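For reference, the ~2-s segment length used below can be checked against the spectrogram's time resolution. A quick sketch I added, using scipy's spectrogram default of noverlap = nperseg // 8:

nperseg = 512
noverlap = nperseg // 8                  # scipy default
dt = (nperseg - noverlap) / fs           # time step between spectrogram columns
print(f"dt = {dt*1000:.1f} ms -> 2 s is about {round(2/dt)} columns")
print(f"total duration = {Sxx.shape[1]*dt/60:.1f} min")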

In [5]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

# -----------------------------
# Parameters
# -----------------------------
segment_len = 54            # ~2 seconds
hop_len = segment_len // 2  # 50% overlap
k = 7                       # number of clusters
examples_per_cluster = 3    # for plotting

# -----------------------------
# 1. Create overlapping segments
# -----------------------------
num_segments = (Sxx.shape[1] - segment_len) // hop_len + 1
segments = []

for i in range(num_segments):
    start = i * hop_len
    end = start + segment_len
    segments.append(Sxx[:, start:end])

segments = jnp.stack(segments, axis=0)  # (num_segments, freqs, segment_len)
X = segments.reshape(num_segments, -1)   # flatten for k-means

# -----------------------------
# 2. Simple K-Means in JAX
# -----------------------------
key = jax.random.PRNGKey(0)
centroids = X[jax.random.choice(key, X.shape[0], (k,), replace=False)]

def kmeans_step(centroids, X):
    dists = jnp.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    labels = jnp.argmin(dists, axis=1)

    def update_centroid(i):
        pts = X[labels == i]
        return jnp.mean(pts, axis=0)
    
    new_centroids = jnp.stack([update_centroid(i) for i in range(k)], axis=0)
    return new_centroids, labels

# Run 10 iterations
for _ in range(10):
    centroids, labels = kmeans_step(centroids, X)

# -----------------------------
# 3. Visualization of clusters
# -----------------------------
fig, axs = plt.subplots(k, examples_per_cluster, figsize=(8, 12))

for cluster in range(k):
    cluster_indices = jnp.where(labels == cluster)[0]
    cluster_indices = cluster_indices[:examples_per_cluster]

    for j, idx in enumerate(cluster_indices):
        axs[cluster, j].imshow(10*jnp.log10(segments[idx]), aspect='auto', origin='lower')
        axs[cluster, j].axis('off')
        axs[cluster, j].set_title(f"C{cluster}-{j}", fontsize=8)

plt.tight_layout()
plt.show()
[Figure: example spectrogram segments, up to three per cluster, for the 7 clusters]
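To see how the segments are distributed over the clusters, a quick occupancy check (my own addition, not part of the generated code):

for c in range(k):
    members = X[labels == c]
    d = jnp.linalg.norm(members - centroids[c], axis=1)
    # note: prints nan for the mean distance if a cluster ended up empty
    print(f"cluster {c}: n = {members.shape[0]}, mean distance to centroid = {float(jnp.mean(d)):.1f}")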

Not bad for a first cut!
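A possible next step, sketched here as my own addition, is to map each segment's label back to its time in the recording, which turns the clustering into a rough song unit sequence over the 15 minutes:

# Segment centre times (s), taken from the spectrogram time axis t
centres = np.array([t[i * hop_len + segment_len // 2] for i in range(num_segments)])

plt.figure(figsize=(8, 2))
plt.scatter(centres, np.asarray(labels), s=2)
plt.xlabel('t (s)')
plt.ylabel('cluster')
plt.tight_layout()
plt.show()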