#include // bool #include // fprintf() #include // EXIT_{FAILURE,SUCCESS} #include // exit() #include // memset() #include // fabsf() #define STB_IMAGE_WRITE_IMPLEMENTATION #define STB_ONLY_PNG #include "stb_image_write.h" #include "util.h" #include "km.h" #define MAX_CLUSTERS 10 #define NUM_TESTS 100 #define MAX_BEST 4 #define IM_WIDTH 128 #define IM_HEIGHT 128 #define IM_STRIDE (3 * IM_WIDTH) typedef struct { km_rand_t rs; struct { float distance, variance, cluster_size; size_t num_empty_clusters; } rows[MAX_CLUSTERS - 2]; km_set_t best[MAX_BEST]; size_t num_best; } find_t; static void find_each_best( const find_t * const find_data, void (*on_best)(const km_set_t * const, const size_t, void *), void * const cb_data ) { // if the number of best sets is greater than MAX_BEST, then // find_data->best is actually a ring buffer const bool is_ring = find_data->num_best >= MAX_BEST; if (!on_best) { return; } // walk best sets for (size_t i = 0; i < MAX_BEST; i++) { // calculate set offset const size_t ofs = i + (is_ring ? (find_data->num_best + 1) : 0); // emit set on_best(find_data->best + (ofs % MAX_BEST), i, cb_data); } } static bool load_on_shape( const km_shape_t * const shape, void * const cb_data ) { km_set_t * const set = cb_data; fprintf(stderr, "DEBUG: shape = { %zu, %zu }\n", shape->num_floats, shape->num_ints); // init set if (!km_set_init(set, shape, 100)) { die("km_set_init() failed"); } // return success return true; } static bool load_on_row( const float * const floats, const int * const ints, void * const cb_data ) { km_set_t * const set = cb_data; // push row if (!km_set_push(set, 1, floats, ints)) { die("km_set_push_rows() failed"); } // return success return true; } static void load_on_error( const char * const err, void * const cb_data ) { UNUSED(cb_data); die("load failed: %s", err); } static const km_load_cbs_t LOAD_CBS = { .on_shape = load_on_shape, .on_row = load_on_row, .on_error = load_on_error, }; static bool find_on_init( km_set_t * const cs, const size_t num_floats, const size_t num_clusters, void *cb_data ) { find_t *data = cb_data; return km_set_init_rand_clusters(cs, num_floats, num_clusters, &(data->rs)); } static bool find_on_fini( km_set_t * const cs, void *cb_data ) { UNUSED(cb_data); km_set_fini(cs); return true; } static void find_on_data( const km_find_data_t * const data, void *cb_data ) { find_t * const find_data = cb_data; const size_t ofs = data->num_clusters - 2; find_data->rows[ofs].distance += data->mean_distance; find_data->rows[ofs].variance += data->mean_variance; find_data->rows[ofs].cluster_size += data->mean_cluster_size; find_data->rows[ofs].num_empty_clusters += data->num_empty_clusters; } static bool find_on_best( const float score, const km_set_t * const cs, void *cb_data ) { find_t *find_data = cb_data; D("best score = %0.3f, num_clusters = %zu", score, cs->num_rows); // get pointer to destination set // (note: data->best is a ring buffer) km_set_t *dst = find_data->best + (find_data->num_best % MAX_BEST); if (find_data->num_best >= MAX_BEST) { // finalize old best data set // D("finalizing old best %zu", find_data->num_best); km_set_fini(dst); } // copy data set to best ring buffer if (!km_set_copy(dst, cs)) { die("km_set_copy()"); } // increment best count find_data->num_best++; // return success return true; } // init find config static const km_find_cbs_t FIND_CBS = { .max_clusters = MAX_CLUSTERS, .num_tests = NUM_TESTS, .on_init = find_on_init, .on_fini = find_on_fini, .on_data = find_on_data, .on_best = find_on_best, }; static float get_score( const size_t ofs, const find_t * const find_data ) { // const size_t num_clusters = ofs + 2; const float mean_distance = find_data->rows[ofs].distance / NUM_TESTS, mean_empty = 1.0 * find_data->rows[ofs].num_empty_clusters / NUM_TESTS; return 1.0 / (mean_distance + mean_empty); } static void print_csv_row( const size_t i, const find_t * const find_data ) { const size_t num_clusters = i + 2; const float mean_distance = find_data->rows[i].distance / NUM_TESTS, mean_variance = find_data->rows[i].variance / NUM_TESTS, mean_cluster_size = find_data->rows[i].cluster_size / NUM_TESTS, mean_empty_clusters = 1.0 * find_data->rows[i].num_empty_clusters / NUM_TESTS; // print result printf("%zu,%0.3f,%0.3f,%0.3f,%0.3f,%0.3f\n", num_clusters, get_score(i, find_data), mean_distance, mean_variance, mean_cluster_size, mean_empty_clusters ); } static void print_csv( const find_t * const find_data ) { // print headers printf( "#," "score," "distance," "variance," "cluster_size," "empty_clusters\n" ); for (size_t i = 0; i < MAX_CLUSTERS - 2; i++) { print_csv_row(i, find_data); } } static uint8_t im_data[3 * IM_WIDTH * IM_HEIGHT]; static void save_on_best( const km_set_t * const set, const size_t rank, void * const cb_data ) { UNUSED(cb_data); // convert rank to channel brightness const uint8_t ch = 0x66 + (0xff - 0x66) * (1.0 * rank) / (MAX_BEST - 1); const uint32_t color = (ch & 0xff) << 16; // const uint32_t color = 0xff0000; D("rank = %zu, color = %u", rank, color); // draw clusters km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 3, color); } static void save_png( const char * const png_path, const km_set_t * const set, const find_t * const find_data ) { // clear image data to white memset(im_data, 0xff, sizeof(im_data)); // draw red points km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 1, 0x000000); if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) { die("stbi_write_png(\"%s\")", png_path); } // draw best cluster points find_each_best(find_data, save_on_best, NULL); // save png if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) { die("stbi_write_png(\"%s\")", png_path); } } int main(int argc, char *argv[]) { // check command-line if (argc < 2) { fprintf(stderr, "Usage: %s \n", argv[0]); return EXIT_FAILURE; } // init random seed srand(getpid()); // init find data find_t find_data; memset(&find_data, 0, sizeof(find_t)); km_rand_init_system(&(find_data.rs)); // init data set km_set_t set; if (!km_load_path(argv[1], &LOAD_CBS, &set)) { die("km_load_path() failed"); } if (!km_set_normalize(&set)) { die("km_set_normalize() failed"); } // find best solution if (!km_find(&set, &FIND_CBS, &find_data)) { die("km_find()"); } // print csv print_csv(&find_data); // save png of normalized data set and best clusters save_png("data.png", &set, &find_data); // finalize data set km_set_fini(&set); // return success return 0; }