diff options
Diffstat (limited to 'km-solve.c')
-rw-r--r-- | km-solve.c | 20 |
1 file changed, 14 insertions, 6 deletions
@@ -70,13 +70,14 @@ km_solve( const float * const floats = km_set_get_row(cs, i); for (size_t j = 0; j < set->num_rows; j++) { - const float * const row_floats = km_set_get_row(set, j); + // get row values + const float * const vals = km_set_get_row(set, j); - // calculate the distance squared between these clusters - const float d2 = distance_squared(num_floats, floats, row_floats); + // calculate the distance squared between row and cluster + const float d2 = distance_squared(num_floats, floats, vals); if (d2 < row_map[j].d2) { - // row is closer to this cluster, update distance and cluster + // row is closer to this cluster, update row map row_map[j].d2 = d2; row_map[j].cluster = i; @@ -129,16 +130,22 @@ km_solve( } while (changed); if (cbs && cbs->on_stats) { - float means[num_clusters], + float sum = 0, + means[num_clusters], variances[num_clusters]; + memset(means, 0, sizeof(float) * num_clusters); memset(variances, 0, sizeof(float) * num_clusters); + // calculate sum of distances across all clusters + for (size_t i = 0; i < set->num_rows; i++) { + sum += row_map[i].d2; + } + // calculate mean distances for (size_t i = 0; i < set->num_rows; i++) { means[row_map[i].cluster] += row_map[i].d2; } - // finalize means for (size_t i = 0; i < num_clusters; i++) { means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0; @@ -159,6 +166,7 @@ km_solve( // build stats const km_solve_stats_t stats = { + .sum = sum, .means = means, .variances = variances, .num_clusters = num_clusters, |