Mean Shift Clustering | Basic Clustering Algorithms
Cluster Analysis
Mean Shift Clustering

Mean shift is one of the simplest density-based clustering algorithms. Simply put, "mean shift" means "iteratively shifting to the mean". Every data point is shifted step by step toward the "regional mean", and the location where each point finally settles identifies the cluster it belongs to. The algorithm consists of the following steps:

Step 1. For each data point, create a sliding window with a specified radius (the bandwidth), centered on that point;

Step 2. Shift each sliding window towards a higher-density region by moving its centroid to the mean of the data points inside the window. Repeat this step until the number of points in the window no longer increases or the centroid stops moving;

Step 3. Merge overlapping sliding windows. When multiple windows overlap, the window containing the most points is preserved, and the others are merged into it;

Step 4. Assign each data point to the sliding window in which it resides. If a data point lies outside every window, assign it to the nearest window.

In other words, mean shift moves each window toward a higher-density region by repeatedly moving its centroid (the center of the sliding window) to the mean of the data points inside the window.
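The four steps can be sketched in plain NumPy. This is only a minimal illustration with a flat (hard-radius) window, not how scikit-learn implements the algorithm; the function name `mean_shift`, the stopping tolerance `tol`, the merge rule (drop any window whose center lies within one bandwidth of a denser window already kept), and the two-blob demo data are all assumptions made for the example.

```python
import numpy as np


def mean_shift(X, bandwidth, max_iter=300, tol=1e-3):
    """Illustrative flat-window mean shift. Returns (centers, labels)."""
    X = np.asarray(X, dtype=float)

    # Step 1: start one sliding window on every data point
    windows = X.copy()

    for _ in range(max_iter):
        # Step 2: move each window center to the mean of the points inside it
        dists = np.linalg.norm(windows[:, None, :] - X[None, :, :], axis=2)
        inside = (dists <= bandwidth).astype(float)
        shifted = (inside @ X) / inside.sum(axis=1, keepdims=True)
        largest_move = np.linalg.norm(shifted - windows, axis=1).max()
        windows = shifted
        if largest_move < tol:  # every centroid has (almost) stopped moving
            break

    # Step 3: merge overlapping windows, keeping the most populated ones first
    dists = np.linalg.norm(windows[:, None, :] - X[None, :, :], axis=2)
    counts = (dists <= bandwidth).sum(axis=1)
    kept = []
    for idx in np.argsort(-counts):
        if all(np.linalg.norm(windows[idx] - c) > bandwidth for c in kept):
            kept.append(windows[idx])
    centers = np.array(kept)

    # Step 4: assign every point to the nearest surviving window center
    to_centers = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    labels = to_centers.argmin(axis=1)
    return centers, labels


# Demo data (assumed for illustration): two well-separated Gaussian blobs
rng = np.random.default_rng(0)
X_demo = np.vstack([
    rng.normal(loc=0.0, scale=1.0, size=(100, 2)),
    rng.normal(loc=10.0, scale=1.0, size=(100, 2)),
])
centers, labels = mean_shift(X_demo, bandwidth=3.0)
print(len(centers), "clusters found")
```

The sketch builds a full window-by-point distance matrix on every iteration, which is fine for a few hundred points but does not scale; it is meant to make the steps concrete rather than to be efficient.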

The Mean shift algorithm is therefore similar to the K-means algorithm: it also works with the means of points, and it is best suited to isolated, well-separated clusters. But there is one significant difference: you do not need to set the number of clusters manually.
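A small sketch of that difference, using scikit-learn's `KMeans` and `MeanShift` on the same data. The blob centers, the `random_state` values, and `bandwidth=3` are assumptions chosen so that the clusters are clearly separated.

```python
from sklearn.cluster import KMeans, MeanShift
from sklearn.datasets import make_blobs

# Three well-separated blobs (centers and seed are illustrative choices)
X, _ = make_blobs(
    n_samples=300,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=1.0,
    random_state=0,
)

# K-means has to be told how many clusters to look for
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# Mean shift only receives a bandwidth; the number of clusters comes out of the fit
mean_shift_model = MeanShift(bandwidth=3).fit(X)

n_kmeans = len(kmeans.cluster_centers_)
n_mean_shift = len(mean_shift_model.cluster_centers_)
print("K-means clusters:", n_kmeans)
print("Mean shift clusters:", n_mean_shift)
```

On data this cleanly separated, both models should end up with the same number of clusters, but only K-means was given that number up front.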

Let's look at an example of using Mean shift clustering in Python:

from sklearn.cluster import MeanShift
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# Create dataset for clustering
X, y = make_blobs(n_samples=500, cluster_std=1, centers=4)
# Stretch the blobs with a linear transformation to make them anisotropic
transformation = [[0.6, -0.6], [-0.4, 0.8]]
X_aniso = np.matmul(X, transformation)

# Train Mean Shift model on the transformed blobs and visualize the results
blobs_clustering = MeanShift(bandwidth=2).fit(X_aniso)

fig, axes = plt.subplots(1, 2)
# Plot the transformed data, since that is what the model was fitted on
axes[0].scatter(X_aniso[:, 0], X_aniso[:, 1], c=blobs_clustering.labels_, s=50, cmap='tab20b')
axes[0].set_title('Clustered aniso blobs data')
axes[1].scatter(X_aniso[:, 0], X_aniso[:, 1], c=y, s=50, cmap='tab20b')
axes[1].set_title('Real aniso blobs data')
plt.show()

Let's check how the Mean shift algorithm deals with the moons dataset:

from sklearn.datasets import make_moons
import matplotlib.pyplot as plt
from sklearn.cluster import MeanShift

# Create moons dataset for clustering
X, y = make_moons(n_samples=500)

# Fit Mean Shift model on moons dataset and visualize the results
moons_clustering = MeanShift(bandwidth=0.7).fit(X)

fig, axes = plt.subplots(1, 2)
axes[0].scatter(X[:, 0], X[:, 1], c=moons_clustering.labels_, s=50, cmap='tab20b')
axes[0].set_title('Clustered moons data')
axes[1].scatter(X[:, 0], X[:, 1], c=y, s=50, cmap='tab20b')
axes[1].set_title('Real moons data')
plt.show()

In the code above, we use the MeanShift class to create the model. The bandwidth parameter defines the radius of the window within which the mean is calculated.
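Because the bandwidth sets the window radius, it also controls how many clusters come out of the fit. The sketch below refits the model with a few bandwidths and also calls scikit-learn's `estimate_bandwidth` helper to get a data-driven starting value. The dataset, the `quantile=0.3` setting, and the specific bandwidths tried are assumptions picked for illustration.

```python
from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs

# Three well-separated blobs (centers and seed are illustrative choices)
X, _ = make_blobs(
    n_samples=300,
    centers=[[0, 0], [10, 10], [-10, 10]],
    cluster_std=1.0,
    random_state=0,
)

# Data-driven bandwidth suggestion based on pairwise distances
estimated = estimate_bandwidth(X, quantile=0.3, random_state=0)
print("Estimated bandwidth:", estimated)

# Refit with a small, a moderate, and a very large window radius
cluster_counts = {}
for bandwidth in (1.0, 3.0, 100.0):
    model = MeanShift(bandwidth=bandwidth).fit(X)
    cluster_counts[bandwidth] = len(model.cluster_centers_)
print(cluster_counts)
```

A window large enough to contain every point pulls all centroids to the overall mean, so the whole dataset collapses into a single cluster; a very small window tends to split the data into more clusters than are really there.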

Note

The MeanShift class provides a .predict() method, which assigns new points to clusters using an already fitted model.
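A short sketch of `.predict()` in use. The training blobs, `bandwidth=3`, and the two new points are made up for the example; the manual nearest-center calculation is included only to show that each new point receives the label of its closest cluster center.

```python
import numpy as np
from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs

# Fit on two well-separated blobs (centers and seed are illustrative choices)
X, _ = make_blobs(
    n_samples=200,
    centers=[[0, 0], [10, 10]],
    cluster_std=1.0,
    random_state=0,
)
model = MeanShift(bandwidth=3).fit(X)

# Points the model has not seen, one near each blob
new_points = np.array([[0.5, -0.5], [9.0, 11.0]])
predicted = model.predict(new_points)

# The same assignment computed by hand: index of the nearest cluster center
dists = np.linalg.norm(
    new_points[:, None, :] - model.cluster_centers_[None, :, :], axis=2
)
nearest = dists.argmin(axis=1)
print(predicted, nearest)
```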

What is the main difference between K-means and Mean shift clustering algorithms?
