diff options
author | Paul Duncan <pabs@pablotron.org> | 2019-02-02 03:25:59 -0500 |
---|---|---|
committer | Paul Duncan <pabs@pablotron.org> | 2019-02-02 03:25:59 -0500 |
commit | a3792d8769d2dc8ee0abae758c6fae3a35b5dfbc (patch) | |
tree | 079d372fb69cabfe6dded99f231408c8b13022cd | |
parent | bf8f5126a6d621be4996df842c70d60297f87706 (diff) | |
download | kmeans-a3792d8769d2dc8ee0abae758c6fae3a35b5dfbc.tar.bz2 kmeans-a3792d8769d2dc8ee0abae758c6fae3a35b5dfbc.zip |
add km_search()
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | Makefile | 10 | ||||
-rw-r--r-- | km.c | 144 | ||||
-rw-r--r-- | km.h | 24 |
4 files changed, 131 insertions, 48 deletions
@@ -1 +1,2 @@ *.o +km-test @@ -1,7 +1,13 @@ +APP=km-test CFLAGS=-W -Wall -Wextra -pedantic -std=c11 -O2 -OBJS=km.o +OBJS=km.o main.o +LIBS=-lm -all: $(OBJS) +.PHONY=all clean +app: $(APP) + +$(APP): $(OBJS) + $(CC) -o $(APP) $(OBJS) $(LIBS) %.o: %.c $(CC) -c $(CFLAGS) $< @@ -9,6 +9,7 @@ #define MIN_ROWS (4096 / sizeof(float)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define UNUSED(a) ((void) (a)) +#define MIN_CLUSTER_DISTANCE 0.0001 // calculate squared euclidean distance between two points static float @@ -19,7 +20,7 @@ distance_squared( ) { float r = 0.0; - for (size_t i = 0; i <num_floats; i++) { + for (size_t i = 0; i < num_floats; i++) { r += (b[i] - a[i]) * (b[i] - a[i]); } @@ -88,14 +89,18 @@ km_set_grow( const size_t capacity ) { // alloc floats - float * const floats = malloc(sizeof(float) * set->shape.num_floats * capacity); + const size_t num_floats = set->shape.num_floats * capacity; + float * const floats = malloc(sizeof(float) * num_floats); if (!floats) { + // return failure return false; } // alloc ints - int * const ints = malloc(sizeof(int) * set->shape.num_ints * capacity); + const size_t num_ints = set->shape.num_ints * capacity; + int * const ints = malloc(sizeof(int) * num_ints); if (!ints) { + // return failure return false; } @@ -103,6 +108,7 @@ km_set_grow( set->ints = ints; set->capacity = capacity; + // return success return true; } @@ -141,47 +147,6 @@ km_set_fini(km_set_t * const set) { set->capacity = 0; } -/* - * // append row to data set - * bool - * km_set_push_row( - * km_set_t * const set, - * const float * const floats, - * const float * const ints - * ) { - * if (set->num_rows + 1 == set->capacity) { - * // resize buffers - * if (!km_set_grow(set, MAX(MIN_ROWS, 2 * set->capacity + 1))) { - * return false; - * } - * } - * - * // copy floats - * const size_t num_floats = set->shape.num_floats; - * if (num_floats > 0) { - * float * const dst = set->floats + num_floats * set->num_rows; - * const size_t stride = sizeof(float) * num_floats; - * - * memcpy(dst, floats, stride); - * } - * - * // copy ints - * const size_t num_ints = set->shape.num_ints; - * if (num_ints > 0) { - * int * const dst = set->ints + num_ints * set->num_rows; - * const size_t stride = sizeof(int) * num_ints; - * - * memcpy(dst, ints, stride); - * } - * - * // increment row count - * set->num_rows++; - * - * // return success - * return true; - * } - */ - // append rows to data set, growing set if necessary bool km_set_push_rows( @@ -325,7 +290,7 @@ km_clusters_solve( // calculate the distance squared between these clusters const float d2 = distance_squared(num_floats, floats, row_floats); - if (d2 < row_map[j].d2) { + if (d2 < row_map[j].d2) { // row is closer to this cluster, update distance and cluster row_map[j].d2 = d2; row_map[j].cluster = i; @@ -402,3 +367,92 @@ km_clusters_solve( // return success return true; } + +typedef struct { + float sum; + size_t num_empty_clusters; +} search_test_data_t; + +static void +search_test_on_means( + const km_set_t * const set, + const float * const means, + const size_t num_clusters, + void * const cb_data +) { + search_test_data_t *test_data = cb_data; + UNUSED(set); + + // calculate numerator for the average distance across all clusters in + // this test + for (size_t i = 0; i < num_clusters; i++) { + if (fabsf(means[i]) > MIN_CLUSTER_DISTANCE) { + test_data->sum += means[i]; + } else { + test_data->num_empty_clusters++; + } + } +} + +static const km_clusters_solve_cbs_t +SEARCH_TEST_CBS = { + .on_means = search_test_on_means, +}; + +bool +km_search( + const km_set_t * const set, + const size_t max_clusters, + const size_t num_tests, + const km_search_row_cb_t on_row, + void *cb_data +) { + // init random source + km_rand_src_t rs; + km_rand_src_system_init(&rs); + + for (size_t i = 2; i < max_clusters; i++) { + for (size_t j = 0; j < num_tests; j++) { + // init cluster set + km_set_t cs; + if (!km_clusters_rand_init(&cs, set->shape.num_floats, i, &rs)) { + // return failure + return false; + } + + // init test data + search_test_data_t data = { + .sum = 0, + .num_empty_clusters = 0, + }; + + // solve test + if (!km_clusters_solve(&cs, set, &SEARCH_TEST_CBS, &data)) { + // return failure + return false; + } + + if (on_row) { + // calculate mean + const float mean = (data.num_empty_clusters < i) ? (data.sum / (i - data.num_empty_clusters)) : 0; + + // init search result row + km_search_row_t row = { + .cluster_set = &cs, + .num_clusters = i, + .mean_distance = mean, + .num_empty_clusters = data.num_empty_clusters, + }; + + // emit row + on_row(&row, cb_data); + } + + // free cluster set + km_set_fini(&cs); + } + } + + // return success + return true; +} @@ -8,7 +8,7 @@ typedef struct km_rand_src_t_ km_rand_src_t; // random number source callbacks typedef struct { - bool (*fill)(km_rand_src_t * const, const size_t, float * const); + _Bool (*fill)(km_rand_src_t * const, const size_t, float * const); void (*fini)(km_rand_src_t * const); } km_rand_src_cbs_t; @@ -79,4 +79,26 @@ km_clusters_solve( void * const ); +typedef struct { + const km_set_t * const cluster_set; + const size_t num_clusters; + const float mean_distance; + const size_t num_empty_clusters; +} km_search_row_t; + +typedef void (*km_search_row_cb_t)( + const km_search_row_t * const, + void * +); + +// repeatedly test different cluster sizes and report results +_Bool +km_search( + const km_set_t * const set, + const size_t max_clusters, + const size_t num_tests, + const km_search_row_cb_t on_row, + void *cb_data +); + #endif /* KM_H */ |