前言
前面的學習中,已經詳細了解了K均值算法的相關原理。本篇文章,我們將使用python實現K均值算法并將其應用于圖像壓縮處理。這對K均值算法的直觀理解是非常有幫助的。在本次算法實現中,利用K均值算法減少一副圖像中圖像顏色數量出現最多的部分像素,來實現圖像壓縮。
算法實現
K均值算法的主要原理就是給定一組初始數據,并初始化聚類中心,根據初始化的聚類中心,將的數據分配給最近的聚類中心,并重新計算新的聚類中心,一直重復這個過程,直到沒有最新的數據分配給聚類中心或者是聚類中心不再發生新改變。
尋找聚類中心
K均值算法中,需要將每一個訓練樣本分配給最接近該樣本的聚類中心,對于每一個訓練樣本
,可以用如下公式求取其聚類中心:
其中,表示最接近
的聚類中心索引,而
則表示第
個聚類中心的值或者位置,在代碼中,用
idx[i]
表示。
尋找聚類中心的的算法可以用以下代碼實現:
- 初始化
首先,對一些參數進行初始化,如加載訓練樣本,聚類中心的數目和初始值設置,如下代碼所示:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as scio
from skimage import io
from skimage import img_as_float
#加載訓練樣本
data = scio.loadmat('ex7data2.mat')
X = data['X']
# 選擇初始的聚類中心的數量和位置(value)
k = 3 #聚類中心的數目
initial_centroids = np.array([[3, 3], [6, 2], [8, 5]]) #聚類中心的初始值
- 尋找聚類中心
求取聚類中心的過程可以內外兩層循環完成,內循環表示求取每一個聚類中心與樣本的范數(距離),而外循環表示每一個樣本
減去聚類中心的所得的值。通過以上步驟,可以得到一個300X3的范數矩陣,最后,返回行方向上最小值索引(長度為300的一維數組)。
def find_closest_centroids(X, centroids):
K = centroids.shape[0] # K=3
m = X.shape[0] # m = 300
idx = np.zeros(m)
means = np.zeros((m, K))
for i in range(m):
x = X[i]
#外循環,每一個x的位置(二維矩陣)減去聚類中心的位置(二維矩陣)
diff = x - centroids
for k in range(K):
#內循環,x減去每一個聚類中心所得范數
means[i, k] = np.linalg.norm(diff[k])
#聚類中心的行方向上最小值所對應的索引
idx = np.argmin(means, axis=1)
return idx
注意:通過求取聚類中心,得到了由聚類中心索引所構成的一維數組。求取聚類中心之后,相當于給每個訓練樣本打上聚類中心
的標識,而這一維數組的索引與訓練樣本
的索引相對應,例如,300個訓練樣本,最后求得一個長度為300的一維數組,而其值表示訓練樣本所對應的聚類中心的標識。
計算聚類中心均值
以上,我們求取了行方向上最小聚類中心的索引,對于給定的個聚類中心,需要求得給第
個聚類中心的所有訓練樣本的均值,可以用如下公式表示:
假設,有兩個訓練樣本被分配給了聚類中心
,則
計算聚類中心的均值的算法實現,如下代碼所示
def compute_centroids(X, idx, K):
(m, n) = X.shape #m=300,n=2
centroids = np.zeros((K, n))
for k in range(K):
#每個聚類中心索引所對應的樣本x,表示分配給聚類中心索引k的訓練樣本x
x_for_centroid_k = X[np.where(idx == k)]
#分配給索引為k的聚類中心的樣本x在列方向上的和除以分配給聚類中心所對應的樣本數量
centroid_k = np.sum(x_for_centroid_k, axis=0) / x_for_centroid_k.shape[0]
centroids[k] = centroid_k
return centroids
K均值算法的可視化實現
通過以上步驟,已經得到了訓練樣本的聚類中心和其均值,通過以下代碼,通過十次迭代,可視化的實現K均值算法在訓練樣本中的運行方式,具體實現,如下代碼所示:
def run_kmeans(X, initial_centroids, max_iters, plot):
if plot:
plt.figure()
(m, n) = X.shape
K = initial_centroids.shape[0]
centroids = initial_centroids
previous_centroids = centroids
idx = np.zeros(m)
for i in range(max_iters):
print('K-Means iteration {}/{}'.format((i + 1), max_iters))
idx = find_closest_centroids(X, centroids)
if plot:
plot_progress(X, centroids, previous_centroids, idx, K, i)
previous_centroids = centroids
input('Press ENTER to continue')
centroids = compute_centroids(X, idx, K)
return centroids, idx
def plot_progress(X, centroids, previous, idx, K, i):
plt.scatter(X[:, 0], X[:, 1], c=idx, s=15)
plt.scatter(centroids[:, 0], centroids[:, 1], marker='x', c='black', s=25)
for j in range(centroids.shape[0]):
draw_line(centroids[j], previous[j])
plt.title('Iteration number {}'.format(i + 1))
def draw_line(p1, p2):
plt.plot(np.array([p1[0], p2[0]]), np.array([p1[1], p2[1]]), c='black', linewidth=1)
通過10次迭代后,運行圖像,如下圖所示:
利用K均值算法實現圖像壓縮
以上,已經詳細了解并實現了K均值算法,在這部分內容中,將使用K均值算法來實現圖像壓縮,所謂圖像壓縮指的是在圖像像素方面的處理。圖像常用的編碼方式為RGB編碼,即用三基色(RED,GREEN,BLUE)表示圖像顏色。每個像素由三個8位無符號二進制數(范圍從0到255)表示其像素顏色,例如,一個像素的顏色可以用(220,101,25)
表示。給定的圖像包含這數千種顏色,通過K均值算法,可以將其顏色的數量降至16種,從而實現圖像壓縮。
像素處理
圖像的每個像素顏色即代表訓練樣本,通過K均值算法尋找16種顏色代表圖像中的所有像素的顏色,即也就是尋找16個聚類中心,最后,將所有的像素顏色替換為16個聚類中心所對應的顏色。
- 圖像加載和預處理
對于每一個像素,可以用一個三維矩陣表示,其中,第一維和第二維表示其所在位置,第三維代表其是藍色,紅色,或者綠色。例如一個矩陣,表示53行,44列所在的像素其顏色為3.
在此過程中,需要將圖像轉換為的矩陣,其中,m=像素的行×列。其實現過程可以用如下代碼表示
image = io.imread('bird_small.png')
#將圖像轉換為浮點型數據
image = img_as_float(image)
img_shape = image.shape
X = image.reshape(img_shape[0] * img_shape[1], 3)
通過以上代碼,將圖像轉化為(128×128,3)的二維矩陣
運行K均值算法處理圖像
- 聚類中心的隨機初始化
在運行算法之前,還需要對聚類中心,進行隨機初始化的處理,其初始化過程也就是對位置進行初始化,具體實現方式如下代碼所示
def kmeans_init_centroids(X, K):
centroids = np.zeros((K, X.shape[1]))
indices = np.random.randint(X.shape[0], size=K)
centroids = X[indices]
return centroids
根據之前的算法實現,現在,可以直接運行K均值算法了,其實現方式如下所示
K = 16 #設置聚類中心數量
max_iters = 10 #最大迭代次數
initial_centroids = kmeans_init_centroids(X, K)
centroids, idx = run_kmeans(X, initial_centroids, max_iters, False)
- 實現圖像壓縮
經過以上處理,圖像壓縮的實現步驟如下代碼所示:
idx = find_closest_centroids(X, centroids)
X_recovered = centroids[idx]
# (128*128*3)
X_recovered = np.reshape(X_recovered, (img_shape[0], img_shape[1], 3))
plt.subplot(2, 1, 1)
plt.imshow(image)
plt.title('Original')
plt.subplot(2, 1, 2)
plt.imshow(X_recovered)
plt.title('Compressed, with {} colors'.format(K))
最后,經過處理過后的圖像對比如下圖所示,可以明顯的看出,第二幅圖像的質量有所降低。