SciPy K-Means
The scipy.cluster.vq package provides the kmeans() function, which performs k-means clustering on a set of observation vectors, grouping them into k clusters. In this tutorial, we shall learn the syntax and usage of the kmeans() function through SciPy K-Means examples.
Syntax
centroids,distortion = scipy.cluster.vq.kmeans(obs, k_or_guess, iter=20, thresh=1e-05, check_finite=True)
Parameter | Optional/Required | [datatype] Description |
obs | Required | [ndarray] Each row of the M by N array is an observation vector. The columns are the features seen during each observation. The features must be whitened first with the whiten function. |
k_or_guess | Required | [int or ndarray] If an int, the number of centroids to generate. A code is assigned to each centroid, which is also the row index of the centroid in the code_book matrix generated. If a k by N ndarray, its rows are used as the initial k centroids. |
iter | Optional | [int] The number of times to run k-means, returning the codebook with the lowest distortion. This argument is ignored if initial centroids are specified with an array for the k_or_guess parameter. This parameter does not represent the number of iterations of the k-means algorithm. |
thresh | Optional | [float] Terminates the k-means algorithm if the change in distortion since the last k-means iteration is less than or equal to thresh. |
check_finite | Optional | [boolean] Whether to check that the input matrices contain only finite numbers. Disabling may give a performance gain, but may result in problems (crashes, non-termination) if the inputs do contain infinities or NaNs. Default: True |
centroids | [Returned value] | [ndarray] A k by N array of k centroids. |
distortion | [Returned value] | [float] The mean (non-squared) Euclidean distance between the observations passed and the centroids generated. |
The values shown for the optional arguments in the syntax above are their defaults.
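The k_or_guess parameter also accepts an array of starting centroids, in which case iter is ignored. The following is a minimal sketch of that usage, not a definitive recipe. The synthetic data, the fixed seed, and the choice of rows 0, 20 and 40 as the initial guess are assumptions made only for illustration.

```python
import numpy as np
from scipy.cluster.vq import kmeans, whiten

# Assumed synthetic data: three groups of 20 two-dimensional points each.
rng = np.random.default_rng(0)
obs = np.vstack([rng.random((20, 2)) + offset for offset in (1.0, 3.0, 4.5)])

# whiten divides each feature (column) by its standard deviation.
obs = whiten(obs)

# Hypothetical initial guess: one observation taken from each group.
guess = obs[[0, 20, 40]]

# Because k_or_guess is an array, k-means runs once from these centroids.
centroids, distortion = kmeans(obs, guess)
print('centroids :', centroids)
print('distortion :', distortion)
```

Starting from an explicit guess makes the result repeatable for a given data set, whereas an integer k picks the initial centroids at random.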
SciPy K-Means Example
In this example, we shall generate a set of random 2-D points, centered around 3 centroids.
# import numpy
from numpy import vstack,array
from numpy.random import rand
# matplotlib
import matplotlib.pyplot as plt
# scipy
from scipy.cluster.vq import kmeans,vq,whiten
data = vstack(((rand(20,2)+1),(rand(20,2)+3),(rand(20,2)+4.5)))
plt.plot(data[:,0],data[:,1],'go')
plt.show()
# whiten the features
data = whiten(data)
# find 3 clusters in the data
centroids,distortion = kmeans(data,3)
print('centroids : ',centroids)
print('distortion :',distortion)
plt.plot(data[:,0],data[:,1],'go',centroids[:,0],centroids[:,1],'bs')
plt.show()
Output (the exact values vary from run to run because the input points are random):
centroids : [[ 1.42125469 1.58213817]
 [ 3.55399219 3.53655637]
 [ 4.91171555 5.02202473]]
distortion : 0.35623898893
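The example imports vq but does not call it. As a hedged sketch of a possible next step, vq() can assign each observation to its nearest centroid, and the centroids can be converted back to the original units by multiplying by the per-feature standard deviation that whiten divided by. The variable names raw, scale and centroids_raw, the fixed seed, and the initial guess are assumptions introduced for this illustration.

```python
import numpy as np
from scipy.cluster.vq import kmeans, vq, whiten

# Assumed synthetic data, similar in shape to the example above.
rng = np.random.default_rng(0)
raw = np.vstack([rng.random((20, 2)) + offset for offset in (1.0, 3.0, 4.5)])

# Per-feature standard deviation; whiten divides each column by this value.
scale = raw.std(axis=0)
data = whiten(raw)

# Hypothetical initial guess (one point per group) to keep the run repeatable.
centroids, distortion = kmeans(data, data[[0, 20, 40]])

# vq returns the index of the nearest centroid for each observation,
# together with the distance from the observation to that centroid.
labels, distances = vq(data, centroids)
print('points per cluster :', np.bincount(labels))

# Undo the whitening to express the centroids in the original units.
centroids_raw = centroids * scale
print('centroids (original units) :', centroids_raw)
```

The labels array can then be used to colour the points by cluster when plotting, for example by selecting data[labels == 0].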