aboutsummaryrefslogtreecommitdiff
path: root/src/km-init-kmeans.c
blob: f0c1b5fbb9d7974fd00ee71ef0e4bb6434432549 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <stdbool.h> // bool
#include <stdlib.h> // calloc(), free()
#include <string.h> // memset()
#include <float.h> // FLT_MAX
#include "util.h"
#include "km.h"

// Total of the squared distances between the point `floats` and each
// row of `set`, accumulated in row order.
static inline float
sum_distance_squared(
  const float * const floats,
  const km_set_t * const set
) {
  float total = 0;
  size_t row = 0;

  // walk every row of the set, adding its squared distance to the point
  while (row < set->num_rows) {
    total += distance_squared(
      set->shape.num_floats,
      floats,
      km_set_get_row(set, row)
    );
    row++;
  }

  return total;
}

// Return a pointer to one row of `set`, selected by drawing a single
// random size_t from `rs` and reducing it modulo the number of rows.
//
// Calls die() if the set has no rows or if the random value could not
// be read from `rs`.
// NOTE(review): die() is assumed not to return -- confirm in util.h.
static const float *
get_random_row(
  const km_set_t * const set,
  km_rand_t * const rs
) {
  // an empty set has no row to return, and the modulo below would
  // divide by zero
  if (!set->num_rows) {
    die("get_random_row(): empty set");
  }

  // get random offset
  size_t ofs = 0;
  if (!km_rand_get_sizes(rs, 1, &ofs)) {
    die("km_rand_get_sizes()");
  }

  // clamp offset to a valid row index
  return km_set_get_row(set, ofs % set->num_rows);
}

// Init cluster set `cs` with `num_clusters` initial centroids chosen
// from the rows of `set`.
//
// The first centroid is a row of `set` picked at random.  Each
// subsequent centroid is found by repeatedly drawing a random candidate
// row, measuring its squared distance to the nearest centroid picked so
// far, and accepting or rejecting the candidate based on a random float
// read from `rs` (see the acceptance test in the loop below).
//
// Returns false if `num_clusters` is zero, if `set` has no rows or no
// floats per row, if memory could not be allocated, if a random float
// could not be read from `rs`, or if the cluster set could not be
// initialized or populated.  Returns true otherwise.
bool
km_init_kmeans(
  km_set_t * const cs,
  const size_t num_clusters,
  const km_set_t * const set,
  km_rand_t * const rs
) {
  const size_t num_floats = set->shape.num_floats,
               stride = sizeof(float) * num_floats; // row size, in BYTES

  // reject degenerate input: there is nothing to pick, or nothing to
  // pick from
  if (!num_clusters || !num_floats || !set->num_rows) {
    return false;
  }

  // init cluster shape
  const km_shape_t shape = {
    .num_floats = num_floats,
    .num_ints = 1,
  };

  // result (set to the result of km_set_push() on the success path)
  bool ok = false;

  // working buffers; heap-allocated because their size depends on
  // caller-supplied values.  calloc() zero-fills them and checks the
  // element count * element size multiplication for overflow.
  //   sums:   per-centroid sum of squared distances to every row
  //   floats: centroid rows, num_floats floats per centroid
  //   ints:   one (zeroed) int per centroid
  float * const sums = calloc(num_clusters, sizeof(*sums));
  float * const floats = calloc(num_clusters, stride);
  int * const ints = calloc(num_clusters, sizeof(*ints));
  if (!sums || !floats || !ints) {
    goto done;
  }

  // pick first row randomly
  {
    // get random row
    const float * const vals = get_random_row(set, rs);

    // copy row values to floats buffer
    memcpy(floats, vals, stride);

    // calculate total distance from this centroid to every row
    sums[0] = sum_distance_squared(floats, set);
  }

  // NOTE(review): a candidate identical to an existing centroid
  // presumably yields min_d2 == 0 and is then rejected for any
  // non-negative rand_val; if the set has fewer distinct rows than
  // num_clusters this loop may never terminate -- confirm and consider
  // bounding the number of attempts.
  for (size_t i = 1; i < num_clusters;) {
    // get random row
    const float * const vals = get_random_row(set, rs);

    // calculate squared distance to nearest cluster
    size_t ofs = 0;
    float min_d2 = FLT_MAX;
    for (size_t j = 0; j < i; j++) {
      // floats is a float pointer, so step by floats per row, not by
      // the byte stride
      const float * const src = floats + j * num_floats;
      const float d2 = distance_squared(num_floats, vals, src);
      if (d2 < min_d2) {
        ofs = j;
        min_d2 = d2;
      }
    }

    // get a random floating point value
    float rand_val = 0;
    if (!km_rand_get_floats(rs, 1, &rand_val)) {
      // return failure
      goto done;
    }

    // check random value; the candidate is accepted when the random
    // value, scaled by the nearest centroid's total squared distance,
    // is below the candidate's squared distance to that centroid
    if (rand_val / sums[ofs] < min_d2) {
      // get destination
      float * const dst = floats + i * num_floats;

      // copy row values
      memcpy(dst, vals, stride);

      // calculate total cluster distance
      sums[i] = sum_distance_squared(dst, set);

      // increment cluster count
      i++;
    }
  }

  // init cluster set
  if (!km_set_init(cs, &shape, num_clusters)) {
    // return failure
    goto done;
  }

  // add data, save result
  ok = km_set_push(cs, num_clusters, floats, ints);

done:
  // free working buffers (free(NULL) is a no-op)
  free(ints);
  free(floats);
  free(sums);

  // return result
  return ok;
}