aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--km-find.c16
-rw-r--r--km-solve.c20
-rw-r--r--km.h4
3 files changed, 28 insertions, 12 deletions
diff --git a/km-find.c b/km-find.c
index 8e7f9ac..7cd24f6 100644
--- a/km-find.c
+++ b/km-find.c
@@ -6,7 +6,8 @@
#define MIN_CLUSTER_DISTANCE 0.0001
typedef struct {
- float distance_sum,
+ float sum,
+ mean_sum,
variance_sum;
size_t num_empty_clusters;
} find_solve_data_t;
@@ -17,7 +18,7 @@ find_get_mean_distance(
const size_t num_clusters
) {
const bool num_filled = num_clusters - d->num_empty_clusters;
- return num_filled ? (d->distance_sum / num_filled) : 0;
+ return num_filled ? (d->mean_sum / num_filled) : 0;
}
static float
@@ -55,12 +56,15 @@ find_solve_on_stats(
find_solve_data_t * const solve_data = cb_data;
UNUSED(set);
+ // save total sum
+ solve_data->sum = stats->sum;
+
// calculate numerator for the average distance across all clusters in
// this test
for (size_t i = 0; i < stats->num_clusters; i++) {
if (fabsf(stats->means[i]) > MIN_CLUSTER_DISTANCE) {
// increment mean count
- solve_data->distance_sum += stats->means[i];
+ solve_data->mean_sum += stats->means[i];
solve_data->variance_sum += stats->variances[i];
} else {
// increment empty cluster count
@@ -107,7 +111,8 @@ km_find(
// init solve data
find_solve_data_t solve_data = {
- .distance_sum = 0,
+ .sum = 0,
+ .mean_sum = 0,
.variance_sum = 0,
.num_empty_clusters = 0,
};
@@ -122,6 +127,7 @@ km_find(
const km_find_data_t result = {
.cluster_set = &cs,
.num_clusters = i,
+ .distance_sum = solve_data.sum,
.mean_distance = find_get_mean_distance(&solve_data, i),
.mean_variance = find_get_mean_variance(&solve_data, i),
.mean_cluster_size = find_get_mean_cluster_size(&cs),
@@ -132,7 +138,7 @@ km_find(
cbs->on_data(&result, cb_data);
// score result
- float score = km_score(result.mean_distance, result.num_empty_clusters);
+ float score = km_score(result.distance_sum, result.num_empty_clusters);
if (score > best_score) {
// emit new best result
diff --git a/km-solve.c b/km-solve.c
index 9eac13c..2a4de16 100644
--- a/km-solve.c
+++ b/km-solve.c
@@ -70,13 +70,14 @@ km_solve(
const float * const floats = km_set_get_row(cs, i);
for (size_t j = 0; j < set->num_rows; j++) {
- const float * const row_floats = km_set_get_row(set, j);
+ // get row values
+ const float * const vals = km_set_get_row(set, j);
- // calculate the distance squared between these clusters
- const float d2 = distance_squared(num_floats, floats, row_floats);
+ // calculate the distance squared between row and cluster
+ const float d2 = distance_squared(num_floats, floats, vals);
if (d2 < row_map[j].d2) {
- // row is closer to this cluster, update distance and cluster
+ // row is closer to this cluster, update row map
row_map[j].d2 = d2;
row_map[j].cluster = i;
@@ -129,16 +130,22 @@ km_solve(
} while (changed);
if (cbs && cbs->on_stats) {
- float means[num_clusters],
+ float sum = 0,
+ means[num_clusters],
variances[num_clusters];
+
memset(means, 0, sizeof(float) * num_clusters);
memset(variances, 0, sizeof(float) * num_clusters);
+ // calculate sum of distances across all clusters
+ for (size_t i = 0; i < set->num_rows; i++) {
+ sum += row_map[i].d2;
+ }
+
// calculate mean distances
for (size_t i = 0; i < set->num_rows; i++) {
means[row_map[i].cluster] += row_map[i].d2;
}
-
// finalize means
for (size_t i = 0; i < num_clusters; i++) {
means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0;
@@ -159,6 +166,7 @@ km_solve(
// build stats
const km_solve_stats_t stats = {
+ .sum = sum,
.means = means,
.variances = variances,
.num_clusters = num_clusters,
diff --git a/km.h b/km.h
index 097026e..bc2e3a6 100644
--- a/km.h
+++ b/km.h
@@ -89,6 +89,7 @@ typedef struct {
} km_row_map_t;
typedef struct {
+ const float sum;
const float *means;
const float *variances;
const size_t num_clusters;
@@ -112,7 +113,8 @@ km_solve(
typedef struct {
const km_set_t * const cluster_set;
- const float mean_distance,
+ const float distance_sum,
+ mean_distance,
mean_variance,
mean_cluster_size;