Chandra B. Pradhan - Fab Futures - Data Science
Home About

Week 6: Density Estimation

Assignment: Probability Clusters

In [2]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# Seed NumPy's global RNG. make_blobs and GaussianMixture both receive an
# explicit random_state, so this only guards any other unseeded draws.
np.random.seed(42)

# ----------------------------------------------------
# 1. Generate Synthetic Data for Clustering
# ----------------------------------------------------
# Three Gaussian blobs with unit standard deviation, 500 points in total.
# y_true holds the generating blob index of each point (unused by the GMM).
X, y_true = make_blobs(
    n_samples=500,
    centers=3,
    cluster_std=1.0,
    random_state=42,
)

# ----------------------------------------------------
# 2. Fit Gaussian Mixture Model (Probabilistic Clustering)
# ----------------------------------------------------
# n_components=3 is chosen based on the knowledge of the generated data.
gmm = GaussianMixture(n_components=3, random_state=42)
gmm.fit(X)
# EM may stop at max_iter without converging; fail loudly instead of
# plotting and printing results from an unconverged model.
assert gmm.converged_, "GaussianMixture did not converge; increase max_iter or n_init"

# ----------------------------------------------------
# 3. Predict Cluster Assignments and Probabilities
# ----------------------------------------------------
# Soft assignments: one row per point, one column per mixture component.
probabilities = gmm.predict_proba(X)
# Hard assignment = most probable component per point. Deriving it from
# `probabilities` keeps the hard and soft views consistent by construction
# and avoids a second pass over X.
cluster_labels = probabilities.argmax(axis=1)

# ----------------------------------------------------
# 4. Visualization (Plotting Probabilistic Clusters)
# ----------------------------------------------------
fig, (ax_hard, ax_soft) = plt.subplots(1, 2, figsize=(12, 6))

# --- Left panel: points colored by their hard (most probable) cluster ---
ax_hard.scatter(X[:, 0], X[:, 1], c=cluster_labels, s=40, cmap='viridis')
ax_hard.set(title='Hard Cluster Assignments (Max Probability)',
            xlabel='Feature 1', ylabel='Feature 2')

# --- Right panel: points colored and sized by membership confidence ---
# Highest membership probability per point; with K components it lies in [1/K, 1].
max_prob = np.max(probabilities, axis=1)
n_components = probabilities.shape[1]

# Color by max_prob (not cluster_labels) so the colorbar actually measures
# probability, as its label states. Pin the color range to the possible
# [1/K, 1] interval so low-confidence boundary points stand out.
soft = ax_soft.scatter(X[:, 0], X[:, 1], c=max_prob, s=100 * max_prob,
                       cmap='viridis', vmin=1.0 / n_components, vmax=1.0)
fig.colorbar(soft, ax=ax_soft, orientation='vertical', label='Max Cluster Probability')
ax_soft.set(title=r'Membership Confidence (Color/Size $\propto$ Max Probability)',
            xlabel='Feature 1', ylabel='Feature 2')

fig.tight_layout()
fig.savefig('gmm_probability_clusters.png')  # File is saved
plt.show()  # Display plot in the notebook

# ----------------------------------------------------
# 5. Print a summary
# ----------------------------------------------------
# Check the property stated in the printed interpretation instead of
# only claiming it in prose: every row of predict_proba sums to 1.
assert np.allclose(probabilities.sum(axis=1), 1.0), "rows of predict_proba should sum to 1"

print("GMM Model Fitted Successfully.")
print("\nSample Probability Matrix (First 5 points):")
print(probabilities[:5].round(3))
print("\nInterpretation: Each row sums to 1.0, showing the probability of the point belonging to Cluster 0, 1, or 2.")
No description has been provided for this image
GMM Model Fitted Successfully.

Sample Probability Matrix (First 5 points):
[[0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]

Interpretation: Each row sums to 1.0, showing the probability of the point belonging to Cluster 0, 1, or 2.
In [ ]: