From d65046e34693c21067dc580e71485a8e703fc284 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Mon, 18 Sep 2023 19:27:12 -0400 Subject: tests/permute: increase test permutation count, add timing note, faster avx512 iota --- tests/permute/permute.c | 66 +++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 32 deletions(-) diff --git a/tests/permute/permute.c b/tests/permute/permute.c index 7c8afa6..3a236a0 100644 --- a/tests/permute/permute.c +++ b/tests/permute/permute.c @@ -444,12 +444,28 @@ void permute_avx512(uint64_t a[static 25], const size_t num_rounds) { // - probably more unnecessary register usage void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) { // unaligned load mask and permutation indices - uint8_t mask = 0x1f; - const __mmask8 m = _load_mask8(&mask); + uint8_t mask = 0x1f, + m0b = 0x01; + const __mmask8 m = _load_mask8(&mask), + m0 = _load_mask8(&m0b); + + // round constants (used in iota) + static const uint64_t RCS[] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, + 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, + 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, + 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, + 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, + }; + + // load round constant + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds)); - // load rc permutation (TODO) - // static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }, - // const __m512i rc_p = _mm512_loadu_epi64(m, (void*) rc_ps); + // load rc permutation + static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }; + const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps); // load rows __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)), @@ -639,32 +655,15 @@ void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) { // iota { - // round constants (ambiguous in spec) - static const uint64_t RCS[] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, - 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, - }; - - // TODO - // if ((24 - num_rounds + i + 1) % 8) { - // rc = _mm512_permutexvar_epi64(p0, r4), - // } else { - // } - - // get rc address - const uint64_t *rc = RCS + (24 - num_rounds + i); - - // load mask - uint8_t m0b = 0x01; - const __mmask8 m0 = _load_mask8(&m0b); - - // mask/store row - const __m512i c0 = _mm512_maskz_loadu_epi64(m0, (void*) rc); - r0 = _mm512_xor_epi64(r0, c0); + // xor round constant, shuffle rc register + r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc); + rc = _mm512_permutexvar_epi64(rc_p, rc); + + if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) { + // load next set of round constants + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1))); + } } } @@ -816,7 +815,7 @@ static void test_permute_fast(void) { } // number of times to run permutation in timing tests below -#define NUM_TIME_PERMUTES 10000000 +#define NUM_TIME_PERMUTES 20000000 // time scalar keccak permutation static void time_permute_scalar(void) { @@ -882,8 +881,11 @@ int main() { test_iota(); test_permute_slow(); test_permute_fast(); + + printf("timing permute, please wait...\n"); time_permute_scalar(); time_permute_avx512_slow(); time_permute_avx512_fast(); + return 0; } -- cgit v1.2.3