aboutsummaryrefslogtreecommitdiff
path: root/src/km-solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/km-solve.c')
-rw-r--r--src/km-solve.c200
1 files changed, 200 insertions, 0 deletions
diff --git a/src/km-solve.c b/src/km-solve.c
new file mode 100644
index 0000000..f579a41
--- /dev/null
+++ b/src/km-solve.c
@@ -0,0 +1,200 @@
+#include <stdbool.h> // bool
+#include <string.h> // memset()
+#include <float.h> // FLT_MAX
+#include <math.h> // sqrt()
+#include "util.h"
+#include "km.h"
+
+#define MIN_CLUSTER_DISTANCE 0.00001
+
+// alloc and initialize row map
+static km_row_map_t *
+km_row_map_init(
+ const size_t num_rows
+) {
+ // alloc row map
+ km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows);
+
+ // check for error
+ if (!row_map) {
+ // return failure
+ return false;
+ }
+
+ // init row map
+ for (size_t i = 0; i < num_rows; i++) {
+ // setting distances to maximum
+ row_map[i].d2 = FLT_MAX;
+ row_map[i].d2_near = FLT_MAX;
+ }
+
+ // return row map
+ return row_map;
+}
+
+static void
+km_row_map_fini(
+ km_row_map_t * const row_map
+) {
+ free(row_map);
+}
+
+// use k-means to iteratively update the cluster centroids until there
+// are no more changes to the centroids
+bool
+km_solve(
+ km_set_t * const cs,
+ const km_set_t * const set,
+ const km_solve_cbs_t * const cbs,
+ void * const cb_data
+) {
+ const size_t num_clusters = cs->num_rows,
+ num_floats = set->shape.num_floats;
+
+ // row map: row => distance, cluster ID
+ km_row_map_t *row_map = km_row_map_init(set->num_rows);
+ if (!row_map) {
+ // return failure
+ return false;
+ }
+
+ // calculate clusters by doing the following:
+ // * walk all clusters and all rows
+ // * if we find a closer cluster, move row to cluster
+ // * if there were changes to any cluster, then calculate a new
+ // centroid for each cluster by averaging the cluster rows
+ // * repeat until there are no more changes
+ bool changed = false;
+ do {
+ // no changes yet
+ changed = false;
+
+ for (size_t i = 0; i < num_clusters; i++) {
+ // get the floats for this cluster
+ const float * const floats = km_set_get_row(cs, i);
+
+ for (size_t j = 0; j < set->num_rows; j++) {
+ // get row values
+ const float * const vals = km_set_get_row(set, j);
+
+ // calculate the distance squared between row and cluster
+ const float d2 = distance_squared(num_floats, floats, vals);
+
+ if (d2 < row_map[j].d2) {
+ // row is closer to this cluster, update row map
+ row_map[j].d2 = d2;
+ row_map[j].cluster = i;
+
+ // flag change
+ changed = true;
+ }
+
+ if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) {
+ row_map[j].d2_near = d2;
+
+ // flag change
+ changed = true;
+ }
+ }
+ }
+
+ if (changed) {
+ // if there were changes, then we need to calculate the new
+ // cluster centers
+
+ // calculate new center
+ for (size_t i = 0; i < num_clusters; i++) {
+ size_t num_rows = 0;
+ float * const floats = km_set_get_row(cs, i);
+ memset(floats, 0, sizeof(float) * num_floats);
+
+ for (size_t j = 0; j < set->num_rows; j++) {
+ const float * const row_floats = km_set_get_row(set, j);
+
+ if (row_map[j].cluster == i) {
+ // calculate numerator for average
+ for (size_t k = 0; k < num_floats; k++) {
+ floats[k] += row_floats[k];
+ }
+
+ // increment denominator for average
+ num_rows++;
+ }
+ }
+
+ // save number of rows in this cluster
+ cs->ints[i] = num_rows;
+
+ if (num_rows > 0) {
+ for (size_t k = 0; k < num_floats; k++) {
+ // divide by denominator to get average
+ floats[k] /= num_rows;
+ }
+ }
+ }
+ }
+
+ if (cbs && cbs->on_step) {
+ // pass clusters and row map to step callback
+ cbs->on_step(cs, row_map, cb_data);
+ }
+ } while (changed);
+
+ if (cbs && cbs->on_stats) {
+ float sum = 0,
+ silouette = 0,
+ mean_dists[num_clusters],
+ mean_nears[num_clusters];
+
+ memset(mean_dists, 0, sizeof(mean_dists));
+ memset(mean_nears, 0, sizeof(mean_nears));
+
+ // calculate sum of distances across all clusters
+ for (size_t i = 0; i < set->num_rows; i++) {
+ sum += row_map[i].d2;
+ }
+
+ // calculate mean numerators and silouette
+ for (size_t i = 0; i < set->num_rows; i++) {
+ // distance squared (d2) to center of this cluster
+ mean_dists[row_map[i].cluster] += row_map[i].d2;
+
+ // distance squared (d2) to center of nearest cluster
+ mean_nears[row_map[i].cluster] += row_map[i].d2_near;
+
+ // calculate silouette denominator
+ // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
+ const float delta = row_map[i].d2_near - row_map[i].d2;
+ if (fabsf(delta) > MIN_CLUSTER_DISTANCE) {
+ silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near);
+ }
+ }
+
+ // finalize means (divide by row count)
+ for (size_t i = 0; i < num_clusters; i++) {
+ mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0;
+ mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0;
+ }
+
+ // finalize silouette
+ silouette /= set->num_rows;
+
+ // build stats
+ const km_solve_stats_t stats = {
+ .sum = sum,
+ .silouette = silouette,
+ .mean_dists = mean_dists,
+ .mean_nears = mean_nears,
+ .num_clusters = num_clusters,
+ };
+
+ // emit means
+ cbs->on_stats(cs, &stats, cb_data);
+ }
+
+ // free row_map
+ km_row_map_fini(row_map);
+
+ // return success
+ return true;
+}