The K-Means algorithm:
- Place k random centroids for the initial clusters.
- Assign data samples to the nearest centroid.
- Update centroids based on the above-assigned data samples.
- Repeat Steps 2 and 3 until convergence.
In this exercise, we will implement Step 4.
This is the part of the algorithm where we repeatedly execute Steps 2 and 3 until the centroids stabilize (convergence).
We can do this using a while loop, with everything from Steps 2 and 3 placed inside the loop.
For the condition of the while loop, we need to create an array named error. Each entry of error holds the distance between an updated centroid (from centroids) and its previous position (from centroids_old).
The loop ends when all three values in error are 0, meaning that no centroid moved during the last update.
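To make the stopping condition concrete, here is a minimal sketch. The distance() helper below is an assumption (a plain Euclidean distance, since the exercise's own definition is not shown in this section), and the centroid values are made up for illustration.

```python
import numpy as np

def distance(a, b):
    # Assumed helper: Euclidean distance between two points.
    return np.sqrt(np.sum((np.asarray(a) - np.asarray(b)) ** 2))

# Made-up centroids: only the third one moved in the last update.
centroids_old = np.array([[5.0, 3.4], [5.9, 2.7], [6.8, 3.1]])
centroids = np.array([[5.0, 3.4], [5.9, 2.7], [6.9, 3.0]])

error = np.zeros(3)
for i in range(3):
    error[i] = distance(centroids[i], centroids_old[i])

# error.any() is True while at least one centroid is still moving.
# error.all() is True only if every centroid moved.
still_moving = bool(error.any())
all_moved = bool(error.all())
print(error, still_moving, all_moved)
```

Because the loop should keep running until every entry is 0, error.any() is the check that matches "all three values are 0" as the exit condition.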
Instructions
On line 40 of script.py, initialize error:
error = np.zeros(3)
Then, use the distance() function to calculate the distance between each updated centroid and its old centroid, and store the results in error:
error[0] = distance(centroids[0], centroids_old[0])
# do the same for error[1]
# do the same for error[2]
After that, add a while loop:
while error.any():
Move everything below it (the code from Steps 2 and 3) inside the loop.
Then recalculate error at the end of each iteration of the while loop:
    error[0] = distance(centroids[0], centroids_old[0])
    # do the same for error[1]
    # do the same for error[2]
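Putting the pieces together, the whole loop might look like the sketch below. It is not the exercise's script.py: the samples are synthetic blobs, distance() is an assumed Euclidean helper, the names samples and labels are assumptions, and the initial centroids are taken from the data instead of being random so that no cluster ends up empty.

```python
import numpy as np
from copy import deepcopy

def distance(a, b):
    # Assumed helper: Euclidean distance between two points.
    return np.sqrt(np.sum((a - b) ** 2))

# Synthetic 2-D data standing in for the exercise's samples: three tight blobs.
rng = np.random.default_rng(0)
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
samples = np.vstack([c + 0.3 * rng.standard_normal((20, 2)) for c in centers])

k = 3
# Step 1 (simplified): start from one sample of each blob.
centroids = samples[[0, 20, 40]].copy()
centroids_old = np.zeros(centroids.shape)
labels = np.zeros(len(samples), dtype=int)

# Initial errors: distance from each starting centroid to the zero placeholder.
error = np.zeros(k)
for i in range(k):
    error[i] = distance(centroids[i], centroids_old[i])

while error.any():
    # Step 2: assign each sample to its nearest centroid.
    for n in range(len(samples)):
        dists = [distance(samples[n], centroids[i]) for i in range(k)]
        labels[n] = int(np.argmin(dists))
    # Step 3: move each centroid to the mean of its assigned samples.
    centroids_old = deepcopy(centroids)
    for i in range(k):
        centroids[i] = samples[labels == i].mean(axis=0)
    # Recompute the errors at the end of the iteration.
    for i in range(k):
        error[i] = distance(centroids[i], centroids_old[i])

print(centroids)
```

Copying centroids into centroids_old before the update is what makes the error meaningful: once the assignments stop changing, the recomputed means are identical to the previous ones and every entry of error becomes exactly 0.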
Awesome, now you have everything. Let's visualize it.
After the while loop finishes, let's create an array of colors:
colors = ['r', 'g', 'b']
Then, create a for loop that iterates k times.
Inside the for loop (similar to what we did in the last exercise), create an array named points that holds all the data points with the cluster label i.
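One way to collect the points for each label is a NumPy boolean mask. The sketch below uses made-up data; the names samples, labels, and clusters are assumptions for illustration, not necessarily the ones in script.py.

```python
import numpy as np

# Made-up samples and cluster labels, for illustration only.
samples = np.array([[5.1, 3.5], [7.0, 3.2], [6.3, 3.3], [4.9, 3.0], [6.4, 3.2]])
labels = np.array([0, 1, 2, 0, 1])
k = 3

clusters = {}
for i in range(k):
    # The mask labels == i keeps only the rows whose label equals i.
    points = samples[labels == i]
    clusters[i] = points
    print(i, points.shape)
```

Each points array keeps the two feature columns, so points[:, 0] and points[:, 1] can be passed directly to a scatter plot.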
Then we are going to make a scatter plot of points[:, 0] vs points[:, 1] using the scatter() function:
plt.scatter(points[:, 0], points[:, 1], c=colors[i], alpha=0.5)
Then, paste the following code at the very end. Together with the loop above, this visualizes the points of each label in a different color, with the final centroids drawn as diamond markers.
plt.scatter(centroids[:, 0], centroids[:, 1], marker='D', s=150)
plt.xlabel('sepal length (cm)')
plt.ylabel('sepal width (cm)')
plt.show()