diff options
-rw-r--r-- | sha3.c | 258 |
1 files changed, 257 insertions, 1 deletions
@@ -43,6 +43,13 @@ // number of rounds for permute() #define SHA3_NUM_ROUNDS 24 +#if !defined(__AVX512F__) || defined(SHA3_TEST) +// If AVX512 is supported and we are not building the test suite, +// then do not compile the scalar step functions below. +// +// (because they aren't used by the AVX512 implementation). + +// theta step of keccak permutation (scalar implementation) static inline void theta(uint64_t a[static 25]) { const uint64_t c[5] = { a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20], @@ -67,6 +74,7 @@ static inline void theta(uint64_t a[static 25]) { a[20] ^= d[0]; a[21] ^= d[1]; a[22] ^= d[2]; a[23] ^= d[3]; a[24] ^= d[4]; } +// rho step of keccak permutation (scalar implementation) static inline void rho(uint64_t a[static 25]) { a[1] = ROL(a[1], 1); // 1 % 64 = 1 a[2] = ROL(a[2], 62); // 190 % 64 = 62 @@ -94,6 +102,7 @@ static inline void rho(uint64_t a[static 25]) { a[24] = ROL(a[24], 14); // 78 % 64 = 14 } +// pi step of keccak permutation (scalar implementation) static inline void pi(uint64_t a[static 25]) { uint64_t t[25] = { 0 }; memcpy(t, a, sizeof(t)); @@ -123,6 +132,7 @@ static inline void pi(uint64_t a[static 25]) { a[24] = t[21]; } +// chi step of keccak permutation (scalar implementation) static inline void chi(uint64_t a[static 25]) { uint64_t t[25] = { 0 }; memcpy(t, a, sizeof(t)); @@ -153,6 +163,7 @@ static inline void chi(uint64_t a[static 25]) { a[24] = t[24] ^ (~t[20] & t[21]); } +// iota step of keccak permutation (scalar implementation) static inline void iota(uint64_t a[static 25], const int i) { // round constants (ambiguous in spec) static const uint64_t RCS[] = { @@ -166,8 +177,10 @@ static inline void iota(uint64_t a[static 25], const int i) { a[0] ^= RCS[i]; } +#endif /* !defined(__AVX512F__) || defined(SHA3_TEST) */ -// keccak permutation. +#ifndef __AVX512F__ +// keccak permutation (scalar implementation) // // note: clang is better about inlining this than gcc with a // configurable number of rounds. the configurable number of rounds is @@ -182,6 +195,249 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) { iota(a, 24 - num_rounds + i); } } +#endif /* __AVX512F__ */ + +#ifdef __AVX512F__ +#include <immintrin.h> + +// keccak permutation (avx512 implementation). +// +// copied from `permute_avx512_fast()` in `tests/permute/permute.c`. all +// steps are inlined as blocks. ~3x faster than scalar implementation, +// but could be sped up more. +static inline void permute(uint64_t s[static 25], const size_t num_rounds) { + // unaligned load mask and permutation indices + uint8_t mask = 0x1f, + m0b = 0x01; + const __mmask8 m = _load_mask8(&mask), + m0 = _load_mask8(&m0b); + + // round constants (used in iota) + static const uint64_t RCS[] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, + 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, + 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, + 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, + 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, + }; + + // load round constant + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds)); + + // load rc permutation + static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }; + const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps); + + // load rows + __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)), + r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)), + r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)), + r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)), + r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20)); + + for (int i = 0; i < (int) num_rounds; i++) { + // theta + { + const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7), + i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7); + + // c = xor(r0, r1, r2, r3, r4) + const __m512i r01 = _mm512_xor_epi64(r0, r1), + r23 = _mm512_xor_epi64(r2, r3), + r03 = _mm512_xor_epi64(r01, r23), + c = _mm512_xor_epi64(r03, r4); + + // d = xor(permute(i0, c), permute(i1, rol(c, 1))) + const __m512i d0 = _mm512_permutexvar_epi64(i0, c), + d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)), + d = _mm512_xor_epi64(d0, d1); + + // row = xor(row, d) + r0 = _mm512_xor_epi64(r0, d); + r1 = _mm512_xor_epi64(r1, d); + r2 = _mm512_xor_epi64(r2, d); + r3 = _mm512_xor_epi64(r3, d); + r4 = _mm512_xor_epi64(r4, d); + } + + // rho + { + // unaligned load mask and rotate values + static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 }, + vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 }, + vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 }, + vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 }, + vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 }; + + // load rotate values + const __m512i v0 = _mm512_loadu_epi64((void*) vs0), + v1 = _mm512_loadu_epi64((void*) vs1), + v2 = _mm512_loadu_epi64((void*) vs2), + v3 = _mm512_loadu_epi64((void*) vs3), + v4 = _mm512_loadu_epi64((void*) vs4); + + // rotate + r0 = _mm512_rolv_epi64(r0, v0); + r1 = _mm512_rolv_epi64(r1, v1); + r2 = _mm512_rolv_epi64(r2, v2); + r3 = _mm512_rolv_epi64(r3, v3); + r4 = _mm512_rolv_epi64(r4, v4); + } + + // pi + { + // mask bytes + uint8_t m01b = 0x03, + m23b = 0x0c, + m4b = 0x10; + + // load masks + const __mmask8 m01 = _load_mask8(&m01b), + m23 = _load_mask8(&m23b), + m4 = _load_mask8(&m4b); + + // permutation indices + // + // (note: these are masked so only the relevant indices for + // _mm512_maskz_permutex2var_epi64() in each array are filled in) + static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 }, + t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 }, + t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 }, + + t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 }, + t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 }, + t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 }, + + t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 }, + t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 }, + t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }, + + t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 }, + t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 }, + t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 }, + + t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 }, + t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 }, + t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 }; + + // load permutation indices + const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01), + t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23), + t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4), + + t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01), + t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23), + t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4), + + t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01), + t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23), + t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4), + + t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01), + t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23), + t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4), + + t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01), + t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23), + t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4); + + // permute rows + const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1), + t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3), + t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0), + t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4), + + t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1), + t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3), + t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0), + t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4), + + t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1), + t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3), + t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0), + t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4), + + t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1), + t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3), + t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0), + t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4), + + t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1), + t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3), + t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0), + t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4); + + // store rows + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; + r4 = t4; + } + + // chi + { + // permutation indices + static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 }, + pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 }; + + // load permutation indices + const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0), + p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1); + + // permute rows + const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0), + t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0), + t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)), + + t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1), + t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1), + t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)), + + t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2), + t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2), + t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)), + + t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3), + t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3), + t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)), + + t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4), + t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4), + t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1)); + + // store rows + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; + r4 = t4; + } + + // iota + { + // xor round constant, shuffle rc register + r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc); + rc = _mm512_permutexvar_epi64(rc_p, rc); + + if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) { + // load next set of round constants + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1))); + } + } + } + + // store rows + _mm512_mask_storeu_epi64((void*) (s), m, r0), + _mm512_mask_storeu_epi64((void*) (s + 5), m, r1), + _mm512_mask_storeu_epi64((void*) (s + 10), m, r2), + _mm512_mask_storeu_epi64((void*) (s + 15), m, r3), + _mm512_mask_storeu_epi64((void*) (s + 20), m, r4); +} +#endif /* __AVX512F__ */ // one-shot keccak. static inline size_t keccak(sha3_state_t * const a, const uint8_t *m, size_t m_len, const size_t rate) { |