aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/km-solve.c37
-rw-r--r--src/km.h5
2 files changed, 9 insertions, 33 deletions
diff --git a/src/km-solve.c b/src/km-solve.c
index f579a41..4730eb4 100644
--- a/src/km-solve.c
+++ b/src/km-solve.c
@@ -5,7 +5,7 @@
#include "util.h"
#include "km.h"
-#define MIN_CLUSTER_DISTANCE 0.00001
+#define MIN_DELTA 0.00001
// alloc and initialize row map
static km_row_map_t *
@@ -142,38 +142,21 @@ km_solve(
if (cbs && cbs->on_stats) {
float sum = 0,
- silouette = 0,
- mean_dists[num_clusters],
- mean_nears[num_clusters];
-
- memset(mean_dists, 0, sizeof(mean_dists));
- memset(mean_nears, 0, sizeof(mean_nears));
-
- // calculate sum of distances across all clusters
- for (size_t i = 0; i < set->num_rows; i++) {
- sum += row_map[i].d2;
- }
+ silouette = 0;
// calculate mean numerators and silouette
for (size_t i = 0; i < set->num_rows; i++) {
- // distance squared (d2) to center of this cluster
- mean_dists[row_map[i].cluster] += row_map[i].d2;
-
- // distance squared (d2) to center of nearest cluster
- mean_nears[row_map[i].cluster] += row_map[i].d2_near;
+ // sum distances
+ sum += row_map[i].d2;
// calculate silouette denominator
// (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
- const float delta = row_map[i].d2_near - row_map[i].d2;
- if (fabsf(delta) > MIN_CLUSTER_DISTANCE) {
- silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near);
- }
- }
+ const float delta = (row_map[i].d2_near - row_map[i].d2);
- // finalize means (divide by row count)
- for (size_t i = 0; i < num_clusters; i++) {
- mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0;
- mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0;
+ // sum silouette
+ silouette += (delta < -MIN_DELTA || delta > MIN_DELTA)
+ ? (delta / MAX(row_map[i].d2, row_map[i].d2_near))
+ : 0.0;
}
// finalize silouette
@@ -183,8 +166,6 @@ km_solve(
const km_solve_stats_t stats = {
.sum = sum,
.silouette = silouette,
- .mean_dists = mean_dists,
- .mean_nears = mean_nears,
.num_clusters = num_clusters,
};
diff --git a/src/km.h b/src/km.h
index 3091685..7906f00 100644
--- a/src/km.h
+++ b/src/km.h
@@ -94,12 +94,7 @@ typedef struct {
typedef struct {
const float sum;
-
const float silouette;
-
- const float *mean_dists;
- const float *mean_nears;
-
const size_t num_clusters;
} km_solve_stats_t;