diff options
Diffstat (limited to 'main.c')
-rw-r--r-- | main.c | 278 |
1 files changed, 219 insertions, 59 deletions
@@ -1,88 +1,248 @@ +#include <stdbool.h> // bool #include <stdio.h> // fprintf() #include <unistd.h> // EXIT_{FAILURE,SUCCESS} #include <stdlib.h> // exit() -#include "km.h" +#include <string.h> // memset() +#include <math.h> // fabsf() -#define UNUSED(a) ((void) (a)) - -// >> puts [[-1, -1], [-1, 1], [1, 1], [1, -1]].reduce([]) { |r, c| 5.times { r << [0.5 + 0.25 * c[0] + 0.4 * rand - 0.2, 0.5 + 0.25 * c[1] + 0.4 * rand - 0.2] }; r }.map { |row| '%1.3f, %1.3f,' % row } -static const float -DATA[] = { - 0.132, 0.190, - 0.313, 0.187, - 0.076, 0.276, - 0.443, 0.414, - 0.136, 0.060, - 0.344, 0.815, - 0.259, 0.760, - 0.211, 0.949, - 0.103, 0.903, - 0.173, 0.818, - 0.873, 0.735, - 0.783, 0.593, - 0.845, 0.674, - 0.808, 0.868, - 0.871, 0.947, - 0.917, 0.090, - 0.691, 0.058, - 0.840, 0.357, - 0.783, 0.275, - 0.807, 0.336, -}; +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_ONLY_PNG +#include "stb_image_write.h" -#define die(...) do { \ - fputs("FATAL: ", stderr); \ - fprintf(stderr, __VA_ARGS__); \ - fputs("\n", stderr); \ - exit(EXIT_FAILURE); \ -} while (0) +#include "util.h" +#include "km.h" #define MAX_CLUSTERS 10 -#define NUM_TESTS 10 +#define NUM_TESTS 100 -static const km_shape_t -SHAPE = { - .num_floats = 2, - .num_ints = 0, -}; +typedef struct { + km_rand_src_t rs; + + struct { + float distance, + variance, + cluster_size; + size_t num_empty_clusters; + } rows[MAX_CLUSTERS - 2]; +} find_t; + +static bool +load_on_shape( + const km_shape_t * const shape, + void * const cb_data +) { + km_set_t * const set = cb_data; + + fprintf(stderr, "DEBUG: shape = { %zu, %zu }\n", shape->num_floats, shape->num_ints); + + // init set + if (!km_set_init(set, shape, 100)) { + die("km_set_init() failed"); + } + + // return success + return true; +} + +static bool +load_on_row( + const float * const floats, + const int * const ints, + void * const cb_data +) { + km_set_t * const set = cb_data; + + // push row + if (!km_set_push(set, 1, floats, ints)) { + die("km_set_push_rows() failed"); + } + + // return success + return true; +} static void -on_search_row( - const km_search_row_t * const row, +load_on_error( + const char * const err, + void * const cb_data +) { + UNUSED(cb_data); + die(err); +} + +static const km_load_cbs_t +LOAD_CBS = { + .on_shape = load_on_shape, + .on_row = load_on_row, + .on_error = load_on_error, +}; + +static bool +find_on_init( + km_set_t * const cs, + const size_t num_floats, + const size_t num_clusters, + void *cb_data +) { + find_t *data = cb_data; + return km_set_init_rand_clusters(cs, num_floats, num_clusters, &(data->rs)); +} + +static bool +find_on_fini( + km_set_t * const cs, void *cb_data ) { UNUSED(cb_data); + km_set_fini(cs); + return true; +} + +static void +find_on_data( + const km_find_data_t * const data, + void *cb_data +) { + find_t * const find_data = cb_data; + const size_t ofs = data->num_clusters - 2; + + find_data->rows[ofs].distance += data->mean_distance; + find_data->rows[ofs].variance += data->mean_variance; + find_data->rows[ofs].cluster_size += data->mean_cluster_size; + find_data->rows[ofs].num_empty_clusters += data->num_empty_clusters; +} - printf("%zu,%0.5f,%zu\n", - row->num_clusters, - row->mean_distance, - row->num_empty_clusters +// init find config +static const km_find_cbs_t +FIND_CBS = { + .max_clusters = MAX_CLUSTERS, + .num_tests = NUM_TESTS, + .on_init = find_on_init, + .on_fini = find_on_fini, + .on_data = find_on_data, +}; + +static float +get_score( + const size_t ofs, + const find_t * const find_data +) { + if (!ofs || ofs == MAX_CLUSTERS - 3) { + return 0; + } + // const size_t num_clusters = ofs + 2; + const float mean_empty_clusters = 1.0 * find_data->rows[ofs].num_empty_clusters / NUM_TESTS; + + const float ds[3] = { + find_data->rows[ofs - 1].distance / NUM_TESTS, + find_data->rows[ofs + 0].distance / NUM_TESTS, + find_data->rows[ofs + 1].distance / NUM_TESTS, + }; + + return ( + (fabsf(ds[0] - ds[1]) / fabsf(ds[1] - ds[2])) + + -2.0 * mean_empty_clusters + ); +} + +static void +print_csv_row( + const size_t i, + const find_t * const find_data +) { + const size_t num_clusters = i + 2; + const float mean_distance = find_data->rows[i].distance / NUM_TESTS, + mean_variance = find_data->rows[i].variance / NUM_TESTS, + mean_cluster_size = find_data->rows[i].cluster_size / NUM_TESTS, + mean_empty_clusters = 1.0 * find_data->rows[i].num_empty_clusters / NUM_TESTS; + + // print result + printf("%zu,%0.3f,%0.3f,%0.3f,%0.3f,%0.3f\n", + num_clusters, + get_score(i, find_data), + mean_distance, + mean_variance, + mean_cluster_size, + mean_empty_clusters + ); +} + +static void +print_csv( + const find_t * const find_data +) { + // print headers + printf( + "#," + "score," + "distance," + "variance," + "cluster_size," + "empty_clusters\n" ); + + for (size_t i = 0; i < MAX_CLUSTERS - 2; i++) { + print_csv_row(i, find_data); + } +} + +#define IM_WIDTH 128 +#define IM_HEIGHT 128 +#define IM_STRIDE (3 * IM_WIDTH) + +static uint8_t im_data[3 * IM_WIDTH * IM_HEIGHT]; + +static void +save_png( + const char * const png_path, + const km_set_t * const set +) { + // clear image data to white + memset(im_data, 0xff, sizeof(im_data)); + + // draw red points + km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 0xff0000); + if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) { + die("stbi_write_png(\"%s\")", png_path); + } } int main(int argc, char *argv[]) { - km_set_t set; - UNUSED(argc); - UNUSED(argv); + // check command-line + if (argc < 2) { + fprintf(stderr, "Usage: %s <data>\n", argv[0]); + return EXIT_FAILURE; + } + + // init random seed + srand(getpid()); + + // init find data + find_t find_data; + memset(find_data.rows, 0, sizeof(find_data.rows)); + km_rand_src_system_init(&(find_data.rs)); // init data set - if (!km_set_init(&set, &SHAPE, 20)) { - die("km_set_init() failed"); + km_set_t set; + if (!km_load_path(argv[1], &LOAD_CBS, &set)) { + die("km_load_path() failed"); } - // push rows - if (!km_set_push_rows(&set, 20, DATA, NULL)) { - die("km_set_push_rows() failed"); + if (!km_set_normalize(&set)) { + die("km_set_normalize() failed"); } - // print headers - printf("num_clusters,mean_distance,num_empty_clusters\n"); - - // search for best solution - if (!km_search(&set, MAX_CLUSTERS, NUM_TESTS, on_search_row, NULL)) { - die("km_search()"); + // find best solution + if (!km_find(&set, &FIND_CBS, &find_data)) { + die("km_find()"); } + // print csv + print_csv(&find_data); + + // save png of data set + save_png("data.png", &set); + // finalize data set km_set_fini(&set); |