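K-means is an unsupervised learning method for clustering. It partitions the points in a vector space into k groups and computes the distance from each point to its group center, with the goal of minimizing the total (squared) distance from the points to their group centers (references: http://en.wikipedia.org/wiki/K-means_clustering, https://www.youtube.com/watch?v=aiJ8II94qck). Because K-means is easy to understand, it is also widely used in industry. This post is not mainly about explaining the algorithm; instead, it uses the scikit-learn package in Python to build a small program that lets users run the k-means algorithm easily.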
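The idea described above can be sketched in a few lines of NumPy. This is only a minimal illustration of the standard Lloyd iteration, not the code used in this post; the function names `nearest_center` and `lloyd_kmeans`, the random initialization, and the sample points are all assumptions made for the example.

```python
import numpy as np

def nearest_center(points, centers):
    """Return, for every point, the index of its closest center."""
    # Pairwise Euclidean distances, shape (n_points, n_centers).
    dist = np.linalg.norm(points[:, None, :] - centers[None, :, :], axis=2)
    return dist.argmin(axis=1)

def lloyd_kmeans(points, k, n_iter=100, seed=0):
    """Alternate between assigning points and recomputing centers."""
    rng = np.random.default_rng(seed)
    # Assumed initialization: k distinct input points picked at random.
    centers = points[rng.choice(len(points), size=k, replace=False)]
    for _ in range(n_iter):
        labels = nearest_center(points, centers)
        # Move each center to the mean of its points; keep it if the group is empty.
        new_centers = np.array([
            points[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    # Final assignment so that labels always match the returned centers.
    return nearest_center(points, centers), centers

# Hypothetical sample: two well-separated groups of three points each.
points = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
                   [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
labels, centers = lloyd_kmeans(points, k=2)
print(labels)
print(centers)
```

In practice scikit-learn's `KMeans` does the same job with better initialization and multiple restarts, which is why the program in this post relies on it.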
The source code is as follows:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Above: the shebang line and the file-encoding comment.
#----below is Bryan's code----
# Calculate the distance between every data point and the center of its own group.
#
import sys
import numpy as np
#import pylab as pl
#from mpl_toolkits.mplot3d import Axes3D
from sklearn.cluster import KMeans
from operator import itemgetter
from collections import Counter

def kmean_distance(data, group):
    k = KMeans(n_clusters=group, tol=0.000000001, init='random')
    rowname = data[:, 0]   # first column: serial number of each row
    data = data[:, 1:]     # remaining columns: the attributes to cluster
    # Fit only once: refitting with init='random' may renumber the groups,
    # so the labels and the distances must come from the same fit.
    distance = k.fit_transform(data)  # distance from each point to "every" group center
    g = k.labels_                     # group id of each point
    # Combine serial number, group id and a placeholder column for the distance.
    g = np.column_stack((rowname, g, np.zeros((len(data),))))
    for nrow in range(len(g)):
        gid = int(g[nrow, 1])             # get the group id
        g[nrow, 2] = distance[nrow, gid]  # distance to the point's own group center
    # g is the result
    cnt = Counter(g[:, 1])
    cnt = sorted(cnt.items(), key=itemgetter(0))
    print("total group: %s" % (group))
    print("cnt of each group %s" % (cnt))
    return g

def export(result, outFile):
    fmt = "%i,%i,%1.3f"
    # savetxt writes the header line itself; comments='' keeps it free of a '#' prefix.
    np.savetxt(outFile, result, delimiter=',', fmt=fmt,
               header="SerialNumber,Group,Distance", comments='')
    print("export result to %s" % (outFile))

def importFile(inFile):
    # Read a comma-separated file of numbers into a 2-D array.
    return np.loadtxt(inFile, delimiter=",")

def getResult(inFile, outFile, groupN):
    data = importFile(inFile)
    result = kmean_distance(data, groupN)
    export(result, outFile)

if __name__ == '__main__':
    if len(sys.argv) != 4:
        print('usage : kmean.py <path_to_originFile, path_to_targetFile, group_number> ')
        sys.exit(1)
    getResult(sys.argv[1], sys.argv[2], int(sys.argv[3]))
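$ python kmean.py [input_filename] [output_filename] [numbers_of_group]
First argument: the input file name, in .csv format
Second argument: the output file name; the output is also written as a .csv file
Third argument: the number of groups to form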
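Once the arguments are set, run the program. It automatically prints the groups and the number of points in each group, so the user can check the clustering result. The input file holds the variables to analyze, laid out as [n_case, n_attributes]; note that the code reads the first column of each row as the serial number. The output file contains the serial number, the assigned group, and the distance from the point to its group center, as shown in the figure below (the source code is also available on GitHub at https://github.com/bryanyang0528/Kmeans_distance_center).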
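To make the input and output layout concrete, here is a small self-contained sketch that runs the same steps in memory. The sample rows, the variable names, and the fixed `random_state` are assumptions added for the illustration (the fixed seed only makes the run repeatable); the output format string and header mirror the ones used in the program above.

```python
import io
import numpy as np
from sklearn.cluster import KMeans

# Hypothetical input: first column is the serial number, the rest are attributes.
csv_text = """1,0.0,0.0
2,0.0,1.0
3,1.0,0.0
4,10.0,10.0
5,10.0,11.0
6,11.0,10.0
"""
data = np.loadtxt(io.StringIO(csv_text), delimiter=",")
serial, features = data[:, 0], data[:, 1:]

# random_state is an assumption here, used only for a repeatable example.
model = KMeans(n_clusters=2, init="random", n_init=10, random_state=0)
distances = model.fit_transform(features)  # shape (n_samples, n_clusters)
labels = model.labels_                     # labels from the same fit
# Pick, for each row, the distance to the center of its own group.
own_distance = distances[np.arange(len(labels)), labels]

result = np.column_stack((serial, labels, own_distance))
out = io.StringIO()
np.savetxt(out, result, delimiter=",", fmt="%i,%i,%1.3f",
           header="SerialNumber,Group,Distance", comments="")
print(out.getvalue())
```

Each output row has three fields: the serial number taken from the input, the group id assigned by k-means, and the distance to that group's center rounded to three decimals. The group ids themselves are arbitrary labels, so they may differ from run to run when no seed is fixed.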
I'm a big fan of your blog. Thanks for sharing! Expecting more application examples that utilize R language.
Thanks for your affirmation! You are welcome to share your experience and thoughts with me!