#include // bool #include // memset() #include // FLT_MAX #include // sqrt() #include "util.h" #include "km.h" #define MIN_CLUSTER_DISTANCE 0.00001 // alloc and initialize row map static km_row_map_t * km_row_map_init( const size_t num_rows ) { // alloc row map km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows); // check for error if (!row_map) { // return failure return false; } // init row map for (size_t i = 0; i < num_rows; i++) { // setting distances to maximum row_map[i].d2 = FLT_MAX; row_map[i].d2_near = FLT_MAX; } // return row map return row_map; } static void km_row_map_fini( km_row_map_t * const row_map ) { free(row_map); } // use k-means to iteratively update the cluster centroids until there // are no more changes to the centroids bool km_solve( km_set_t * const cs, const km_set_t * const set, const km_solve_cbs_t * const cbs, void * const cb_data ) { const size_t num_clusters = cs->num_rows, num_floats = set->shape.num_floats; // row map: row => distance, cluster ID km_row_map_t *row_map = km_row_map_init(set->num_rows); if (!row_map) { // return failure return false; } // calculate clusters by doing the following: // * walk all clusters and all rows // * if we find a closer cluster, move row to cluster // * if there were changes to any cluster, then calculate a new // centroid for each cluster by averaging the cluster rows // * repeat until there are no more changes bool changed = false; do { // no changes yet changed = false; for (size_t i = 0; i < num_clusters; i++) { // get the floats for this cluster const float * const floats = km_set_get_row(cs, i); for (size_t j = 0; j < set->num_rows; j++) { // get row values const float * const vals = km_set_get_row(set, j); // calculate the distance squared between row and cluster const float d2 = distance_squared(num_floats, floats, vals); if (d2 < row_map[j].d2) { // row is closer to this cluster, update row map row_map[j].d2 = d2; row_map[j].cluster = i; // flag change changed = true; } if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) { row_map[j].d2_near = d2; // flag change changed = true; } } } if (changed) { // if there were changes, then we need to calculate the new // cluster centers // calculate new center for (size_t i = 0; i < num_clusters; i++) { size_t num_rows = 0; float * const floats = km_set_get_row(cs, i); memset(floats, 0, sizeof(float) * num_floats); for (size_t j = 0; j < set->num_rows; j++) { const float * const row_floats = km_set_get_row(set, j); if (row_map[j].cluster == i) { // calculate numerator for average for (size_t k = 0; k < num_floats; k++) { floats[k] += row_floats[k]; } // increment denominator for average num_rows++; } } // save number of rows in this cluster cs->ints[i] = num_rows; if (num_rows > 0) { for (size_t k = 0; k < num_floats; k++) { // divide by denominator to get average floats[k] /= num_rows; } } } } if (cbs && cbs->on_step) { // pass clusters and row map to step callback cbs->on_step(cs, row_map, cb_data); } } while (changed); if (cbs && cbs->on_stats) { float sum = 0, silouette = 0, mean_dists[num_clusters], mean_nears[num_clusters]; memset(mean_dists, 0, sizeof(mean_dists)); memset(mean_nears, 0, sizeof(mean_nears)); // calculate sum of distances across all clusters for (size_t i = 0; i < set->num_rows; i++) { sum += row_map[i].d2; } // calculate mean numerators and silouette for (size_t i = 0; i < set->num_rows; i++) { // distance squared (d2) to center of this cluster mean_dists[row_map[i].cluster] += row_map[i].d2; // distance squared (d2) to center of nearest cluster mean_nears[row_map[i].cluster] += row_map[i].d2_near; // calculate silouette denominator // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29) const float delta = row_map[i].d2_near - row_map[i].d2; if (fabsf(delta) > MIN_CLUSTER_DISTANCE) { silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near); } } // finalize means (divide by row count) for (size_t i = 0; i < num_clusters; i++) { mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0; mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0; } // finalize silouette silouette /= set->num_rows; // build stats const km_solve_stats_t stats = { .sum = sum, .silouette = silouette, .mean_dists = mean_dists, .mean_nears = mean_nears, .num_clusters = num_clusters, }; // emit means cbs->on_stats(cs, &stats, cb_data); } // free row_map km_row_map_fini(row_map); // return success return true; }