aboutsummaryrefslogtreecommitdiff
path: root/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'main.c')
-rw-r--r--main.c110
1 files changed, 98 insertions, 12 deletions
diff --git a/main.c b/main.c
index ec62022..4d86b09 100644
--- a/main.c
+++ b/main.c
@@ -14,6 +14,11 @@
#define MAX_CLUSTERS 10
#define NUM_TESTS 100
+#define MAX_BEST 4
+
+#define IM_WIDTH 128
+#define IM_HEIGHT 128
+#define IM_STRIDE (3 * IM_WIDTH)
typedef struct {
km_rand_t rs;
@@ -24,8 +29,35 @@ typedef struct {
cluster_size;
size_t num_empty_clusters;
} rows[MAX_CLUSTERS - 2];
+
+ km_set_t best[MAX_BEST];
+ size_t num_best;
} find_t;
+static void
+find_each_best(
+ const find_t * const find_data,
+ void (*on_best)(const km_set_t * const, const size_t, void *),
+ void * const cb_data
+) {
+ // if the number of best sets is greater than MAX_BEST, then
+ // find_data->best is actually a ring buffer
+ const bool is_ring = find_data->num_best >= MAX_BEST;
+
+ if (!on_best) {
+ return;
+ }
+
+ // walk best sets
+ for (size_t i = 0; i < MAX_BEST; i++) {
+ // calculate set offset
+ const size_t ofs = i + (is_ring ? (find_data->num_best + 1) : 0);
+
+ // emit set
+ on_best(find_data->best + (ofs % MAX_BEST), i, cb_data);
+ }
+}
+
static bool
load_on_shape(
const km_shape_t * const shape,
@@ -112,14 +144,48 @@ find_on_data(
find_data->rows[ofs].num_empty_clusters += data->num_empty_clusters;
}
+static bool
+find_on_best(
+ const float score,
+ const km_set_t * const cs,
+ void *cb_data
+) {
+ find_t *find_data = cb_data;
+
+ D("best score = %0.3f, num_clusters = %zu", score, cs->num_rows);
+
+ // get pointer to destination set
+ // (note: data->best is a ring buffer)
+ km_set_t *dst = find_data->best + (find_data->num_best % MAX_BEST);
+
+ if (find_data->num_best >= MAX_BEST) {
+ // finalize old best data set
+ // D("finalizing old best %zu", find_data->num_best);
+ km_set_fini(dst);
+ }
+
+ // copy data set to best ring buffer
+ if (!km_set_copy(dst, cs)) {
+ die("km_set_copy()");
+ }
+
+ // increment best count
+ find_data->num_best++;
+
+ // return success
+ return true;
+}
+
// init find config
static const km_find_cbs_t
FIND_CBS = {
.max_clusters = MAX_CLUSTERS,
.num_tests = NUM_TESTS,
+
.on_init = find_on_init,
.on_fini = find_on_fini,
.on_data = find_on_data,
+ .on_best = find_on_best,
};
static float
@@ -127,9 +193,6 @@ get_score(
const size_t ofs,
const find_t * const find_data
) {
- if (!ofs || ofs == MAX_CLUSTERS - 3) {
- return 0;
- }
// const size_t num_clusters = ofs + 2;
const float mean_distance = find_data->rows[ofs].distance / NUM_TESTS,
mean_empty = 1.0 * find_data->rows[ofs].num_empty_clusters / NUM_TESTS;
@@ -178,22 +241,45 @@ print_csv(
}
}
-#define IM_WIDTH 128
-#define IM_HEIGHT 128
-#define IM_STRIDE (3 * IM_WIDTH)
-
static uint8_t im_data[3 * IM_WIDTH * IM_HEIGHT];
static void
+save_on_best(
+ const km_set_t * const set,
+ const size_t rank,
+ void * const cb_data
+) {
+ UNUSED(cb_data);
+
+ // convert rank to channel brightness
+ const uint8_t ch = 0x66 + (0xff - 0x66) * (1.0 * rank) / (MAX_BEST - 1);
+ const uint32_t color = (ch & 0xff) << 16;
+ // const uint32_t color = 0xff0000;
+ D("rank = %zu, color = %u", rank, color);
+
+ // draw clusters
+ km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 3, color);
+}
+
+static void
save_png(
const char * const png_path,
- const km_set_t * const set
+ const km_set_t * const set,
+ const find_t * const find_data
) {
// clear image data to white
memset(im_data, 0xff, sizeof(im_data));
// draw red points
- km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 0xff0000);
+ km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 1, 0x000000);
+ if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) {
+ die("stbi_write_png(\"%s\")", png_path);
+ }
+
+ // draw best cluster points
+ find_each_best(find_data, save_on_best, NULL);
+
+ // save png
if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) {
die("stbi_write_png(\"%s\")", png_path);
}
@@ -211,7 +297,7 @@ int main(int argc, char *argv[]) {
// init find data
find_t find_data;
- memset(find_data.rows, 0, sizeof(find_data.rows));
+ memset(&find_data, 0, sizeof(find_t));
km_rand_init_system(&(find_data.rs));
// init data set
@@ -232,8 +318,8 @@ int main(int argc, char *argv[]) {
// print csv
print_csv(&find_data);
- // save png of data set
- save_png("data.png", &set);
+ // save png of normalized data set and best clusters
+ save_png("data.png", &set, &find_data);
// finalize data set
km_set_fini(&set);