aboutsummaryrefslogtreecommitdiff
path: root/km-find.c
blob: 8e7f9ac243ad4d74b2238892765613b41981f091 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <stdbool.h> // bool
#include <math.h> // fabsf()
#include "util.h"
#include "km.h"

// threshold applied to the absolute value of each entry of stats->means
// in find_solve_on_stats(): a cluster whose value is not greater than
// this is counted as empty
#define MIN_CLUSTER_DISTANCE 0.0001

// per-test accumulator that is handed to the solver as callback data
typedef struct {
  // running sums over the clusters that were not counted as empty
  float distance_sum,
        variance_sum;

  // number of clusters that were counted as empty
  size_t num_empty_clusters;
} find_solve_data_t;

// return the number of non-empty clusters, or zero if the empty count
// is greater than or equal to the total (guards against unsigned
// wrap-around in the subtraction)
static size_t
find_get_num_filled(
  const find_solve_data_t * const d,
  const size_t num_clusters
) {
  return (d->num_empty_clusters < num_clusters)
       ? (num_clusters - d->num_empty_clusters)
       : 0;
}

// return the accumulated distance sum divided by the number of
// non-empty clusters, or zero if every cluster is empty
static float
find_get_mean_distance(
  const find_solve_data_t * const d,
  const size_t num_clusters
) {
  // note: this has to be a count, not a bool; storing it in a bool
  // collapses every non-zero count to one and turns the mean into the
  // raw sum
  const size_t num_filled = find_get_num_filled(d, num_clusters);
  return num_filled ? (d->distance_sum / num_filled) : 0;
}

// return the accumulated variance sum divided by the number of
// non-empty clusters, or zero if every cluster is empty
static float
find_get_mean_variance(
  const find_solve_data_t * const d,
  const size_t num_clusters
) {
  // count, not bool (see find_get_mean_distance())
  const size_t num_filled = find_get_num_filled(d, num_clusters);
  return num_filled ? (d->variance_sum / num_filled) : 0;
}

// average the positive entries of set->ints over the first
// set->num_rows rows; entries that are not positive are skipped and do
// not count toward the divisor.  returns zero if no entry is positive.
static float
find_get_mean_cluster_size(
  const km_set_t * const set
) {
  float total = 0;
  size_t count = 0;

  for (size_t row = 0; row < set->num_rows; row++) {
    // skip rows without a positive entry
    if (!(set->ints[row] > 0)) {
      continue;
    }

    total += set->ints[row];
    count++;
  }

  // nothing to average
  if (count == 0) {
    return 0;
  }

  return total / count;
}

// solver stats callback: fold the per-cluster values in stats into the
// find_solve_data_t accumulator passed as cb_data.
//
// a cluster whose stats->means entry has an absolute value greater than
// MIN_CLUSTER_DISTANCE adds its means and variances entries to the
// running sums; any other cluster bumps the empty cluster count.
static void
find_solve_on_stats(
  const km_set_t * const set,
  const km_solve_stats_t * const stats,
  void * const cb_data
) {
  find_solve_data_t * const acc = cb_data;
  UNUSED(set);

  for (size_t c = 0; c < stats->num_clusters; c++) {
    // written as a negated "greater than" so that a NaN entry is
    // treated as an empty cluster
    if (!(fabsf(stats->means[c]) > MIN_CLUSTER_DISTANCE)) {
      acc->num_empty_clusters++;
      continue;
    }

    // accumulate numerators for the mean distance and mean variance
    acc->distance_sum += stats->means[c];
    acc->variance_sum += stats->variances[c];
  }
}

// solver callbacks passed to km_solve() by km_find(); only the stats
// callback is set, all other members are left zero-initialized
static const km_solve_cbs_t
FIND_SOLVE_CBS = {
  .on_stats = find_solve_on_stats,
};

// run repeated solver tests over a range of cluster counts and report
// the results through the given callbacks.
//
// for each cluster count from 2 up to (but not including)
// cbs->max_clusters, cbs->num_tests tests are run.  each test:
//
//   1. initializes a cluster set with cbs->on_init,
//   2. solves it with km_solve(),
//   3. emits the result with cbs->on_data,
//   4. scores it with km_score() and, if the score is greater than the
//      best score so far, emits it with cbs->on_best (optional),
//   5. finalizes the cluster set with cbs->on_fini.
//
// the on_init, on_fini, and on_data callbacks are required; on_best is
// optional.
//
// returns false if a required callback is missing, or if on_init,
// km_solve(), on_best, or on_fini reports failure; returns true
// otherwise.  once on_init has succeeded for a test, on_fini is always
// called for that test, including on the failure paths.
//
// NOTE(review): the upper bound on the cluster count is exclusive --
// confirm that is intended.
bool
km_find(
  const km_set_t * const set,
  const km_find_cbs_t * const cbs,
  void * const cb_data
) {
  // check required callbacks
  if (!cbs->on_init || !cbs->on_fini || !cbs->on_data) {
    return false;
  }

  float best_score = 0.0;
  for (size_t i = 2; i < cbs->max_clusters; i++) {
    for (size_t j = 0; j < cbs->num_tests; j++) {
      // init cluster set
      km_set_t cs;
      if (!cbs->on_init(&cs, i, set, cb_data)) {
        // return failure (nothing to finalize)
        return false;
      }

      // init solve data
      find_solve_data_t solve_data = {
        .distance_sum       = 0,
        .variance_sum       = 0,
        .num_empty_clusters = 0,
      };

      // solve test
      if (!km_solve(&cs, set, &FIND_SOLVE_CBS, &solve_data)) {
        // finalize cluster set so it is not leaked, then return
        // failure; the result of on_fini is irrelevant here because we
        // are already failing
        (void) cbs->on_fini(&cs, cb_data);
        return false;
      }

      // init result data
      const km_find_data_t result = {
        .cluster_set        = &cs,
        .num_clusters       = i,
        .mean_distance      = find_get_mean_distance(&solve_data, i),
        .mean_variance      = find_get_mean_variance(&solve_data, i),
        .mean_cluster_size  = find_get_mean_cluster_size(&cs),
        .num_empty_clusters = solve_data.num_empty_clusters,
      };

      // emit result
      cbs->on_data(&result, cb_data);

      // score result
      const float score = km_score(result.mean_distance, result.num_empty_clusters);

      if (score > best_score) {
        // emit new best result
        if (cbs->on_best && !cbs->on_best(score, &cs, cb_data)) {
          // finalize cluster set so it is not leaked, then return
          // failure
          (void) cbs->on_fini(&cs, cb_data);
          return false;
        }

        // update best score
        best_score = score;
      }

      // finalize cluster set
      if (!cbs->on_fini(&cs, cb_data)) {
        // return failure
        return false;
      }
    }
  }

  // return success
  return true;
}