From 006ed7e29912cacfe8a5aacdaab13e76aadb96a5 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Mon, 18 Sep 2023 19:05:15 -0400 Subject: tests/permute/permute.c: add comments --- tests/permute/permute.c | 62 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 10 deletions(-) (limited to 'tests/permute') diff --git a/tests/permute/permute.c b/tests/permute/permute.c index 9ba5095..7c8afa6 100644 --- a/tests/permute/permute.c +++ b/tests/permute/permute.c @@ -12,6 +12,7 @@ // number of rounds for permute() #define SHA3_NUM_ROUNDS 24 +// scalar impl of theta step static void theta_scalar(uint64_t a[static 25]) { const uint64_t c[5] = { a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20], @@ -36,6 +37,7 @@ static void theta_scalar(uint64_t a[static 25]) { a[20] ^= d[0]; a[21] ^= d[1]; a[22] ^= d[2]; a[23] ^= d[3]; a[24] ^= d[4]; } +// scalar impl of rho step static void rho_scalar(uint64_t a[static 25]) { a[1] = ROL(a[1], 1); // 1 % 64 = 1 a[2] = ROL(a[2], 62); // 190 % 64 = 62 @@ -63,6 +65,7 @@ static void rho_scalar(uint64_t a[static 25]) { a[24] = ROL(a[24], 14); // 78 % 64 = 14 } +// scalar impl of pi step static void pi_scalar(uint64_t a[static 25]) { uint64_t t[25] = { 0 }; memcpy(t, a, sizeof(t)); @@ -92,6 +95,7 @@ static void pi_scalar(uint64_t a[static 25]) { a[24] = t[21]; } +// scalar impl of chi step static void chi_scalar(uint64_t a[static 25]) { uint64_t t[25] = { 0 }; memcpy(t, a, sizeof(t)); @@ -122,6 +126,7 @@ static void chi_scalar(uint64_t a[static 25]) { a[24] = t[24] ^ (~t[20] & t[21]); } +// scalar impl of iota step static void iota_scalar(uint64_t a[static 25], const int i) { // round constants (ambiguous in spec) static const uint64_t RCS[] = { @@ -136,7 +141,7 @@ static void iota_scalar(uint64_t a[static 25], const int i) { a[0] ^= RCS[i]; } -// keccak permutation (scalar implementation). +// scalar impl of keccak permutation. 
void permute_scalar(uint64_t a[static 25], const size_t num_rounds) { for (int i = 0; i < (int) num_rounds; i++) { theta_scalar(a); @@ -150,6 +155,7 @@ void permute_scalar(uint64_t a[static 25], const size_t num_rounds) { #ifdef __AVX512F__ #include <immintrin.h> +// avx512 impl of theta step static void theta_avx512(uint64_t s[static 25]) { // unaligned load mask and permutation indices uint8_t mask = 0x1f; @@ -190,6 +196,7 @@ static void theta_avx512(uint64_t s[static 25]) { _mm512_mask_storeu_epi64((void*) (s + 20), m, r4); } +// avx512 impl of rho step static void rho_avx512(uint64_t s[static 25]) { // unaligned load mask and rotate values uint8_t mask = 0x1f; @@ -229,6 +236,15 @@ static void rho_avx512(uint64_t s[static 25]) { _mm512_mask_storeu_epi64((void*) (s + 20), m, r4); } +// avx512 impl of pi step +// +// note: i originally tried a simpler implementation which just copied +// the state array out to a temporary buffer and then gathered it back +// in shuffled order. +// +// the "simpler" implementation did not work correctly, so i rewrote the +// shuffling as in-register permutations (which should actually be +// faster). 
static void pi_avx512(uint64_t s[static 25]) { // mask bytes uint8_t mask = 0x1f, @@ -241,14 +257,8 @@ static void pi_avx512(uint64_t s[static 25]) { m01 = _load_mask8(&m01b), m23 = _load_mask8(&m23b), m4 = _load_mask8(&m4b); - // permutation indices (offsets into state array) - // static uint64_t vs0[8] = { 0, 6, 12, 18, 24, 0, 0, 0 }, - // vs1[8] = { 3, 9, 10, 16, 22, 0, 0, 0 }, - // vs2[8] = { 1, 7, 13, 19, 20, 0, 0, 0 }, - // vs3[8] = { 4, 5, 11, 17, 23, 0, 0, 0 }, - // vs4[8] = { 2, 8, 14, 15, 21, 0, 0, 0 }; - // permutation indices + // permutation indices (offsets into state array) // // (note: these are masked so only the relevant indices for // _mm512_maskz_permutex2var_epi64() in each array are filled in) @@ -334,6 +344,7 @@ static void pi_avx512(uint64_t s[static 25]) { _mm512_mask_storeu_epi64((void*) (s + 20), m, t4); } +// avx512 impl of chi step static void chi_avx512(uint64_t s[static 25]) { // mask bytes uint8_t mask = 0x1f; @@ -384,6 +395,7 @@ static void chi_avx512(uint64_t s[static 25]) { _mm512_mask_storeu_epi64((void*) (s + 20), m, t4); } +// avx512 impl of iota step static void iota_avx512(uint64_t s[static 25], const int i) { // round constants (ambiguous in spec) static const uint64_t RCS[] = { @@ -407,7 +419,7 @@ static void iota_avx512(uint64_t s[static 25], const int i) { _mm512_mask_storeu_epi64((void*) s, m0, t0); } -// keccak permutation (slow avx512 implementation). +// slow avx512 impl of keccak permutation. void permute_avx512(uint64_t a[static 25], const size_t num_rounds) { for (int i = 0; i < (int) num_rounds; i++) { theta_avx512(a); @@ -418,7 +430,18 @@ void permute_avx512(uint64_t a[static 25], const size_t num_rounds) { } } -// keccak permutation (fast avx512 implementation). 
+// fast avx512 impl of keccak permutation +// +// this version is similar to permute_avx512(), except the +// function calls are inlined as blocks, duplicate definitions have +// been removed, and the state array loads and stores only happen at the +// beginning and end of the function. +// +// there are still several optimizations that can be done. for example: +// - iota does not need to reload the round constant every round +// (only every 8th round) +// - some spills could be addressed +// - probably more unnecessary register usage void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) { // unaligned load mask and permutation indices uint8_t mask = 0x1f; @@ -654,6 +677,14 @@ void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) { } #endif /* __AVX512F__ */ +// verify that both state arrays are identical. +// +// if the state arrays are equal, print "$TEST_NAME passed" to standard +// output. if the state arrays are NOT equal, then print the contents +// of both state arrays with the differing cells highlighted to standard +// error, then exit with an error. +// +// used by test_*() functions below. 
static void check(const char *name, uint64_t a_scalar[static 25], uint64_t a_avx512[static 25]) { // compare if (!memcmp(a_scalar, a_avx512, 25 * sizeof(uint64_t))) { @@ -689,6 +720,7 @@ static void check(const char *name, uint64_t a_scalar[static 25], uint64_t a_avx exit(-1); } +// test avx512 theta static void test_theta(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -700,6 +732,7 @@ static void test_theta(void) { check("test_theta()", a_scalar, a_avx512); } +// test avx512 rho static void test_rho(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -711,6 +744,7 @@ static void test_rho(void) { check("test_rho()", a_scalar, a_avx512); } +// test avx512 pi static void test_pi(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -722,6 +756,7 @@ static void test_pi(void) { check("test_pi()", a_scalar, a_avx512); } +// test avx512 chi static void test_chi(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -733,6 +768,7 @@ static void test_chi(void) { check("test_chi()", a_scalar, a_avx512); } +// test avx512 iota static void test_iota(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 
4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -751,6 +787,7 @@ static void test_iota(void) { } } +// test avx512 permute_slow static void test_permute_slow(void) { uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -763,6 +800,7 @@ static void test_permute_slow(void) { check("test_permute_slow()", a_scalar, a_avx512); } +// test avx512 permute_fast static void test_permute_fast(void) { #ifdef __AVX512F__ uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -777,8 +815,10 @@ static void test_permute_fast(void) { #endif /* __AVX512F__ */ } +// number of times to run permutation in timing tests below #define NUM_TIME_PERMUTES 10000000 +// time scalar keccak permutation static void time_permute_scalar(void) { uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -794,6 +834,7 @@ static void time_permute_scalar(void) { printf("time_permute_scalar(): %zu\n", t1 - t0); } +// time slow avx512 keccak permutation static void time_permute_avx512_slow(void) { #ifdef __AVX512F__ uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; @@ -813,6 +854,7 @@ static void time_permute_avx512_slow(void) { #endif /* __AVX512F__ */ } +// time fast avx512 keccak permutation static void time_permute_avx512_fast(void) { #ifdef __AVX512F__ uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 }; -- cgit v1.2.3