From f557d1f49a2914c6084dd18efc783395228d8ce0 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Tue, 5 Feb 2019 00:22:15 -0500 Subject: mv *.[hc] src/ --- src/km-solve.c | 200 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 src/km-solve.c (limited to 'src/km-solve.c') diff --git a/src/km-solve.c b/src/km-solve.c new file mode 100644 index 0000000..f579a41 --- /dev/null +++ b/src/km-solve.c @@ -0,0 +1,200 @@ +#include // bool +#include // memset() +#include // FLT_MAX +#include // sqrt() +#include "util.h" +#include "km.h" + +#define MIN_CLUSTER_DISTANCE 0.00001 + +// alloc and initialize row map +static km_row_map_t * +km_row_map_init( + const size_t num_rows +) { + // alloc row map + km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows); + + // check for error + if (!row_map) { + // return failure + return false; + } + + // init row map + for (size_t i = 0; i < num_rows; i++) { + // setting distances to maximum + row_map[i].d2 = FLT_MAX; + row_map[i].d2_near = FLT_MAX; + } + + // return row map + return row_map; +} + +static void +km_row_map_fini( + km_row_map_t * const row_map +) { + free(row_map); +} + +// use k-means to iteratively update the cluster centroids until there +// are no more changes to the centroids +bool +km_solve( + km_set_t * const cs, + const km_set_t * const set, + const km_solve_cbs_t * const cbs, + void * const cb_data +) { + const size_t num_clusters = cs->num_rows, + num_floats = set->shape.num_floats; + + // row map: row => distance, cluster ID + km_row_map_t *row_map = km_row_map_init(set->num_rows); + if (!row_map) { + // return failure + return false; + } + + // calculate clusters by doing the following: + // * walk all clusters and all rows + // * if we find a closer cluster, move row to cluster + // * if there were changes to any cluster, then calculate a new + // centroid for each cluster by averaging the cluster rows + // * repeat until there are no more changes + bool changed = false; + do { + // no changes yet + changed = false; + + for (size_t i = 0; i < num_clusters; i++) { + // get the floats for this cluster + const float * const floats = km_set_get_row(cs, i); + + for (size_t j = 0; j < set->num_rows; j++) { + // get row values + const float * const vals = km_set_get_row(set, j); + + // calculate the distance squared between row and cluster + const float d2 = distance_squared(num_floats, floats, vals); + + if (d2 < row_map[j].d2) { + // row is closer to this cluster, update row map + row_map[j].d2 = d2; + row_map[j].cluster = i; + + // flag change + changed = true; + } + + if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) { + row_map[j].d2_near = d2; + + // flag change + changed = true; + } + } + } + + if (changed) { + // if there were changes, then we need to calculate the new + // cluster centers + + // calculate new center + for (size_t i = 0; i < num_clusters; i++) { + size_t num_rows = 0; + float * const floats = km_set_get_row(cs, i); + memset(floats, 0, sizeof(float) * num_floats); + + for (size_t j = 0; j < set->num_rows; j++) { + const float * const row_floats = km_set_get_row(set, j); + + if (row_map[j].cluster == i) { + // calculate numerator for average + for (size_t k = 0; k < num_floats; k++) { + floats[k] += row_floats[k]; + } + + // increment denominator for average + num_rows++; + } + } + + // save number of rows in this cluster + cs->ints[i] = num_rows; + + if (num_rows > 0) { + for (size_t k = 0; k < num_floats; k++) { + // divide by denominator to get average + floats[k] /= num_rows; + } + } + } + } + + if (cbs && cbs->on_step) { + // pass clusters and row map to step callback + cbs->on_step(cs, row_map, cb_data); + } + } while (changed); + + if (cbs && cbs->on_stats) { + float sum = 0, + silouette = 0, + mean_dists[num_clusters], + mean_nears[num_clusters]; + + memset(mean_dists, 0, sizeof(mean_dists)); + memset(mean_nears, 0, sizeof(mean_nears)); + + // calculate sum of distances across all clusters + for (size_t i = 0; i < set->num_rows; i++) { + sum += row_map[i].d2; + } + + // calculate mean numerators and silouette + for (size_t i = 0; i < set->num_rows; i++) { + // distance squared (d2) to center of this cluster + mean_dists[row_map[i].cluster] += row_map[i].d2; + + // distance squared (d2) to center of nearest cluster + mean_nears[row_map[i].cluster] += row_map[i].d2_near; + + // calculate silouette denominator + // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29) + const float delta = row_map[i].d2_near - row_map[i].d2; + if (fabsf(delta) > MIN_CLUSTER_DISTANCE) { + silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near); + } + } + + // finalize means (divide by row count) + for (size_t i = 0; i < num_clusters; i++) { + mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0; + mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0; + } + + // finalize silouette + silouette /= set->num_rows; + + // build stats + const km_solve_stats_t stats = { + .sum = sum, + .silouette = silouette, + .mean_dists = mean_dists, + .mean_nears = mean_nears, + .num_clusters = num_clusters, + }; + + // emit means + cbs->on_stats(cs, &stats, cb_data); + } + + // free row_map + km_row_map_fini(row_map); + + // return success + return true; +} -- cgit v1.2.3