Saturday, October 25, 2014

[Python] K-means Clustering Script


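K-means is an unsupervised learning method for clustering. It partitions the points in a vector space into k groups, computes each point's distance to the center of its group, and aims to make the total (squared) distance between the points and their group centers as small as possible (references: http://en.wikipedia.org/wiki/K-means_clustering and https://www.youtube.com/watch?v=aiJ8II94qck). Because K-means is easy to understand, it is widely used in industry. This post does not explain the algorithm itself; instead, it uses the Python scikit-learn package to build a small program that lets users apply k-means easily.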
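The post skips the algorithm, but a short sketch may help show what "minimizing the distance to the group center" means in practice. Below is a minimal NumPy version of Lloyd's algorithm, the standard iterative k-means procedure: it alternates between assigning each point to its nearest center and moving each center to the mean of its points. This is an illustration only, not the author's code and not scikit-learn's internal implementation; the names `lloyd_kmeans` and `_assign` and their parameters are hypothetical.

```python
import numpy as np


def _assign(X, centers):
    """Index of the nearest center for every row of X (Euclidean distance)."""
    # Shape (n_points, k): distance from every point to every center.
    dists = np.linalg.norm(X[:, None, :] - centers[None, :, :], axis=2)
    return dists.argmin(axis=1)


def lloyd_kmeans(X, k, n_iter=100, seed=0):
    """Minimal Lloyd's k-means sketch. Returns (centers, labels)."""
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    # Random initialization: k distinct data points become the first centers.
    centers = X[rng.choice(len(X), size=k, replace=False)]
    for _ in range(n_iter):
        labels = _assign(X, centers)  # assignment step
        # Update step: each center moves to the mean of its points
        # (an empty group keeps its old center).
        new_centers = np.array([
            X[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        converged = np.allclose(new_centers, centers)
        centers = new_centers
        if converged:
            break
    # Final assignment so that labels always match the returned centers.
    return centers, _assign(X, centers)
```

Each iteration never increases the within-group sum of squared distances, `((X - centers[labels]) ** 2).sum()`, which is the quantity k-means minimizes; the result can still depend on the random initialization.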


The source code is as follows:

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ---- below is Bryan's code ----
# Cluster the input rows with k-means, then report for every row its group
# and the distance between the row and the center of its own group.
import sys
from collections import Counter

import numpy as np
from sklearn.cluster import KMeans


def kmean_distance(data, group):
    """Cluster `data` into `group` groups.

    Column 0 of `data` is the serial number (row ID) and is not used for
    clustering. Returns an array with columns: serial number, group id,
    distance to the row's own group center.
    """
    k = KMeans(n_clusters=group, tol=1e-9, init='random')
    rowname = data[:, 0]
    features = data[:, 1:]
    # Fit once. fit_transform returns the distance from each point to *every*
    # group center, and k.labels_ holds each point's group from the same fit.
    # (Calling fit_predict and then fit_transform fits twice with random
    # initialization, so the labels and the distances could disagree.)
    distance = k.fit_transform(features)
    labels = k.labels_
    # Keep only the distance to each point's own group center.
    own_distance = distance[np.arange(len(features)), labels]
    result = np.column_stack((rowname, labels, own_distance))
    cnt = sorted(Counter(labels.tolist()).items())  # (group id, size) pairs
    print("total group: %s" % group)
    print("cnt of each group %s" % cnt)
    return result


def export(result, outFile):
    # One format string covers the whole row: ID, group, distance.
    np.savetxt(outFile, result, fmt="%i,%i,%1.3f",
               header="SerialNumber,Group,Distance", comments="")
    print("export result to %s" % outFile)


def importFile(inFile):
    # Numeric, comma-separated, no header row; ndmin=2 keeps a one-row file 2-D.
    return np.loadtxt(inFile, delimiter=",", ndmin=2)


def getResult(inFile, outFile, groupN):
    data = importFile(inFile)
    result = kmean_distance(data, groupN)
    export(result, outFile)


if __name__ == '__main__':
    if len(sys.argv) != 4:
        print('usage: kmean.py <path_to_originFile> <path_to_targetFile> <group_number>')
        sys.exit(1)
    getResult(sys.argv[1], sys.argv[2], int(sys.argv[3]))
```
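The program is quite simple: it uses scikit-learn's KMeans to split the input data into groups. What is a little different is that it also computes the distance from each point to its group center and outputs that together with the group. To make the program easy to use for people who don't know Python, it is packaged as a script. Usage is as follows: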
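The distance step relies on scikit-learn's `KMeans.fit_transform`, which returns the distance from every point to every group center; the script then picks, for each point, the column of its own group. A minimal sketch of that idea on a small made-up dataset, assuming a recent scikit-learn (the function name `own_center_distance` is hypothetical); it returns the fitted model so the result can be cross-checked against distances computed directly from `cluster_centers_`:

```python
import numpy as np
from sklearn.cluster import KMeans


def own_center_distance(X, k, seed=0):
    """Return (model, labels, distance to own center) from a single KMeans fit."""
    km = KMeans(n_clusters=k, init="random", n_init=10, random_state=seed)
    # (n_samples, n_clusters) array of Euclidean distances to every center.
    all_dist = km.fit_transform(X)
    labels = km.labels_
    # Row i keeps only column labels[i], i.e. the point's own group center.
    own = all_dist[np.arange(len(X)), labels]
    return km, labels, own


X = np.array([[0.0, 0.0], [0.0, 1.0], [1.0, 0.0],
              [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]])
km, labels, own = own_center_distance(X, 2)
# Same distances computed directly from the fitted centers.
direct = np.linalg.norm(X - km.cluster_centers_[labels], axis=1)
```

Fitting once and reading `labels_` from the same model keeps the group ids and the distance columns consistent.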


```
$ python kmean.py [input_filename] [output_filename] [number_of_groups]
```
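First argument: name of the file to analyze. The supported format is a .csv (comma-separated) file with numeric values and no header row. (The first column is the serial number of each row; it is not used for clustering.)
Second argument: name of the output file; the output is also a .csv file.
Third argument: number of groups to form (k).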
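The script reads these three values positionally from `sys.argv`. If you prefer clearer error messages and a built-in `--help`, the same interface could be declared with the standard `argparse` module; this is an alternative sketch, not part of the original script, and the function name `parse_args` and the help texts are hypothetical.

```python
import argparse


def parse_args(argv=None):
    """Hypothetical argparse front end for the script's three positional arguments."""
    parser = argparse.ArgumentParser(
        prog="kmean.py",
        description="Cluster a CSV file with k-means and export group and distance.")
    parser.add_argument("input_filename",
                        help="comma-separated numeric input; column 0 = serial number")
    parser.add_argument("output_filename", help="path of the CSV file to write")
    parser.add_argument("number_of_groups", type=int, help="number of groups k")
    args = parser.parse_args(argv)
    # parser.error prints the usage message and exits with status 2.
    if args.number_of_groups < 1:
        parser.error("number_of_groups must be a positive integer")
    return args
```

For example, `parse_args(["in.csv", "out.csv", "3"])` yields `number_of_groups == 3`, while a non-integer group count is rejected by `type=int` before the script does any work.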
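After you set the arguments and run the script, it prints the total number of groups and the number of points in each group so that you can check the clustering result. The input file holds the variables to analyze, one case per row ([n_cases, n_attributes], after the serial-number column). The output file lists each case's serial number, its group, and the distance between the point and its group center, as shown in the figure below (the source code is also on GitHub: https://github.com/bryanyang0528/Kmeans_distance_center):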
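To inspect the output file beyond the counts printed on screen, a small helper can read the `SerialNumber,Group,Distance` CSV back and report, per group, the number of cases and the mean distance to the group center. This is a minimal sketch assuming the output layout written by `export` (one header line, then numeric rows); the function name `summarize_output` is hypothetical.

```python
import numpy as np


def summarize_output(path):
    """Return {group: (count, mean_distance)} for a SerialNumber,Group,Distance CSV."""
    # skiprows=1 skips the header line; ndmin=2 keeps a one-row file 2-D.
    data = np.loadtxt(path, delimiter=",", skiprows=1, ndmin=2)
    groups = data[:, 1].astype(int)
    dist = data[:, 2]
    summary = {}
    for g in np.unique(groups):
        mask = groups == g
        summary[int(g)] = (int(mask.sum()), float(dist[mask].mean()))
    return summary
```

The counts should match the "cnt of each group" line printed by the script; a group with a much larger mean distance than the others is usually a looser cluster worth a closer look.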






2 comments:

  1. I'm a big fan of your blog. Thanks for sharing! Looking forward to more application examples that use the R language.

    1. Thanks for your affirmation! You're welcome to share your experience and thoughts with me!
