aboutsummaryrefslogtreecommitdiff
path: root/src/km-solve.c
blob: 271795be093d17b4d301c44231764d379f6d86df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
#include <stdbool.h> // bool
#include <string.h> // memset()
#include <float.h> // FLT_MAX
#include "util.h"
#include "km.h"

// Tolerance used when computing the silhouette: a (d2_near - d2)
// difference whose magnitude is within this value contributes 0.
#define MIN_DELTA 1e-5

// Reset the first `num_rows` entries of `row_map` so that both the
// nearest-cluster distance (d2) and the second-nearest distance
// (d2_near) start at FLT_MAX. The `cluster` field is not touched here.
static void
km_row_map_clear(
  km_row_map_t * const row_map,
  const size_t num_rows
) {
  #pragma omp parallel for
  for (size_t r = 0; r < num_rows; r++) {
    km_row_map_t * const entry = row_map + r;

    // "infinitely far away" from every cluster
    entry->d2_near = FLT_MAX;
    entry->d2 = FLT_MAX;
  }
}

// Assign every row of `set` to its nearest centroid in `cs`.
//
// For each row j this writes:
//   * row_map[j].cluster -- index of the nearest centroid (lowest index
//     wins ties, because the comparison is strict)
//   * row_map[j].d2      -- squared distance to that centroid
//   * row_map[j].d2_near -- squared distance to the second-nearest
//     centroid (FLT_MAX when there is only one centroid)
//
// Distances are recomputed from scratch against the current centroids
// on every call, so distances to centroids that have since moved never
// influence the assignment.
//
// The parallel loop runs over rows only. Each iteration writes a single
// row_map entry, so no two threads touch the same entry.
//
// Returns the number of rows whose cluster changed. When `force` is
// true, every row counts as changed. This is used on the first pass,
// where row_map[].cluster has not been assigned yet.
static size_t
km_solve_assign_rows(
  const km_set_t * const cs,
  const km_set_t * const set,
  km_row_map_t * const row_map,
  const bool force
) {
  const size_t num_clusters = cs->num_rows,
               num_floats = set->shape.num_floats;
  size_t num_changed = 0;

  #pragma omp parallel for reduction(+:num_changed)
  for (size_t j = 0; j < set->num_rows; j++) {
    // get row values
    const float * const vals = km_set_get_row(set, j);

    size_t best = 0;
    float d2 = FLT_MAX,
          d2_near = FLT_MAX;

    for (size_t i = 0; i < num_clusters; i++) {
      const float * const floats = km_set_get_row(cs, i);
      const float d = distance_squared(num_floats, floats, vals);

      if (d < d2) {
        // new nearest cluster; the old nearest becomes second-nearest
        d2_near = d2;
        d2 = d;
        best = i;
      } else if (d < d2_near) {
        d2_near = d;
      }
    }

    if (force || row_map[j].cluster != best) {
      num_changed++;
    }

    row_map[j].cluster = best;
    row_map[j].d2 = d2;
    row_map[j].d2_near = d2_near;
  }

  return num_changed;
}

// Recompute each centroid in `cs` as the mean of the rows of `set`
// assigned to it by `row_map`, and store the per-cluster row count in
// cs->ints[].
//
// A cluster with no assigned rows keeps its previous centroid rather
// than being moved to the zero vector, which would be an arbitrary
// point unrelated to the data.
static void
km_solve_update_centroids(
  km_set_t * const cs,
  const km_set_t * const set,
  const km_row_map_t * const row_map
) {
  const size_t num_clusters = cs->num_rows,
               num_floats = set->shape.num_floats;

  // each iteration writes only its own centroid row and cs->ints[i]
  #pragma omp parallel for
  for (size_t i = 0; i < num_clusters; i++) {
    // count members first, so an empty cluster is detected before its
    // centroid is overwritten
    size_t num_rows = 0;
    for (size_t j = 0; j < set->num_rows; j++) {
      if (row_map[j].cluster == i) {
        num_rows++;
      }
    }

    // save number of rows in this cluster
    cs->ints[i] = num_rows;

    if (num_rows == 0) {
      // empty cluster: keep previous centroid
      continue;
    }

    float * const floats = km_set_get_row(cs, i);
    memset(floats, 0, sizeof(float) * num_floats);

    // calculate numerator for average
    for (size_t j = 0; j < set->num_rows; j++) {
      if (row_map[j].cluster == i) {
        const float * const row_floats = km_set_get_row(set, j);
        for (size_t k = 0; k < num_floats; k++) {
          floats[k] += row_floats[k];
        }
      }
    }

    // divide by denominator to get average
    for (size_t k = 0; k < num_floats; k++) {
      floats[k] /= num_rows;
    }
  }
}

