 
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import time


def _assign_to_nearest(X, medoids):
    """Return, for each row of ``X``, the index of the closest row of ``medoids`` (Euclidean)."""
    dists = np.linalg.norm(X[:, None, :] - medoids[None, :, :], axis=2)
    return np.argmin(dists, axis=1)


def k_medoids(X, k, max_iters=100, tol=1e-4, random_state=42, verbose=True):
    """Cluster the rows of ``X`` into ``k`` groups with a simple k-medoids scheme.

    Alternates between (1) assigning every sample to its nearest medoid
    (Euclidean distance) and (2) replacing each medoid with the cluster member
    whose summed distance to all members of that cluster is smallest. Stops
    when the largest medoid displacement falls below ``tol`` or after
    ``max_iters`` rounds.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Input samples; converted with ``np.asarray``.
    k : int
        Number of clusters; must satisfy ``1 <= k <= n_samples``.
    max_iters : int
        Maximum number of assign/update rounds; must be >= 1.
    tol : float
        Convergence threshold on the largest medoid displacement.
    random_state : int
        Seed used to pick the initial medoids.
    verbose : bool
        If true, print the (0-based) index of the last round executed,
        exactly as the original implementation did.

    Returns
    -------
    labels : ndarray of int, shape (n_samples,)
        Index of the nearest *returned* medoid for every sample.
    medoids : ndarray, shape (k, n_features)
        The medoids; every row is a row of ``X``.

    Raises
    ------
    ValueError
        If ``X`` is not 2-D, or ``k`` / ``max_iters`` are out of range.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2-D, got shape {X.shape}")
    n_samples = X.shape[0]
    if not 1 <= k <= n_samples:
        raise ValueError(f"k must be in [1, {n_samples}], got {k}")
    if max_iters < 1:
        raise ValueError(f"max_iters must be >= 1, got {max_iters}")

    rng = np.random.RandomState(random_state)
    # Randomly pick k distinct samples as the initial medoids.
    init_idx = rng.choice(n_samples, k, replace=False)
    medoids = X[init_idx]

    for it in range(max_iters):
        labels = _assign_to_nearest(X, medoids)

        new_medoids = medoids.copy()
        for j in range(k):
            cluster_idx = np.flatnonzero(labels == j)
            if cluster_idx.size == 0:
                # Empty cluster: keep its previous medoid.
                continue
            cluster_pts = X[cluster_idx]
            # Total distance from each cluster member to all members of the cluster.
            pairwise_d = np.linalg.norm(
                cluster_pts[:, None, :] - cluster_pts[None, :, :],
                axis=2,
            )
            # The member with the smallest total distance becomes the new medoid.
            new_medoids[j] = X[cluster_idx[np.argmin(pairwise_d.sum(axis=1))]]

        shift = np.linalg.norm(new_medoids - medoids, axis=1).max()
        medoids = new_medoids
        if shift < tol:
            break

    if verbose:
        print("the iterating number is", it)

    # Labels from inside the loop were computed against the medoids *before*
    # the final update; re-assign so labels always match the returned medoids.
    labels = _assign_to_nearest(X, medoids)
    return labels, medoids

if __name__ == "__main__":
    # Number of clusters; defined once so the call and the plot title agree.
    k = 5

    df = pd.read_csv('./data_file/student_habits_performance.csv')

    # Drop rows with missing values: NaN coordinates yield NaN distances,
    # which make the argmin-based assignment and medoid selection meaningless.
    X = df[['study_hours_per_day', 'exam_score']].dropna().to_numpy(dtype=float)

    start = time.perf_counter()
    labels, centers = k_medoids(X, k=k)
    end = time.perf_counter()

    print(f"K-medoids 聚类用时: {end - start:.4f} 秒")

    # Cluster ids and how many samples fell into each.
    print(np.unique(labels, return_counts=True))

    plt.figure(figsize=(8, 6))
    plt.scatter(
        X[:, 0], X[:, 1],
        c=labels,
        cmap='rainbow',
        alpha=0.7,
        s=60,
    )
    # Overlay the medoids.
    plt.scatter(
        centers[:, 0], centers[:, 1],
        marker='X', c='black',
        s=200, linewidths=2,
        label='Medoids',
    )

    plt.xlabel('Study Hours per Day')
    plt.ylabel('Exam Score')
    plt.title(f'K-medoids Clustering (k={k})')
    plt.legend()
    plt.grid(True)
    plt.show()
