diff options
Diffstat (limited to 'km-solve.c')
-rw-r--r-- | km-solve.c | 59 |
1 files changed, 37 insertions, 22 deletions
@@ -5,6 +5,8 @@ #include "util.h" #include "km.h" +#define MIN_CLUSTER_DISTANCE 0.00001 + // alloc and initialize row map static km_row_map_t * km_row_map_init( @@ -19,9 +21,11 @@ km_row_map_init( return false; } - // init row map by setting the maximum distance + // init row map for (size_t i = 0; i < num_rows; i++) { + // setting distances to maximum row_map[i].d2 = FLT_MAX; + row_map[i].d2_near = FLT_MAX; } // return row map @@ -84,6 +88,13 @@ km_solve( // flag change changed = true; } + + if ((row_map[j].cluster != i) && (d2 < row_map[j].d2_near)) { + row_map[j].d2_near = d2; + + // flag change + changed = true; + } } } @@ -131,44 +142,48 @@ km_solve( if (cbs && cbs->on_stats) { float sum = 0, - means[num_clusters], - variances[num_clusters]; + silouette = 0, + mean_dists[num_clusters], + mean_nears[num_clusters]; - memset(means, 0, sizeof(float) * num_clusters); - memset(variances, 0, sizeof(float) * num_clusters); + memset(mean_dists, 0, sizeof(mean_dists)); + memset(mean_nears, 0, sizeof(mean_nears)); // calculate sum of distances across all clusters for (size_t i = 0; i < set->num_rows; i++) { sum += row_map[i].d2; } - // calculate mean distances + // calculate mean numerators and silouette for (size_t i = 0; i < set->num_rows; i++) { - means[row_map[i].cluster] += row_map[i].d2; - } - // finalize means - for (size_t i = 0; i < num_clusters; i++) { - means[i] = (cs->ints[i]) ? (sqrt(means[i]) / cs->ints[i]) : 0; - } + // distance squared (d2) to center of this cluster + mean_dists[row_map[i].cluster] += row_map[i].d2; - // calculate variances - for (size_t i = 0; i < set->num_rows; i++) { - const size_t cluster = row_map[i].cluster; - const float variance = (sqrt(row_map[i].d2) - means[cluster]) * - (sqrt(row_map[i].d2) - means[cluster]); - variances[cluster] += variance; + // distance squared (d2) to center of nearest cluster + mean_nears[row_map[i].cluster] += row_map[i].d2_near; + + // calculate silouette denominator + const float delta = row_map[i].d2_near - row_map[i].d2; + if (fabsf(delta) > MIN_CLUSTER_DISTANCE) { + silouette += delta / MAX(row_map[i].d2, row_map[i].d2_near); + } } - // finalize variances + // finalize means (divide by row count) for (size_t i = 0; i < num_clusters; i++) { - variances[i] = (cs->ints[i]) ? (variances[i] / cs->ints[i]) : 0; + mean_dists[i] = (cs->ints[i]) ? (sqrt(mean_dists[i]) / cs->ints[i]) : 0; + mean_nears[i] = (cs->ints[i]) ? (sqrt(mean_nears[i]) / cs->ints[i]) : 0; } + // finalize silouette + silouette /= set->num_rows; + // build stats const km_solve_stats_t stats = { .sum = sum, - .means = means, - .variances = variances, + .silouette = silouette, + .mean_dists = mean_dists, + .mean_nears = mean_nears, .num_clusters = num_clusters, }; |