1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
#include <stdbool.h> // bool
#include "util.h"
#include "km.h"
/*
 * Average size of the non-empty clusters in a cluster set.
 *
 * Walks the first set->num_rows entries of set->ints, summing only the
 * entries that are > 0 and counting how many there were.  Returns that
 * sum divided by the count, or 0 when no entry is > 0.
 */
static float
get_mean_cluster_size(
  const km_set_t * const set
) {
  size_t count = 0;
  float total = 0;

  for (size_t row = 0; row < set->num_rows; row++) {
    if (set->ints[row] > 0) {
      count++;
      total += set->ints[row];
    }
  }

  if (count == 0) {
    // no non-empty clusters: avoid dividing by zero
    return 0;
  }

  return total / count;
}
/*
 * km_solve() statistics callback used by km_find() (via FIND_SOLVE_CBS).
 *
 * cb_data is expected to point at a km_find_data_t; km_find() passes its
 * local find_data to km_solve() together with FIND_SOLVE_CBS.  The
 * solver's sum is copied into distance_sum and its silhouette into
 * silhouette.  The set argument is not used.
 */
static void
find_solve_on_stats(
  const km_set_t * const set,
  const km_solve_stats_t * const stats,
  void * const cb_data
) {
  UNUSED(set);

  km_find_data_t * const out = cb_data;

  // record the silhouette and total distance sum for this test
  out->silhouette = stats->silhouette;
  out->distance_sum = stats->sum;
}
// km_solve() callbacks used by km_find(): only on_stats is provided
// (find_solve_on_stats); every other member is zero-initialized.
static const km_solve_cbs_t FIND_SOLVE_CBS = { .on_stats = &find_solve_on_stats };
/*
 * Run a single clustering test with num_clusters clusters.
 *
 * The steps are:
 *  - initialize a cluster set with cbs->on_init();
 *  - solve it with km_solve() into the caller-provided row_map;
 *  - report the result via cbs->on_data();
 *  - when the silhouette beats *best_silhouette, report it via the optional
 *    cbs->on_best() and update *best_silhouette;
 *  - finalize the cluster set with cbs->on_fini().
 *
 * Once on_init() has succeeded, on_fini() is called exactly once, even if
 * km_solve() or on_best() fails, so the cluster set is never leaked.
 *
 * Returns true on success; false if any step (including on_fini()) fails.
 */
static bool
find_run_test(
  const km_set_t * const set,
  const km_find_cbs_t * const cbs,
  void * const cb_data,
  km_row_map_t * const row_map,
  const size_t num_clusters,
  float * const best_silhouette
) {
  // init cluster set
  km_set_t cs;
  if (!cbs->on_init(&cs, num_clusters, set, cb_data)) {
    // nothing was initialized, so there is nothing to finalize
    return false;
  }

  bool ok = false;

  // init find data; remaining fields start zeroed, and km_solve() is
  // expected to populate sum and silhouette through FIND_SOLVE_CBS
  km_find_data_t find_data = {
    .cluster_set = &cs,
    .num_clusters = num_clusters,
  };

  // solve test
  if (!km_solve(&cs, set, row_map, &FIND_SOLVE_CBS, &find_data)) {
    goto fini;
  }

  // populate mean cluster size
  find_data.mean_cluster_size = get_mean_cluster_size(&cs);

  // emit result
  cbs->on_data(&find_data, cb_data);

  if (find_data.silhouette > *best_silhouette) {
    // emit new best result (on_best is optional)
    if (cbs->on_best && !cbs->on_best(find_data.silhouette, &cs, cb_data)) {
      goto fini;
    }

    // update best silhouette
    *best_silhouette = find_data.silhouette;
  }

  ok = true;

fini:
  // always finalize an initialized cluster set; a finalize failure fails
  // the whole test
  if (!cbs->on_fini(&cs, cb_data)) {
    ok = false;
  }
  return ok;
}

/*
 * Search for a good cluster count.
 *
 * For each cluster count from 2 up to (but not including)
 * cbs->max_clusters, run cbs->num_tests tests (see find_run_test()).
 * Every result is reported via cbs->on_data().  A result whose silhouette
 * exceeds the best seen so far is also reported via the optional
 * cbs->on_best().
 *
 * cbs->on_init, cbs->on_fini and cbs->on_data are required.
 *
 * Returns true on success.  Returns false if a required callback is
 * missing, the row map cannot be allocated, or any test fails; iteration
 * stops at the first failing test.  The row map is freed on every path.
 */
bool
km_find(
  const km_set_t * const set,
  const km_find_cbs_t * const cbs,
  void * const cb_data
) {
  // check required callbacks
  if (!cbs->on_init || !cbs->on_fini || !cbs->on_data) {
    return false;
  }

  // allocate row map: row => distance, cluster ID (reused by every test)
  km_row_map_t * const row_map = calloc(set->num_rows, sizeof(km_row_map_t));
  if (!row_map) {
    // return failure
    return false;
  }

  // silhouette scores are expected in [-1, 1], so the first result
  // always becomes the initial best
  float best_silhouette = -2.0f;
  bool ok = true;

  // NOTE(review): max_clusters is treated as an exclusive upper bound
  // (original behavior); confirm this is intended.
  for (size_t i = 2; ok && i < cbs->max_clusters; i++) {
    for (size_t j = 0; ok && j < cbs->num_tests; j++) {
      ok = find_run_test(set, cbs, cb_data, row_map, i, &best_silhouette);
    }
  }

  // free row map on both success and failure paths
  free(row_map);

  return ok;
}
|