diff options
| -rw-r--r-- | sha3.c | 258 | 
1 files changed, 257 insertions, 1 deletions
| @@ -43,6 +43,13 @@  // number of rounds for permute()  #define SHA3_NUM_ROUNDS 24 +#if !defined(__AVX512F__) || defined(SHA3_TEST) +// If AVX512 is supported and we are not building the test suite, +// then do not compile the scalar step functions below. +// +// (because they aren't used by the AVX512 implementation). + +// theta step of keccak permutation (scalar implementation)  static inline void theta(uint64_t a[static 25]) {    const uint64_t c[5] = {      a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20], @@ -67,6 +74,7 @@ static inline void theta(uint64_t a[static 25]) {    a[20] ^= d[0]; a[21] ^= d[1]; a[22] ^= d[2]; a[23] ^= d[3]; a[24] ^= d[4];  } +// rho step of keccak permutation (scalar implementation)  static inline void rho(uint64_t a[static 25]) {    a[1] = ROL(a[1], 1); // 1 % 64 = 1    a[2] = ROL(a[2], 62); // 190 % 64 = 62 @@ -94,6 +102,7 @@ static inline void rho(uint64_t a[static 25]) {    a[24] = ROL(a[24], 14); // 78 % 64 = 14  } +// pi step of keccak permutation (scalar implementation)  static inline void pi(uint64_t a[static 25]) {    uint64_t t[25] = { 0 };    memcpy(t, a, sizeof(t)); @@ -123,6 +132,7 @@ static inline void pi(uint64_t a[static 25]) {    a[24] = t[21];  } +// chi step of keccak permutation (scalar implementation)  static inline void chi(uint64_t a[static 25]) {    uint64_t t[25] = { 0 };    memcpy(t, a, sizeof(t)); @@ -153,6 +163,7 @@ static inline void chi(uint64_t a[static 25]) {    a[24] = t[24] ^ (~t[20] & t[21]);  } +// iota step of keccak permutation (scalar implementation)  static inline void iota(uint64_t a[static 25], const int i) {    // round constants (ambiguous in spec)    static const uint64_t RCS[] = { @@ -166,8 +177,10 @@ static inline void iota(uint64_t a[static 25], const int i) {    a[0] ^= RCS[i];  } +#endif /* !defined(__AVX512F__) || defined(SHA3_TEST) */ -// keccak permutation. +#ifndef __AVX512F__ +// keccak permutation (scalar implementation)  //  // note: clang is better about inlining this than gcc with a  // configurable number of rounds.  the configurable number of rounds is @@ -182,6 +195,249 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) {      iota(a, 24 - num_rounds + i);    }  } +#endif /* __AVX512F__ */ + +#ifdef __AVX512F__ +#include <immintrin.h> + +// keccak permutation (avx512 implementation). +// +// copied from `permute_avx512_fast()` in `tests/permute/permute.c`. all +// steps are inlined as blocks. ~3x faster than scalar implementation, +// but could be sped up more. +static inline void permute(uint64_t s[static 25], const size_t num_rounds) { +  // unaligned load mask and permutation indices +  uint8_t mask = 0x1f, +          m0b = 0x01; +  const __mmask8 m = _load_mask8(&mask), +                 m0 = _load_mask8(&m0b); + +  // round constants (used in iota) +  static const uint64_t RCS[] = { +    0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, +    0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, +    0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, +    0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, +    0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, +    0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, +  }; + +  // load round constant +  // note: this will bomb if num_rounds < 8 or num_rounds > 24. +  __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds)); + +  // load rc permutation +  static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }; +  const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps); + +  // load rows +  __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)), +          r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)), +          r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)), +          r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)), +          r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20)); + +  for (int i = 0; i < (int) num_rounds; i++) { +    // theta +    { +      const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7), +                    i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7); + +      // c = xor(r0, r1, r2, r3, r4) +      const __m512i r01 = _mm512_xor_epi64(r0, r1), +                    r23 = _mm512_xor_epi64(r2, r3), +                    r03 = _mm512_xor_epi64(r01, r23), +                    c = _mm512_xor_epi64(r03, r4); + +      // d = xor(permute(i0, c), permute(i1, rol(c, 1))) +      const __m512i d0 = _mm512_permutexvar_epi64(i0, c), +                    d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)), +                    d = _mm512_xor_epi64(d0, d1); + +      // row = xor(row, d) +      r0 = _mm512_xor_epi64(r0, d); +      r1 = _mm512_xor_epi64(r1, d); +      r2 = _mm512_xor_epi64(r2, d); +      r3 = _mm512_xor_epi64(r3, d); +      r4 = _mm512_xor_epi64(r4, d); +    } + +    // rho +    { +      // unaligned load mask and rotate values +      static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 }, +                            vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 }, +                            vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 }, +                            vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 }, +                            vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 }; + +      // load rotate values +      const __m512i v0 = _mm512_loadu_epi64((void*) vs0), +                    v1 = _mm512_loadu_epi64((void*) vs1), +                    v2 = _mm512_loadu_epi64((void*) vs2), +                    v3 = _mm512_loadu_epi64((void*) vs3), +                    v4 = _mm512_loadu_epi64((void*) vs4); + +      // rotate +      r0 = _mm512_rolv_epi64(r0, v0); +      r1 = _mm512_rolv_epi64(r1, v1); +      r2 = _mm512_rolv_epi64(r2, v2); +      r3 = _mm512_rolv_epi64(r3, v3); +      r4 = _mm512_rolv_epi64(r4, v4); +    } + +    // pi +    { +      // mask bytes +      uint8_t m01b = 0x03, +              m23b = 0x0c, +              m4b = 0x10; + +      // load masks +      const __mmask8 m01 = _load_mask8(&m01b), +                     m23 = _load_mask8(&m23b), +                     m4 = _load_mask8(&m4b); + +      // permutation indices +      // +      // (note: these are masked so only the relevant indices for +      // _mm512_maskz_permutex2var_epi64() in each array are filled in) +      static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 }, +                            t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 }, +                            t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 }, + +                            t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 }, +                            t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 }, +                            t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 }, + +                            t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 }, +                            t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 }, +                            t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }, + +                            t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 }, +                            t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 }, +                            t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 }, + +                            t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 }, +                            t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 }, +                            t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 }; + +      // load permutation indices +      const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01), +                    t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23), +                    t0_p4  = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4), + +                    t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01), +                    t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23), +                    t1_p4  = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4), + +                    t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01), +                    t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23), +                    t2_p4  = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4), + +                    t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01), +                    t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23), +                    t3_p4  = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4), + +                    t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01), +                    t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23), +                    t4_p4  = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4); + +      // permute rows +      const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1), +                    t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3), +                    t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0), +                    t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4), + +                    t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1), +                    t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3), +                    t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0), +                    t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4), + +                    t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1), +                    t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3), +                    t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0), +                    t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4), + +                    t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1), +                    t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3), +                    t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0), +                    t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4), + +                    t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1), +                    t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3), +                    t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0), +                    t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4); + +      // store rows +      r0 = t0; +      r1 = t1; +      r2 = t2; +      r3 = t3; +      r4 = t4; +    } + +    // chi +    { +      // permutation indices +      static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 }, +                            pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 }; + +      // load permutation indices +      const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0), +                    p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1); + +      // permute rows +      const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0), +                    t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0), +                    t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)), + +                    t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1), +                    t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1), +                    t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)), + +                    t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2), +                    t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2), +                    t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)), + +                    t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3), +                    t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3), +                    t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)), + +                    t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4), +                    t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4), +                    t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1)); + +      // store rows +      r0 = t0; +      r1 = t1; +      r2 = t2; +      r3 = t3; +      r4 = t4; +    } + +    // iota +    { +      // xor round constant, shuffle rc register +      r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc); +      rc = _mm512_permutexvar_epi64(rc_p, rc); + +      if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) { +        // load next set of round constants +        // note: this will bomb if num_rounds < 8 or num_rounds > 24. +        rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1))); +      } +    } +  } + +  // store rows +  _mm512_mask_storeu_epi64((void*) (s), m, r0), +  _mm512_mask_storeu_epi64((void*) (s + 5), m, r1), +  _mm512_mask_storeu_epi64((void*) (s + 10), m, r2), +  _mm512_mask_storeu_epi64((void*) (s + 15), m, r3), +  _mm512_mask_storeu_epi64((void*) (s + 20), m, r4); +} +#endif /* __AVX512F__ */  // one-shot keccak.  static inline size_t keccak(sha3_state_t * const a, const uint8_t *m, size_t m_len, const size_t rate) { | 
