aboutsummaryrefslogtreecommitdiff
path: root/km-init-kmeans.c
diff options
context:
space:
mode:
Diffstat (limited to 'km-init-kmeans.c')
-rw-r--r--km-init-kmeans.c34
1 files changed, 32 insertions, 2 deletions
diff --git a/km-init-kmeans.c b/km-init-kmeans.c
index 61d67a7..f0c1b5f 100644
--- a/km-init-kmeans.c
+++ b/km-init-kmeans.c
@@ -4,6 +4,24 @@
#include "util.h"
#include "km.h"
+// sum the squared distance of every row in set from this point
+static inline float
+sum_distance_squared(
+ const float * const floats,
+ const km_set_t * const set
+) {
+ float r = 0;
+
+ // sum squared distances
+ for (size_t i = 0; i < set->num_rows; i++) {
+ const float * const vals = km_set_get_row(set, i);
+ r += distance_squared(set->shape.num_floats, floats, vals);
+ }
+
+ // return result
+ return r;
+}
+
static const float *
get_random_row(
const km_set_t * const set,
@@ -32,6 +50,7 @@ km_init_kmeans(
stride = sizeof(float) * num_floats;
// row values (filled below)
+ float sums[num_clusters];
float floats[num_floats * num_clusters];
// pick first row randomly
@@ -41,6 +60,9 @@ km_init_kmeans(
// copy row values to floats buffer
memcpy(floats, vals, stride);
+
+ // copy row values to floats buffer
+ sums[0] = sum_distance_squared(floats, set);
}
for (size_t i = 1; i < num_clusters;) {
@@ -48,10 +70,12 @@ km_init_kmeans(
const float * const vals = get_random_row(set, rs);
// calculate squared distance to nearest cluster
+ size_t ofs = 0;
float min_d2 = FLT_MAX;
for (size_t j = 0; j < i; j++) {
const float d2 = distance_squared(num_floats, vals, floats + j * stride);
if (d2 < min_d2) {
+ ofs = j;
min_d2 = d2;
}
}
@@ -64,9 +88,15 @@ km_init_kmeans(
}
// check random value
- if (rand_val < min_d2) {
+ if (rand_val / sums[ofs] < min_d2) {
+ // get destination
+ float * const dst = floats + i * num_floats;
+
// copy row values
- memcpy(floats + i * num_floats, vals, stride);
+ memcpy(dst, vals, stride);
+
+ // calculate total cluster distance
+ sums[i] = sum_distance_squared(dst, set);
// increment cluster count
i++;