From 1a6d29a92333b2b0ab0e1bd363fb0243e2c030f0 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Fri, 3 May 2024 22:20:56 -0400 Subject: sha3.c: refactor backends so they only implement permute_n() i verified that (gcc, at least) does constant propagation and inlines permute_n_ and that this change does not affect performance. bench results, before: pabs@flex:~/git/sha3/tests/bench> ./bench info: cpucycles: version=20240318 implementation=amd64-pmc persecond=4800000000 info: backend=avx512 num_trials=100000 src_lens=64,256,1024,4096,16384 dst_lens=32 function,dst_len,64,256,1024,4096,16384 sha3_224,28,15.4,7.8,7.8,7.1,7.0 sha3_256,32,15.5,7.8,7.8,7.6,7.4 sha3_384,48,15.5,11.7,9.8,9.8,9.7 sha3_512,64,15.6,15.5,14.6,13.9,13.9 shake128,32,15.5,7.8,6.9,6.2,6.1 shake256,32,15.5,7.8,7.9,7.6,7.4 bench results, after change: pabs@flex:~/git/sha3/tests/bench> ./bench info: cpucycles: version=20240318 implementation=amd64-pmc persecond=4800000000 info: backend=avx512 num_trials=100000 src_lens=64,256,1024,4096,16384 dst_lens=32 function,dst_len,64,256,1024,4096,16384 sha3_224,28,15.4,7.8,7.8,7.1,7.0 sha3_256,32,15.6,7.8,7.8,7.6,7.4 sha3_384,48,15.6,11.7,9.8,9.8,9.7 sha3_512,64,15.6,15.5,14.6,13.8,13.8 shake128,32,15.6,7.9,6.9,6.2,6.1 shake256,32,15.7,7.9,7.9,7.6,7.4 --- sha3.c | 61 +++++++++++++++++++++---------------------------------------- 1 file changed, 21 insertions(+), 40 deletions(-) (limited to 'sha3.c') diff --git a/sha3.c b/sha3.c index 6f936b9..8563d27 100644 --- a/sha3.c +++ b/sha3.c @@ -231,23 +231,6 @@ static inline void permute_n_scalar(uint64_t a[static 25], const size_t num_roun iota(a, (SHA3_NUM_ROUNDS - num_rounds + i)); } } - -/** - * @brief 24 round scalar Keccak permutation. - * @param[in,out] a Keccak state (array of 25 64-bit integers). - */ -static inline void permute_scalar(uint64_t a[static 25]) { - permute_n_scalar(a, 24); -} - -/** - * @brief 12 round scalar Keccak permutation. - * @note Only used by TurboSHAKE and KangarooTwelve. - * @param[in,out] a Keccak state (array of 25 64-bit integers). - */ -static inline void permute12_scalar(uint64_t a[static 25]) { - permute_n_scalar(a, 12); -} #endif /* (SHA3_BACKEND == BACKEND_SCALAR) || defined(SHA3_TEST) */ #if SHA3_BACKEND == BACKEND_AVX512 @@ -477,36 +460,34 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3); _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4); } +#endif /* SHA3_BACKEND == BACKEND_AVX512 */ + +#if SHA3_BACKEND == BACKEND_AVX512 +// use avx512 backend +#define permute_n permute_n_avx512 +#elif SHA3_BACKEND == BACKEND_SCALAR +// use scalar backend +#define permute_n permute_n_scalar +#else +#error "unknown sha3 backend" +#endif /* SHA3_BACKEND */ /** - * @brief 24 round AVX-512 Keccak permutation. + * @brief 24 round Keccak permutation. * @param[in,out] a Keccak state (array of 25 64-bit integers). */ -static inline void permute_avx512(uint64_t s[static 25]) { - permute_n_avx512(s, 24); +static inline void permute(uint64_t s[static 25]) { + permute_n(s, 24); } /** - * @brief 12 round AVX-512 Keccak permutation. + * @brief 12 round Keccak permutation. * @note Only used by TurboSHAKE and KangarooTwelve. * @param[in,out] a Keccak state (array of 25 64-bit integers). */ -static inline void permute12_avx512(uint64_t s[static 25]) { - permute_n_avx512(s, 12); +static inline void permute12(uint64_t s[static 25]) { + permute_n(s, 12); } -#endif /* SHA3_BACKEND == BACKEND_AVX512 */ - -#if SHA3_BACKEND == BACKEND_AVX512 -// use avx512 backend -#define permute permute_avx512 -#define permute12 permute12_avx512 -#elif SHA3_BACKEND == BACKEND_SCALAR -// use scalar backend -#define permute permute_scalar -#define permute12 permute12_scalar -#else -#error "unknown sha3 backend" -#endif /* SHA3_BACKEND */ // absorb message into state, return updated byte count // used by `hash_absorb()`, `hash_once()`, and `xof_absorb_raw()` @@ -2231,7 +2212,7 @@ static void test_permute_scalar(void) { uint64_t got[25] = { 0 }; memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); - permute_scalar(got); + permute_n_scalar(got, 24); // call permute_n_scalar() directly if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); @@ -2246,7 +2227,7 @@ static void test_permute_avx512(void) { uint64_t got[25] = { 0 }; memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); - permute_avx512(got); + permute_n_avx512(got, 24); // call permute_n_avx512() directly if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); @@ -2271,7 +2252,7 @@ static void test_permute12_scalar(void) { uint64_t got[25] = { 0 }; memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); - permute12_scalar(got); + permute_n_scalar(got, 12); // call permute_n_scalar() directly if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); @@ -2286,7 +2267,7 @@ static void test_permute12_avx512(void) { uint64_t got[25] = { 0 }; memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); - permute12_avx512(got); + permute_n_avx512(got, 12); // call permute_n_avx512() directly if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); -- cgit v1.2.3