diff options
Diffstat (limited to 'src/km-solve.c')
-rw-r--r-- | src/km-solve.c | 72 |
1 files changed, 26 insertions, 46 deletions
diff --git a/src/km-solve.c b/src/km-solve.c index b6d98d6..271795b 100644 --- a/src/km-solve.c +++ b/src/km-solve.c @@ -6,36 +6,18 @@ #define MIN_DELTA 0.00001 -// alloc and initialize row map -static km_row_map_t * -km_row_map_init( +static void +km_row_map_clear( + km_row_map_t * const row_map, const size_t num_rows ) { - // alloc row map - km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows); - - // check for error - if (!row_map) { - // return failure - return false; - } - // init row map + #pragma omp parallel for for (size_t i = 0; i < num_rows; i++) { // setting distances to maximum row_map[i].d2 = FLT_MAX; row_map[i].d2_near = FLT_MAX; } - - // return row map - return row_map; -} - -static void -km_row_map_fini( - km_row_map_t * const row_map -) { - free(row_map); } // use k-means to iteratively update the cluster centroids until there @@ -44,18 +26,15 @@ bool km_solve( km_set_t * const cs, const km_set_t * const set, + km_row_map_t * const row_map, const km_solve_cbs_t * const cbs, void * const cb_data ) { const size_t num_clusters = cs->num_rows, num_floats = set->shape.num_floats; - // row map: row => distance, cluster ID - km_row_map_t *row_map = km_row_map_init(set->num_rows); - if (!row_map) { - // return failure - return false; - } + // clear row map distances + km_row_map_clear(row_map, set->num_rows); // calculate clusters by doing the following: // * walk all clusters and all rows @@ -68,11 +47,13 @@ km_solve( // no changes yet changed = false; + #pragma omp parallel for collapse(2) for (size_t i = 0; i < num_clusters; i++) { - // get the floats for this cluster - const float * const floats = km_set_get_row(cs, i); - for (size_t j = 0; j < set->num_rows; j++) { + // get the floats for this cluster + // NOTE: moved to inner loop for "collapse(2)" above + const float * const floats = km_set_get_row(cs, i); + // get row values const float * const vals = km_set_get_row(set, j); @@ -101,7 +82,8 @@ km_solve( // if there were changes, then we need to calculate the new // cluster centers - // calculate new center + // calculate new cluster centers + #pragma omp parallel for for (size_t i = 0; i < num_clusters; i++) { size_t num_rows = 0; float * const floats = km_set_get_row(cs, i); @@ -141,30 +123,31 @@ km_solve( if (cbs && cbs->on_stats) { float sum = 0, - silouette = 0; + silhouette = 0; - // calculate mean numerators and silouette + // calculate mean numerators and silhouette + #pragma omp parallel for reduction(+:sum,silhouette) for (size_t i = 0; i < set->num_rows; i++) { - // sum distances - sum += row_map[i].d2; - - // calculate silouette denominator + // calculate silhouette denominator // (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29) const float delta = (row_map[i].d2_near - row_map[i].d2); - // sum silouette - silouette += (delta < -MIN_DELTA || delta > MIN_DELTA) + // sum silhouette + silhouette += (delta < -MIN_DELTA || delta > MIN_DELTA) ? (delta / MAX(row_map[i].d2, row_map[i].d2_near)) : 0.0; + + // sum distances + sum += row_map[i].d2; } - // finalize silouette - silouette /= set->num_rows; + // finalize silhouette + silhouette /= set->num_rows; // build stats const km_solve_stats_t stats = { .sum = sum, - .silouette = silouette, + .silhouette = silhouette, .num_clusters = num_clusters, }; @@ -172,9 +155,6 @@ km_solve( cbs->on_stats(cs, &stats, cb_data); } - // free row_map - km_row_map_fini(row_map); - // return success return true; } |