k-means法をpythonで実装してみる

概要

  • 実装してみた
  • k-means法を実装してみたくなる一番の理由は「実装できそうな気がするから
  • fit()、fit_predict()という2メソッドの名前だけsklearnを参考にした

実装

import numpy as np


class MyKMeans(object):
    """クラスタリングを行うためのクラス"""
    
    def __init__(self, nc, n_iter=10, tol=1e-4, 
        save_cents=True):
        self.nc = nc    # クラスタ数
        self.uniq_labels = np.arange(self.nc)    # 一意なラベル
        self.n_iter = n_iter    # 中心点の更新回数
        self.tol = tol    # 中心点の移動距離がtol以下で更新打ち切り
        self.centroids = None    # 現在の中心点
        self.save_cents = save_cents    # 中心点の移動履歴を記憶するか
        self.hist_centroids = None    # 中心点の移動履歴

    def _choice_rand_cents(self, X):
        """サンプルからランダムに中心点を選択する"""
        n = X.shape[0]    # サンプル数
        inds = np.random.choice(n, self.nc, replace=False)    # ダブりなし
        self.centroids = X[inds, :]

        if self.save_cents:
            self.hist_centroids = [self.centroids]

    def _measure_distances(self, X):
        """現在の中心点からの距離を算出する"""
        distances = []
        for centroid in self.centroids:
            c_distances = [np.sqrt(np.sum((x - centroid) ** 2)) for x in X]
            distances.append(c_distances)
        return np.array(distances).T

    def _predict(self, X):
        """各サンプルのラベルを予測する"""

        # 中心点からの距離を算出
        distances = self._measure_distances(X)

        # より距離が近いほうのラベルを更新する
        return np.argmin(distances, axis=1)

    def _update_cents(self, X, y):
        """中心点を更新する"""
        new_centroids = []
        for label in self.uniq_labels:
            mask = y == label

            # そのクラスタに所属すると予測されるサンプル
            X_new = X[mask, :]

            new_centroids.append(np.mean(X_new, axis=0))

        # 履歴として保存
        if self.save_cents:
            self.hist_centroids.append(np.array(new_centroids))

        # 前回の中心点
        prev_centroids = self.centroids

        self.centroids = np.array(new_centroids)

        # 中心点の移動距離を返す
        return np.sqrt(np.sum(
            (self.centroids - prev_centroids) ** 2
        ))

    def _iter(self, X):
        """ループの繰り返し部分"""
        
        # ラベルを予測する
        y = self._predict(X)

        # 中心点を更新する
        move_distance = self._update_cents(X, y)

        return move_distance

    def fit(self, X, debug=False):
        """学習させる"""

        # ランダムに中心点を選択する
        self._choice_rand_cents(X)

        for i in range(self.n_iter):
            move_dist = self._iter(X)
            if debug:
                print(self.centroids)
                print("move %f" % (move_dist))
            if move_dist < self.tol or i == (self.n_iter - 1):
                self.hist_centroids = np.array(self.hist_centroids)    # 後始末
                break

    def fit_predict(self, X, debug=False):
        """サンプルを分類してラベルを予測する"""
        self.fit(X, debug=debug)
        return self._predict(X)

使用例

import numpy as np
from matplotlib import pyplot as plt
from my_kmeans import MyKMeans


def make_two_clusters(n_samples, mu1, std1, mu2, std2):
    """
    指定された平均、標準偏差に基づく正規分布から
    クラスターを2つ生成する(次元数=2)
    """
    ndim = 2    # 次元数
    a = np.random.randn(n_samples, ndim) * std1 + mu1
    b = np.random.randn(n_samples, ndim) * std2 + mu2
    X = np.vstack((a, b))
    return X, a, b


def main():
    
    # 再現性を確保
    np.random.seed(123)

    # (1)ランダムなクラスタを2つ生成する
    n_samples = 3000
    X, a, b = make_two_clusters(n_samples,
        mu1=3.0, std1=0.75,
        mu2=6.0, std2=1.05)

    # k-means法でクラスタリングを行う(クラスタ数=2)
    n_clusters = 2    # クラスタ数
    mkm = MyKMeans(nc=n_clusters)
    y = mkm.fit_predict(X, debug=False)

    # デバッグ
    plot_for_debug = True
    if plot_for_debug:
        plt.figure(figsize=(10, 5))

        # プロット(1) - 分類前
        plt.subplot(1, 2, 1)
        plt.title("Before")

        # サンプルをプロット
        plt.scatter(X[:, 0], X[:, 1], 
            color="springgreen",
            edgecolor="green",
            marker="o",
            label="random sample")

        # 中心点をプロット
        initial_cents = mkm.hist_centroids[0]
        plt.scatter(initial_cents[:, 0], initial_cents[:, 1],
            color="red",
            marker="*",
            label="centroid")

        plt.legend()

        # プロット(2) - 分類後
        plt.subplot(1, 2, 2)
        plt.title("After")

        # クラスタごとにサンプルをプロット
        colors = ("springgreen", "lightblue")
        edgecolors = ("green", "blue")
        labels = tuple(
            ["cluster %s" % str(i + 1) for i in range(len(mkm.uniq_labels))]
        )
        for i, label in enumerate(mkm.uniq_labels):
            mask = y == label
            X_c = X[mask, :]
            plt.scatter(X_c[:, 0], X_c[:, 1],
                color=colors[i],
                edgecolor=edgecolors[i],
                label=labels[i])

        # 中心点の移動履歴をプロット
        for i in range(mkm.uniq_labels.shape[0]):
            hist_c = mkm.hist_centroids[:, i, :]    # i+1回目に更新された中心点
            x_c = hist_c[:, 0]
            y_c = hist_c[:, 1]
            plt.plot(x_c, y_c, "-", color="red",
                label="centroids %s" % str(i + 1))

        plt.legend()
        plt.tight_layout()
        plt.show()


if __name__ == '__main__':
    main()

実行結果

※赤い線は中心点の移動履歴(それぞれ、より中心点らしき方向に移動している)
f:id:zdassen:20170327150934p:plain