aboutsummaryrefslogtreecommitdiff
path: root/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'main.c')
-rw-r--r--main.c154
1 files changed, 87 insertions, 67 deletions
diff --git a/main.c b/main.c
index 73d5655..f071804 100644
--- a/main.c
+++ b/main.c
@@ -14,7 +14,7 @@
#define MAX_CLUSTERS 10
#define NUM_TESTS 100
-#define MAX_BEST 4
+#define MAX_BEST 10
#define IM_WIDTH 128
#define IM_HEIGHT 128
@@ -26,6 +26,10 @@ typedef struct {
} best_item_t;
typedef struct {
+ // cluster initialization method
+ km_init_type_t init_type;
+
+ // random number source
km_rand_t rs;
struct {
@@ -35,9 +39,10 @@ typedef struct {
size_t num_empty_clusters;
} rows[MAX_CLUSTERS - 2];
+ // best clusters
best_item_t best[MAX_BEST];
size_t num_best;
-} find_t;
+} ctx_t;
static int
best_score_cmp(
@@ -47,32 +52,32 @@ best_score_cmp(
const best_item_t * const a = ap;
const best_item_t * const b = bp;
- return (a->score > b->score) ? -1 : 1;
+ return (a->score > b->score) ? 1 : -1;
}
static void
-find_sort_best(
- find_t * const find_data
+ctx_best_sort(
+ ctx_t * const ctx
) {
// sort best sets by ascending score (worst to best)
qsort(
- find_data->best,
- (find_data->num_best % MAX_BEST),
+ ctx->best,
+ (ctx->num_best % MAX_BEST),
sizeof(best_item_t),
best_score_cmp
);
}
static void
-find_each_best(
- const find_t * const find_data,
+ctx_best_each(
+ const ctx_t * const ctx,
void (*on_best)(const km_set_t * const, const size_t, const float, void *),
void * const cb_data
) {
if (on_best) {
// walk best sets and emit each one
- for (size_t i = 0; i < MIN(find_data->num_best, MAX_BEST); i++) {
- on_best(&(find_data->best[i].set), i, find_data->best[i].score, cb_data);
+ for (size_t i = 0; i < MIN(ctx->num_best, MAX_BEST); i++) {
+ on_best(&(ctx->best[i].set), i, ctx->best[i].score, cb_data);
}
}
}
@@ -84,7 +89,7 @@ load_on_shape(
) {
km_set_t * const set = cb_data;
- D("shape: %zu floats, %zu ints", shape->num_floats, shape->num_ints);
+ // D("shape: %zu floats, %zu ints", shape->num_floats, shape->num_ints);
// init set
if (!km_set_init(set, shape, 100)) {
@@ -131,12 +136,22 @@ LOAD_CBS = {
static bool
find_on_init(
km_set_t * const cs,
- const size_t num_floats,
const size_t num_clusters,
+ const km_set_t * const set,
void *cb_data
) {
- find_t *data = cb_data;
- return km_set_init_rand_clusters(cs, num_floats, num_clusters, &(data->rs));
+ ctx_t * const ctx = cb_data;
+ km_rand_t * const rs = &(ctx->rs);
+
+ switch(ctx->init_type) {
+ case KM_INIT_TYPE_RAND:
+ return km_init_rand(cs, num_clusters, set->shape.num_floats, rs);
+ case KM_INIT_TYPE_FORGY:
+ return km_init_forgy(cs, num_clusters, set, rs);
+ default:
+ die("unknown cluster init method");
+ return false;
+ }
}
static bool
@@ -154,13 +169,13 @@ find_on_data(
const km_find_data_t * const data,
void *cb_data
) {
- find_t * const find_data = cb_data;
+ ctx_t * const ctx = cb_data;
const size_t ofs = data->num_clusters - 2;
- find_data->rows[ofs].distance += data->mean_distance;
- find_data->rows[ofs].variance += data->mean_variance;
- find_data->rows[ofs].cluster_size += data->mean_cluster_size;
- find_data->rows[ofs].num_empty_clusters += data->num_empty_clusters;
+ ctx->rows[ofs].distance += data->mean_distance;
+ ctx->rows[ofs].variance += data->mean_variance;
+ ctx->rows[ofs].cluster_size += data->mean_cluster_size;
+ ctx->rows[ofs].num_empty_clusters += data->num_empty_clusters;
}
static bool
@@ -169,17 +184,17 @@ find_on_best(
const km_set_t * const cs,
void *cb_data
) {
- find_t * const find_data = cb_data;
+ ctx_t * const ctx = cb_data;
D("best score = %0.3f, num_clusters = %zu", score, cs->num_rows);
// get pointer to destination set
// (note: data->best is a ring buffer)
- km_set_t *dst = &(find_data->best[find_data->num_best % MAX_BEST].set);
+ km_set_t * const dst = &(ctx->best[ctx->num_best % MAX_BEST].set);
- if (find_data->num_best >= MAX_BEST) {
+ if (ctx->num_best >= MAX_BEST) {
// finalize old best data set
- // D("finalizing old best %zu", find_data->num_best);
+ // D("finalizing old best %zu", ctx->num_best);
km_set_fini(dst);
}
@@ -189,7 +204,7 @@ find_on_best(
}
// increment best count
- find_data->num_best++;
+ ctx->num_best++;
// return success
return true;
@@ -207,33 +222,22 @@ FIND_CBS = {
.on_best = find_on_best,
};
-static float
-get_score(
- const size_t ofs,
- const find_t * const find_data
-) {
- // const size_t num_clusters = ofs + 2;
- const float mean_distance = find_data->rows[ofs].distance / NUM_TESTS,
- mean_empty = 1.0 * find_data->rows[ofs].num_empty_clusters / NUM_TESTS;
-
- return 1.0 / (mean_distance + mean_empty);
-}
-
static void
print_csv_row(
- const size_t i,
- const find_t * const find_data
+ const ctx_t * const ctx,
+ const size_t i
) {
const size_t num_clusters = i + 2;
- const float mean_distance = find_data->rows[i].distance / NUM_TESTS,
- mean_variance = find_data->rows[i].variance / NUM_TESTS,
- mean_cluster_size = find_data->rows[i].cluster_size / NUM_TESTS,
- mean_empty_clusters = 1.0 * find_data->rows[i].num_empty_clusters / NUM_TESTS;
+ const float mean_distance = ctx->rows[i].distance / NUM_TESTS,
+ mean_variance = ctx->rows[i].variance / NUM_TESTS,
+ mean_cluster_size = ctx->rows[i].cluster_size / NUM_TESTS,
+ mean_empty_clusters = 1.0 * ctx->rows[i].num_empty_clusters / NUM_TESTS,
+ score = km_score(mean_distance, mean_empty_clusters);
// print result
printf("%zu,%0.3f,%0.3f,%0.3f,%0.3f,%0.3f\n",
num_clusters,
- get_score(i, find_data),
+ score,
mean_distance,
mean_variance,
mean_cluster_size,
@@ -243,7 +247,7 @@ print_csv_row(
static void
print_csv(
- const find_t * const find_data
+ const ctx_t * const ctx
) {
// print headers
printf(
@@ -256,7 +260,7 @@ print_csv(
);
for (size_t i = 0; i < MAX_CLUSTERS - 2; i++) {
- print_csv_row(i, find_data);
+ print_csv_row(ctx, i);
}
}
@@ -269,36 +273,36 @@ save_on_best(
const float score,
void * const cb_data
) {
+ UNUSED(score);
UNUSED(cb_data);
-
// convert rank to channel brightness
const uint8_t ch = 0x66 + (0xff - 0x66) * (1.0 * rank) / (MAX_BEST - 1);
const uint32_t color = (ch & 0xff) << 16;
// const uint32_t color = 0xff0000;
- D("rank = %zu, score = %0.3f, size = %zu, color = %06x", rank, score, set->num_rows, color);
+ // D("rank = %zu, score = %0.3f, size = %zu, color = %06x", rank, score, set->num_rows, color);
// draw clusters
km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 3, color);
}
static void
-save_png(
+ctx_save_png(
+ const ctx_t * const ctx,
const char * const png_path,
- const km_set_t * const set,
- const find_t * const find_data
+ const km_set_t * const set
) {
// clear image data to white
memset(im_data, 0xff, sizeof(im_data));
- // draw red points
+ // draw data points
km_set_draw(set, im_data, IM_WIDTH, IM_HEIGHT, 1, 0x000000);
if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) {
die("stbi_write_png(\"%s\")", png_path);
}
// draw best cluster points
- find_each_best(find_data, save_on_best, NULL);
+ ctx_best_each(ctx, save_on_best, NULL);
// save png
if (!stbi_write_png(png_path, IM_WIDTH, IM_HEIGHT, 3, im_data, IM_STRIDE)) {
@@ -306,45 +310,61 @@ save_png(
}
}
+static const char USAGE_FORMAT[] =
+ "Usage: %s [init] [data_path] <png_path>\n"
+ "\n"
+ "Arguments:\n"
+ "* init: Cluster init method (one of \"rand\" or \"set\").\n"
+ "* data_path: Path to input data file.\n"
+ "* png_path: Path to output file (optional).\n"
+ "";
+
int main(int argc, char *argv[]) {
- // check command-line
- if (argc < 2) {
- fprintf(stderr, "Usage: %s <data_path> <png_path>\n", argv[0]);
+ // check command-line arguments
+ if (argc < 3) {
+ fprintf(stderr, USAGE_FORMAT, argv[0]);
return EXIT_FAILURE;
}
+ // get command-line arguments
+ const char * const init_type_name = argv[1];
+ const char * const data_path = argv[2];
+ const char * const png_path = (argc > 3) ? argv[3] : NULL;
+
// init random seed
srand(getpid());
- // init find data
- find_t find_data;
- memset(&find_data, 0, sizeof(find_t));
- km_rand_init_system(&(find_data.rs));
+ // init context
+ ctx_t ctx;
+ memset(&ctx, 0, sizeof(ctx_t));
+ km_rand_init_system(&(ctx.rs));
+ ctx.init_type = km_get_init_type(init_type_name);
// init data set
km_set_t set;
- if (!km_load_path(argv[1], &LOAD_CBS, &set)) {
- die("km_load_path() failed");
+ if (!km_load_path(data_path, &LOAD_CBS, &set)) {
+ die("km_load_path(\"%s\") failed", data_path);
}
+ // init data set
if (!km_set_normalize(&set)) {
die("km_set_normalize() failed");
}
// find best solutions
- if (!km_find(&set, &FIND_CBS, &find_data)) {
+ if (!km_find(&set, &FIND_CBS, &ctx)) {
die("km_find()");
}
// print csv
- print_csv(&find_data);
+ print_csv(&ctx);
// sort best results from lowest to highest
- find_sort_best(&find_data);
+ ctx_best_sort(&ctx);
- if (argc > 2) {
+ if (png_path) {
// save png of normalized data set and best clusters
- save_png(argv[2], &set, &find_data);
+ ctx_save_png(&ctx, png_path, &set);
}
// finalize data set