From cb7823eaa631c43ed2f9620c30e9fbbbe574bd41 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Mon, 4 Feb 2019 21:31:19 -0500 Subject: add distance sum --- km-find.c | 16 +++++++++++----- km-solve.c | 20 ++++++++++++++------ km.h | 4 +++- 3 files changed, 28 insertions(+), 12 deletions(-) diff --git a/km-find.c b/km-find.c index 8e7f9ac..7cd24f6 100644 --- a/km-find.c +++ b/km-find.c @@ -6,7 +6,8 @@ #define MIN_CLUSTER_DISTANCE 0.0001 typedef struct { - float distance_sum, + float sum, + mean_sum, variance_sum; size_t num_empty_clusters; } find_solve_data_t; @@ -17,7 +18,7 @@ find_get_mean_distance( const size_t num_clusters ) { const bool num_filled = num_clusters - d->num_empty_clusters; - return num_filled ? (d->distance_sum / num_filled) : 0; + return num_filled ? (d->mean_sum / num_filled) : 0; } static float @@ -55,12 +56,15 @@ find_solve_on_stats( find_solve_data_t * const solve_data = cb_data; UNUSED(set); + // save total sum + solve_data->sum = stats->sum; + // calculate numerator for the average distance across all clusters in // this test for (size_t i = 0; i < stats->num_clusters; i++) { if (fabsf(stats->means[i]) > MIN_CLUSTER_DISTANCE) { // increment mean count - solve_data->distance_sum += stats->means[i]; + solve_data->mean_sum += stats->means[i]; solve_data->variance_sum += stats->variances[i]; } else { // increment empty cluster count @@ -107,7 +111,8 @@ km_find( // init solve data find_solve_data_t solve_data = { - .distance_sum = 0, + .sum = 0, + .mean_sum = 0, .variance_sum = 0, .num_empty_clusters = 0, }; @@ -122,6 +127,7 @@ km_find( const km_find_data_t result = { .cluster_set = &cs, .num_clusters = i, + .distance_sum = solve_data.sum, .mean_distance = find_get_mean_distance(&solve_data, i), .mean_variance = find_get_mean_variance(&solve_data, i), .mean_cluster_size = find_get_mean_cluster_size(&cs), @@ -132,7 +138,7 @@ km_find( cbs->on_data(&result, cb_data); // score result - float score = km_score(result.mean_distance, result.num_empty_clusters); + float score = km_score(result.distance_sum, result.num_empty_clusters); if (score > best_score) { // emit new best result diff --git a/km-solve.c b/km-solve.c index 9eac13c..2a4de16 100644 --- a/km-solve.c +++ b/km-solve.c @@ -70,13 +70,14 @@ km_solve( const float * const floats = km_set_get_row(cs, i); for (size_t j = 0; j < set->num_rows; j++) { - const float * const row_floats = km_set_get_row(set, j); + // get row values + const float * const vals = km_set_get_row(set, j); - // calculate the distance squared between these clusters - const float d2 = distance_squared(num_floats, floats, row_floats); + // calculate the distance squared between row and cluster + const float d2 = distance_squared(num_floats, floats, vals); if (d2 < row_map[j].d2) { - // row is closer to this cluster, update distance and cluster + // row is closer to this cluster, update row map row_map[j].d2 = d2; row_map[j].cluster = i; @@ -129,16 +130,22 @@ km_solve( } while (changed); if (cbs && cbs->on_stats) { - float means[num_clusters], + float sum = 0, + means[num_clusters], variances[num_clusters]; + memset(means, 0, sizeof(float) * num_clusters); memset(variances, 0, sizeof(float) * num_clusters); + // calculate sum of distances across all clusters + for (size_t i = 0; i < set->num_rows; i++) { + sum += row_map[i].d2; + } + // calculate mean distances for (size_t i = 0; i < set->num_rows; i++) { means[row_map[i].cluster] += row_map[i].d2; } - // finalize means for (size_t i = 0; i < num_clusters; i++) { means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0; @@ -159,6 +166,7 @@ km_solve( // build stats const km_solve_stats_t stats = { + .sum = sum, .means = means, .variances = variances, .num_clusters = num_clusters, diff --git a/km.h b/km.h index 097026e..bc2e3a6 100644 --- a/km.h +++ b/km.h @@ -89,6 +89,7 @@ typedef struct { } km_row_map_t; typedef struct { + const float sum; const float *means; const float *variances; const size_t num_clusters; @@ -112,7 +113,8 @@ km_solve( typedef struct { const km_set_t * const cluster_set; - const float mean_distance, + const float distance_sum, + mean_distance, mean_variance, mean_cluster_size; -- cgit v1.2.3