#include // bool #include "util.h" #include "km.h" static float get_mean_cluster_size( const km_set_t * const set ) { float sum = 0; size_t num_filled = 0; for (size_t i = 0; i < set->num_rows; i++) { if (set->ints[i] > 0) { sum += set->ints[i]; num_filled++; } } return (num_filled > 0) ? (sum / num_filled) : 0; } static void find_solve_on_stats( const km_set_t * const set, const km_solve_stats_t * const stats, void * const cb_data ) { km_find_data_t * const find_data = cb_data; UNUSED(set); // save total sum and silouette find_data->distance_sum = stats->sum; find_data->silouette = stats->silouette; } static const km_solve_cbs_t FIND_SOLVE_CBS = { .on_stats = find_solve_on_stats, }; bool km_find( const km_set_t * const set, const km_find_cbs_t * const cbs, void * const cb_data ) { // check init callback if (!cbs->on_init) { return false; } // check fini callback if (!cbs->on_fini) { return false; } // check data callback if (!cbs->on_data) { return false; } float best_silouette = -2.0; for (size_t i = 2; i < cbs->max_clusters; i++) { for (size_t j = 0; j < cbs->num_tests; j++) { // init cluster set km_set_t cs; if (!cbs->on_init(&cs, i, set, cb_data)) { // return failure return false; } // init find data km_find_data_t find_data = { .cluster_set = &cs, .num_clusters = i, }; // solve test // (populates sum and silouette) if (!km_solve(&cs, set, &FIND_SOLVE_CBS, &find_data)) { // return failure return false; } // populate mean cluster size find_data.mean_cluster_size = get_mean_cluster_size(&cs); // emit result cbs->on_data(&find_data, cb_data); if (find_data.silouette > best_silouette) { // emit new best result if (cbs->on_best && !cbs->on_best(find_data.silouette, &cs, cb_data)) { // return failure return false; } // update best silouette best_silouette = find_data.silouette; } // finalize cluster set if (!cbs->on_fini(&cs, cb_data)) { // return failure return false; } } } // return success return true; }