From d898f431491846c5742874d81f834e45c984e535 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Mon, 29 Apr 2024 19:48:16 -0400 Subject: sha3.c: add permute_n_{scalar,avx512}() and refactor permute{,12}_{scalar,avx512}() to use them --- sha3.c | 221 +++++++---------------------------------------------------------- 1 file changed, 23 insertions(+), 198 deletions(-) diff --git a/sha3.c b/sha3.c index b364ae2..3925f25 100644 --- a/sha3.c +++ b/sha3.c @@ -194,36 +194,34 @@ static inline void iota(uint64_t a[static 25], const int i) { a[0] ^= RCS[i]; } -// 24-round keccak permutation (scalar implementation) -static inline void permute_scalar(uint64_t a[static 25]) { +// keccak permutation (scalar implementation) +static inline void permute_n_scalar(uint64_t a[static 25], const size_t num_rounds) { uint64_t tmp[25] = { 0 }; - for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) { + for (size_t i = 0; i < num_rounds; i++) { theta(a); rho(a); pi(tmp, a); chi(a, tmp); - iota(a, i); + iota(a, (SHA3_NUM_ROUNDS - num_rounds + i)); } } +// 24 round keccak permutation (scalar implementation) +static inline void permute_scalar(uint64_t a[static 25]) { + permute_n_scalar(a, 24); +} + // 12 round keccak permutation (scalar implementation) // (only used by turboshake) static inline void permute12_scalar(uint64_t a[static 25]) { - uint64_t tmp[25] = { 0 }; - for (size_t i = 0; i < 12; i++) { - theta(a); - rho(a); - pi(tmp, a); - chi(a, tmp); - iota(a, 12 + i); - } + permute_n_scalar(a, 12); } #endif /* (SHA3_BACKEND == SHA3_BACKEND_SCALAR) || defined(SHA3_TEST) */ #if SHA3_BACKEND == SHA3_BACKEND_AVX512 #include -// 24 round keccak permutation (avx512 implementation). +// keccak permutation (avx512 implementation). // // how it operates (roughly): // @@ -251,7 +249,7 @@ static inline void permute12_scalar(uint64_t a[static 25]) { // 3. store the rows first 5 64-bit lanes of registers r0-r4 back to the // state `s`. // -static inline void permute_avx512(uint64_t s[static 25]) { +static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_rounds) { // load rows (r0-r4) __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0 r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1 @@ -260,7 +258,7 @@ static inline void permute_avx512(uint64_t s[static 25]) { r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4 // 24 rounds - for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) { + for (size_t i = 0; i < num_rounds; i++) { // theta { // permute ids @@ -423,8 +421,11 @@ static inline void permute_avx512(uint64_t s[static 25]) { // iota { + // calculate RCS offset + const size_t ofs = SHA3_NUM_ROUNDS - num_rounds + i; + // xor round constant to first cell - r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + i)); + r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + ofs)); } } @@ -436,190 +437,14 @@ static inline void permute_avx512(uint64_t s[static 25]) { _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4); } +// 24 round keccak permutation (avx512 implementation). +static inline void permute_avx512(uint64_t s[static 25]) { + permute_n_avx512(s, 24); +} + // 12 round keccak permutation (avx512 implementation). static inline void permute12_avx512(uint64_t s[static 25]) { - // load rows (r0-r4) - __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0 - r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1 - r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2 - r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3 - r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4 - - // 12 rounds - for (size_t i = 0; i < 12; i++) { - // theta - { - // permute ids - static const __m512i I0 = { 4, 0, 1, 2, 3 }, - I1 = { 1, 2, 3, 4, 0 }; - - // c = xor(r0, r1, r2, r3, r4) - const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1), - r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3), - c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96); - - // d = xor(permute(i0, c), permute(i1, rol(c, 1))) - const __m512i d0 = _mm512_permutexvar_epi64(I0, c), - d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)), - d = _mm512_xor_epi64(d0, d1); - - // row = xor(row, d) - r0 = _mm512_xor_epi64(r0, d); - r1 = _mm512_xor_epi64(r1, d); - r2 = _mm512_xor_epi64(r2, d); - r3 = _mm512_xor_epi64(r3, d); - r4 = _mm512_xor_epi64(r4, d); - } - - // rho - { - // rotate values - // - // note: switching from maskz_load_epi64()s to static const - // __m512i incurs a 500 cycle penalty; leaving them for now - static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 }, - V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 }, - V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 }, - V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 }, - V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 }; - - // rotate rows - r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS)); - r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS)); - r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS)); - r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS)); - r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS)); - } - - // pi - // - // The cells are permuted across all rows of the state array. each - // output row is the combination of three permutations: - // - // - e0: row 0 and row 1 - // - e2: row 2 and row 3 - // - e4: row 4 and row 0 - // - // the IDs for each permutation are merged into a single array - // (T*_IDS) to reduce register pressure, and the permute operations - // are masked so that each permutation only uses the relevant IDs. - // - // afterwards, the permutations are combined to form a temporary - // row: - // - // t0 = t0e0 | t0e2 | t0e4 - // - // once the permutations for all rows are complete, the temporary - // rows are saved to the actual row registers: - // - // r0 = t0 - // - { - // permute ids - static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 }, - T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 }, - T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 }, - T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 }, - T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 }; - - __m512i t0, t1, t2, t3, t4; - { - // permute r0 - const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1), - t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3), - t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0); - - // permute r1 - const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1), - t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3), - t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0); - - // permute r2 - const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1), - t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3), - t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0); - - // permute r3 - const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1), - t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3), - t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0); - - // permute r4 - const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1), - t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3), - t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0); - - // combine permutes: tN = e0 | e2 | e4 - t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe); - t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe); - t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe); - t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe); - t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe); - } - - // store rows - r0 = t0; - r1 = t1; - r2 = t2; - r3 = t3; - r4 = t4; - } - - // chi - { - // permute ids - static const __m512i P0 = { 1, 2, 3, 4, 0 }, - P1 = { 2, 3, 4, 0, 1 }; - - { - // r0 ^= ~e0 & e1 - const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0), - t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0); - r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2); - } - - { - // r1 ^= ~e0 & e1 - const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1), - t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1); - r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2); - } - - { - // r2 ^= ~e0 & e1 - const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2), - t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2); - r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2); - } - - { - // r3 ^= ~e0 & e1 - const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3), - t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3); - r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2); - } - - { - // r4 ^= ~e0 & e1 - const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4), - t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4); - r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2); - } - } - - // iota - { - // xor round constant to first cell - r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + (12 + i))); - } - } - - // store rows - _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0); - _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1); - _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2); - _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3); - _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4); + permute_n_avx512(s, 12); } #endif /* SHA3_BACKEND == SHA3_BACKEND_AVX512 */ -- cgit v1.2.3