Sahil Gupta - Fab Futures - Data Science

Week 6: Density Estimation

Goal: Fit a probability distribution to our data

My understanding of this week's material is that a distribution can contain multiple clusters, which can be found iteratively, and that finding them gives a more nuanced interpretation of the data.

Here was the prompt to Copilot: Help me explore clustering in the data from chatgpt1.csv

...and its helpful explanation:

"To explore clustering in chatgpt1.csv, we'll use a Gaussian Mixture Model (GMM), which fits probability distributions with multiple clusters (as per your Week 6 understanding). GMM is a density estimation technique that assumes data comes from a mixture of Gaussian distributions, allowing iterative discovery of clusters for nuanced interpretations.

We'll cluster the tweets based on numerical features (e.g., ReplyCount, RetweetCount, etc.), as these represent engagement patterns. This can reveal groups like "low engagement" vs. "high engagement" tweets."
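
For reference, the density that a Gaussian mixture fits is usually written as below. The notation (K components with weights \pi_k, means \mu_k, and covariances \Sigma_k, plus responsibilities \gamma_{ik}) is the standard textbook form and is my assumption here, not something taken from the course material or from Copilot's answer. The "iterative" part is expectation-maximization, which alternates between computing the responsibilities and re-estimating the parameters from them.

```latex
% Mixture density: a weighted sum of K Gaussian components
p(x) = \sum_{k=1}^{K} \pi_k \, \mathcal{N}(x \mid \mu_k, \Sigma_k),
\qquad \pi_k \ge 0, \quad \sum_{k=1}^{K} \pi_k = 1

% E-step: responsibility of component k for data point x_i
\gamma_{ik} = \frac{\pi_k \, \mathcal{N}(x_i \mid \mu_k, \Sigma_k)}
                   {\sum_{j=1}^{K} \pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}
```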

In [3]:
import pandas as pd
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

# Load and preprocess data
df = pd.read_csv('/Users/srgupta/Desktop/Academany/Data_Science/sahil-gupta/datasets/chatgpt1.csv')
numerical_cols = df.select_dtypes(include=[np.number]).columns
df_num = df[numerical_cols].fillna(0)

# Select features for clustering (engagement metrics)
features = ['ReplyCount', 'RetweetCount', 'QuoteCount', 'LikeCount', 'hastag_counts']
X = df_num[features]

# Standardize data (important for GMM)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Fit GMM (assume 3 clusters; adjust n_components based on data)
gmm = GaussianMixture(n_components=3, random_state=42)
gmm.fit(X_scaled)

# Predict cluster labels
labels = gmm.predict(X_scaled)
df_num['Cluster'] = labels

# Evaluate: gmm.score returns the average per-sample log-likelihood; BIC penalizes model complexity
print(f"Log-Likelihood: {gmm.score(X_scaled)}")
print(f"BIC: {gmm.bic(X_scaled)}")

# Visualize clusters (pairplot of key features, colored by cluster)
# Sample a subset for speed (full dataset is large)
sample_df = df_num.sample(n=5000, random_state=42)
sns.pairplot(sample_df, vars=['RetweetCount', 'LikeCount', 'ReplyCount'], hue='Cluster', palette='viridis')
plt.suptitle('GMM Clusters (Sampled Data)', y=1.02)
plt.show()

# Optional: Cluster sizes
print("Cluster Sizes:")
print(df_num['Cluster'].value_counts())
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: divide by zero encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: overflow encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: invalid value encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: divide by zero encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: overflow encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: invalid value encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
Log-Likelihood: 5.280790520421618
BIC: -527418.7861295816
[Figure: pairplot of RetweetCount, LikeCount, and ReplyCount colored by cluster, titled "GMM Clusters (Sampled Data)"]
Cluster Sizes:
Cluster
0    48924
1     1076
2        1
Name: count, dtype: int64

Copilot's Explanation:

  • Preprocessing: Selects numerical features and standardizes them (GMM is sensitive to scale).
  • GMM Fitting: n_components=3 assumes 3 clusters (e.g., low, medium, high engagement). The model iteratively refines clusters via expectation-maximization.
  • Evaluation: Log-likelihood measures fit; lower BIC is better (penalizes complexity).
  • Visualization: Pairplot shows clusters in 2D slices. Expect clusters to group similar engagement levels.
  • Insights: Clusters might reveal patterns (e.g., viral tweets in one cluster). Adjust n_components if needed (try 2-5 and compare BIC).
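
The suggestion to try 2-5 components and compare BIC can be sketched as follows. This is a minimal sketch on synthetic data (two made-up Gaussian blobs standing in for the engagement features, not the real chatgpt1.csv), so the numbers it prints say nothing about the tweets. The names `bic_by_k` and `best_k` are my own.

```python
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data (an assumption, NOT chatgpt1.csv):
# two well-separated Gaussian blobs in two dimensions.
rng = np.random.default_rng(42)
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(300, 2)),
    rng.normal(loc=8.0, scale=1.0, size=(300, 2)),
])
X_scaled = StandardScaler().fit_transform(X)

# Fit one GMM per candidate component count and record its BIC.
bic_by_k = {}
for k in range(2, 6):  # 2, 3, 4, 5
    gmm = GaussianMixture(n_components=k, random_state=42).fit(X_scaled)
    bic_by_k[k] = gmm.bic(X_scaled)

# Lower BIC is better, so pick the component count with the minimum.
best_k = min(bic_by_k, key=bic_by_k.get)

for k, bic in bic_by_k.items():
    print(f"n_components={k}: BIC={bic:.1f}")
print(f"Lowest BIC at n_components={best_k}")
```

On the real data the same loop would run over `X_scaled` from the cell above, although the matmul warnings suggest the fit may be numerically unstable there.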

Trying a different method: K-Means clustering:

"To switch to K-Means clustering (a simpler, centroid-based method that partitions data into k groups by minimizing distances), we'll apply it to the same numerical features from chatgpt1.csv. K-Means isn't a density estimator like GMM, but it's great for exploratory clustering and can reveal engagement patterns (e.g., grouping tweets by similar counts)."

In [4]:
import pandas as pd
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns

# Load and preprocess data
df = pd.read_csv('/Users/srgupta/Desktop/Academany/Data_Science/sahil-gupta/datasets/chatgpt1.csv')
numerical_cols = df.select_dtypes(include=[np.number]).columns
df_num = df[numerical_cols].fillna(0)

# Select features for clustering (engagement metrics)
features = ['ReplyCount', 'RetweetCount', 'QuoteCount', 'LikeCount', 'hastag_counts']
X = df_num[features]

# Standardize data (recommended for K-Means)
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Fit K-Means (assume 3 clusters; adjust n_clusters based on data)
kmeans = KMeans(n_clusters=3, random_state=42, n_init=10)
labels = kmeans.fit_predict(X_scaled)
df_num['Cluster'] = labels

# Evaluate: Inertia (within-cluster sum of squares; lower is better)
print(f"Inertia: {kmeans.inertia_}")

# Visualize clusters (pairplot of key features, colored by cluster)
# Sample a subset for speed
sample_df = df_num.sample(n=5000, random_state=42)
sns.pairplot(sample_df, vars=['RetweetCount', 'LikeCount', 'ReplyCount'], hue='Cluster', palette='viridis')
plt.suptitle('K-Means Clusters (Sampled Data)', y=1.02)
plt.show()

# Optional: Cluster sizes and centroids
print("Cluster Sizes:")
print(df_num['Cluster'].value_counts())
print("\nCluster Centroids (original scale):")
centroids = scaler.inverse_transform(kmeans.cluster_centers_)
centroid_df = pd.DataFrame(centroids, columns=features)
print(centroid_df)
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: divide by zero encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: overflow encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/utils/extmath.py:203: RuntimeWarning: invalid value encountered in matmul
  ret = a @ b
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: divide by zero encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: overflow encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
/Users/srgupta/Desktop/Academany/Data_Science/.venv/lib/python3.10/site-packages/sklearn/cluster/_kmeans.py:237: RuntimeWarning: invalid value encountered in matmul
  current_pot = closest_dist_sq @ sample_weight
Inertia: 97401.52927266332
[Figure: pairplot of RetweetCount, LikeCount, and ReplyCount colored by cluster, titled "K-Means Clusters (Sampled Data)"]
Cluster Sizes:
Cluster
0    49988
1       12
2        1
Name: count, dtype: int64

Cluster Centroids (original scale):
    ReplyCount  RetweetCount   QuoteCount     LikeCount  hastag_counts
0     0.666460      0.908458     0.116108      6.383292   7.835080e-01
1   976.833333   1891.666667   268.833333   9138.750000   1.110223e-16
2  1421.000000   6815.000000  1947.000000  56073.000000   0.000000e+00
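
To make the earlier contrast between K-Means and GMM concrete, here is a minimal sketch of what each fitted model can report. It uses synthetic data (two made-up Gaussian blobs, not chatgpt1.csv), and the variable names are my own. K-Means gives each row a single hard label, while the GMM also gives per-cluster membership probabilities and a log-density for every point, which is what makes it a density estimator.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.mixture import GaussianMixture

# Synthetic stand-in data (an assumption, NOT chatgpt1.csv).
rng = np.random.default_rng(0)
X = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(200, 2)),
    rng.normal(loc=6.0, scale=1.0, size=(200, 2)),
])

kmeans = KMeans(n_clusters=2, random_state=42, n_init=10).fit(X)
gmm = GaussianMixture(n_components=2, random_state=42).fit(X)

# K-Means: one integer label per row, nothing probabilistic.
hard_labels = kmeans.predict(X)

# GMM: membership probabilities per component (each row sums to 1)...
soft_probs = gmm.predict_proba(X)

# ...and the log of the fitted mixture density at each point.
log_density = gmm.score_samples(X)

print("K-Means labels:", hard_labels[:5])
print("GMM membership probabilities:\n", soft_probs[:5].round(3))
print("GMM log-density:", log_density[:5].round(3))
```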

I guess the clusters don't add much to what is already known. Most tweets receive very little engagement, and the few that receive a lot are almost outliers. A "middle" group is visible in the plots, but it is not far from the low-engagement cluster.
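
To put a number on how lopsided the clusters are, the short calculation below takes the cluster sizes printed above and computes the share of tweets in the largest cluster for each method. The helper name `largest_share` is my own. It works out to roughly 97.8% for GMM and 99.97% for K-Means.

```python
# Cluster sizes copied from the printed outputs of the two cells above.
gmm_sizes = {0: 48924, 1: 1076, 2: 1}
kmeans_sizes = {0: 49988, 1: 12, 2: 1}


def largest_share(sizes):
    """Fraction of all rows that fall in the largest cluster."""
    total = sum(sizes.values())
    return max(sizes.values()) / total


for name, sizes in [("GMM", gmm_sizes), ("K-Means", kmeans_sizes)]:
    total = sum(sizes.values())
    print(f"{name}: {total} rows, largest cluster holds {largest_share(sizes):.2%}")
```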