K-Means in R
If you look up the k-means algorithm online or in a reference book, you will be met with a flurry of mathematical symbols and formal explanations. The basic principle (informally stated) is rather simple: given a set of observations (picture a scatter plot of points) and the number of groups, or clusters, that you wish to divide them into, the k-means algorithm finds the center of each group and assigns each observation to the group whose center is "closest".
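To make the informal description concrete, here is a minimal base-R sketch of Lloyd's algorithm, the textbook loop that alternates an assignment step with an update step. It is an illustration only: simple_kmeans is a hypothetical helper, the starting centers are drawn at random from the distinct observations, and R's own kmeans uses the Hartigan-Wong algorithm by default (Lloyd is available via algorithm = "Lloyd").

```r
# Hypothetical helper: a bare-bones Lloyd's algorithm for illustration only.
simple_kmeans <- function(x, k, max_iter = 100) {
  x <- as.matrix(x)
  # Pick k distinct observations at random as the starting centers
  ux <- unique(x)
  centers <- ux[sample(nrow(ux), k), , drop = FALSE]
  cluster <- integer(nrow(x))
  for (iter in seq_len(max_iter)) {
    # Assignment step: squared Euclidean distance from every point (rows)
    # to every center (columns), then pick the nearest center
    d <- sapply(seq_len(k), function(j) colSums((t(x) - centers[j, ])^2))
    new_cluster <- max.col(-d, ties.method = "first")
    if (all(new_cluster == cluster)) break  # no point changed cluster: converged
    cluster <- new_cluster
    # Update step: move each center to the mean of the points assigned to it
    # (an empty cluster simply keeps its previous center)
    for (j in seq_len(k)) {
      if (any(cluster == j)) {
        centers[j, ] <- colMeans(x[cluster == j, , drop = FALSE])
      }
    }
  }
  list(cluster = cluster, centers = centers)
}
```

For example, simple_kmeans(iris[, c("Petal.Length", "Petal.Width")], k = 3) clusters the same petal measurements used below; like kmeans, the result depends on the random starting centers.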
To use k-means in R, call the kmeans function with a matrix of values and the number of centers. The function seeks to partition the points into k groups (the number of centers) such that the sum of squared distances from points to their assigned cluster centers is minimized. Each observation (point) belongs to the cluster with the nearest mean.
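Stated formally, the quantity being minimized is the total within-cluster sum of squares. Writing the clusters as C_1, ..., C_k and their means as mu_j:

$$
\min_{C_1,\dots,C_k} \sum_{j=1}^{k} \sum_{x_i \in C_j} \lVert x_i - \mu_j \rVert^2,
\qquad
\mu_j = \frac{1}{|C_j|} \sum_{x_i \in C_j} x_i
$$

Each inner sum is what kmeans reports per cluster as withinss, and the overall total is reported as tot.withinss. The algorithm only guarantees a local minimum of this quantity, which is why the starting centers matter.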
How-To
To start, we copy the iris data set to a separate data frame. This is not strictly necessary, but it makes it easier for me to reflexively type df whenever the data frame is in view. Next we create a matrix containing only the petal length and width.
df=iris
m=cbind(df$Petal.Length, df$Petal.Width)  # cbind already returns a two-column matrix
Now we will do the actual clustering.
cl=kmeans(m,3)
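Because kmeans picks its starting centers at random, the cluster numbers (and occasionally the clusters themselves) can change from one run to the next. A sketch of the usual remedy, assuming you want repeatable results: fix the random seed with set.seed and request several random starts with the nstart argument, which keeps the best solution found. The seed 42, nstart = 25 and the names pet, cl_a and cl_b are arbitrary choices for this example.

```r
# Self-contained: the same petal matrix as above, built directly from iris
pet <- cbind(iris$Petal.Length, iris$Petal.Width)

# Same seed -> same random starting centers -> identical clustering
set.seed(42)
cl_a <- kmeans(pet, centers = 3, nstart = 25)
set.seed(42)
cl_b <- kmeans(pet, centers = 3, nstart = 25)

identical(cl_a$cluster, cl_b$cluster)  # TRUE
```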
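Simple, eh? The cl object is a list whose components describe the fitted model; for example, size holds the number of points in each cluster and withinss holds the within-cluster sum of squares for each cluster.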
cl$size
cl$withinss
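A sketch of what withinss actually measures, using only base R. It re-fits the petal clustering with an arbitrary seed and nstart so it runs on its own (pet, fit and manual_withinss are names chosen for this example), recomputes the within-cluster sum of squares by hand, and checks the identity totss = tot.withinss + betweenss; totss, tot.withinss and betweenss are further components of the object kmeans returns.

```r
# Self-contained: cluster the iris petal measurements (seed and nstart are arbitrary)
set.seed(42)
pet <- cbind(iris$Petal.Length, iris$Petal.Width)
fit <- kmeans(pet, centers = 3, nstart = 25)

# Squared distance from each point to the center of the cluster it was assigned to
sq_dist <- rowSums((pet - fit$centers[fit$cluster, ])^2)
# Summing those distances per cluster reproduces withinss
manual_withinss <- as.vector(tapply(sq_dist, fit$cluster, sum))

all.equal(manual_withinss, fit$withinss)               # TRUE
all.equal(fit$totss, fit$tot.withinss + fit$betweenss) # TRUE
fit$size                   # number of points in each cluster
fit$betweenss / fit$totss  # share of the total sum of squares "explained" by the clusters
```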
Next we do a bit of data preparation for the plotting calls that follow. Notice that we add the cluster assignments back to our original data frame. Keeping them alongside the observations is good organization, and it is also what ggplot2 requires, since it is designed to work with data frames.
df$cluster=factor(cl$cluster)
centers=as.data.frame(cl$centers)
The following graph color-codes the points by cluster. We also add the centers, each with a large semi-transparent halo, to emphasize where each center lies and its role in assigning observations to clusters.
library(ggplot2)
# centers came from an unnamed matrix, so its columns are V1 (length) and V2 (width)
ggplot(data=df, aes(x=Petal.Length, y=Petal.Width, color=cluster)) +
  geom_point() +
  geom_point(data=centers, aes(x=V1, y=V2, color='Center')) +
  geom_point(data=centers, aes(x=V1, y=V2, color='Center'), size=52, alpha=.3, show.legend=FALSE)
This plot is an interesting example of how several different sets of data (here, the observations and the cluster centers), each in its own data frame, can be layered into a single ggplot2 chart.
Misclassified Observations
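The model is pretty accurate, but not perfect. Since k-means never sees the species labels, "misclassified" here means an observation that lands in a cluster dominated by a different species. The following SQL query, run through the sqldf package, counts observations by species and cluster and reveals the few misclassified ones:
library(sqldf)
sqldf('select Species, cluster, count(*) from df group by Species, cluster')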
     Species cluster count(*)
1     setosa       2       50
2 versicolor       1       48
3 versicolor       3        2
4  virginica       1        6
5  virginica       3       44
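If you would rather not load sqldf, base R's table() produces the same species-by-cluster counts as a contingency table. A sketch, re-running kmeans with an arbitrary seed (and hypothetical names pet, fit, tab) so it stands alone; the cluster numbers will generally differ from the output above because they depend on the random start.

```r
set.seed(42)
pet <- cbind(iris$Petal.Length, iris$Petal.Width)
fit <- kmeans(pet, centers = 3, nstart = 25)

# Rows are species, columns are cluster numbers, cells are counts
tab <- table(Species = iris$Species, cluster = fit$cluster)
print(tab)
```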
So we pull the misclassified observations (those in any species/cluster combination with fewer than 10 members) into their own data frame:
df2 = sqldf('select * from df where (Species || cluster) in
(select Species || cluster from df group by Species, Cluster having count(*) < 10)')
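The count(*) < 10 cut-off works here but is tied to this data set's group sizes. A sketch of a more general alternative (a swapped-in technique, not the method used above): label each cluster with the species that makes up most of it, then flag any observation whose species differs from its cluster's majority. The names dat, pet, majority, misclassified and flagged are hypothetical.

```r
set.seed(42)
dat <- iris
pet <- cbind(dat$Petal.Length, dat$Petal.Width)
dat$cluster <- factor(kmeans(pet, centers = 3, nstart = 25)$cluster)

tab <- table(dat$Species, dat$cluster)
# For each cluster (a column of tab), the species with the largest count
majority <- rownames(tab)[apply(tab, 2, which.max)]
names(majority) <- colnames(tab)

# TRUE when a flower's species differs from the majority species of its cluster
dat$misclassified <- as.character(dat$Species) != majority[as.character(dat$cluster)]
flagged <- dat[dat$misclassified, ]
nrow(flagged)
```

Unlike a fixed count threshold, this rule adapts to any number of clusters and group sizes.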
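Now we can enhance the previous graph by drawing a diamond around each misclassified point.
# shape, alpha and size are fixed settings, not data mappings, so they go outside aes()
last_plot() + geom_point(data=df2, aes(x=Petal.Length, y=Petal.Width), shape=5, alpha=.7, size=4.5, show.legend=FALSE)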
And so, with a bit of data mining knowledge and the R programming language, even our machines can stop and smell the roses...