diff options
Diffstat (limited to 'km.c')
-rw-r--r-- | km.c | 144 |
1 files changed, 99 insertions, 45 deletions
@@ -9,6 +9,7 @@ #define MIN_ROWS (4096 / sizeof(float)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define UNUSED(a) ((void) (a)) +#define MIN_CLUSTER_DISTANCE 0.0001 // calculate squared euclidean distance between two points static float @@ -19,7 +20,7 @@ distance_squared( ) { float r = 0.0; - for (size_t i = 0; i <num_floats; i++) { + for (size_t i = 0; i < num_floats; i++) { r += (b[i] - a[i]) * (b[i] - a[i]); } @@ -88,14 +89,18 @@ km_set_grow( const size_t capacity ) { // alloc floats - float * const floats = malloc(sizeof(float) * set->shape.num_floats * capacity); + const size_t num_floats = set->shape.num_floats * capacity; + float * const floats = malloc(sizeof(float) * num_floats); if (!floats) { + // return failure return false; } // alloc ints - int * const ints = malloc(sizeof(int) * set->shape.num_ints * capacity); + const size_t num_ints = set->shape.num_ints * capacity; + int * const ints = malloc(sizeof(int) * num_ints); if (!ints) { + // return failure return false; } @@ -103,6 +108,7 @@ km_set_grow( set->ints = ints; set->capacity = capacity; + // return success return true; } @@ -141,47 +147,6 @@ km_set_fini(km_set_t * const set) { set->capacity = 0; } -/* - * // append row to data set - * bool - * km_set_push_row( - * km_set_t * const set, - * const float * const floats, - * const float * const ints - * ) { - * if (set->num_rows + 1 == set->capacity) { - * // resize buffers - * if (!km_set_grow(set, MAX(MIN_ROWS, 2 * set->capacity + 1))) { - * return false; - * } - * } - * - * // copy floats - * const size_t num_floats = set->shape.num_floats; - * if (num_floats > 0) { - * float * const dst = set->floats + num_floats * set->num_rows; - * const size_t stride = sizeof(float) * num_floats; - * - * memcpy(dst, floats, stride); - * } - * - * // copy ints - * const size_t num_ints = set->shape.num_ints; - * if (num_ints > 0) { - * int * const dst = set->ints + num_ints * set->num_rows; - * const size_t stride = sizeof(int) * num_ints; - * - * memcpy(dst, ints, stride); - * } - * - * // increment row count - * set->num_rows++; - * - * // return success - * return true; - * } - */ - // append rows to data set, growing set if necessary bool km_set_push_rows( @@ -325,7 +290,7 @@ km_clusters_solve( // calculate the distance squared between these clusters const float d2 = distance_squared(num_floats, floats, row_floats); - if (d2 < row_map[j].d2) { + if (d2 < row_map[j].d2) { // row is closer to this cluster, update distance and cluster row_map[j].d2 = d2; row_map[j].cluster = i; @@ -402,3 +367,92 @@ km_clusters_solve( // return success return true; } + +typedef struct { + float sum; + size_t num_empty_clusters; +} search_test_data_t; + +static void +search_test_on_means( + const km_set_t * const set, + const float * const means, + const size_t num_clusters, + void * const cb_data +) { + search_test_data_t *test_data = cb_data; + UNUSED(set); + + // calculate numerator for the average distance across all clusters in + // this test + for (size_t i = 0; i < num_clusters; i++) { + if (fabsf(means[i]) > MIN_CLUSTER_DISTANCE) { + test_data->sum += means[i]; + } else { + test_data->num_empty_clusters++; + } + } +} + +static const km_clusters_solve_cbs_t +SEARCH_TEST_CBS = { + .on_means = search_test_on_means, +}; + +bool +km_search( + const km_set_t * const set, + const size_t max_clusters, + const size_t num_tests, + const km_search_row_cb_t on_row, + void *cb_data +) { + // init random source + km_rand_src_t rs; + km_rand_src_system_init(&rs); + + for (size_t i = 2; i < max_clusters; i++) { + for (size_t j = 0; j < num_tests; j++) { + // init cluster set + km_set_t cs; + if (!km_clusters_rand_init(&cs, set->shape.num_floats, i, &rs)) { + // return failure + return false; + } + + // init test data + search_test_data_t data = { + .sum = 0, + .num_empty_clusters = 0, + }; + + // solve test + if (!km_clusters_solve(&cs, set, &SEARCH_TEST_CBS, &data)) { + // return failure + return false; + } + + if (on_row) { + // calculate mean + const float mean = (data.num_empty_clusters < i) ? (data.sum / (i - data.num_empty_clusters)) : 0; + + // init search result row + km_search_row_t row = { + .cluster_set = &cs, + .num_clusters = i, + .mean_distance = mean, + .num_empty_clusters = data.num_empty_clusters, + }; + + // emit row + on_row(&row, cb_data); + } + + // free cluster set + km_set_fini(&cs); + } + } + + // return success + return true; +} |