K-Means Clustering
Implementing K-Means: Scikit-Learn

Awesome, you have implemented K-Means clustering from scratch!

Writing an algorithm from scratch every time you need it is time-consuming, and it is easy to introduce mistakes and typos along the way. Next, we will show you a more efficient way to run K-Means: the scikit-learn library.

Instead of implementing K-Means yourself, you can use the KMeans class from scikit-learn's sklearn.cluster module, which does the work for you.

To import KMeans from sklearn.cluster:

from sklearn.cluster import KMeans

For Step 1, call the KMeans() constructor to build a model that finds k clusters. Specify the number of clusters (k) with the n_clusters keyword argument:

model = KMeans(n_clusters=k)
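
A minimal sketch of Step 1 on its own, assuming k = 3 purely for illustration. The random_state and n_init arguments are optional extras that this lesson does not require: random_state makes the random centroid placement repeatable, and n_init sets how many different starting placements .fit() will try before keeping the best one (its default has changed between scikit-learn versions, so passing it explicitly keeps behavior stable). Creating the model does not look at any data yet.

```python
from sklearn.cluster import KMeans

k = 3  # number of clusters to look for (illustrative value)

# Creating the model only stores its settings; no data is used yet.
# random_state fixes the random initialization so runs are repeatable,
# and n_init is how many different initializations fit() will try.
model = KMeans(n_clusters=k, random_state=0, n_init=10)

print(model.n_clusters)                    # 3
print(hasattr(model, "cluster_centers_"))  # False: centroids exist only after .fit()
```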

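For Steps 2 and 3, use the .fit() method to compute K-Means clustering on the data X. It repeats the assign and update steps until the clusters converge (or an iteration limit is reached):

model.fit(X)
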
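To see what .fit() produces, here is a small sketch on a made-up 2-D dataset with two obvious groups (the data and the choice of k = 2 are illustrative assumptions). After fitting, the model exposes the learned centroids as cluster_centers_, the cluster index of each training sample as labels_, and the total squared distance from samples to their closest centroid as inertia_.

```python
import numpy as np
from sklearn.cluster import KMeans

# A tiny made-up 2-D dataset: two well-separated groups.
X = np.array([
    [1.0, 1.0], [1.5, 2.0], [1.0, 1.5],   # group near (1.2, 1.5)
    [8.0, 8.0], [8.5, 9.0], [9.0, 8.0],   # group near (8.5, 8.3)
])

model = KMeans(n_clusters=2, random_state=0, n_init=10)

# fit() initializes the centroids, then alternates assigning samples to the
# nearest centroid and moving each centroid to the mean of its samples,
# stopping at convergence (or after max_iter iterations).
model.fit(X)

print(model.cluster_centers_)  # final centroid coordinates, one row per cluster
print(model.labels_)           # cluster index assigned to each training sample
print(model.inertia_)          # sum of squared distances to the closest centroid
```

Which group ends up as cluster 0 and which as cluster 1 depends on the random initialization, so compare groupings rather than specific label numbers.
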
Once the model is fitted, we can predict which cluster each sample in X belongs to. Use the .predict() method, which returns the index of the closest cluster center for each sample:

model.predict(X)
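
A sketch of .predict() on the same kind of made-up toy data (the training points and the two new points are illustrative assumptions). On the training data it returns the same indices as labels_; on new points it simply reports the nearest learned centroid without refitting. The fit_predict() method, also part of the KMeans API, fits and returns the training labels in one call.

```python
import numpy as np
from sklearn.cluster import KMeans

# Made-up training data: two well-separated groups of 2-D points.
X = np.array([
    [1.0, 1.0], [1.5, 2.0], [1.0, 1.5],
    [8.0, 8.0], [8.5, 9.0], [9.0, 8.0],
])

# fit() returns the fitted model itself, so the call can be chained.
model = KMeans(n_clusters=2, random_state=0, n_init=10).fit(X)

# Index of the closest learned centroid for each training sample;
# this agrees with model.labels_.
train_labels = model.predict(X)

# New points are assigned to the nearest existing centroid;
# predict() does not move the centroids.
new_points = np.array([[0.0, 0.0], [10.0, 10.0]])
new_labels = model.predict(new_points)

# fit_predict() fits a fresh model and returns the training labels in one step.
combined = KMeans(n_clusters=2, random_state=0, n_init=10).fit_predict(X)

print(train_labels)
print(new_labels)
print(combined)
```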

1. First, import KMeans from sklearn.cluster.
2. Somewhere after samples = iris.data, use KMeans() to create an instance called model that finds 3 clusters. To specify the number of clusters, use the n_clusters keyword argument.
3. Next, use the .fit() method of model to fit the model to the array of points, samples.
4. After you have the “fitted” model, determine the cluster labels of samples. Then, print the labels.
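
One possible sketch of the whole exercise, assuming the iris data is loaded with sklearn.datasets.load_iris() (the exercise's setup code is not shown in this section). The random_state and n_init arguments are optional additions for repeatable output.

```python
from sklearn import datasets
from sklearn.cluster import KMeans

# Assumed setup: load the iris dataset bundled with scikit-learn.
iris = datasets.load_iris()
samples = iris.data  # 150 flowers, 4 measurements each

# Build a model that looks for 3 clusters.
model = KMeans(n_clusters=3, random_state=0, n_init=10)

# Fit the model to the samples.
model.fit(samples)

# Determine the cluster label of each sample and print them.
labels = model.predict(samples)
print(labels)
```

The printed labels are arbitrary cluster indices (0, 1, 2); they are not guaranteed to line up with the species numbers in iris.target.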