// Compute summary statistics from the final row map and pass them to
// cbs->on_stats (the caller checks that it is set):
//   * sum        -- sum of each row's squared distance to its centroid
//   * silhouette -- mean over rows of
//                   (d2_near - d2) / max(d2, d2_near), using squared
//                   distances; a difference within MIN_DELTA counts as 0
//     (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
static void
km_solve_emit_stats(
  km_set_t * const cs,
  const km_set_t * const set,
  const km_row_map_t * const row_map,
  const km_solve_cbs_t * const cbs,
  void * const cb_data
) {
  float sum = 0,
        silhouette = 0;

  // calculate mean numerators and silhouette
  #pragma omp parallel for reduction(+:sum,silhouette)
  for (size_t i = 0; i < set->num_rows; i++) {
    const float delta = (row_map[i].d2_near - row_map[i].d2);

    // sum silhouette (the MIN_DELTA check also avoids 0/0 when both
    // distances are zero)
    silhouette += (delta < -MIN_DELTA || delta > MIN_DELTA)
      ? (delta / MAX(row_map[i].d2, row_map[i].d2_near))
      : 0.0;

    // sum distances
    sum += row_map[i].d2;
  }

  // finalize silhouette; skip for an empty set to avoid 0/0 = NaN
  if (set->num_rows > 0) {
    silhouette /= set->num_rows;
  }

  // build stats
  const km_solve_stats_t stats = {
    .sum          = sum,
    .silhouette   = silhouette,
    .num_clusters = cs->num_rows,
  };

  // emit means
  cbs->on_stats(cs, &stats, cb_data);
}

// Use k-means (Lloyd's algorithm) to iteratively refine the cluster
// centroids until no row changes cluster.
//
// Parameters:
//   cs      -- cluster set. Its rows are the centroids used on the first
//              pass (presumably seeded by the caller -- confirm). On
//              return it holds the final centroids, and cs->ints[i]
//              holds the number of rows assigned to cluster i.
//   set     -- rows to cluster.
//   row_map -- output, indexed 0 .. set->num_rows - 1; receives each
//              row's cluster and its nearest/second-nearest squared
//              distances.
//   cbs     -- optional callbacks (may be NULL). on_step is called after
//              every pass, on_stats once at the end.
//   cb_data -- opaque pointer passed through to the callbacks.
//
// Each pass:
//   * reassigns every row to its nearest centroid
//   * if any assignment changed, recomputes each centroid as the mean
//     of its rows
//   * repeats until a pass produces no assignment changes
//
// Always returns true.
bool
km_solve(
  km_set_t * const cs,
  const km_set_t * const set,
  km_row_map_t * const row_map,
  const km_solve_cbs_t * const cbs,
  void * const cb_data
) {
  const size_t num_clusters = cs->num_rows;

  // clear row map distances
  km_row_map_clear(row_map, set->num_rows);

  // the first pass counts every row as changed, because
  // row_map[].cluster has not been assigned yet
  bool first = true,
       changed = false;
  do {
    // with no clusters there is nothing to assign
    changed = (num_clusters > 0) &&
              (km_solve_assign_rows(cs, set, row_map, first) > 0);
    first = false;

    if (changed) {
      // assignments changed, so the cluster centers must be recomputed
      km_solve_update_centroids(cs, set, row_map);
    }

    if (cbs && cbs->on_step) {
      // pass clusters and row map to step callback
      cbs->on_step(cs, row_map, cb_data);
    }
  } while (changed);

  if (cbs && cbs->on_stats) {
    km_solve_emit_stats(cs, set, row_map, cbs, cb_data);
  }

  // return success
  return true;
}