#include // bool #include // memset() #include // FLT_MAX #include // sqrt() #include "util.h" #include "km.h" // alloc and initialize row map static km_row_map_t * km_row_map_init( const size_t num_rows ) { // alloc row map km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows); // check for error if (!row_map) { // return failure return false; } // init row map by setting the maximum distance for (size_t i = 0; i < num_rows; i++) { row_map[i].d2 = FLT_MAX; } // return row map return row_map; } static void km_row_map_fini( km_row_map_t * const row_map ) { free(row_map); } // use k-means to iteratively update the cluster centroids until there // are no more changes to the centroids bool km_solve( km_set_t * const cs, const km_set_t * const set, const km_solve_cbs_t * const cbs, void * const cb_data ) { const size_t num_clusters = cs->num_rows, num_floats = set->shape.num_floats; // row map: row => distance, cluster ID km_row_map_t *row_map = km_row_map_init(set->num_rows); if (!row_map) { // return failure return false; } // calculate clusters by doing the following: // * walk all clusters and all rows // * if we find a closer cluster, move row to cluster // * if there were changes to any cluster, then calculate a new // centroid for each cluster by averaging the cluster rows // * repeat until there are no more changes bool changed = false; do { // no changes yet changed = false; for (size_t i = 0; i < num_clusters; i++) { // get the floats for this cluster const float * const floats = km_set_get_row(cs, i); for (size_t j = 0; j < set->num_rows; j++) { const float * const row_floats = km_set_get_row(set, j); // calculate the distance squared between these clusters const float d2 = distance_squared(num_floats, floats, row_floats); if (d2 < row_map[j].d2) { // row is closer to this cluster, update distance and cluster row_map[j].d2 = d2; row_map[j].cluster = i; // flag change changed = true; } } } if (changed) { // if there were changes, then we need to calculate the new // cluster centers // calculate new center for (size_t i = 0; i < num_clusters; i++) { size_t num_rows = 0; float * const floats = km_set_get_row(cs, i); memset(floats, 0, sizeof(float) * num_floats); for (size_t j = 0; j < set->num_rows; j++) { const float * const row_floats = km_set_get_row(set, j); if (row_map[j].cluster == i) { // calculate numerator for average for (size_t k = 0; k < num_floats; k++) { floats[k] += row_floats[k]; } // increment denominator for average num_rows++; } } // save number of rows in this cluster cs->ints[i] = num_rows; if (num_rows > 0) { for (size_t k = 0; k < num_floats; k++) { // divide by denominator to get average floats[k] /= num_rows; } } } } if (cbs && cbs->on_step) { // pass clusters and row map to step callback cbs->on_step(cs, row_map, cb_data); } } while (changed); if (cbs && cbs->on_stats) { float means[num_clusters], variances[num_clusters]; memset(means, 0, sizeof(float) * num_clusters); memset(variances, 0, sizeof(float) * num_clusters); // calculate mean distances for (size_t i = 0; i < set->num_rows; i++) { means[row_map[i].cluster] += row_map[i].d2; } // finalize means for (size_t i = 0; i < num_clusters; i++) { means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0; } // calculate variances for (size_t i = 0; i < set->num_rows; i++) { const size_t cluster = row_map[i].cluster; const float variance = (sqrt(row_map[i].d2) - means[cluster]) * (sqrt(row_map[i].d2) - means[cluster]); variances[cluster] += variance; } // finalize variances for (size_t i = 0; i < num_clusters; i++) { variances[i] = (cs->ints[i]) ? (variances[i] / cs->ints[i]) : 0; } // build stats const km_solve_stats_t stats = { .means = means, .variances = variances, .num_clusters = num_clusters, }; // emit means cbs->on_stats(cs, &stats, cb_data); } // free row_map km_row_map_fini(row_map); // return success return true; }