diff options
Diffstat (limited to 'km-solve.c')
-rw-r--r-- | km-solve.c | 193 |
1 files changed, 193 insertions, 0 deletions
diff --git a/km-solve.c b/km-solve.c new file mode 100644 index 0000000..4560c67 --- /dev/null +++ b/km-solve.c @@ -0,0 +1,193 @@ +#include <stdbool.h> // bool +#include <string.h> // memset() +#include <float.h> // FLT_MAX +#include <math.h> // sqrt() +#include "util.h" +#include "km.h" + +// calculate squared euclidean distance between two points +static float +distance_squared( + const size_t num_floats, + const float * const a, + const float * const b +) { + float r = 0.0; + + for (size_t i = 0; i < num_floats; i++) { + r += (b[i] - a[i]) * (b[i] - a[i]); + } + + // return squared distance + return r; +} + +// alloc and initialize row map +static km_row_map_t * +km_row_map_init( + const size_t num_rows +) { + // alloc row map + km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows); + + // check for error + if (!row_map) { + // return failure + return false; + } + + // init row map by setting the maximum distance + for (size_t i = 0; i < num_rows; i++) { + row_map[i].d2 = FLT_MAX; + } + + // return row map + return row_map; +} + +static void +km_row_map_fini( + km_row_map_t * const row_map +) { + free(row_map); +} + +// use k-means to iteratively update the cluster centroids until there +// are no more changes to the centroids +bool +km_solve( + km_set_t * const cs, + const km_set_t * const set, + const km_solve_cbs_t * const cbs, + void * const cb_data +) { + const size_t num_clusters = cs->num_rows, + num_floats = set->shape.num_floats; + + // row map: row => distance, cluster ID + km_row_map_t *row_map = km_row_map_init(set->num_rows); + if (!row_map) { + // return failure + return false; + } + + // calculate clusters by doing the following: + // * walk all clusters and all rows + // * if we find a closer cluster, move row to cluster + // * if there were changes to any cluster, then calculate a new + // centroid for each cluster by averaging the cluster rows + // * repeat until there are no more changes + bool changed = false; + do { + // no changes yet + changed = false; + + for (size_t i = 0; i < num_clusters; i++) { + // get the floats for this cluster + const float * const floats = km_set_get_row(cs, i); + + for (size_t j = 0; j < set->num_rows; j++) { + const float * const row_floats = km_set_get_row(set, j); + + // calculate the distance squared between these clusters + const float d2 = distance_squared(num_floats, floats, row_floats); + + if (d2 < row_map[j].d2) { + // row is closer to this cluster, update distance and cluster + row_map[j].d2 = d2; + row_map[j].cluster = i; + + // flag change + changed = true; + } + } + } + + if (changed) { + // if there were changes, then we need to calculate the new + // cluster centers + + // calculate new center + for (size_t i = 0; i < num_clusters; i++) { + size_t num_rows = 0; + float * const floats = km_set_get_row(cs, i); + memset(floats, 0, sizeof(float) * num_floats); + + for (size_t j = 0; j < set->num_rows; j++) { + const float * const row_floats = km_set_get_row(set, j); + + if (row_map[j].cluster == i) { + // calculate numerator for average + for (size_t k = 0; k < num_floats; k++) { + floats[k] += row_floats[k]; + } + + // increment denominator for average + num_rows++; + } + } + + // save number of rows in this cluster + cs->ints[i] = num_rows; + + if (num_rows > 0) { + for (size_t k = 0; k < num_floats; k++) { + // divide by denominator to get average + floats[k] /= num_rows; + } + } + } + } + + if (cbs && cbs->on_step) { + // pass clusters and row map to step callback + cbs->on_step(cs, row_map, cb_data); + } + } while (changed); + + if (cbs && cbs->on_stats) { + float means[num_clusters], + variances[num_clusters]; + memset(means, 0, sizeof(float) * num_clusters); + memset(variances, 0, sizeof(float) * num_clusters); + + // calculate mean distances + for (size_t i = 0; i < set->num_rows; i++) { + means[row_map[i].cluster] += row_map[i].d2; + } + + // finalize means + for (size_t i = 0; i < num_clusters; i++) { + means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0; + } + + // calculate variances + for (size_t i = 0; i < set->num_rows; i++) { + const size_t cluster = row_map[i].cluster; + const float variance = (sqrt(row_map[i].d2) - means[cluster]) * + (sqrt(row_map[i].d2) - means[cluster]); + variances[cluster] += variance; + } + + // finalize variances + for (size_t i = 0; i < num_clusters; i++) { + variances[i] = (cs->ints[i]) ? (variances[i] / cs->ints[i]) : 0; + } + + // build stats + const km_solve_stats_t stats = { + .means = means, + .variances = variances, + .num_clusters = num_clusters, + }; + + // emit means + cbs->on_stats(cs, &stats, cb_data); + } + + // free row_map + km_row_map_fini(row_map); + + // return success + return true; +} |