aboutsummaryrefslogtreecommitdiff
path: root/km-solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'km-solve.c')
-rw-r--r-- km-solve.c 20
1 files changed, 14 insertions, 6 deletions
diff --git a/km-solve.c b/km-solve.c
index 9eac13c..2a4de16 100644
--- a/km-solve.c
+++ b/km-solve.c
@@ -70,13 +70,14 @@ km_solve(
const float * const floats = km_set_get_row(cs, i);
for (size_t j = 0; j < set->num_rows; j++) {
- const float * const row_floats = km_set_get_row(set, j);
+ // get row values
+ const float * const vals = km_set_get_row(set, j);
- // calculate the distance squared between these clusters
- const float d2 = distance_squared(num_floats, floats, row_floats);
+ // calculate the distance squared between row and cluster
+ const float d2 = distance_squared(num_floats, floats, vals);
if (d2 < row_map[j].d2) {
- // row is closer to this cluster, update distance and cluster
+ // row is closer to this cluster, update row map
row_map[j].d2 = d2;
row_map[j].cluster = i;
@@ -129,16 +130,22 @@ km_solve(
} while (changed);
if (cbs && cbs->on_stats) {
- float means[num_clusters],
+ float sum = 0,
+ means[num_clusters],
variances[num_clusters];
+
memset(means, 0, sizeof(float) * num_clusters);
memset(variances, 0, sizeof(float) * num_clusters);
+ // calculate sum of squared distances (d2) across all rows
+ for (size_t i = 0; i < set->num_rows; i++) {
+ sum += row_map[i].d2;
+ }
+
// calculate mean distances
for (size_t i = 0; i < set->num_rows; i++) {
means[row_map[i].cluster] += row_map[i].d2;
}
-
// finalize means
for (size_t i = 0; i < num_clusters; i++) {
means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0;
@@ -159,6 +166,7 @@ km_solve(
// build stats
const km_solve_stats_t stats = {
+ .sum = sum,
.means = means,
.variances = variances,
.num_clusters = num_clusters,