// path: src/km-find.c
#include <stdbool.h> // bool
#include "util.h"
#include "km.h"

/**
 * Average the strictly positive entries of set->ints.
 *
 * Walks the first set->num_rows entries of set->ints, ignoring any entry
 * that is not greater than zero, and returns the mean of the rest.
 * (Presumably each entry is the member count of one cluster, so empty
 * clusters are excluded from the mean -- confirm against km_set_t.)
 *
 * @param set  Set whose ints array is averaged.
 * @return Mean of the positive entries, or 0 if there are none.
 */
static float
get_mean_cluster_size(
  const km_set_t * const set
) {
  float total = 0;
  size_t count = 0;

  for (size_t row = 0; row < set->num_rows; row++) {
    // entries that are not positive do not contribute to the mean
    if (!(set->ints[row] > 0)) {
      continue;
    }

    total += set->ints[row];
    count++;
  }

  // avoid dividing by zero when nothing was counted
  if (count == 0) {
    return 0;
  }

  return total / count;
}

/**
 * Solver stats callback: copy the reported stats into the find data.
 *
 * @param set      Unused.
 * @param stats    Stats reported by the solver; sum and silhouette are read.
 * @param cb_data  Pointer to the km_find_data_t that receives the values.
 */
static void
find_solve_on_stats(
  const km_set_t * const set,
  const km_solve_stats_t * const stats,
  void * const cb_data
) {
  UNUSED(set);

  // cb_data is the km_find_data_t handed to the solver by the caller
  km_find_data_t * const result = cb_data;

  // record silhouette and total distance sum
  result->silhouette = stats->silhouette;
  result->distance_sum = stats->sum;
}

// Solver callbacks passed to km_solve() by km_find(). Only on_stats is
// set (to find_solve_on_stats); every other member is left
// zero-initialized.
static const km_solve_cbs_t
FIND_SOLVE_CBS = {
  .on_stats = find_solve_on_stats,
};

/**
 * Run repeated solves over a range of cluster counts and report each
 * result through the supplied callbacks.
 *
 * For every cluster count i in [2, cbs->max_clusters) and for each of
 * cbs->num_tests tests, this function:
 *
 *   1. initializes a cluster set with cbs->on_init,
 *   2. solves it with km_solve(),
 *   3. reports the result with cbs->on_data,
 *   4. calls cbs->on_best (if set) when the silhouette beats the best
 *      seen so far, and
 *   5. finalizes the cluster set with cbs->on_fini.
 *
 * The row map is released on every exit path, and a cluster set that was
 * successfully initialized is always handed to cbs->on_fini, including
 * when km_solve() or cbs->on_best fails.
 *
 * @param set      Data set to solve against.
 * @param cbs      Callbacks; on_init, on_fini and on_data are required,
 *                 on_best is optional.
 * @param cb_data  Opaque pointer passed through to every callback.
 *
 * @return true on success; false if a required callback is missing, the
 *         row map cannot be allocated, or on_init, km_solve(), on_best
 *         or on_fini reports failure.
 */
bool
km_find(
  const km_set_t * const set,
  const km_find_cbs_t * const cbs,
  void * const cb_data
) {
  // check init callback
  if (!cbs->on_init) {
    return false;
  }

  // check fini callback
  if (!cbs->on_fini) {
    return false;
  }

  // check data callback
  if (!cbs->on_data) {
    return false;
  }

  // allocate row map: row => distance, cluster ID
  km_row_map_t * const row_map = calloc(set->num_rows, sizeof *row_map);
  if (!row_map) {
    // return failure
    return false;
  }

  // result; stays false unless every test completes
  bool ok = false;

  // start below any expected silhouette so the first result counts as
  // the best (presumably silhouettes are never below -1 -- confirm
  // against km_solve)
  float best_silhouette = -2.0;

  for (size_t i = 2; i < cbs->max_clusters; i++) {
    for (size_t j = 0; j < cbs->num_tests; j++) {
      // init cluster set
      km_set_t cs;
      if (!cbs->on_init(&cs, i, set, cb_data)) {
        // nothing to finalize: init failed
        goto done;
      }

      // init find data
      km_find_data_t find_data = {
        .cluster_set        = &cs,
        .num_clusters       = i,
      };

      // solve test
      // (populates sum and silhouette)
      if (!km_solve(&cs, set, row_map, &FIND_SOLVE_CBS, &find_data)) {
        // finalize cluster set before failing so it is not leaked;
        // the result is ignored because we are already failing
        cbs->on_fini(&cs, cb_data);
        goto done;
      }

      // populate mean cluster size
      find_data.mean_cluster_size = get_mean_cluster_size(&cs);

      // emit result
      cbs->on_data(&find_data, cb_data);

      if (find_data.silhouette > best_silhouette) {
        // emit new best result
        if (cbs->on_best && !cbs->on_best(find_data.silhouette, &cs, cb_data)) {
          // finalize cluster set before failing so it is not leaked;
          // the result is ignored because we are already failing
          cbs->on_fini(&cs, cb_data);
          goto done;
        }

        // update best silhouette
        best_silhouette = find_data.silhouette;
      }

      // finalize cluster set
      if (!cbs->on_fini(&cs, cb_data)) {
        goto done;
      }
    }
  }

  // all tests completed
  ok = true;

done:
  // free row map (shared by success and failure paths)
  free(row_map);

  return ok;
}