Skip to content

Commit a35a3e5

Browse files
committed
K-Means聚类
1 parent 1f8e75d commit a35a3e5

File tree

1 file changed

+96
-19
lines changed

1 file changed

+96
-19
lines changed

K-Menas/K-Menas.py

+96-19
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,121 @@
1+
#-*- coding: utf-8 -*-
12
import numpy as np
23
from matplotlib import pyplot as plt
4+
from matplotlib import colors
35
from scipy import io as spio
6+
from scipy import misc # 图片操作
7+
from matplotlib.font_manager import FontProperties
8+
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14) # 解决windows环境下画图汉字乱码问题
9+
410

511

612
def KMeans():
13+
'''二维数据聚类过程演示'''
14+
print u'聚类过程展示...\n'
715
data = spio.loadmat("data.mat")
816
X = data['X']
9-
K = 3 # 总类数
10-
initial_centroids = np.array([[3,3],[6,2],[8,5]]) # 初始化类中心
11-
idx = findClosestCentroids(X,initial_centroids) # 找到每条数据属于哪个类
17+
K = 3 # 总类数
18+
initial_centroids = np.array([[3,3],[6,2],[8,5]]) # 初始化类中心
19+
max_iters = 10
20+
runKMeans(X,initial_centroids,max_iters,True) # 执行K-Means聚类算法
21+
'''
22+
图片压缩
23+
'''
24+
print u'K-Means压缩图片\n'
25+
img_data = misc.imread("bird.png") # 读取图片像素数据
26+
img_data = img_data/255.0 # 像素值映射到0-1
27+
img_size = img_data.shape
28+
X = img_data.reshape(img_size[0]*img_size[1],3) # 调整为N*3的矩阵,N是所有像素点个数
29+
30+
K = 16
31+
max_iters = 5
32+
initial_centroids = kMeansInitCentroids(X,K)
33+
centroids,idx = runKMeans(X, initial_centroids, max_iters, False)
34+
print u'\nK-Means运行结束\n'
35+
print u'\n压缩图片...\n'
36+
idx = findClosestCentroids(X, centroids)
37+
X_recovered = centroids[idx,:]
38+
X_recovered = X_recovered.reshape(img_size[0],img_size[1],3)
39+
40+
print u'绘制图片...\n'
41+
plt.subplot(1,2,1)
42+
plt.imshow(img_data)
43+
plt.title(u"原先图片",fontproperties=font)
44+
plt.subplot(1,2,2)
45+
plt.imshow(X_recovered)
46+
plt.title(u"压缩图像",fontproperties=font)
47+
plt.show()
48+
print u'运行结束!'
1249

13-
centroids = computerCentroids(X,idx,K) # 重新计算类中心
14-
print centroids
1550

16-
# 找到每条数据距离哪个类中心最近
51+
# 找到每条数据距离哪个类中心最近
1752
def findClosestCentroids(X,initial_centroids):
18-
m = X.shape[0] # 数据条数
19-
K = initial_centroids.shape[0] # 类的总数
20-
dis = np.zeros((m,K)) # 存储计算每个点分别到K个类的距离
21-
idx = np.zeros((m,1)) # 要返回的每条数据属于哪个类
53+
m = X.shape[0] # 数据条数
54+
K = initial_centroids.shape[0] # 类的总数
55+
dis = np.zeros((m,K)) # 存储计算每个点分别到K个类的距离
56+
idx = np.zeros((m,1)) # 要返回的每条数据属于哪个类
2257

23-
'''计算每个点到每个类中心的距离'''
58+
'''计算每个点到每个类中心的距离'''
2459
for i in range(m):
2560
for j in range(K):
2661
dis[i,j] = np.dot((X[i,:]-initial_centroids[j,:]).reshape(1,-1),(X[i,:]-initial_centroids[j,:]).reshape(-1,1))
2762

