
K-Means Clustering

K-Means: Inertia

Inertia measures how well a dataset was clustered by K-Means. To calculate it, measure the distance between each data point and the centroid of its cluster, square each distance, and sum these squares across all clusters. Lower inertia means the points sit more tightly around their centroids.
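The NumPy function below is a minimal sketch of that calculation. It assumes each point's cluster label and the centroid positions are already known. The helper name inertia and the toy points are hypothetical. A fitted scikit-learn KMeans model reports the same quantity in its inertia_ attribute.

```python
import numpy as np

def inertia(X, labels, centroids):
    """Sum of squared distances from each point to the centroid of its cluster."""
    # centroids[labels] picks, for every point, the centroid it is assigned to
    diffs = X - centroids[labels]
    # square each coordinate difference and add them up over all points in all clusters
    return float((diffs ** 2).sum())

# Toy data (hypothetical): two clusters of two points each
X = np.array([[1.0, 1.0], [1.0, 3.0], [5.0, 5.0], [7.0, 5.0]])
labels = np.array([0, 0, 1, 1])
centroids = np.array([[1.0, 2.0], [6.0, 5.0]])  # the mean of each cluster

print(inertia(X, labels, centroids))  # 4.0 (each point is 1 unit from its centroid)
```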

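A good model is one with low inertia AND a low number of clusters (K). However, these goals trade off against each other, because inertia generally decreases as K increases. In the extreme case, where K equals the number of data points, every point is its own centroid and inertia is 0.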

To find a good K for a dataset, use the Elbow method: plot inertia against K and find the point where the decrease in inertia begins to slow. K=3 is the “elbow” of this graph.

Figure: line graph titled 'Optimal Number of Clusters', with 'Number of Clusters (k)' (1 to 7) on the x-axis and 'Inertia' on the y-axis. Inertia falls steeply from about 650 at k=1 to about 225 at k=2 and about 150 at k=3. It then levels off at about 130, 115, 105, and 100 for k=4 through 7.
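Here is a minimal sketch of the Elbow method with scikit-learn. It fits one model per K and records each model's inertia_. The Iris features serve only as example data, so the resulting values are not the ones in the graph. Setting n_init=10 and random_state=0 is an assumption that makes the runs repeatable.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

X = load_iris().data  # example data: 150 samples, 4 features each

k_values = range(1, 8)  # K = 1 through 7, as in the graph
inertias = []
for k in k_values:
    model = KMeans(n_clusters=k, n_init=10, random_state=0)
    model.fit(X)
    # inertia_: sum of squared distances from each sample to its nearest centroid
    inertias.append(model.inertia_)

for k, value in zip(k_values, inertias):
    print(k, round(value, 1))
```

To find the elbow, plot inertias against k_values (for example, with matplotlib's plt.plot) and look for the point where the curve bends.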

Unsupervised Learning Basics

Unsupervised learning, an important branch of machine learning, finds patterns and structure in unlabeled data. Clustering is one of the most popular unsupervised learning techniques. It groups data points into clusters based on their similarity. Because much of the world's data is unlabeled, unsupervised learning algorithms are widely applicable.

Possible applications of clustering include:

  • Search engines: grouping news topics and search results
  • Market segmentation: grouping customers based on geography, demographics, and behaviors

K-Means Algorithm: Intro

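K-Means is one of the most popular clustering algorithms. It uses an iterative technique to group unlabeled data into K clusters based on cluster centers (centroids). K-Means looks for clusters that minimize the sum of squared distances between each data point and its cluster's centroid (the inertia). The algorithm is guaranteed to converge, but only to a local minimum, which depends on where the initial centroids were placed.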

  1. Randomly place K centroids for the initial clusters.
  2. Assign each data point to its nearest centroid.
  3. Move each centroid to the mean (average) position of the data points assigned to it.

Repeat Steps 2 and 3 until points don’t move between clusters and centroids stabilize.
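The three steps and the stopping rule can be sketched in NumPy as follows. This is a minimal illustration, not scikit-learn's implementation. The function name kmeans, the max_iter cap, and the choice to leave an empty cluster's centroid in place are all assumptions made for this sketch.

```python
import numpy as np

def kmeans(X, k, max_iter=100, seed=0):
    """Minimal K-Means sketch. Returns (labels, centroids)."""
    rng = np.random.default_rng(seed)
    # Step 1: place K random centroids within the range of the data
    centroids = rng.uniform(X.min(axis=0), X.max(axis=0), size=(k, X.shape[1]))
    labels = np.zeros(len(X), dtype=int)
    for _ in range(max_iter):
        # Step 2: distance from every point to every centroid, shape (n_points, k),
        # then label each point with the index of its nearest centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
        labels = dists.argmin(axis=1)
        # Step 3: move each centroid to the mean of its assigned points;
        # a centroid with no assigned points stays where it is
        new_centroids = np.array([
            X[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
            for i in range(k)
        ])
        # Repeat steps 2 and 3 until the centroids stop moving
        converged = np.allclose(new_centroids, centroids)
        centroids = new_centroids
        if converged:
            break
    return labels, centroids

# Toy example (hypothetical data): two well-separated groups of 2D points
X = np.array([[1.0, 1.0], [1.0, 2.0], [2.0, 1.0],
              [8.0, 8.0], [8.0, 9.0], [9.0, 8.0]])
labels, centroids = kmeans(X, k=2)
print(labels)
print(centroids)
```

The starting centroids are random, so a different seed can give a different result, including a centroid that never gets any points. This is why scikit-learn's KMeans runs several initializations (its n_init parameter) and keeps the one with the lowest inertia.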

Animation: K-Means with K=2 on 2D data (both axes 0 to 8). All points start out blue. A green centroid (an 'X') is placed at about (1, 6) and a red one at about (5, 3), and each point is colored by its nearest centroid. The centroids then move to the mean of their points, and the assign/update cycle repeats. Over three updates, the green centroid moves through about (2, 5) and (3.5, 5.5) to a final position near (4.5, 6.5). The red centroid moves through about (3.5, 2.5) and (3, 1.5) to a final position near (2, 1.5).

K-Means Algorithm: 2nd Step

After randomly choosing centroid locations for K-Means, each data sample is allocated to its closest centroid to start creating more precise clusters.

The distance between each data sample and every centroid is calculated, the minimum distance is selected, and each data sample is assigned a label that indicates its closest cluster.

The distance formula is implemented as a distance() function, which is applied to each data point.

np.argmin() returns the index of the smallest distance. This index is the label of the closest cluster.

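# distance formula: Euclidean distance between two 2D points a and b
def distance(a, b):
    one = (a[0] - b[0]) ** 2  # squared difference in x
    two = (a[1] - b[1]) ** 2  # squared difference in y
    dist = (one + two) ** 0.5  # square root of the sum
    return dist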
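Building on that function, here is a sketch of step 2 for a few 2D points. The helper assign_to_centroid and the sample values are hypothetical. The centroids reuse the starting positions from the animation above.

```python
import numpy as np

def distance(a, b):
    # Euclidean distance between two 2D points
    return ((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2) ** 0.5

def assign_to_centroid(sample, centroids):
    # Distance from this sample to every centroid
    distances = np.array([distance(sample, c) for c in centroids])
    # The index of the smallest distance is the label of the closest cluster
    return int(np.argmin(distances))

samples = np.array([[1.0, 1.0], [5.0, 4.0], [1.5, 6.0]])
centroids = np.array([[1.0, 6.0], [5.0, 3.0]])  # example centroid positions

labels = np.array([assign_to_centroid(s, centroids) for s in samples])
print(labels)  # [1 1 0]
```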

Scikit-Learn Datasets

The scikit-learn library contains built-in datasets in its datasets module that are often used in machine learning problems like classification or regression.

Examples:

  • Iris dataset (classification)
  • Boston house-prices dataset (regression; removed in scikit-learn 1.2, so newer versions offer alternatives such as the diabetes dataset)

The format of these datasets matters when using them with algorithms. For example, each row of the Iris dataset is a sample (one individual flower), and each element within a sample is a feature (e.g., petal width).
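The following sketch loads the Iris dataset with scikit-learn's load_iris() to show this sample/feature layout.

```python
from sklearn.datasets import load_iris

iris = load_iris()
samples = iris.data   # one row per flower (sample), one column per feature
target = iris.target  # species label for each flower: 0, 1, or 2

print(samples.shape)       # (150, 4)
print(iris.feature_names)  # names of the 4 features, e.g. 'petal width (cm)'
print(samples[0])          # [5.1 3.5 1.4 0.2]
print(iris.target_names)   # ['setosa' 'versicolor' 'virginica']
```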

Figure: illustrations of the three iris species in the dataset ('Iris Versicolor', 'Iris Setosa', and 'Iris Virginica'), each with one sepal and one petal labeled. The species differ in the relative size and shape of their sepals and petals.

K-Means Using Scikit-Learn

Scikit-Learn, or sklearn, is a machine learning library for Python. It includes a K-Means implementation that can be used instead of writing one from scratch.

To use it:

  • Import the KMeans class from the sklearn.cluster module and build a model with n_clusters

  • Fit the model to the data samples using .fit()

  • Predict the cluster that each data sample belongs to using .predict() and store these as labels

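from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

data_samples = load_iris().data  # example data: 150 samples, 4 features each
model = KMeans(n_clusters=3)  # build a model that finds 3 clusters
model.fit(data_samples)  # run K-Means on the samples
labels = model.predict(data_samples)  # cluster index (0, 1, or 2) for each sample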

Cross Tabulation Overview

A cross-tabulation counts how often each combination of values from two categorical variables occurs. Grouping the data this way can reveal relationships that stay hidden when you look only at total responses.

This technique is often used in Python after running K-Means: the pandas function pd.crosstab() compares the resulting cluster labels with the user-defined labels for each data sample. To validate a K-Means model this way, every data sample must have a user-defined label.

import pandas as pd
cross_tab = pd.crosstab(df['pred_labels'], df['user_labels'])
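Here is a self-contained version of that call. It assumes the Iris data as the samples and the species names as the user-defined labels, and it sets n_init and random_state only to make the run repeatable. Each row of cross_tab is a predicted cluster and each column is a species, so a cluster whose counts fall mostly in one column matches that species well.

```python
import pandas as pd
from sklearn.cluster import KMeans
from sklearn.datasets import load_iris

iris = load_iris()
model = KMeans(n_clusters=3, n_init=10, random_state=0)
pred_labels = model.fit_predict(iris.data)  # fit, then label every sample

df = pd.DataFrame({
    'pred_labels': pred_labels,
    # user-defined labels: the species name of each flower
    'user_labels': iris.target_names[iris.target],
})

cross_tab = pd.crosstab(df['pred_labels'], df['user_labels'])
print(cross_tab)
```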

K-Means: Reaching Convergence

In K-Means, after placing K random centroids, the data samples are repeatedly assigned to the nearest centroid and then centroid locations are updated. This continues until each of the centroids’ coordinates converge, or stop changing.

This sequence of events can be implemented in Python using a while loop. The loop continues until each updated centroid in centroids differs from its previous position in centroids_old by 0. At that point the centroids have converged and the clusters are complete!
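A sketch of that loop follows. It uses hypothetical toy data and fixed starting centroids so the result is repeatable. The error array holds how far each centroid moved in the last update, and the loop stops once every entry is exactly 0. This sketch does not handle empty clusters.

```python
import numpy as np

def distance(a, b):
    # Euclidean distance between two points
    return np.sqrt(((a - b) ** 2).sum())

# Toy 2D data and starting centroids (hypothetical values)
samples = np.array([[1.0, 1.0], [1.0, 2.0], [2.0, 1.0],
                    [8.0, 8.0], [8.0, 9.0], [9.0, 8.0]])
centroids = np.array([[0.0, 0.0], [5.0, 5.0]])
k = len(centroids)

centroids_old = np.zeros(centroids.shape)
labels = np.zeros(len(samples), dtype=int)
error = np.array([distance(centroids[i], centroids_old[i]) for i in range(k)])

iterations = 0
while not np.all(error == 0):  # stop once no centroid moved
    # Step 2: assign each sample to its nearest centroid
    for j, sample in enumerate(samples):
        labels[j] = np.argmin([distance(sample, c) for c in centroids])
    # Remember the current positions before updating them
    centroids_old = centroids.copy()
    # Step 3: move each centroid to the mean of its samples
    for i in range(k):
        centroids[i] = samples[labels == i].mean(axis=0)
    # How far did each centroid move in this update?
    error = np.array([distance(centroids[i], centroids_old[i]) for i in range(k)])
    iterations += 1

print(iterations)  # 2
print(centroids)
```

The exact zero test works here because the second update averages exactly the same points. On larger or noisier data, a small tolerance (for example, np.allclose) is a common alternative.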

Figure: scatter plot of the Iris data after K-Means has converged, with 'sepal length (cm)' (4.5 to 8.0) on the x-axis and 'sepal width (cm)' (2.0 to 4.5) on the y-axis. Three centroids (blue diamonds) sit near (5.0, 3.5), (5.25, 2.75), and (6.5, 3.0). Each is surrounded by the points of its cluster, shown in red, green, and purple respectively.

K-Means Algorithm: 3rd Step

The third step of K-Means updates centroid locations. After the data are assigned to their respectively closest centroid in step 2, each cluster center location is adjusted to be the average of its assigned data points.

NumPy's np.mean() function (with axis=0) finds the average x and y-coordinates of the data points in each cluster. These averages are stored as the new centroid locations.
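Here is a sketch of step 3, assuming step 2 has already produced a labels array. The sample points and labels are hypothetical. This sketch does not handle a cluster with no points: np.mean of an empty selection returns nan and issues a warning.

```python
import numpy as np

# Example output of step 2 (hypothetical values): 2D samples and their labels
samples = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
                    [10.0, 10.0], [12.0, 14.0]])
labels = np.array([0, 0, 0, 1, 1])
k = 2

centroids = np.zeros((k, samples.shape[1]))
for i in range(k):
    points = samples[labels == i]           # samples assigned to cluster i
    centroids[i] = np.mean(points, axis=0)  # average x and average y

# cluster 0 moves to (3.0, 4.0), cluster 1 moves to (11.0, 12.0)
print(centroids)
```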

Animation: K-Means with K=2 on 2D data (both axes 0 to 8), shown beside this list of steps, with each step highlighted as it runs:
1. Place K random centroids for the initial clusters.
2. Assign data samples to the nearest centroid.
3. Update centroids based on the newly assigned samples.

A green centroid starts at about (1, 6) and a red one at about (5, 3). Steps 2 and 3 then alternate. The green centroid moves through about (2, 5) and (3.5, 5.5) to a final position near (4.5, 6.5), and the red centroid moves through about (3.5, 2.5) and (3, 1.5) to a final position near (2, 1.5).

K-Means Algorithm: 1st Step

The first step of the K-Means clustering algorithm places K random centroids, which become the centers of the K initial clusters. This step can be implemented in Python using NumPy's np.random.uniform() function, which chooses random x and y-coordinates within the x and y ranges of the data points.
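A sketch of step 1 follows. It assumes the two sepal features of the Iris data, which are the same axes as the plot below. The variable names centroids_x and centroids_y are illustrative choices.

```python
import numpy as np
from sklearn.datasets import load_iris

samples = load_iris().data[:, :2]  # sepal length and sepal width only
x, y = samples[:, 0], samples[:, 1]
k = 3

# Random x and y coordinates drawn uniformly within the range of the data
centroids_x = np.random.uniform(x.min(), x.max(), size=k)
centroids_y = np.random.uniform(y.min(), y.max(), size=k)
centroids = np.column_stack((centroids_x, centroids_y))  # shape (k, 2)

print(centroids)  # k random (x, y) points; different on every run
```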

Figure: scatter plot of the Iris data with 'sepal length (cm)' (4.5 to 8.0) on the x-axis and 'sepal width (cm)' (2.0 to 4.5) on the y-axis. Three randomly placed centroids (orange dots) sit near (4.75, 2.0), (7.25, 3.0), and (8.0, 3.75). The data points (dark and light blue) are densest in the middle of the x-range, around a sepal width of 3.0.
