aboutsummaryrefslogtreecommitdiff
path: root/src/km-solve.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/km-solve.c')
-rw-r--r--src/km-solve.c72
1 files changed, 26 insertions, 46 deletions
diff --git a/src/km-solve.c b/src/km-solve.c
index b6d98d6..271795b 100644
--- a/src/km-solve.c
+++ b/src/km-solve.c
@@ -6,36 +6,18 @@
#define MIN_DELTA 0.00001
-// alloc and initialize row map
-static km_row_map_t *
-km_row_map_init(
+static void
+km_row_map_clear(
+ km_row_map_t * const row_map,
const size_t num_rows
) {
- // alloc row map
- km_row_map_t * const row_map = malloc(sizeof(km_row_map_t) * num_rows);
-
- // check for error
- if (!row_map) {
- // return failure
- return false;
- }
-
// init row map
+ #pragma omp parallel for
for (size_t i = 0; i < num_rows; i++) {
// setting distances to maximum
row_map[i].d2 = FLT_MAX;
row_map[i].d2_near = FLT_MAX;
}
-
- // return row map
- return row_map;
-}
-
-static void
-km_row_map_fini(
- km_row_map_t * const row_map
-) {
- free(row_map);
}
// use k-means to iteratively update the cluster centroids until there
@@ -44,18 +26,15 @@ bool
km_solve(
km_set_t * const cs,
const km_set_t * const set,
+ km_row_map_t * const row_map,
const km_solve_cbs_t * const cbs,
void * const cb_data
) {
const size_t num_clusters = cs->num_rows,
num_floats = set->shape.num_floats;
- // row map: row => distance, cluster ID
- km_row_map_t *row_map = km_row_map_init(set->num_rows);
- if (!row_map) {
- // return failure
- return false;
- }
+ // clear row map distances
+ km_row_map_clear(row_map, set->num_rows);
// calculate clusters by doing the following:
// * walk all clusters and all rows
@@ -68,11 +47,13 @@ km_solve(
// no changes yet
changed = false;
+ #pragma omp parallel for collapse(2)
for (size_t i = 0; i < num_clusters; i++) {
- // get the floats for this cluster
- const float * const floats = km_set_get_row(cs, i);
-
for (size_t j = 0; j < set->num_rows; j++) {
+ // get the floats for this cluster
+ // NOTE: moved to inner loop for "collapse(2)" above
+ const float * const floats = km_set_get_row(cs, i);
+
// get row values
const float * const vals = km_set_get_row(set, j);
@@ -101,7 +82,8 @@ km_solve(
// if there were changes, then we need to calculate the new
// cluster centers
- // calculate new center
+ // calculate new cluster centers
+ #pragma omp parallel for
for (size_t i = 0; i < num_clusters; i++) {
size_t num_rows = 0;
float * const floats = km_set_get_row(cs, i);
@@ -141,30 +123,31 @@ km_solve(
if (cbs && cbs->on_stats) {
float sum = 0,
- silouette = 0;
+ silhouette = 0;
- // calculate mean numerators and silouette
+ // calculate mean numerators and silhouette
+ #pragma omp parallel for reduction(+:sum,silhouette)
for (size_t i = 0; i < set->num_rows; i++) {
- // sum distances
- sum += row_map[i].d2;
-
- // calculate silouette denominator
+ // calculate silhouette denominator
// (https://en.wikipedia.org/wiki/Silhouette_%28clustering%29)
const float delta = (row_map[i].d2_near - row_map[i].d2);
- // sum silouette
- silouette += (delta < -MIN_DELTA || delta > MIN_DELTA)
+ // sum silhouette
+ silhouette += (delta < -MIN_DELTA || delta > MIN_DELTA)
? (delta / MAX(row_map[i].d2, row_map[i].d2_near))
: 0.0;
+
+ // sum distances
+ sum += row_map[i].d2;
}
- // finalize silouette
- silouette /= set->num_rows;
+ // finalize silhouette
+ silhouette /= set->num_rows;
// build stats
const km_solve_stats_t stats = {
.sum = sum,
- .silouette = silouette,
+ .silhouette = silhouette,
.num_clusters = num_clusters,
};
@@ -172,9 +155,6 @@ km_solve(
cbs->on_stats(cs, &stats, cb_data);
}
- // free row_map
- km_row_map_fini(row_map);
-
// return success
return true;
}