aboutsummaryrefslogtreecommitdiff
path: root/km.c
diff options
context:
space:
mode:
Diffstat (limited to 'km.c')
-rw-r--r--km.c144
1 files changed, 99 insertions, 45 deletions
diff --git a/km.c b/km.c
index a6c9301..135132b 100644
--- a/km.c
+++ b/km.c
@@ -9,6 +9,7 @@
#define MIN_ROWS (4096 / sizeof(float))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define UNUSED(a) ((void) (a))
+#define MIN_CLUSTER_DISTANCE 0.0001
// calculate squared euclidean distance between two points
static float
@@ -19,7 +20,7 @@ distance_squared(
) {
float r = 0.0;
- for (size_t i = 0; i <num_floats; i++) {
+ for (size_t i = 0; i < num_floats; i++) {
r += (b[i] - a[i]) * (b[i] - a[i]);
}
@@ -88,14 +89,18 @@ km_set_grow(
const size_t capacity
) {
// alloc floats
- float * const floats = malloc(sizeof(float) * set->shape.num_floats * capacity);
+ const size_t num_floats = set->shape.num_floats * capacity;
+ float * const floats = malloc(sizeof(float) * num_floats);
if (!floats) {
+ // return failure
return false;
}
// alloc ints
- int * const ints = malloc(sizeof(int) * set->shape.num_ints * capacity);
+ const size_t num_ints = set->shape.num_ints * capacity;
+ int * const ints = malloc(sizeof(int) * num_ints);
if (!ints) {
+ // return failure
return false;
}
@@ -103,6 +108,7 @@ km_set_grow(
set->ints = ints;
set->capacity = capacity;
+ // return success
return true;
}
@@ -141,47 +147,6 @@ km_set_fini(km_set_t * const set) {
set->capacity = 0;
}
-/*
- * // append row to data set
- * bool
- * km_set_push_row(
- * km_set_t * const set,
- * const float * const floats,
- * const float * const ints
- * ) {
- * if (set->num_rows + 1 == set->capacity) {
- * // resize buffers
- * if (!km_set_grow(set, MAX(MIN_ROWS, 2 * set->capacity + 1))) {
- * return false;
- * }
- * }
- *
- * // copy floats
- * const size_t num_floats = set->shape.num_floats;
- * if (num_floats > 0) {
- * float * const dst = set->floats + num_floats * set->num_rows;
- * const size_t stride = sizeof(float) * num_floats;
- *
- * memcpy(dst, floats, stride);
- * }
- *
- * // copy ints
- * const size_t num_ints = set->shape.num_ints;
- * if (num_ints > 0) {
- * int * const dst = set->ints + num_ints * set->num_rows;
- * const size_t stride = sizeof(int) * num_ints;
- *
- * memcpy(dst, ints, stride);
- * }
- *
- * // increment row count
- * set->num_rows++;
- *
- * // return success
- * return true;
- * }
- */
-
// append rows to data set, growing set if necessary
bool
km_set_push_rows(
@@ -325,7 +290,7 @@ km_clusters_solve(
// calculate the distance squared between these clusters
const float d2 = distance_squared(num_floats, floats, row_floats);
- if (d2 < row_map[j].d2) {
+ if (d2 < row_map[j].d2) {
// row is closer to this cluster, update distance and cluster
row_map[j].d2 = d2;
row_map[j].cluster = i;
@@ -402,3 +367,92 @@ km_clusters_solve(
// return success
return true;
}
+
+typedef struct {
+ float sum;
+ size_t num_empty_clusters;
+} search_test_data_t;
+
+static void
+search_test_on_means(
+ const km_set_t * const set,
+ const float * const means,
+ const size_t num_clusters,
+ void * const cb_data
+) {
+ search_test_data_t *test_data = cb_data;
+ UNUSED(set);
+
+ // calculate numerator for the average distance across all clusters in
+ // this test
+ for (size_t i = 0; i < num_clusters; i++) {
+ if (fabsf(means[i]) > MIN_CLUSTER_DISTANCE) {
+ test_data->sum += means[i];
+ } else {
+ test_data->num_empty_clusters++;
+ }
+ }
+}
+
+static const km_clusters_solve_cbs_t
+SEARCH_TEST_CBS = {
+ .on_means = search_test_on_means,
+};
+
+bool
+km_search(
+ const km_set_t * const set,
+ const size_t max_clusters,
+ const size_t num_tests,
+ const km_search_row_cb_t on_row,
+ void *cb_data
+) {
+ // init random source
+ km_rand_src_t rs;
+ km_rand_src_system_init(&rs);
+
+ for (size_t i = 2; i < max_clusters; i++) {
+ for (size_t j = 0; j < num_tests; j++) {
+ // init cluster set
+ km_set_t cs;
+ if (!km_clusters_rand_init(&cs, set->shape.num_floats, i, &rs)) {
+ // return failure
+ return false;
+ }
+
+ // init test data
+ search_test_data_t data = {
+ .sum = 0,
+ .num_empty_clusters = 0,
+ };
+
+ // solve test
+ if (!km_clusters_solve(&cs, set, &SEARCH_TEST_CBS, &data)) {
+ // return failure
+ return false;
+ }
+
+ if (on_row) {
+ // calculate mean
+ const float mean = (data.num_empty_clusters < i) ? (data.sum / (i - data.num_empty_clusters)) : 0;
+
+ // init search result row
+ km_search_row_t row = {
+ .cluster_set = &cs,
+ .num_clusters = i,
+ .mean_distance = mean,
+ .num_empty_clusters = data.num_empty_clusters,
+ };
+
+ // emit row
+ on_row(&row, cb_data);
+ }
+
+ // free cluster set
+ km_set_fini(&cs);
+ }
+ }
+
+ // return success
+ return true;
+}