aboutsummaryrefslogtreecommitdiff
path: root/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'main.c')
-rw-r--r--main.c278
1 files changed, 219 insertions, 59 deletions
diff --git a/main.c b/main.c
index 143c14a..bd4192f 100644
--- a/main.c
+++ b/main.c
@@ -1,88 +1,248 @@
+#include <stdbool.h> // bool
#include <stdio.h> // fprintf()
#include <unistd.h> // EXIT_{FAILURE,SUCCESS}
#include <stdlib.h> // exit()
-#include "km.h"
+#include <string.h> // memset()
+#include <math.h> // fabsf()
-#define UNUSED(a) ((void) (a))
-
-// >> puts [[-1, -1], [-1, 1], [1, 1], [1, -1]].reduce([]) { |r, c| 5.times { r << [0.5 + 0.25 * c[0] + 0.4 * rand - 0.2, 0.5 + 0.25 * c[1] + 0.4 * rand - 0.2] }; r }.map { |row| '%1.3f, %1.3f,' % row }
-static const float
-DATA[] = {
- 0.132, 0.190,
- 0.313, 0.187,
- 0.076, 0.276,
- 0.443, 0.414,
- 0.136, 0.060,
- 0.344, 0.815,
- 0.259, 0.760,
- 0.211, 0.949,
- 0.103, 0.903,
- 0.173, 0.818,
- 0.873, 0.735,
- 0.783, 0.593,
- 0.845, 0.674,
- 0.808, 0.868,
- 0.871, 0.947,
- 0.917, 0.090,
- 0.691, 0.058,
- 0.840, 0.357,
- 0.783, 0.275,
- 0.807, 0.336,
-};
+#define STB_IMAGE_WRITE_IMPLEMENTATION
+#define STB_ONLY_PNG
+#include "stb_image_write.h"
-#define die(...) do { \
- fputs("FATAL: ", stderr); \
- fprintf(stderr, __VA_ARGS__); \
- fputs("\n", stderr); \
- exit(EXIT_FAILURE); \
-} while (0)
+#include "util.h"
+#include "km.h"
#define MAX_CLUSTERS 10
-#define NUM_TESTS 10
+#define NUM_TESTS 100
-static const km_shape_t
-SHAPE = {
- .num_floats = 2,
- .num_ints = 0,
-};
+typedef struct {
+ km_rand_src_t rs;
+
+ struct {
+ float distance,
+ variance,
+ cluster_size;
+ size_t num_empty_clusters;
+ } rows[MAX_CLUSTERS - 2];
+} find_t;
+
+static bool
+load_on_shape(
+ const km_shape_t * const shape,
+ void * const cb_data
+) {
+ km_set_t * const set = cb_data;
+
+ fprintf(stderr, "DEBUG: shape = { %zu, %zu }\n", shape->num_floats, shape->num_ints);
+
+ // init set
+ if (!km_set_init(set, shape, 100)) {
+ die("km_set_init() failed");
+ }
+
+ // return success
+ return true;
+}
+
+static bool
+load_on_row(
+ const float * const floats,
+ const int * const ints,
+ void * const cb_data
+) {
+ km_set_t * const set = cb_data;
+
+ // push row
+ if (!km_set_push(set, 1, floats, ints)) {
+ die("km_set_push_rows() failed");
+ }
+
+ // return success
+ return true;
+}
static void
-on_search_row(
- const km_search_row_t * const row,
+load_on_error(
+ const char * const err,
+ void * const cb_data
+) {
+ UNUSED(cb_data);
+ die(err);
+}
+
+static const km_load_cbs_t
+LOAD_CBS = {
+ .on_shape = load_on_shape,
+ .on_row = load_on_row,
+ .on_error = load_on_error,
+};
+
+static bool
+find_on_init(
+ km_set_t * const cs,
+ const size_t num_floats,
+ const size_t num_clusters,
+ void *cb_data
+) {
+ find_t *data = cb_data;
+ return km_set_init_rand_clusters(cs, num_floats, num_clusters, &(data->rs));
+}
+
+static bool
+find_on_fini(
+ km_set_t * const cs,
void *cb_data
) {
UNUSED(cb_data);
+ km_set_fini(cs);
+ return true;
+}
+
+static void
+find_on_data(
+ const km_find_data_t * const data,
+ void *cb_data
+) {
+ find_t * const find_data = cb_data;
+ const size_t ofs = data->num_clusters - 2;
+
+ find_data->rows[ofs].distance += data->mean_distance;
+ find_data->rows[ofs].variance += data->mean_variance;
+ find_data->rows[ofs].cluster_size += data->mean_cluster_size;
+ find_data->rows[ofs].num_empty_clusters += data->num_empty_clusters;
+}
- printf("%zu,%0.5f,%zu\n",
- row->num_clusters,
- row->mean_distance,
- row->num_empty_clusters
+// init find config
+static const km_find_cbs_t
+FIND_CBS = {
+ .max_clusters = MAX_CLUSTERS,
+ .num_tests = NUM_TESTS,
+ .on_init = find_on_init,
+ .on_fini = find_on_fini,
+ .on_data = find_on_data,
+};
+
+static float
+get_score(
+ const size_t ofs,
+ const find_t * const find_data
+) {
+ if (!ofs || ofs == MAX_CLUSTERS - 3) {
+ return 0;
+ }
+ // const size_t num_clusters = ofs + 2;
+ const float mean_empty_clusters = 1.0 * find_data->rows[ofs].num_empty_clusters / NUM_TESTS;
+
+ const float ds[3] = {
+ find_data->rows[ofs - 1].distance / NUM_TESTS,
+ find_data->rows[ofs + 0].distance / NUM_TESTS,
+ find_data->rows[ofs + 1].distance / NUM_TESTS,
+ };
+
+ return (
+ (fabsf(ds[0] - ds[1]) / fabsf(ds[1] - ds[2])) +
+ -2.0 * mean_empty_clusters
+ );
+}
+
+static void
+print_csv_row(
+ const size_t i,
+ const find_t * const find_data
+) {
+ const size_t num_clusters = i + 2;
+ const float mean_distance = find_data->rows[i].distance / NUM_TESTS,
+ mean_variance = find_data->rows[i].variance / NUM_TESTS,
+ mean_cluster_size = find_data->rows[i].cluster_size / NUM_TESTS,
+ mean_empty_clusters = 1.0 * find_data->rows[i].num_empty_clusters / NUM_TESTS;
+
+ // print result
+ printf("%zu,%0.3f,%0.3f,%0.3f,%0.3f,%0.3f\n",
+ num_clusters,
+ get_score(i, find_data),
+ mean_distance,
+ mean_variance,
+ mean_cluster_size,
+ mean_empty_clusters
+ );
+}
+
+static void
+print_csv(
+ const find_t * const find_data
+) {
+ // print headers
+ printf(
+ "#,"
+ "score,"
+ "distance,"
+ "variance,"
+ "cluster_size,"
+ "empty_clusters\n"
);
+
+ for (size_t i = 0; i < MAX_CLUSTERS - 2; i++) {
+ print_csv_row(i, find_data);
+ }
+}
+
+#define IM_WIDTH 128
+#define IM_HEIGHT 128
+#define IM_STRIDE (3 * IM_WIDTH)
+
+static uint8_t im_data[3 * IM_WIDTH * IM_HEIGHT];
+
+static void
+save_png(
+ const char * const png_path,
+ const km_set_t * const set
+) {
+ // clear image data to white
+ memset(im_data, 0xff, sizeof(im_data));
+
+ // draw red points
+ km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 0xff0000);
+ if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) {
+ die("stbi_write_png(\"%s\")", png_path);
+ }
}
int main(int argc, char *argv[]) {
- km_set_t set;
- UNUSED(argc);
- UNUSED(argv);
+ // check command-line
+ if (argc < 2) {
+ fprintf(stderr, "Usage: %s <data>\n", argv[0]);
+ return EXIT_FAILURE;
+ }
+
+ // init random seed
+ srand(getpid());
+
+ // init find data
+ find_t find_data;
+ memset(find_data.rows, 0, sizeof(find_data.rows));
+ km_rand_src_system_init(&(find_data.rs));
// init data set
- if (!km_set_init(&set, &SHAPE, 20)) {
- die("km_set_init() failed");
+ km_set_t set;
+ if (!km_load_path(argv[1], &LOAD_CBS, &set)) {
+ die("km_load_path() failed");
}
- // push rows
- if (!km_set_push_rows(&set, 20, DATA, NULL)) {
- die("km_set_push_rows() failed");
+ if (!km_set_normalize(&set)) {
+ die("km_set_normalize() failed");
}
- // print headers
- printf("num_clusters,mean_distance,num_empty_clusters\n");
-
- // search for best solution
- if (!km_search(&set, MAX_CLUSTERS, NUM_TESTS, on_search_row, NULL)) {
- die("km_search()");
+ // find best solution
+ if (!km_find(&set, &FIND_CBS, &find_data)) {
+ die("km_find()");
}
+ // print csv
+ print_csv(&find_data);
+
+ // save png of data set
+ save_png("data.png", &set);
+
// finalize data set
km_set_fini(&set);