diff options
author | Paul Duncan <pabs@pablotron.org> | 2019-02-04 15:33:35 -0500 |
---|---|---|
committer | Paul Duncan <pabs@pablotron.org> | 2019-02-04 15:33:35 -0500 |
commit | 159c42498365913f6ed400e13c77798d041a7d43 (patch) | |
tree | 96cc3b14bbbbd44dc5173e89b85cd2c8228e86e4 /km-init-kmeans.c | |
parent | f4a38b43d43f9395d6042d234a5e0ada7455ace1 (diff) | |
download | kmeans-159c42498365913f6ed400e13c77798d041a7d43.tar.bz2 kmeans-159c42498365913f6ed400e13c77798d041a7d43.zip |
add rand-{path,erand48}, minor fixes
Diffstat (limited to 'km-init-kmeans.c')
-rw-r--r-- | km-init-kmeans.c | 34 |
1 files changed, 32 insertions, 2 deletions
diff --git a/km-init-kmeans.c b/km-init-kmeans.c index 61d67a7..f0c1b5f 100644 --- a/km-init-kmeans.c +++ b/km-init-kmeans.c @@ -4,6 +4,24 @@ #include "util.h" #include "km.h" +// sum the squared distance of every row in set from this point +static inline float +sum_distance_squared( + const float * const floats, + const km_set_t * const set +) { + float r = 0; + + // sum squared distances + for (size_t i = 0; i < set->num_rows; i++) { + const float * const vals = km_set_get_row(set, i); + r += distance_squared(set->shape.num_floats, floats, vals); + } + + // return result + return r; +} + static const float * get_random_row( const km_set_t * const set, @@ -32,6 +50,7 @@ km_init_kmeans( stride = sizeof(float) * num_floats; // row values (filled below) + float sums[num_clusters]; float floats[num_floats * num_clusters]; // pick first row randomly @@ -41,6 +60,9 @@ km_init_kmeans( // copy row values to floats buffer memcpy(floats, vals, stride); + + // sum squared distances from the first cluster to every row in set + sums[0] = sum_distance_squared(floats, set); } for (size_t i = 1; i < num_clusters;) { @@ -48,10 +70,12 @@ km_init_kmeans( const float * const vals = get_random_row(set, rs); // calculate squared distance to nearest cluster + size_t ofs = 0; float min_d2 = FLT_MAX; for (size_t j = 0; j < i; j++) { const float d2 = distance_squared(num_floats, vals, floats + j * stride); if (d2 < min_d2) { + ofs = j; min_d2 = d2; } } @@ -64,9 +88,15 @@ km_init_kmeans( } // check random value - if (rand_val < min_d2) { + if (rand_val / sums[ofs] < min_d2) { + // get destination + float * const dst = floats + i * num_floats; + // copy row values - memcpy(floats + i * num_floats, vals, stride); + memcpy(dst, vals, stride); + + // calculate total cluster distance + sums[i] = sum_distance_squared(dst, set); // increment cluster count i++; |