1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
|
#include <stdbool.h> // bool
#include <string.h> // memset()
#include <float.h> // FLT_MAX
#include "util.h"
#include "km.h"
#define MIN_DELTA 0.00001
// Reset the cached distances of the first num_rows entries of row_map.
//
// Both the distance to the assigned cluster (d2) and the distance to the
// nearest other cluster (d2_near) are set to FLT_MAX, so that any smaller
// distance computed afterwards replaces them.  The cluster field of each
// entry is not touched here.
static void
km_row_map_clear(
  km_row_map_t * const row_map,
  const size_t num_rows
) {
  #pragma omp parallel for
  for (size_t row = 0; row < num_rows; row++) {
    km_row_map_t * const entry = row_map + row;

    // FLT_MAX acts as the "no distance recorded yet" marker
    entry->d2_near = FLT_MAX;
    entry->d2 = FLT_MAX;
  }
}
// Use k-means to iteratively update the cluster centroids until a pass
// over all (row, cluster) pairs makes no more changes to the row map.
//
// Parameters:
//   cs      - cluster set, updated in place: row i holds the centroid of
//             cluster i and cs->ints[i] receives the number of rows
//             assigned to cluster i.
//   set     - input rows (read only).
//   row_map - one entry per row of `set`; on return each entry holds the
//             assigned cluster index, the recorded squared distance to
//             that cluster (d2) and the recorded squared distance to the
//             nearest other cluster (d2_near).
//   cbs     - optional callbacks (may be NULL).  on_step, if set, is
//             called once per pass; on_stats, if set, is called once
//             after the last pass.
//   cb_data - opaque pointer handed to the callbacks unchanged.
//
// Returns true (this function has no failure path).
//
// Notes:
//   * The row map distances are reset once, before the first pass.
//     Later passes only ever lower d2 / d2_near, so the recorded values
//     can refer to a centroid position from an earlier pass.
//     NOTE(review): confirm this is intended rather than re-measuring
//     every row against the moved centroids on each pass.
//   * A cluster that ends a pass with no rows gets an all-zero centroid.
//     NOTE(review): confirm callers expect this.
bool
km_solve(
  km_set_t * const cs,
  const km_set_t * const set,
  km_row_map_t * const row_map,
  const km_solve_cbs_t * const cbs,
  void * const cb_data
) {
  const size_t num_clusters = cs->num_rows,
               num_rows = set->num_rows,
               num_floats = set->shape.num_floats;

  // clear row map distances
  km_row_map_clear(row_map, num_rows);

  // calculate clusters by doing the following:
  // * walk all rows and, for each row, all clusters
  // * if we find a closer cluster, move row to cluster
  // * if there were changes to any row, then calculate a new
  //   centroid for each cluster by averaging the cluster rows
  // * repeat until there are no more changes
  bool changed = false;
  do {
    // no changes yet
    changed = false;

    // Parallelize over rows only: each row map entry is then read and
    // written by exactly one thread, and the clusters are visited in
    // index order for that row.  (Collapsing the cluster and row loops
    // would let several threads update the same entry concurrently.)
    // The change flag is combined across threads with a reduction
    // instead of being written as a plain shared variable.
    #pragma omp parallel for reduction(||:changed)
    for (size_t j = 0; j < num_rows; j++) {
      // get row values
      const float * const vals = km_set_get_row(set, j);

      for (size_t i = 0; i < num_clusters; i++) {
        // get the floats for this cluster
        const float * const floats = km_set_get_row(cs, i);

        // calculate the distance squared between row and cluster
        const float d2 = distance_squared(num_floats, floats, vals);

        if (d2 < row_map[j].d2) {
          // row is closer to this cluster, update row map
          row_map[j].d2 = d2;
          row_map[j].cluster = i;

          // flag change
          changed = true;
        }

        if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) {
          // closest cluster seen so far that is not the assigned one
          row_map[j].d2_near = d2;

          // flag change
          changed = true;
        }
      }
    }

    if (changed) {
      // there were changes, so calculate the new cluster centers;
      // each thread writes only its own centroid row and counter
      #pragma omp parallel for
      for (size_t i = 0; i < num_clusters; i++) {
        size_t count = 0;
        float * const floats = km_set_get_row(cs, i);

        // start from zero and accumulate the member rows
        memset(floats, 0, sizeof(float) * num_floats);

        for (size_t j = 0; j < num_rows; j++) {
          if (row_map[j].cluster == i) {
            const float * const row_floats = km_set_get_row(set, j);

            // calculate numerator for average
            for (size_t k = 0; k < num_floats; k++) {
              floats[k] += row_floats[k];
            }

            // increment denominator for average
            count++;
          }
        }

        // save number of rows in this cluster
        cs->ints[i] = count;

        // skip the division for an empty cluster (centroid stays zero)
        if (count > 0) {
          for (size_t k = 0; k < num_floats; k++) {
            // divide by denominator to get average
            floats[k] /= count;
          }
        }
      }
    }

    if (cbs && cbs->on_step) {
      // pass clusters and row map to step callback
      cbs->on_step(cs, row_map, cb_data);
    }
  } while (changed);

  if (cbs && cbs->on_stats) {
    float sum = 0,
          silhouette = 0;

    // calculate sum of distances and silhouette numerator
    #pragma omp parallel for reduction(+:sum,silhouette)
    for (size_t i = 0; i < num_rows; i++) {
      // per-row silhouette term, built from the recorded squared
      // distances to the assigned and nearest other cluster
      // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
      const float delta = (row_map[i].d2_near - row_map[i].d2);

      // sum silhouette; differences within MIN_DELTA of zero count as
      // zero, which also keeps the divisor below away from zero
      silhouette += (delta < -MIN_DELTA || delta > MIN_DELTA)
                  ? (delta / MAX(row_map[i].d2, row_map[i].d2_near))
                  : 0.0;

      // sum distances
      sum += row_map[i].d2;
    }

    // finalize silhouette (mean over rows); an empty set keeps 0
    // instead of dividing zero by zero
    if (num_rows > 0) {
      silhouette /= num_rows;
    }

    // build stats
    const km_solve_stats_t stats = {
      .sum = sum,
      .silhouette = silhouette,
      .num_clusters = num_clusters,
    };

    // emit stats
    cbs->on_stats(cs, &stats, cb_data);
  }

  // return success
  return true;
}
|