吳恩達 machine learning 編程作業 python實現 ex7

 

# -*- coding: utf-8 -*-
"""
Created on Fri Jul  3 18:23:49 2020

@author: cheetah023
"""

import numpy as np
import scipy.io as sci
import matplotlib.pyplot as plt
from skimage import io
#函數定義
def findClosestCentroids(X, centroids):
    """Assign every sample to its nearest centroid (squared Euclidean distance).

    Parameters
    ----------
    X : ndarray of shape (m, n)
        One sample per row.
    centroids : array-like holding K centroids of n values each
        Nested lists are accepted; the input is reshaped to (K, n).

    Returns
    -------
    ndarray of shape (m, 1), dtype float
        Zero-based index of the closest centroid for each sample. When
        several centroids are equally close, the lowest index wins.
    """
    m, n = X.shape
    centroids = np.reshape(centroids, [len(centroids), n])
    K = centroids.shape[0]
    # One column of squared distances per centroid. Looping over the K
    # centroids (instead of over the m samples) keeps the heavy work inside
    # NumPy while only ever allocating (m, n) temporaries.
    sq_dist = np.empty([m, K])
    for j in range(K):
        sq_dist[:, j] = np.sum((X - centroids[j, :]) ** 2, axis=1)
    # argmin reports the first minimum, which matches a strict "<" scan.
    return np.argmin(sq_dist, axis=1).astype(float).reshape(m, 1)
def computeCentroids(X, idx, K):
    """Return the mean of the samples assigned to each of the K clusters.

    Parameters
    ----------
    X : ndarray of shape (m, n)
        One sample per row.
    idx : array-like with one cluster label per sample
        Zero-based labels; both (m, 1) and (m,) layouts are accepted.
    K : int
        Number of clusters.

    Returns
    -------
    ndarray of shape (K, n), dtype float
        Row k is the mean of the samples labelled k. A cluster without any
        member yields a row of NaN (the value the former 0/0 division
        produced), but no RuntimeWarning is emitted for it.
    """
    m, n = X.shape
    # Flatten so the labels can be compared against k in a single step; only
    # the first m labels are considered, one per sample.
    labels = np.ravel(idx)[:m]
    centroids = np.full([K, n], np.nan)
    for k in range(K):
        members = X[labels == k]
        if members.shape[0] > 0:
            centroids[k, :] = members.mean(axis=0)
    return centroids
def plotDataPoints(X, idx, K):
    """Scatter-plot the first two columns of X, coloured by cluster label.

    Parameters
    ----------
    X : ndarray of shape (m, n) with n >= 2
        Samples; column 0 is drawn on the x axis and column 1 on the y axis.
    idx : array-like with one cluster label per sample
        Zero-based labels; both (m, 1) and (m,) layouts are accepted.
    K : int
        Number of clusters. Not used for drawing; kept so existing callers
        keep working.
    """
    m = X.shape[0]
    if m == 0:
        return
    color = ['b','r','y','g','m','c','k','w']
    labels = np.ravel(idx)[:m].astype(int)
    # Wrap around the palette so labels >= len(color) reuse a colour instead
    # of raising IndexError.
    point_colors = [color[k % len(color)] for k in labels]
    # A single scatter call draws all points at once instead of creating one
    # collection per sample.
    plt.scatter(X[:, 0], X[:, 1], c=point_colors)
    
def plotProgresskMeans(X, centroids, idx, K):
    """Draw the clustered samples and overlay the path of every centroid.

    `centroids` is indexed as a 2-D array in which columns 2*k and 2*k + 1
    hold the x and y coordinates of centroid k; each row contributes one
    point to that centroid's line.
    """
    plotDataPoints(X, idx, K)
    for k in range(K):
        x_col = 2 * k
        path_x = centroids[:, x_col]
        path_y = centroids[:, x_col + 1]
        plt.plot(path_x, path_y, 'kx--', markersize=8)
def runkMeans(X, initial_centroids, max_iters, plot_progress):
    """Run K-means for a fixed number of iterations.

    Parameters
    ----------
    X : ndarray of shape (m, n)
        One sample per row.
    initial_centroids : array-like holding K centroids of n values each
        Starting centroids; nested lists are accepted.
    max_iters : int
        Number of assignment / update rounds to perform.
    plot_progress : bool
        When true, open a new figure and, after the last iteration, plot the
        samples together with the trajectory of every centroid.

    Returns
    -------
    (centroids, idx)
        `centroids` is the (K, n) array after the last update and `idx` the
        (m, 1) array of zero-based assignments from the last round (all
        zeros when `max_iters` is 0).
    """
    m, n = X.shape
    K = len(initial_centroids)
    centroids = np.reshape(initial_centroids, [K, n])
    idx = np.zeros([m, 1])
    # Each history entry stores all centroids of one step flattened into a
    # single row; the entries are stacked once at the end, and only when a
    # plot is requested.
    history = [centroids.flatten()]
    if plot_progress:
        plt.figure()

    for _ in range(max_iters):
        idx = findClosestCentroids(X, centroids)
        centroids = computeCentroids(X, idx, K)
        if plot_progress:
            history.append(centroids.flatten())

    if plot_progress:
        # np.vstack replaces the deprecated np.row_stack alias and always
        # yields a 2-D array, even when no iteration was run.
        plotProgresskMeans(X, np.vstack(history), idx, K)
    return centroids,idx
