#-*- coding: utf-8 -*-
import numpy as np
from matplotlib import pyplot as plt
from matplotlib import colors
from scipy import io as spio
from scipy import misc  # image I/O (misc.imread; removed in newer SciPy releases)
from matplotlib.font_manager import FontProperties
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)  # avoids garbled CJK characters in plots on Windows


def KMeans():
    '''Demo: cluster 2-D data, then use K-Means to compress an image'''
    print(u'Demonstrating the clustering process...\n')
    data = spio.loadmat("data.mat")
    X = data['X']
    K = 3                                                    # number of clusters
    initial_centroids = np.array([[3, 3], [6, 2], [8, 5]])   # initial cluster centroids
    max_iters = 10
    runKMeans(X, initial_centroids, max_iters, True)         # run K-Means and plot how the centroids move
    '''
    Image compression
    '''
    print(u'Compressing an image with K-Means\n')
    img_data = misc.imread("bird.png")   # read the pixel data (misc.imread is gone from newer SciPy; imageio.imread is the usual replacement)
    img_data = img_data / 255.0          # map pixel values to [0, 1]
    img_size = img_data.shape
    X = img_data.reshape(img_size[0] * img_size[1], 3)  # reshape to an N x 3 matrix, N = number of pixels

    K = 16
    max_iters = 5
    initial_centroids = kMeansInitCentroids(X, K)
    centroids, idx = runKMeans(X, initial_centroids, max_iters, False)
    print(u'\nK-Means finished\n')
    print(u'\nCompressing the image...\n')
    idx = findClosestCentroids(X, centroids)
    X_recovered = centroids[idx, :]      # replace every pixel by its cluster centroid
    X_recovered = X_recovered.reshape(img_size[0], img_size[1], 3)

    print(u'Plotting...\n')
    plt.subplot(1, 2, 1)
    plt.imshow(img_data)
    plt.title(u"Original image", fontproperties=font)
    plt.subplot(1, 2, 2)
    plt.imshow(X_recovered)
    plt.title(u"Compressed image", fontproperties=font)
    plt.show()
    print(u'Done!')
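
# Why K = 16 actually compresses the image (a back-of-the-envelope calculation; the 128 x 128 size
# of the classic bird example is an assumption): the original image stores 24 bits per pixel
# (3 x 8-bit channels), i.e. 128 * 128 * 24 = 393,216 bits. After clustering, each pixel only needs
# the 4-bit index (2^4 = 16) of its centroid plus a shared 16-colour palette of 16 * 24 bits:
# 16 * 24 + 128 * 128 * 4 = 65,920 bits, roughly a factor of 6 smaller.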


# Find the closest centroid (cluster) for every sample
def findClosestCentroids(X, initial_centroids):
    m = X.shape[0]                    # number of samples
    K = initial_centroids.shape[0]    # number of clusters
    dis = np.zeros((m, K))            # squared distance from every point to each of the K centroids
    idx = np.zeros((m, 1))            # cluster assignment to return

    '''Compute the squared distance from every point to every centroid'''
    for i in range(m):
        for j in range(K):
            dis[i, j] = np.dot((X[i, :] - initial_centroids[j, :]).reshape(1, -1), (X[i, :] - initial_centroids[j, :]).reshape(-1, 1))

    '''Return, for every row of dis, the column index of its minimum, i.e. the assigned cluster
    - np.min(dis, axis=1) gives the minimum of each row
    - np.where(dis == np.min(dis, axis=1).reshape(-1,1)) gives the coordinates of those minima
    - Note: a row can attain its minimum in more than one column and where() reports every match,
      so only the first m matches are kept; with real-valued distances ties are rare, and
      np.argmin(dis, axis=1) would be an equivalent, tie-safe one-liner
    '''
    dummy, idx = np.where(dis == np.min(dis, axis=1).reshape(-1, 1))
    return idx[0:dis.shape[0]]        # keep only the first m assignments
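
# A vectorized alternative (a sketch, not used above; the helper name is hypothetical): broadcasting
# computes all m*K squared distances at once and np.argmin picks the closest centroid per row,
# avoiding the double Python loop.
def findClosestCentroidsVectorized(X, centroids):
    diff = X[:, np.newaxis, :] - centroids[np.newaxis, :, :]  # differences, shape (m, K, n)
    dis = np.sum(diff ** 2, axis=2)                           # squared distances, shape (m, K)
    return np.argmin(dis, axis=1)                             # index of the closest centroid per sample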


# Recompute each centroid as the mean of the points currently assigned to it
def computerCentroids(X, idx, K):
    n = X.shape[1]
    centroids = np.zeros((K, n))
    for i in range(K):
        # the index must be one-dimensional: idx==i is a boolean mask selecting the points of
        # cluster i, and axis=0 takes the column-wise mean over those rows
        centroids[i, :] = np.mean(X[np.ravel(idx == i), :], axis=0).reshape(1, -1)
    return centroids
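
# A defensive variant (a sketch, not in the original; the helper name and the extra
# previous_centroids argument are hypothetical): np.mean over an empty selection returns NaN, so a
# cluster that loses all of its points would poison every later iteration. A common fix is to keep
# that cluster's previous centroid (or re-seed it) instead of averaging nothing.
def computeCentroidsSafe(X, idx, K, previous_centroids):
    centroids = previous_centroids.astype(float)   # work on a float copy of the current centroids
    for i in range(K):
        mask = np.ravel(idx == i)                  # boolean mask of the points assigned to cluster i
        if np.any(mask):                           # only update clusters that still own at least one point
            centroids[i, :] = np.mean(X[mask, :], axis=0)
    return centroids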


# The K-Means main loop: alternate the assignment step and the centroid update for max_iters rounds
def runKMeans(X, initial_centroids, max_iters, plot_process):
    m, n = X.shape                        # number of samples and their dimensionality
    K = initial_centroids.shape[0]        # number of clusters
    centroids = initial_centroids         # current centroids
    previous_centroids = centroids        # centroids of the previous iteration
    idx = np.zeros((m, 1))                # cluster assignment of each sample

    for i in range(max_iters):            # fixed number of iterations
        print(u'Iteration %d' % (i + 1))
        idx = findClosestCentroids(X, centroids)
        if plot_process:                  # if requested, plot the clustering process
            plt = plotProcessKMeans(X, centroids, previous_centroids)  # draw how the centroids move
            previous_centroids = centroids    # remember the current centroids for the next frame
        centroids = computerCentroids(X, idx, K)  # recompute the centroids
    if plot_process:                      # show the final figure
        plt.show()
    return centroids, idx                 # final centroids and each sample's cluster
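
# Optional improvement (a sketch, not part of the original code): K-Means has converged once the
# assignments stop changing, so the fixed max_iters loop above could break out early, e.g.:
#
#     new_idx = findClosestCentroids(X, centroids)
#     if np.array_equal(new_idx, idx):
#         break                          # converged: no sample changed its cluster
#     idx = new_idx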


# Plot the data and the movement of the cluster centroids
def plotProcessKMeans(X, centroids, previous_centroids):
    plt.scatter(X[:, 0], X[:, 1])  # scatter plot of the raw data
    plt.plot(previous_centroids[:, 0], previous_centroids[:, 1], 'rx', markersize=10, linewidth=5.0)  # centroids of the previous iteration
    plt.plot(centroids[:, 0], centroids[:, 1], 'rx', markersize=10, linewidth=5.0)  # current centroids
    for j in range(centroids.shape[0]):  # for every cluster, draw a line from the old centroid to the new one
        p1 = centroids[j, :]
        p2 = previous_centroids[j, :]
        plt.plot([p1[0], p2[0]], [p1[1], p2[1]], "->", linewidth=2.0)
    return plt


# Initialize the centroids by picking K random samples as the starting cluster centers
def kMeansInitCentroids(X, K):
    m = X.shape[0]
    m_arr = np.arange(0, m)                # indices 0 .. m-1
    centroids = np.zeros((K, X.shape[1]))
    np.random.shuffle(m_arr)               # shuffle the indices in place
    rand_indices = m_arr[:K]               # take the first K of them
    centroids = X[rand_indices, :]
    return centroids
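
# An equivalent one-liner (a sketch, not part of the original): np.random.permutation both creates
# and shuffles the index array, so the K random rows can be picked directly:
#
#     centroids = X[np.random.permutation(X.shape[0])[:K], :]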


if __name__ == "__main__":
    KMeans()