28-
'''返回dis每一行的最小值对应的列号,即为对应的类别'''
29-
idx = np.array(np.where(dis[0,:] == np.min(dis, axis=1)[0]))
30-
for i in np.arange(1, m):
31-
t = np.array(np.where(dis[i,:] == np.min(dis, axis=1)[i]))
32-
idx = np.vstack((idx,t))
33-
return idx
63+
'''返回dis每一行的最小值对应的列号,即为对应的类别
64+
- np.min(dis, axis=1)返回每一行的最小值
65+
- np.where(dis == np.min(dis, axis=1).reshape(-1,1)) 返回对应最小值的坐标
66+
- 注意:可能最小值对应的坐标有多个,where都会找出来,所以返回时返回前m个需要的即可(因为对于多个最小值,属于哪个类别都可以)
67+
'''
68+
dummy,idx = np.where(dis == np.min(dis, axis=1).reshape(-1,1))
69+
return idx[0:dis.shape[0]] # 注意截取一下
3470

3571

36-
# 计算类中心
72+
# 计算类中心
3773
def computerCentroids(X,idx,K):
3874
n = X.shape[1]
3975
centroids = np.zeros((K,n))
4076
for i in range(K):
41-
centroids[i,:] = np.mean(X[np.array(np.where(idx==i)),:], axis=0).reshape(1,-1) # axis=0为每一列
77+
centroids[i,:] = np.mean(X[np.ravel(idx==i),:], axis=0).reshape(1,-1) # 索引要是一维的,axis=0为每一列,idx==i一次找出属于哪一类的,然后计算均值
78+
return centroids
79+
80+
# 聚类算法
81+
def runKMeans(X,initial_centroids,max_iters,plot_process):
82+
m,n = X.shape # 数据条数和维度
83+
K = initial_centroids.shape[0] # 类数
84+
centroids = initial_centroids # 记录当前类中心
85+
previous_centroids = centroids # 记录上一次类中心
86+
idx = np.zeros((m,1)) # 每条数据属于哪个类
87+
88+
for i in range(max_iters): # 迭代次数
89+
print u'迭代计算次数:%d'%(i+1)
90+
idx = findClosestCentroids(X, centroids)
91+
if plot_process: # 如果绘制图像
92+
plt = plotProcessKMeans(X,centroids,previous_centroids) # 画聚类中心的移动过程
93+
previous_centroids = centroids # 重置
94+
centroids = computerCentroids(X, idx, K) # 重新计算类中心
95+
if plot_process: # 显示最终的绘制结果
96+
plt.show()
97+
return centroids,idx # 返回聚类中心和数据属于哪个类
98+
99+
# 画图,聚类中心的移动过程
100+
def plotProcessKMeans(X,centroids,previous_centroids):
101+
plt.scatter(X[:,0], X[:,1]) # 原数据的散点图
102+
plt.plot(previous_centroids[:,0],previous_centroids[:,1],'rx',markersize=10,linewidth=5.0) # 上一次聚类中心
103+
plt.plot(centroids[:,0],centroids[:,1],'rx',markersize=10,linewidth=5.0) # 当前聚类中心
104+
for j in range(centroids.shape[0]): # 遍历每个类,画类中心的移动直线
105+
p1 = centroids[j,:]
106+
p2 = previous_centroids[j,:]
107+
plt.plot([p1[0],p2[0]],[p1[1],p2[1]],"->",linewidth=2.0)
108+
return plt
109+
110+
111+
# 初始化类中心--随机取K个点作为聚类中心
112+
def kMeansInitCentroids(X,K):
113+
m = X.shape[0]
114+
m_arr = np.arange(0,m) # 生成0-m-1
115+
centroids = np.zeros((K,X.shape[1]))
116+
np.random.shuffle(m_arr) # 打乱m_arr顺序
117+
rand_indices = m_arr[:K] # 取前K个
118+
centroids = X[rand_indices,:]
42119
return centroids
43120

44121
if __name__ == "__main__":

0 commit comments

Comments
 (0)