aboutsummaryrefslogtreecommitdiff
path: root/km-solve.c
blob: f579a417fae71bf33b626254cdeef55dcd5dd1ea (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
#include <stdbool.h> // bool
#include <stdlib.h> // malloc(), calloc(), free()
#include <string.h> // memset()
#include <float.h> // FLT_MAX
#include <math.h> // sqrt()
#include "util.h"
#include "km.h"

// Threshold on |d2_near - d2|, the gap between a row's squared distance to
// its own centroid and to its nearest other centroid.  When the gap is at
// or below this value, the row contributes nothing to the silhouette score
// computed in km_solve().  This also avoids a 0 / 0 when both are zero.
#define MIN_CLUSTER_DISTANCE 0.00001

// alloc and initialize row map
static km_row_map_t *
km_row_map_init(
  const size_t num_rows
) {
  // alloc row map
  km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows);

  // check for error
  if (!row_map) {
    // return failure
    return false;
  }

  // init row map
  for (size_t i = 0; i < num_rows; i++) {
    // setting distances to maximum
    row_map[i].d2 = FLT_MAX;
    row_map[i].d2_near = FLT_MAX;
  }

  // return row map
  return row_map;
}

// Release a row map previously returned by km_row_map_init().
// A NULL row map is accepted and ignored, because free(NULL) is a no-op.
static void
km_row_map_fini(km_row_map_t * const row_map) {
  free(row_map);
}

// use k-means to iteratively update the cluster centroids until there
// are no more changes to the centroids
bool
km_solve(
  km_set_t * const cs,
  const km_set_t * const set,
  const km_solve_cbs_t * const cbs,
  void * const cb_data
) {
  const size_t num_clusters = cs->num_rows,
               num_floats = set->shape.num_floats;

  // row map: row => distance, cluster ID
  km_row_map_t *row_map = km_row_map_init(set->num_rows);
  if (!row_map) {
    // return failure
    return false;
  }

  // calculate clusters by doing the following:
  //   * walk all clusters and all rows
  //   * if we find a closer cluster, move row to cluster
  //   * if there were changes to any cluster, then calculate a new
  //     centroid for each cluster by averaging the cluster rows
  //   * repeat until there are no more changes
  bool changed = false;
  do {
    // no changes yet
    changed = false;

    for (size_t i = 0; i < num_clusters; i++) {
      // get the floats for this cluster
      const float * const floats = km_set_get_row(cs, i);

      for (size_t j = 0; j < set->num_rows; j++) {
        // get row values
        const float * const vals = km_set_get_row(set, j);

        // calculate the distance squared between row and cluster
        const float d2 = distance_squared(num_floats, floats, vals);

        if (d2 < row_map[j].d2) {
          // row is closer to this cluster, update row map
          row_map[j].d2 = d2;
          row_map[j].cluster = i;

          // flag change
          changed = true;
        }

        if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) {
          row_map[j].d2_near = d2;

          // flag change
          changed = true;
        }
      }
    }

    if (changed) {
      // if there were changes, then we need to calculate the new
      // cluster centers

      // calculate new center
      for (size_t i = 0; i < num_clusters; i++) {
        size_t num_rows = 0;
        float * const floats = km_set_get_row(cs, i);
        memset(floats, 0, sizeof(float) * num_floats);

        for (size_t j = 0; j < set->num_rows; j++) {
          const float * const row_floats = km_set_get_row(set, j);

          if (row_map[j].cluster == i) {
            // calculate numerator for average
            for (size_t k = 0; k < num_floats; k++) {
              floats[k] += row_floats[k];
            }

            // increment denominator for average
            num_rows++;
          }
        }

        // save number of rows in this cluster
        cs->ints[i] = num_rows;

        if (num_rows > 0) {
          for (size_t k = 0; k < num_floats; k++) {
            // divide by denominator to get average
            floats[k] /= num_rows;
          }
        }
      }
    }

    if (cbs && cbs->on_step) {
      // pass clusters and row map to step callback
      cbs->on_step(cs, row_map, cb_data);
    }
  } while (changed);

  if (cbs && cbs->on_stats) {
    float sum = 0,
          silouette = 0,
          mean_dists[num_clusters],
          mean_nears[num_clusters];

    memset(mean_dists, 0, sizeof(mean_dists));
    memset(mean_nears, 0, sizeof(mean_nears));

    // calculate sum of distances across all clusters
    for (size_t i = 0; i < set->num_rows; i++) {
      sum += row_map[i].d2;
    }

    // calculate mean numerators and silouette
    for (size_t i = 0; i < set->num_rows; i++) {
      // distance squared (d2) to center of this cluster
      mean_dists[row_map[i].cluster] += row_map[i].d2;

      // distance squared (d2) to center of nearest cluster
      mean_nears[row_map[i].cluster] += row_map[i].d2_near;

      // calculate silouette denominator
      // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
      const float delta = row_map[i].d2_near - row_map[i].d2;
      if (fabsf(delta) > MIN_CLUSTER_DISTANCE) {
        silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near);
      }
    }

    // finalize means (divide by row count)
    for (size_t i = 0; i < num_clusters; i++) {
      mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0;
      mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0;
    }

    // finalize silouette
    silouette /= set->num_rows;

    // build stats
    const km_solve_stats_t stats = {
      .sum          = sum,
      .silouette    = silouette,
      .mean_dists   = mean_dists,
      .mean_nears   = mean_nears,
      .num_clusters = num_clusters,
    };

    // emit means
    cbs->on_stats(cs, &stats, cb_data);
  }

  // free row_map
  km_row_map_fini(row_map);

  // return success
  return true;
}