diff options
Diffstat (limited to 'src/km-find.c')
-rw-r--r-- | src/km-find.c | 115 |
1 files changed, 115 insertions, 0 deletions
diff --git a/src/km-find.c b/src/km-find.c new file mode 100644 index 0000000..657c645 --- /dev/null +++ b/src/km-find.c @@ -0,0 +1,115 @@ +#include <stdbool.h> // bool +#include <math.h> // fabsf() +#include "util.h" +#include "km.h" + +#define MIN_CLUSTER_DISTANCE 0.0001 + +static float +get_mean_cluster_size( + const km_set_t * const set +) { + float sum = 0; + size_t num_filled = 0; + + for (size_t i = 0; i < set->num_rows; i++) { + if (set->ints[i] > 0) { + sum += set->ints[i]; + num_filled++; + } + } + + return (num_filled > 0) ? (sum / num_filled) : 0; +} + +static void +find_solve_on_stats( + const km_set_t * const set, + const km_solve_stats_t * const stats, + void * const cb_data +) { + km_find_data_t * const find_data = cb_data; + UNUSED(set); + + // save total sum and silouette + find_data->distance_sum = stats->sum; + find_data->silouette = stats->silouette; +} + +static const km_solve_cbs_t +FIND_SOLVE_CBS = { + .on_stats = find_solve_on_stats, +}; + +bool +km_find( + const km_set_t * const set, + const km_find_cbs_t * const cbs, + void * const cb_data +) { + // check init callback + if (!cbs->on_init) { + return false; + } + + // check fini callback + if (!cbs->on_fini) { + return false; + } + + // check data callback + if (!cbs->on_data) { + return false; + } + + float best_silouette = -2.0; + for (size_t i = 2; i < cbs->max_clusters; i++) { + for (size_t j = 0; j < cbs->num_tests; j++) { + // init cluster set + km_set_t cs; + if (!cbs->on_init(&cs, i, set, cb_data)) { + // return failure + return false; + } + + // init find data + km_find_data_t find_data = { + .cluster_set = &cs, + .num_clusters = i, + }; + + // solve test + // (populates sum and silouette) + if (!km_solve(&cs, set, &FIND_SOLVE_CBS, &find_data)) { + // return failure + return false; + } + + // populate mean cluster size + find_data.mean_cluster_size = get_mean_cluster_size(&cs); + + // emit result + cbs->on_data(&find_data, cb_data); + + if (find_data.silouette > best_silouette) { + // emit new best result + if (cbs->on_best && !cbs->on_best(find_data.silouette, &cs, cb_data)) { + // return failure + return false; + } + + // update best silouette + best_silouette = find_data.silouette; + } + + // finalize cluster set + if (!cbs->on_fini(&cs, cb_data)) { + // return failure + return false; + } + } + } + + // return success + return true; +} |