diff options
Diffstat (limited to 'km-find.c')
-rw-r--r-- | km-find.c | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/km-find.c b/km-find.c new file mode 100644 index 0000000..2dd1ba1 --- /dev/null +++ b/km-find.c @@ -0,0 +1,143 @@ +#include <stdbool.h> // bool +#include <math.h> // fabsf() +#include "util.h" +#include "km.h" + +#define MIN_CLUSTER_DISTANCE 0.0001 + +typedef struct { + float distance_sum, + variance_sum; + size_t num_empty_clusters; +} find_solve_data_t; + +static float +find_get_mean_distance( + const find_solve_data_t * const d, + const size_t num_clusters +) { + const bool num_filled = num_clusters - d->num_empty_clusters; + return num_filled ? (d->distance_sum / num_filled) : 0; +} + +static float +find_get_mean_variance( + const find_solve_data_t * const d, + const size_t num_clusters +) { + const bool num_filled = num_clusters - d->num_empty_clusters; + return num_filled ? (d->variance_sum / num_filled) : 0; +} + +static float +find_get_mean_cluster_size( + const km_set_t * const set +) { + float sum = 0; + size_t num_filled = 0; + + for (size_t i = 0; i < set->num_rows; i++) { + if (set->ints[i] > 0) { + sum += set->ints[i]; + num_filled++; + } + } + + return (num_filled > 0) ? sum / num_filled : 0; +} + +static void +find_solve_on_stats( + const km_set_t * const set, + const km_solve_stats_t * const stats, + void * const cb_data +) { + find_solve_data_t * const solve_data = cb_data; + UNUSED(set); + + // calculate numerator for the average distance across all clusters in + // this test + for (size_t i = 0; i < stats->num_clusters; i++) { + if (fabsf(stats->means[i]) > MIN_CLUSTER_DISTANCE) { + // increment mean count + solve_data->distance_sum += stats->means[i]; + solve_data->variance_sum += stats->variances[i]; + } else { + // increment empty cluster count + solve_data->num_empty_clusters++; + } + } +} + +static const km_solve_cbs_t +FIND_SOLVE_CBS = { + .on_stats = find_solve_on_stats, +}; + +bool +km_find( + const km_set_t * const set, + const km_find_cbs_t * const cbs, + void * const cb_data +) { + // check init callback + if (!cbs->on_init) { + return false; + } + + // check fini callback + if (!cbs->on_fini) { + return false; + } + + // check data callback + if (!cbs->on_data) { + return false; + } + + for (size_t i = 2; i < cbs->max_clusters; i++) { + for (size_t j = 0; j < cbs->num_tests; j++) { + // init cluster set + km_set_t cs; + if (!cbs->on_init(&cs, set->shape.num_floats, i, cb_data)) { + // return failure + return false; + } + + // init solve data + find_solve_data_t solve_data = { + .distance_sum = 0, + .variance_sum = 0, + .num_empty_clusters = 0, + }; + + // solve test + if (!km_solve(&cs, set, &FIND_SOLVE_CBS, &solve_data)) { + // return failure + return false; + } + + // init result data + const km_find_data_t result = { + .cluster_set = &cs, + .num_clusters = i, + .mean_distance = find_get_mean_distance(&solve_data, i), + .mean_variance = find_get_mean_variance(&solve_data, i), + .mean_cluster_size = find_get_mean_cluster_size(&cs), + .num_empty_clusters = solve_data.num_empty_clusters, + }; + + // emit result + cbs->on_data(&result, cb_data); + + // finalize cluster set + if (!cbs->on_fini(&cs, cb_data)) { + // return failure + return false; + } + } + } + + // return success + return true; +} |