about summary refs log tree commit diff
path: root/km-solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'km-solve.c')
-rw-r--r-- km-solve.c 193
1 files changed, 193 insertions, 0 deletions
diff --git a/km-solve.c b/km-solve.c
new file mode 100644
index 0000000..4560c67
--- /dev/null
+++ b/km-solve.c
@@ -0,0 +1,193 @@
+#include <stdbool.h> // bool
+#include <stdlib.h> // malloc(), calloc(), free()
+#include <string.h> // memset()
+#include <float.h> // FLT_MAX
+#include <math.h> // sqrt(), sqrtf()
+#include "util.h"
+#include "km.h"
+
+// squared euclidean distance between two points
+//
+// a and b must each hold at least num_floats values.  the result is the
+// sum of the squared per-component differences; no square root is
+// taken.  returns 0 when num_floats is 0.
+static float
+distance_squared(
+  const size_t num_floats,
+  const float * const a,
+  const float * const b
+) {
+  float sum = 0.0;
+  const float *pa = a,
+              *pb = b;
+
+  // walk both points in lock step, one component at a time
+  for (size_t left = num_floats; left > 0; left--, pa++, pb++) {
+    sum += (*pb - *pa) * (*pb - *pa);
+  }
+
+  return sum;
+}
+
+// alloc and initialize row map
+//
+// allocates one entry per row.  every entry starts at the maximum
+// distance (FLT_MAX) and in cluster 0, so both the d2 and cluster
+// fields hold defined values before they are first read.
+//
+// returns NULL if the allocation fails.  the caller releases the map
+// with km_row_map_fini().
+static km_row_map_t *
+km_row_map_init(
+  const size_t num_rows
+) {
+  // alloc row map.  calloc() fails cleanly if num_rows * entry size
+  // would overflow size_t, instead of silently under-allocating
+  km_row_map_t * const row_map = calloc(num_rows, sizeof *row_map);
+
+  // check for error
+  if (!row_map) {
+    // return failure (NULL rather than false: the return type is a
+    // pointer)
+    return NULL;
+  }
+
+  // init row map by setting the maximum distance and a valid cluster
+  for (size_t i = 0; i < num_rows; i++) {
+    row_map[i].d2 = FLT_MAX;
+    row_map[i].cluster = 0;
+  }
+
+  // return row map
+  return row_map;
+}
+
+// release a row map
+//
+// hands row_map straight to free(); passing NULL is harmless because
+// free(NULL) is a no-op.
+static void
+km_row_map_fini(km_row_map_t * const row_map)
+{
+  free(row_map);
+}
+
+// use k-means to iteratively update the cluster centroids until no row
+// moves to a closer cluster
+//
+// parameters:
+// * cs: cluster set.  each row is one cluster centroid; the centroids
+//   are overwritten in place and cs->ints[i] receives the number of
+//   rows assigned to cluster i.
+// * set: rows to cluster (only read).
+// * cbs: optional callbacks (may be NULL, and either member may be
+//   NULL).  on_step is called once per pass with the clusters and the
+//   row map.  on_stats is called once after the last pass; the means
+//   and variances arrays it receives live on this function's stack and
+//   are only valid for the duration of the call.
+// * cb_data: opaque pointer passed through to both callbacks.
+//
+// returns false if there are no clusters or if the row map could not be
+// allocated, and true otherwise.
+bool
+km_solve(
+  km_set_t * const cs,
+  const km_set_t * const set,
+  const km_solve_cbs_t * const cbs,
+  void * const cb_data
+) {
+  const size_t num_clusters = cs->num_rows,
+               num_floats = set->shape.num_floats;
+
+  if (!num_clusters) {
+    // no cluster to assign rows to.  bail out early: the stats below
+    // would otherwise declare and index zero-length arrays
+    return false;
+  }
+
+  // row map: row => distance, cluster ID
+  km_row_map_t *row_map = km_row_map_init(set->num_rows);
+  if (!row_map) {
+    // return failure
+    return false;
+  }
+
+  // start every row in cluster 0 so that a row which never compares
+  // closer to any centroid (e.g. because of NaN data) still carries a
+  // valid cluster index
+  for (size_t j = 0; j < set->num_rows; j++) {
+    row_map[j].cluster = 0;
+  }
+
+  // calculate clusters by doing the following:
+  // * refresh each row's distance to its (possibly moved) centroid
+  // * walk all clusters and all rows
+  // * if we find a closer cluster, move row to cluster
+  // * if there were changes to any cluster, then calculate a new
+  //   centroid for each cluster by averaging the cluster rows
+  // * repeat until there are no more changes
+  bool changed = false,
+       first_pass = true;
+  do {
+    // no changes yet
+    changed = false;
+
+    if (!first_pass) {
+      // the centroids moved at the end of the previous pass, so the
+      // stored distances refer to centroids that no longer exist.
+      // recompute them against the current centroids; otherwise a row
+      // whose centroid moved away would keep a stale, too-small
+      // distance and never be reassigned
+      for (size_t j = 0; j < set->num_rows; j++) {
+        row_map[j].d2 = distance_squared(
+          num_floats,
+          km_set_get_row(cs, row_map[j].cluster),
+          km_set_get_row(set, j)
+        );
+      }
+    }
+    first_pass = false;
+
+    for (size_t i = 0; i < num_clusters; i++) {
+      // get the floats for this cluster
+      const float * const floats = km_set_get_row(cs, i);
+
+      for (size_t j = 0; j < set->num_rows; j++) {
+        const float * const row_floats = km_set_get_row(set, j);
+
+        // calculate the distance squared between cluster and row
+        const float d2 = distance_squared(num_floats, floats, row_floats);
+
+        // strict comparison: on a tie the row stays where it is
+        if (d2 < row_map[j].d2) {
+          // row is closer to this cluster, update distance and cluster
+          row_map[j].d2 = d2;
+          row_map[j].cluster = i;
+
+          // flag change
+          changed = true;
+        }
+      }
+    }
+
+    if (changed) {
+      // if there were changes, then we need to calculate the new
+      // cluster centers
+
+      // calculate new center
+      for (size_t i = 0; i < num_clusters; i++) {
+        size_t num_rows = 0;
+        float * const floats = km_set_get_row(cs, i);
+
+        // note: a cluster that ends up with no rows keeps this
+        // all-zero centroid
+        memset(floats, 0, sizeof(float) * num_floats);
+
+        for (size_t j = 0; j < set->num_rows; j++) {
+          const float * const row_floats = km_set_get_row(set, j);
+
+          if (row_map[j].cluster == i) {
+            // calculate numerator for average
+            for (size_t k = 0; k < num_floats; k++) {
+              floats[k] += row_floats[k];
+            }
+
+            // increment denominator for average
+            num_rows++;
+          }
+        }
+
+        // save number of rows in this cluster
+        cs->ints[i] = num_rows;
+
+        if (num_rows > 0) {
+          for (size_t k = 0; k < num_floats; k++) {
+            // divide by denominator to get average
+            floats[k] /= num_rows;
+          }
+        }
+      }
+    }
+
+    if (cbs && cbs->on_step) {
+      // pass clusters and row map to step callback
+      cbs->on_step(cs, row_map, cb_data);
+    }
+  } while (changed);
+
+  if (cbs && cbs->on_stats) {
+    float means[num_clusters],
+          variances[num_clusters];
+    memset(means, 0, sizeof(float) * num_clusters);
+    memset(variances, 0, sizeof(float) * num_clusters);
+
+    // sum the distance (not the squared distance) from each row to its
+    // centroid, so that dividing by the row count below yields the mean
+    // distance that the variance loop compares against
+    for (size_t i = 0; i < set->num_rows; i++) {
+      means[row_map[i].cluster] += sqrtf(row_map[i].d2);
+    }
+
+    // finalize means
+    for (size_t i = 0; i < num_clusters; i++) {
+      means[i] = (cs->ints[i]) ? (means[i] / cs->ints[i]) : 0;
+    }
+
+    // calculate variances: sum of squared deviations of each row's
+    // distance from its cluster's mean distance
+    for (size_t i = 0; i < set->num_rows; i++) {
+      const size_t cluster = row_map[i].cluster;
+      const float delta = sqrtf(row_map[i].d2) - means[cluster];
+      variances[cluster] += delta * delta;
+    }
+
+    // finalize variances
+    for (size_t i = 0; i < num_clusters; i++) {
+      variances[i] = (cs->ints[i]) ? (variances[i] / cs->ints[i]) : 0;
+    }
+
+    // build stats
+    const km_solve_stats_t stats = {
+      .means = means,
+      .variances = variances,
+      .num_clusters = num_clusters,
+    };
+
+    // emit stats
+    cbs->on_stats(cs, &stats, cb_data);
+  }
+
+  // free row_map
+  km_row_map_fini(row_map);
+
+  // return success
+  return true;
+}