from typing import Tuple, Iterable, Sequence, List, Dict, DefaultDict, Set
from random import sample
from math import fsum, sqrt
from collections import defaultdict


def partial(func, *args):
    """Rewrite functools.partial() in a way that doesn't confuse mypy.

    Return a callable that invokes ``func`` with ``args`` followed by
    whatever positional arguments it is later called with.  Keyword
    arguments are not supported.
    """
    def inner(*moreargs):
        return func(*args, *moreargs)
    return inner


def mean(data: Iterable[float]) -> float:
    """Return the accurate arithmetic mean of ``data``.

    The input is materialized into a list so that generators can be
    both summed and counted.  ``fsum`` is used to avoid accumulating
    floating-point round-off.

    Raises:
        ZeroDivisionError: if ``data`` is empty.
    """
    data = list(data)
    return fsum(data) / len(data)


def transpose(matrix: Iterable[Iterable]) -> Iterable[tuple]:
    """Swap rows with columns for a 2-D array.

    Returns a lazy ``zip`` iterator of tuples; rows of unequal length
    are truncated to the shortest one.
    """
    return zip(*matrix)


Point = Tuple[float, ...]
Centroid = Point


def dist(p: Point, q: Point, sqrt=sqrt, fsum=fsum, zip=zip) -> float:
    """Return the multi-dimensional euclidean distance between ``p`` and ``q``.

    The ``sqrt``, ``fsum`` and ``zip`` parameters merely bind the
    module-level functions as defaults (presumably a local-lookup
    speed-up); callers are not expected to pass them.
    """
    return sqrt(fsum((x1 - x2) ** 2.0 for x1, x2 in zip(p, q)))


def assign_data(centroids: Sequence[Centroid],
                data: Iterable[Point]) -> Dict[Centroid, Sequence[Point]]:
    """Assign each data point to its closest centroid.

    Returns a dict mapping a centroid to the list of points nearest to
    it.  Ties go to the centroid that appears first in ``centroids``.
    A centroid that is nearest to no point does not appear in the
    result at all.
    """
    groups: DefaultDict[Point, List[Point]] = defaultdict(list)
    for point in data:
        closest: Point = min(centroids, key=partial(dist, point))
        groups[closest].append(point)
    return dict(groups)


def compute_centroids(groups: Iterable[Sequence[Point]]) -> List[Centroid]:
    """Compute the centroid (per-dimension mean) of each group of points."""
    return [tuple(map(mean, transpose(group))) for group in groups]


def k_means(data: Iterable[Point], k: int = 2,
            iterations: int = 10) -> List[Point]:
    """Return k centroids for the data.

    Starts from ``k`` points sampled at random from ``data`` and then
    alternates between assigning points to their nearest centroid and
    moving each centroid to the mean of its points, ``iterations``
    times.

    Exactly ``k`` centroids are always returned: a centroid that ends
    up with no points of its own (because it attracted none, or because
    it is a duplicate of another centroid) is carried over unchanged
    instead of being silently dropped.

    Raises:
        ValueError: propagated from ``random.sample`` when ``k`` is
            negative or larger than the number of data points.
    """
    data = list(data)
    centroids = sample(data, k)
    for _ in range(iterations):
        labeled = assign_data(centroids, data)
        updated: List[Centroid] = compute_centroids(labeled.values())

        # ``labeled`` has one entry per *distinct* centroid that received
        # points.  Any other centroid (empty cluster, or a repeat of an
        # already-claimed value) would otherwise vanish and shrink the
        # result below k, so keep it as it is for the next round.
        claimed: Set[Centroid] = set()
        for centroid in centroids:
            if centroid in labeled and centroid not in claimed:
                claimed.add(centroid)
            else:
                updated.append(centroid)
        centroids = updated
    return centroids


def quality(labeled: Dict[Centroid, Sequence[Point]]) -> float:
    """Mean value of squared distances from data to its assigned centroid.

    Raises:
        ZeroDivisionError: if ``labeled`` contains no points.
    """
    return mean(dist(c, p) ** 2 for c, pts in labeled.items() for p in pts)


def main() -> None:
    """Run the two demonstration clusterings and print their results."""
    from pprint import pprint

    print('Simple example with six 3-D points clustered into two groups')
    points = [
        (10, 41, 23),
        (22, 30, 29),
        (11, 42, 5),
        (20, 32, 4),
        (12, 40, 12),
        (21, 36, 23),
    ]

    centroids = k_means(points, k=2)
    pprint(assign_data(centroids, points))

    print('\nExample with a richer dataset.')
    print('See: https://www.datascience.com/blog/introduction-to-k-means-clustering-algorithm-learn-data-science-tutorials')

    data = [
        (10, 30),
        (12, 50),
        (14, 70),

        (9, 150),
        (20, 175),
        (8, 200),
        (14, 240),

        (50, 35),
        (40, 50),
        (45, 60),
        (55, 45),

        (60, 130),
        (60, 220),
        (70, 150),
        (60, 190),
        (90, 160),
    ]

    print('k quality')
    print('- -------')
    for k in range(1, 8):
        centroids = k_means(data, k, iterations=20)
        d = assign_data(centroids, data)
        print(f'{k} {quality(d) :8,.1f}')


if __name__ == '__main__':
    main()