aboutsummaryrefslogtreecommitdiff
path: root/km-solve.c
blob: 4560c67c652e8322882eb0e7915ea79e86a135ac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#include <stdbool.h> // bool
#include <stdlib.h> // calloc(), free()
#include <string.h> // memset()
#include <float.h> // FLT_MAX
#include <math.h> // sqrt()
#include "util.h"
#include "km.h"

// Return the squared Euclidean distance between the first `num_floats`
// elements of `a` and `b`.  The square root is deliberately not taken:
// squared distances order points the same way and are cheaper to compute.
static float
distance_squared(
  const size_t num_floats,
  const float * const a,
  const float * const b
) {
  float sum = 0.0f;

  for (size_t k = 0; k < num_floats; k++) {
    const float delta = b[k] - a[k];
    sum += delta * delta;
  }

  return sum;
}

// alloc and initialize row map
static km_row_map_t *
km_row_map_init(
  const size_t num_rows
) {
  // alloc row map
  km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows);

  // check for error
  if (!row_map) {
    // return failure
    return false;
  }

  // init row map by setting the maximum distance
  for (size_t i = 0; i < num_rows; i++) {
    row_map[i].d2 = FLT_MAX;
  }

  // return row map
  return row_map;
}

// Release a row map obtained from km_row_map_init().  Passing NULL is
// harmless because free(NULL) does nothing.
static void
km_row_map_fini(
  km_row_map_t * const row_map
) {
  // the map is one heap block, so a single free() releases all of it
  free(row_map);
}

// use k-means to iteratively update the cluster centroids until there
// are no more changes to the centroids
bool
km_solve(
  km_set_t * const cs,
  const km_set_t * const set,
  const km_solve_cbs_t * const cbs,
  void * const cb_data
) {
  const size_t num_clusters = cs->num_rows,
               num_floats = set->shape.num_floats;

  // row map: row => distance, cluster ID
  km_row_map_t *row_map = km_row_map_init(set->num_rows);
  if (!row_map) {
    // return failure
    return false;
  }

  // calculate clusters by doing the following:
  //   * walk all clusters and all rows
  //   * if we find a closer cluster, move row to cluster
  //   * if there were changes to any cluster, then calculate a new
  //     centroid for each cluster by averaging the cluster rows
  //   * repeat until there are no more changes
  bool changed = false;
  do {
    // no changes yet
    changed = false;

    for (size_t i = 0; i < num_clusters; i++) {
      // get the floats for this cluster
      const float * const floats = km_set_get_row(cs, i);

      for (size_t j = 0; j < set->num_rows; j++) {
        const float * const row_floats = km_set_get_row(set, j);

        // calculate the distance squared between these clusters
        const float d2 = distance_squared(num_floats, floats, row_floats);

        if (d2 < row_map[j].d2) {
          // row is closer to this cluster, update distance and cluster
          row_map[j].d2 = d2;
          row_map[j].cluster = i;

          // flag change
          changed = true;
        }
      }
    }

    if (changed) {
      // if there were changes, then we need to calculate the new
      // cluster centers

      // calculate new center
      for (size_t i = 0; i < num_clusters; i++) {
        size_t num_rows = 0;
        float * const floats = km_set_get_row(cs, i);
        memset(floats, 0, sizeof(float) * num_floats);

        for (size_t j = 0; j < set->num_rows; j++) {
          const float * const row_floats = km_set_get_row(set, j);

          if (row_map[j].cluster == i) {
            // calculate numerator for average
            for (size_t k = 0; k < num_floats; k++) {
              floats[k] += row_floats[k];
            }

            // increment denominator for average
            num_rows++;
          }
        }

        // save number of rows in this cluster
        cs->ints[i] = num_rows;

        if (num_rows > 0) {
          for (size_t k = 0; k < num_floats; k++) {
            // divide by denominator to get average
            floats[k] /= num_rows;
          }
        }
      }
    }

    if (cbs && cbs->on_step) {
      // pass clusters and row map to step callback
      cbs->on_step(cs, row_map, cb_data);
    }
  } while (changed);

  if (cbs && cbs->on_stats) {
    float means[num_clusters],
          variances[num_clusters];
    memset(means, 0, sizeof(float) * num_clusters);
    memset(variances, 0, sizeof(float) * num_clusters);

    // calculate mean distances
    for (size_t i = 0; i < set->num_rows; i++) {
      means[row_map[i].cluster] += row_map[i].d2;
    }

    // finalize means
    for (size_t i = 0; i < num_clusters; i++) {
      means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0;
    }

    // calculate variances
    for (size_t i = 0; i < set->num_rows; i++) {
      const size_t cluster = row_map[i].cluster;
      const float variance = (sqrt(row_map[i].d2) - means[cluster]) *
                             (sqrt(row_map[i].d2) - means[cluster]);
      variances[cluster] += variance;
    }

    // finalize variances
    for (size_t i = 0; i < num_clusters; i++) {
      variances[i] = (cs->ints[i]) ? (variances[i] / cs->ints[i]) : 0;
    }

    // build stats
    const km_solve_stats_t stats = {
      .means        = means,
      .variances    = variances,
      .num_clusters = num_clusters,
    };

    // emit means
    cbs->on_stats(cs, &stats, cb_data);
  }

  // free row_map
  km_row_map_fini(row_map);

  // return success
  return true;
}