From f557d1f49a2914c6084dd18efc783395228d8ce0 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Tue, 5 Feb 2019 00:22:15 -0500 Subject: mv *.[hc] src/ --- src/main.c | 398 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 398 insertions(+) create mode 100644 src/main.c (limited to 'src/main.c') diff --git a/src/main.c b/src/main.c new file mode 100644 index 0000000..55056b7 --- /dev/null +++ b/src/main.c @@ -0,0 +1,398 @@ +#include // bool +#include // fprintf() +#include // EXIT_{FAILURE,SUCCESS} +#include // exit() +#include // memset() +#include // fabsf() + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_ONLY_PNG +#include "stb_image_write.h" + +#include "util.h" +#include "km.h" + +#define MAX_CLUSTERS 10 +#define NUM_TESTS 300 +#define MAX_BEST 10 + +#define IM_WIDTH 128 +#define IM_HEIGHT 128 +#define IM_STRIDE (3 * IM_WIDTH) + +typedef struct { + float score; + km_set_t set; +} best_item_t; + +typedef struct { + // cluster initialization method + km_init_type_t init_type; + + // random number source + km_rand_t rs; + + struct { + float distance, + silouette, + cluster_size; + size_t num_empty; + } rows[MAX_CLUSTERS - 2]; + + // best clusters + best_item_t best[MAX_BEST]; + size_t num_best; +} ctx_t; + +static int +best_score_cmp( + const void * const ap, + const void * const bp +) { + const best_item_t * const a = ap; + const best_item_t * const b = bp; + + return (a->score > b->score) ? 1 : -1; +} + +static void +ctx_best_sort( + ctx_t * const ctx +) { + // sort best sets by ascending score (worst to best) + qsort( + ctx->best, + MIN(ctx->num_best, MAX_BEST), + sizeof(best_item_t), + best_score_cmp + ); +} + +static void +ctx_best_walk( + const ctx_t * const ctx, + void (*on_best)(const km_set_t * const, const size_t, const float, void *), + void * const cb_data +) { + if (on_best) { + // walk best sets and emit each one + for (size_t i = 0; i < MIN(ctx->num_best, MAX_BEST); i++) { + on_best(&(ctx->best[i].set), i, ctx->best[i].score, cb_data); + } + } +} + +static bool +load_on_shape( + const km_shape_t * const shape, + void * const cb_data +) { + km_set_t * const set = cb_data; + + // D("shape: %zu floats, %zu ints", shape->num_floats, shape->num_ints); + + // init set + if (!km_set_init(set, shape, 100)) { + die("km_set_init() failed"); + } + + // return success + return true; +} + +static bool +load_on_row( + const float * const floats, + const int * const ints, + void * const cb_data +) { + km_set_t * const set = cb_data; + + // push row + if (!km_set_push(set, 1, floats, ints)) { + die("km_set_push_rows() failed"); + } + + // return success + return true; +} + +static void +load_on_error( + const char * const err, + void * const cb_data +) { + UNUSED(cb_data); + die("load failed: %s", err); +} + +static const km_load_cbs_t +LOAD_CBS = { + .on_shape = load_on_shape, + .on_row = load_on_row, + .on_error = load_on_error, +}; + +static bool +find_on_init( + km_set_t * const cs, + const size_t num_clusters, + const km_set_t * const set, + void *cb_data +) { + ctx_t * const ctx = cb_data; + return km_init(cs, ctx->init_type, num_clusters, set, &(ctx->rs)); +} + +static bool +find_on_fini( + km_set_t * const cs, + void *cb_data +) { + UNUSED(cb_data); + km_set_fini(cs); + return true; +} + +static void +find_on_data( + const km_find_data_t * const data, + void *cb_data +) { + ctx_t * const ctx = cb_data; + const size_t ofs = data->num_clusters - 2; + + ctx->rows[ofs].distance += data->distance_sum; + ctx->rows[ofs].silouette += data->silouette; + ctx->rows[ofs].cluster_size += data->mean_cluster_size; + ctx->rows[ofs].num_empty += data->num_empty_clusters; +} + +static bool +find_on_best( + const float score, + const km_set_t * const cs, + void *cb_data +) { + ctx_t * const ctx = cb_data; + + D("new best: score = %0.3f, num_clusters = %zu", score, cs->num_rows); + + // get pointer to destination set + // (note: data->best is a ring buffer) + const size_t ofs = ctx->num_best % MAX_BEST; + ctx->best[ofs].score = score; + km_set_t * const dst = &(ctx->best[ofs].set); + + if (ctx->num_best >= MAX_BEST) { + // finalize old best data set + km_set_fini(dst); + } + + // copy data set to best ring buffer + if (!km_set_copy(dst, cs)) { + die("km_set_copy()"); + } + + // increment best count + ctx->num_best++; + + // return success + return true; +} + +// init find config +static const km_find_cbs_t +FIND_CBS = { + .max_clusters = MAX_CLUSTERS, + .num_tests = NUM_TESTS, + + .on_init = find_on_init, + .on_fini = find_on_fini, + .on_data = find_on_data, + .on_best = find_on_best, +}; + +static void +ctx_csv_print_row( + const ctx_t * const ctx, + FILE * const fh, + const size_t i +) { + const size_t num_clusters = i + 2; + const float mean_distance = ctx->rows[i].distance / NUM_TESTS, + mean_cluster_size = ctx->rows[i].cluster_size / NUM_TESTS, + mean_empty = 1.0 * ctx->rows[i].num_empty / NUM_TESTS, + score = ctx->rows[i].silouette / NUM_TESTS; + + // print result + fprintf(fh, "%zu,%0.3f,%0.3f,%0.3f,%0.3f\n", + num_clusters, + score, + mean_distance, + mean_cluster_size, + mean_empty + ); +} + +static void +ctx_csv_print( + const ctx_t * const ctx, + FILE * const fh +) { + // print headers + fprintf(fh, "#,score,distance,size,empty\n"); + + // print rows + for (size_t i = 0; i < MAX_CLUSTERS - 2; i++) { + ctx_csv_print_row(ctx, fh, i); + } +} + +// static image data buffer +static uint8_t im_data[3 * IM_WIDTH * IM_HEIGHT]; + +static void +save_on_best( + const km_set_t * const set, + const size_t rank, + const float score, + void * const cb_data +) { + const ctx_t * const ctx = cb_data; + UNUSED(score); + + // convert rank to channel brightness + const uint8_t ch = 0x33 + (0xff - 0x33) * (1.0 * rank + 1) / (MAX_BEST); + const uint8_t shift = (rank == MIN(ctx->num_best, MAX_BEST) - 1) ? 8 : 16; + const uint32_t color = (ch & 0xff) << shift; + // const uint32_t color = 0xff0000; + // D("rank = %zu, score = %0.3f, size = %zu, color = %06x", rank, score, set->num_rows, color); + + // draw clusters + km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 3, color); +} + +static void +ctx_save_png( + const ctx_t * const ctx, + const char * const png_path, + const km_set_t * const set +) { + // clear image data to white + memset(im_data, 0xff, sizeof(im_data)); + + // draw data points + km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 1, 0x000000); + if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) { + die("stbi_write_png(\"%s\")", png_path); + } + + // draw best cluster points + ctx_best_walk(ctx, save_on_best, (void*) ctx); + + // save png + if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) { + die("stbi_write_png(\"%s\")", png_path); + } +} + +static void +ctx_best_print_on_best( + const km_set_t * const set, + const size_t rank, + const float score, + void * const cb_data +) { + FILE * const fh = cb_data; + + fprintf(fh, "rank = %zu, score = %0.3f, num_clusters = %zu: [\n", rank, score, set->num_rows); + for (size_t i = 0; i < set->num_rows; i++) { + const float * const vals = km_set_get_row(set, i); + fprintf(fh, " ["); + + for (size_t j = 0; j < set->shape.num_floats; j++) { + fprintf(fh, "%s%0.3f", (j > 0) ? ", " : "", vals[j]); + } + + fprintf(fh, "], (%d rows)\n", set->ints[i]); + } + fprintf(fh, "]\n"); +} + +static void +ctx_best_print( + const ctx_t * const ctx, + FILE * const fh +) { + ctx_best_walk(ctx, ctx_best_print_on_best, fh); +} + +static const char USAGE_FORMAT[] = + "Usage: %s [init] [data_path] \n" + "\n" + "Arguments:\n" + "* init: Cluster init method (one of \"rand\" or \"set\").\n" + "* data_path: Path to input data file.\n" + "* png_path: Path to output file (optional).\n" + ""; + +int main(int argc, char *argv[]) { + // check command-line arguments + if (argc < 3) { + fprintf(stderr, USAGE_FORMAT, argv[0]); + return EXIT_FAILURE; + } + + // get command-line arguments + const char * const init_type_name = argv[1]; + const char * const data_path = argv[2]; + const char * const png_path = (argc > 3) ? argv[3] : NULL; + + // init random seed + srand(getpid()); + + // init context + ctx_t ctx; + memset(&ctx, 0, sizeof(ctx_t)); + ctx.init_type = km_init_get_type(init_type_name); + + // init ctx rng + km_rand_init_erand48(&(ctx.rs), rand()); + + // init data set + km_set_t set; + if (!km_load_path(data_path, &LOAD_CBS, &set)) { + die("km_load_path(\"%s\") failed", data_path); + } + + // normalize data set + if (!km_set_normalize(&set)) { + die("km_set_normalize() failed"); + } + + // find best solutions + if (!km_find(&set, &FIND_CBS, &ctx)) { + die("km_find()"); + } + + // print csv + ctx_csv_print(&ctx, stdout); + + // sort best results from lowest to highest + ctx_best_sort(&ctx); + + // print best + ctx_best_print(&ctx, stdout); + + if (png_path) { + // save png of normalized data set and best clusters + ctx_save_png(&ctx, png_path, &set); + } + + // finalize data set + km_set_fini(&set); + + // return success + return 0; +} -- cgit v1.2.3