def kMeansInitCentroids(X, K):
    """Pick K rows of X at random to serve as initial centroids.

    Rows are drawn without replacement so that no two centroids start from
    the same sample; identical starting centroids leave one of them without
    any assigned sample after the first assignment step. Only when K exceeds
    the number of rows does the draw fall back to sampling with replacement,
    so such calls keep returning K rows instead of raising.

    Parameters
    ----------
    X : ndarray of shape (m, n)
        One sample per row.
    K : int
        Number of centroids to pick.

    Returns
    -------
    ndarray of shape (K, n)
        The selected rows of X.
    """
    m = X.shape[0]
    chosen = np.random.choice(m, K, replace=K > m)
    return X[chosen, :]
# Part 1: Find Closest Centroids
# Load the example data set and assign every sample to one of three
# hand-picked starting centroids.
data = sci.loadmat('ex7data2.mat')
X = data['X']
print('X:', X.shape)
K = 3
initial_centroids = [[3, 3], [6, 2], [8, 5]]

idx = findClosestCentroids(X, initial_centroids)
# Labels are zero-based; add one so they can be compared with the 1-based
# reference values printed on the next line.
first_three = idx[:3] + 1
print('Closest centroids for the first 3 examples:', first_three)
print('(the closest centroids should be 1, 3, 2 respectively)')

# Part 2: Compute Means
# Recompute the centroids from the assignments found in Part 1.
centroids = computeCentroids(X, idx, K)
print('Centroids computed after initial finding of closest centroids:')
print(centroids)
for reference_line in (
    '(the centroids should be',
    '[ 2.428301 3.157924 ]',
    '[ 5.813503 2.633656 ]',
    '[ 7.119387 3.616684 ]',
):
    print(reference_line)

# Part 3: K-Means Clustering
# Run the full algorithm from the same starting centroids and plot the path
# taken by each centroid.
K, max_iters = 3, 10
initial_centroids = [[3, 3], [6, 2], [8, 5]]
centroids, idx = runkMeans(X, initial_centroids, max_iters, True)

# Part 4: K-Means Clustering on Pixels
# Treat every pixel of the image as one sample and cluster the pixel colours
# into K groups.
data = io.imread('bird_small.png')
print('data:',data.shape)
m,n,k = data.shape
# Scale the channel values by 1/255 (into [0, 1] for 8-bit image data).
data = data / 255
# One row per pixel, one column per channel; using the channel count read
# from the image keeps the reshape valid when it is not exactly 3.
X = np.reshape(data,[-1,k])
K = 16
max_iters = 10
initial_centroids = kMeansInitCentroids(X, K)
centroids, idx = runkMeans(X, initial_centroids, max_iters, False)

# Part 5: Image Compression
# Replace every pixel by the centroid colour it is closest to.
idx = findClosestCentroids(X, centroids)
# Index the centroid table with all labels at once instead of copying the
# colours pixel by pixel.
X_recovered = centroids[idx[:,0].astype(int)]
X_recovered = X_recovered.reshape([m,n,k])

fig,axes = plt.subplots(nrows=1,ncols=2)
axes[0].imshow(data)
axes[0].set_title('original')
axes[1].imshow(X_recovered)
# The template uses a %-style placeholder, so it has to be filled with the
# % operator; str.format() left the literal "%d" in the title.
axes[1].set_title('Compressed, with %d colors.' % K)


運行結果:

X: (300, 2)
Closest centroids for the first 3 examples: [[1.]
 [3.]
 [2.]]
(the closest centroids should be 1, 3, 2 respectively)
Centroids computed after initial finding of closest centroids:
[[2.42830111 3.15792418]
 [5.81350331 2.63365645]
 [7.11938687 3.6166844 ]]
(the centroids should be
[ 2.428301 3.157924 ]
[ 5.813503 2.633656 ]
[ 7.119387 3.616684 ]
data: (128, 128, 3)

 

總結:

1、畫 centroids(形心)的移動軌跡時,按照 Octave 的流程在每次迭代中逐步繪製不太方便,所以增加了 centroids_history 來保存每次迭代後的形心座標,待迭代結束後再一次畫出完整軌跡,這樣實現起來更簡單。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章