diff options
-rw-r--r-- | sha3.c | 249 |
1 files changed, 249 insertions, 0 deletions
@@ -68,6 +68,7 @@ #define BACKEND_NEON 3 // Neon backend. Slower than scalar. #define BACKEND_DIET_NEON 4 // Neon backend, fewer registers. Slower than scalar. #define BACKEND_HYBRID 5 // Hybrid scalar/neon backend. Slower than scalar. +#define BACKEND_AVX2 6 // AVX2 backend // if BACKEND is defined and set to 0 (the default), then unset it // and auto-detect the appropriate backend below @@ -79,6 +80,8 @@ #ifndef BACKEND #if defined(__AVX512F__) #define BACKEND BACKEND_AVX512 +#elif defined(__AVX2__) +#define BACKEND BACKEND_AVX2 #elif 0 && defined(__ARM_NEON) #define BACKEND BACKEND_NEON #else @@ -520,6 +523,248 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun } #endif /* BACKEND == BACKEND_AVX512 */ +#if BACKEND == BACKEND_AVX2 +#include <immintrin.h> + +static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only + K64 = { 64, 64, 64, 64 }; // 64, all lanes (used by ROLV) + +// load state array to avx2 registers +// FIXME: remove macro, not needed +#define AVX2_LOAD(s) __m256i \ + r0_lo = _mm256_loadu_epi64(s + 0), /* row 0, cols 0-3 */ \ + r1_lo = _mm256_loadu_epi64(s + 5), /* row 1, cols 0-3 */ \ + r2_lo = _mm256_loadu_epi64(s + 10), /* row 2, cols 0-3 */ \ + r3_lo = _mm256_loadu_epi64(s + 15), /* row 3, cols 0-3 */ \ + r4_lo = _mm256_loadu_epi64(s + 20), /* row 4, cols 0-3 */ \ + r0_hi = { s[ 4] }, /* row 0, col 4 */ \ + r1_hi = { s[ 9] }, /* row 1, col 4 */ \ + r2_hi = { s[14] }, /* row 2, col 4 */ \ + r3_hi = { s[19] }, /* row 3, col 4 */ \ + r4_hi = { s[24] }; /* row 4, col 4 */ + +// store avx2 registers to state array +#define AVX2_STORE(s) do { \ + union { long long int *i64; uint64_t *u64; } p = { .u64 = s }; \ + \ + /* store rows */ \ + _mm256_storeu_epi64(p.i64 + 0, r0_lo); /* row 0, cols 0-3 */ \ + _mm256_storeu_epi64(p.i64 + 5, r1_lo); /* row 1, cols 0-3 */ \ + _mm256_storeu_epi64(p.i64 + 10, r2_lo); /* row 2, cols 0-3 */ \ + _mm256_storeu_epi64(p.i64 + 15, r3_lo); /* row 3, cols 0-3 */ \ + _mm256_storeu_epi64(p.i64 + 20, r4_lo); /* row 4, cols 0-3 */ \ + _mm256_maskstore_epi64(p.i64 + 4, M0, r0_hi); /* row 0, col 4 */ \ + _mm256_maskstore_epi64(p.i64 + 9, M0, r1_hi); /* row 1, col 4 */ \ + _mm256_maskstore_epi64(p.i64 + 14, M0, r2_hi); /* row 2, col 4 */ \ + _mm256_maskstore_epi64(p.i64 + 19, M0, r3_hi); /* row 3, col 4 */ \ + _mm256_maskstore_epi64(p.i64 + 24, M0, r4_hi); /* row 4, col 4 */ \ +} while (0) + +// rotate left immediate +#define AVX2_ROLI(v, n) (_mm256_slli_epi64((v), (n)) | _mm256_srli_epi64((v), (64-(n)))) + +// rotate left by vector +#define AVX2_ROLV(v, n) (_mm256_sllv_epi64((v), (n)) | _mm256_srlv_epi64((v), (K64-(n)))) + +// pi permute IDs +#define PI_I0_LO 0x90 // 0, 0, 1, 2 -> 0b10010000 -> 0x90 +#define PI_I0_HI 0x03 // 3, 0, 0, 0 -> 0b00000011 -> 0x03 +#define PI_I1_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39 +#define PI_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00 + +// chi permute IDs +#define CHI_I0_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39 +#define CHI_I1_LO 0x0e // 2, 3, 0, 0 -> 0b00001110 -> 0x0e +#define CHI_I1_HI 0x01 // 1, 0, 0, 0 -> 0b00000001 -> 0x01 + +/** + * @brief AVX2 Keccak permutation. + * + * @param[in,out] s Keccak state (array of 25 64-bit integers). + * @param[in] num_rounds Number of rounds (12 or 24). + * + * How it works: + * + * 1. The Keccak state for cells in columns 0-3 is loaded from `s` (an + * array of 25 64-bit unsigned integers) into four 64-bit lanes of 5 + * 256-bit registers r0_lo-r4_lo. + * + * 2. The Keccak state for cells in column 4 is loaded from `s` into + * the first 64-bit lanes of 5 256-bit registers r0_hi-r4_hi. + * + * When steps #1 and #2 are done, the registers look like this: + * + * ------------------------------------------------ + * | | 64-bit Lanes of 256-bit Registers | + * |----------|-----------------------------------| + * | Register | Lane 0 | Lane 1 | Lane 2 | Lane 3 | + * |----------|--------|--------|--------|--------| + * | r0_lo | s[ 0] | s[ 1] | s[ 2] | s[ 3] | + * | r1_lo | s[ 5] | s[ 6] | s[ 7] | s[ 8] | + * | r2_lo | s[10] | s[11] | s[12] | s[13] | + * | r3_lo | s[15] | s[16] | s[17] | s[18] | + * | r4_lo | s[20] | s[21] | s[22] | s[23] | + * | r0_hi | s[ 4] | n/a | n/a | n/a | + * | r1_hi | s[ 9] | n/a | n/a | n/a | + * | r2_hi | s[14] | n/a | n/a | n/a | + * | r3_hi | s[19] | n/a | n/a | n/a | + * | r4_hi | s[24] | n/a | n/a | n/a | + * ------------------------------------------------ + * + * 3. The Keccak permutation is applied `num_rounds` times, where + * `num_rounds` is either 12 for TurboSHAKE and KangarooTwelve or 24 + * otherwise. + * + * (Note: for the Pi step the registers are stored back to the state + * array and then gathered to permute the state. This is different than + * the AVX-512 implementation because of register pressure). + * + * 4. The permuted Keccak state is copied back to `s`. + */ +static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds) { + // load state + AVX2_LOAD(s); + + // loop over rounds + for (size_t i = (SHA3_NUM_ROUNDS - num_rounds); __builtin_expect(i < SHA3_NUM_ROUNDS, 1); i++) { + // theta + { + // c = xor(r0, r1, r2, r3, r4) + const __m256i c_lo = r0_lo ^ r1_lo ^ r2_lo ^ r3_lo ^ r4_lo, + c_hi = r0_hi ^ r1_hi ^ r2_hi ^ r3_hi ^ r4_hi; + + // avx512 permute ids (for reference) + // static const __m512i I0 = { 4, 0, 1, 2, 3 }, + // I1 = { 1, 2, 3, 4, 0 }; + + // masks + static const __m256i M0 = { ~0, 0, 0, 0 }, // { 1, 0, 0, 0 } + M1 = { ~0, ~0, ~0, 0 }; // { 1, 1, 1, 0 } + + // d = xor(permute(i0, c), permute(i1, rol(c, 1))) + const __m256i d0_lo = (_mm256_permute4x64_epi64(c_lo, PI_I0_LO) & ~M0) | (c_hi & M0), + d0_hi = _mm256_permute4x64_epi64(c_lo, PI_I0_HI) & M0, + d1_lo = (_mm256_permute4x64_epi64(c_lo, PI_I1_LO) & M1) | (_mm256_permute4x64_epi64(c_hi, PI_I1_HI) & ~M1), + d1_hi = (c_lo & M0), + d_lo = d0_lo ^ AVX2_ROLI(d1_lo, 1), + d_hi = d0_hi ^ AVX2_ROLI(d1_hi, 1); + + // row = xor(row, d) + r0_lo ^= d_lo; r1_lo ^= d_lo; r2_lo ^= d_lo; r3_lo ^= d_lo; r4_lo ^= d_lo; + r0_hi ^= d_hi; r1_hi ^= d_hi; r2_hi ^= d_hi; r3_hi ^= d_hi; r4_hi ^= d_hi; + } + + // rho + { + // rotate values + static const __m256i V0_LO = { 0, 1, 62, 28 }, V0_HI = { 27 }, + V1_LO = { 36, 44, 6, 55 }, V1_HI = { 20 }, + V2_LO = { 3, 10, 43, 25 }, V2_HI = { 39 }, + V3_LO = { 41, 45, 15, 21 }, V3_HI = { 8 }, + V4_LO = { 18, 2, 61, 56 }, V4_HI = { 14 }; + + // rotate rows + // FIXME: could reduce rotates by permuting + r0_lo = AVX2_ROLV(r0_lo, V0_LO); r0_hi = AVX2_ROLV(r0_hi, V0_HI); + r1_lo = AVX2_ROLV(r1_lo, V1_LO); r1_hi = AVX2_ROLV(r1_hi, V1_HI); + r2_lo = AVX2_ROLV(r2_lo, V2_LO); r2_hi = AVX2_ROLV(r2_hi, V2_HI); + r3_lo = AVX2_ROLV(r3_lo, V3_LO); r3_hi = AVX2_ROLV(r3_hi, V3_HI); + r4_lo = AVX2_ROLV(r4_lo, V4_LO); r4_hi = AVX2_ROLV(r4_hi, V4_HI); + } + + // pi + // + // store state array, then gather to permute the state. note: with + // some work we could probably do in-register permutes, but + // benchmark first to see if this is worth the trouble. + { + static const __m256i V0_LO = { 0, 6, 12, 18 }, + V1_LO = { 3, 9, 10, 16 }, + V2_LO = { 1, 7, 13, 19 }, + V3_LO = { 4, 5, 11, 17 }, + V4_LO = { 2, 8, 14, 15 }; + static const size_t V0_HI = 24, V1_HI = 22, V2_HI = 20, V3_HI = 23, V4_HI = 21; + + // store rows to state, then gather to permute + AVX2_STORE(s); + + // re-load using gather to permute + union { long long int *i64; uint64_t *u64; } p = { .u64 = s }; + r0_lo = _mm256_i64gather_epi64(p.i64, V0_LO, 8); r0_hi = ((__m256i) { s[V0_HI] }); + r1_lo = _mm256_i64gather_epi64(p.i64, V1_LO, 8); r1_hi = ((__m256i) { s[V1_HI] }); + r2_lo = _mm256_i64gather_epi64(p.i64, V2_LO, 8); r2_hi = ((__m256i) { s[V2_HI] }); + r3_lo = _mm256_i64gather_epi64(p.i64, V3_LO, 8); r3_hi = ((__m256i) { s[V3_HI] }); + r4_lo = _mm256_i64gather_epi64(p.i64, V4_LO, 8); r4_hi = ((__m256i) { s[V4_HI] }); + } + + // chi + { + // masks + static const __m256i M0 = { ~0, 0, 0, 0 }, // { 1, 0, 0, 0 } + M1 = { ~0, ~0, ~0, 0 }, // { 1, 1, 1, 0 } + M2 = { ~0, ~0, 0, ~0 }; // { 1, 1, 0, 1 } + + // r0 + { + const __m256i a_lo = (_mm256_permute4x64_epi64(r0_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r0_hi, CHI_I0_LO) & ~M1), + a_hi = r0_lo & M0, + b_lo = (_mm256_permute4x64_epi64(r0_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r0_hi, CHI_I1_LO) & ~M0), + b_hi = _mm256_permute4x64_epi64(r0_lo, CHI_I1_HI) & M0; + + r0_lo ^= ~a_lo & b_lo; r0_hi ^= ~a_hi & b_hi; // r0 ^= ~a & b + } + + // r1 + { + const __m256i a_lo = (_mm256_permute4x64_epi64(r1_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r1_hi, CHI_I0_LO) & ~M1), + a_hi = r1_lo & M0, + b_lo = (_mm256_permute4x64_epi64(r1_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r1_hi, CHI_I1_LO) & ~M0), + b_hi = _mm256_permute4x64_epi64(r1_lo, CHI_I1_HI) & M0; + + r1_lo ^= ~a_lo & b_lo; r1_hi ^= ~a_hi & b_hi; // r1 ^= ~a & b + } + + // r2 + { + const __m256i a_lo = (_mm256_permute4x64_epi64(r2_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r2_hi, CHI_I0_LO) & ~M1), + a_hi = r2_lo & M0, + b_lo = (_mm256_permute4x64_epi64(r2_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r2_hi, CHI_I1_LO) & ~M0), + b_hi = _mm256_permute4x64_epi64(r2_lo, CHI_I1_HI) & M0; + + r2_lo ^= ~a_lo & b_lo; r2_hi ^= ~a_hi & b_hi; // r2 ^= ~a & b + } + + // r3 + { + const __m256i a_lo = (_mm256_permute4x64_epi64(r3_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r3_hi, CHI_I0_LO) & ~M1), + a_hi = r3_lo & M0, + b_lo = (_mm256_permute4x64_epi64(r3_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r3_hi, CHI_I1_LO) & ~M0), + b_hi = _mm256_permute4x64_epi64(r3_lo, CHI_I1_HI) & M0; + + r3_lo ^= ~a_lo & b_lo; r3_hi ^= ~a_hi & b_hi; // r3 ^= ~a & b + } + + // r4 + { + const __m256i a_lo = (_mm256_permute4x64_epi64(r4_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r4_hi, CHI_I0_LO) & ~M1), + a_hi = r4_lo & M0, + b_lo = (_mm256_permute4x64_epi64(r4_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r4_hi, CHI_I1_LO) & ~M0), + b_hi = _mm256_permute4x64_epi64(r4_lo, CHI_I1_HI) & M0; + + r4_lo ^= ~a_lo & b_lo; r4_hi ^= ~a_hi & b_hi; // r4 ^= ~a & b + } + } + + // iota + const __m256i rc = { RCS[i], 0, 0, 0 }; + r0_lo ^= rc; + } + + // store rows to state + AVX2_STORE(s); +} +#endif /* BACKEND == BACKEND_AVX2 */ + #if BACKEND == BACKEND_NEON #include <arm_neon.h> @@ -1310,6 +1555,8 @@ static inline void permute_n_hybrid(uint64_t a[static 25], const size_t num_roun // map permute_n() to active backend #if BACKEND == BACKEND_AVX512 #define permute_n permute_n_avx512 // use avx512 backend +#elif BACKEND == BACKEND_AVX2 +#define permute_n permute_n_avx2 // use AVX2 backend #elif BACKEND == BACKEND_NEON #define permute_n permute_n_neon // use neon backend #elif BACKEND == BACKEND_DIET_NEON @@ -2968,6 +3215,8 @@ void k12_once(const uint8_t *src, const size_t src_len, uint8_t *dst, const size const char *sha3_backend(void) { #if BACKEND == BACKEND_AVX512 return "avx512"; +#elif BACKEND == BACKEND_AVX2 + return "avx2"; #elif BACKEND == BACKEND_NEON return "neon"; #elif BACKEND == BACKEND_DIET_NEON |