diff options
Diffstat (limited to 'tests/bench/bench.c')
| -rw-r--r-- | tests/bench/bench.c | 222 | 
1 files changed, 222 insertions, 0 deletions
diff --git a/tests/bench/bench.c b/tests/bench/bench.c new file mode 100644 index 0000000..9144e91 --- /dev/null +++ b/tests/bench/bench.c @@ -0,0 +1,222 @@ +// +// Benchmark all three ML-KEM parameter sets and print summary +// statistics to standard output in CSV format. +// +// Requires libcpucycles (https://cpucycles.cr.yp.to/). +// +// Note: You may need to adjust your system configuration or run `bench` +// as root to grant libcpucycles access to the high-resolution cycle +// counter.  See the following URL for details: +// +// https://cpucycles.cr.yp.to/security.html +// + +#include <stdlib.h> // exit(), qsort() +#include <stdio.h> // printf() +#include <string.h> // memcmp() +#include <math.h> // sqrt(), pow() +#include <cpucycles.h> // cpucycles() +#include "sha3.h" // sha3_*(), shake*() +#include "rand-bytes.h" // rand_bytes() + +// default number of trials +#define NUM_TRIALS 100000 + +// Random data used for key generation and encapsulation. +typedef struct { +  uint8_t keygen[64], // random data for keygen() +          encaps[32]; // random data for encaps() +} seeds_t; + +// Aggregate statistics for a series of tests. +typedef struct { +  // min/max/median times +  long long lo, hi, median; + +  // mean/stddev +  double mean, stddev; +} stats_t; + +static void *checked_calloc(const char *name, const size_t nmemb, const size_t size) { +  // alloc keygen times +  void *mem = calloc(nmemb, size); +  if (!mem) { +    fprintf(stderr, "%s: calloc() failed\n", name); +    exit(-1); +  } +  return mem; +} + +// Callback for `qsort()` to sort observed times in ascending order. +static int sort_asc_cb(const void *ap, const void *bp) { +  const long long *a = ap, *b = bp; +  return *a - *b; +} + +// Get summary statistics of a series of test times. +static stats_t get_stats(long long * const vals, const size_t num_vals) { +  stats_t stats = { 0 }; + +  // sort values in ascending order (used for min, max, and median) +  qsort(vals, num_vals, sizeof(long long), sort_asc_cb); + +  // get low, high, and median +  stats.lo = vals[0]; +  stats.hi = vals[num_vals - 1]; +  stats.median = vals[num_vals / 2]; + +  // calculate mean +  for (size_t i = 0; i < num_vals; i++) { +    stats.mean += vals[i]; +  } +  stats.mean /= num_vals; + +  // calculate standard deviation +  for (size_t i = 0; i < num_vals; i++) { +    stats.stddev += pow(stats.mean - vals[i], 2); +  } +  stats.stddev = sqrt(stats.stddev / num_vals); + +  // return stats +  return stats; +} + +// define xof benchmark function +#define DEF_BENCH_XOF(FN) \ +  static stats_t bench_ ## FN (const size_t num_trials, const size_t src_len, const size_t dst_len) { \ +    /* allocate times, src, and dst buffers */ \ +    long long *times = checked_calloc(__func__, num_trials, sizeof(long long)); \ +    uint8_t *src = checked_calloc(__func__, num_trials, src_len); \ +    uint8_t *dst = checked_calloc(__func__, num_trials, dst_len); \ +    \ +    /* generate random source data */ \ +    rand_bytes(src, num_trials * src_len); \ +    \ +    /* run trials */ \ +    for (size_t i = 0; i < num_trials; i++) { \ +      /* call function */ \ +      const long long t0 = cpucycles(); \ +      FN (src + (i * src_len), src_len, dst + (i * dst_len), dst_len); \ +      const long long t1 = cpucycles() - t0; \ +      \ +      /* save time */ \ +      times[i] = t1; \ +    } \ +    \ +    /* generate summary stats */ \ +    const stats_t stats = get_stats(times, num_trials); \ +    \ +    /* free buffers */ \ +    free(src); \ +    free(times); \ +    \ +    /* return summary stats */ \ +    return stats; \ +  } + +// define hash benchmark function +#define DEF_BENCH_HASH(FN, OUT_LEN) \ +  static stats_t bench_ ## FN (const size_t num_trials, const size_t src_len) { \ +    /* allocate times and src buffers */ \ +    long long *times = checked_calloc(__func__, num_trials, sizeof(long long)); \ +    uint8_t *src = checked_calloc(__func__, src_len, 1); \ +    \ +    /* run trials */ \ +    for (size_t i = 0; i < num_trials; i++) { \ +      /* generate random source data */ \ +      rand_bytes(src, src_len); \ +      \ +      /* call function */ \ +      uint8_t dst[OUT_LEN] = { 0 }; \ +      const long long t0 = cpucycles(); \ +      FN (src, src_len, dst); \ +      const long long t1 = cpucycles() - t0; \ +      \ +      /* save time */ \ +      times[i] = t1; \ +    } \ +    \ +    /* generate summary stats */ \ +    const stats_t stats = get_stats(times, num_trials); \ +    \ +    /* free buffers */ \ +    free(src); \ +    free(times); \ +    \ +    /* return summary stats */ \ +    return stats; \ +  } + +// define xof benchmarks *() +DEF_BENCH_XOF(shake128) +DEF_BENCH_XOF(shake256) + +// define hash benchmarks +DEF_BENCH_HASH(sha3_224, 28) +DEF_BENCH_HASH(sha3_256, 32) +DEF_BENCH_HASH(sha3_384, 48) +DEF_BENCH_HASH(sha3_512, 64) + +// print function stats to standard output as CSV row. +static void print_row(const char *name, const size_t src_len, const size_t dst_len, stats_t fs) { +    const double median_cpb = 1.0 * fs.median / src_len, +                 mean_cpb = 1.0 * fs.mean / src_len; +  printf("%s,%zu,%zu,%.0f,%.0f,%lld,%.0f,%.0f,%lld,%lld\n", name, dst_len, src_len, median_cpb, mean_cpb, fs.median, fs.mean, fs.stddev, fs.lo, fs.hi); +} + +#define MIN_SRC_LEN 64 +#define MAX_SRC_LEN 2048 + +#define MIN_DST_LEN 32 +#define MAX_DST_LEN 128 + +int main(int argc, char *argv[]) { +  // get number of trials from first command-line argument, or fall back +  // to default if no argument was provided +  const size_t num_trials = (argc > 1) ? atoi(argv[1]) : NUM_TRIALS; +  if (num_trials < 2) { +    fprintf(stderr, "num_trials must be greater than 1\n"); +    return -1; +  } + +  // print metadata to stderr +  fprintf(stderr,"info: cpucycles: version=%s implementation=%s persecond=%lld\ninfo: num_trials=%zu\n", cpucycles_version(), cpucycles_implementation(), cpucycles_persecond(), num_trials); + +  // print column headers to stdout +  printf("function,dst,src,median_cpb,mean_cpb,median_cycles,mean_cycles,stddev_cycles,min_cycles,max_cycles\n"); + +  // sha3-224 +  for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +    print_row("sha3_224", src_len, 28, bench_sha3_224(num_trials, src_len)); +  } + +  // sha3-256 +  for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +    print_row("sha3_256", src_len, 32, bench_sha3_256(num_trials, src_len)); +  } + +  // sha3-384 +  for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +    print_row("sha3_384", src_len, 48, bench_sha3_384(num_trials, src_len)); +  } + +  // sha3-512 +  for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +    print_row("sha3_512", src_len, 64, bench_sha3_512(num_trials, src_len)); +  } + +  for (size_t dst_len = MIN_DST_LEN; dst_len < MAX_DST_LEN; dst_len <<= 1) { +    // shake128 +    for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +      print_row("shake128", src_len, dst_len, bench_shake128(num_trials, src_len, dst_len)); +    } + +    // shake256 +    for (size_t src_len = MIN_SRC_LEN; src_len < MAX_SRC_LEN; src_len <<= 1) { +      print_row("shake256", src_len, dst_len, bench_shake256(num_trials, src_len, dst_len)); +    } +  } + +  // return success +  return 0; +}  | 
