aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sha3.c249
1 files changed, 249 insertions, 0 deletions
diff --git a/sha3.c b/sha3.c
index 9b213d0..7979e72 100644
--- a/sha3.c
+++ b/sha3.c
@@ -68,6 +68,7 @@
#define BACKEND_NEON 3 // Neon backend. Slower than scalar.
#define BACKEND_DIET_NEON 4 // Neon backend, fewer registers. Slower than scalar.
#define BACKEND_HYBRID 5 // Hybrid scalar/neon backend. Slower than scalar.
+#define BACKEND_AVX2 6 // AVX2 backend
// if BACKEND is defined and set to 0 (the default), then unset it
// and auto-detect the appropriate backend below
@@ -79,6 +80,8 @@
#ifndef BACKEND
#if defined(__AVX512F__)
#define BACKEND BACKEND_AVX512
+#elif defined(__AVX2__)
+#define BACKEND BACKEND_AVX2
#elif 0 && defined(__ARM_NEON)
#define BACKEND BACKEND_NEON
#else
@@ -520,6 +523,248 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun
}
#endif /* BACKEND == BACKEND_AVX512 */
+#if BACKEND == BACKEND_AVX2
+#include <immintrin.h>
+
+static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only
+ K64 = { 64, 64, 64, 64 }; // 64, all lanes (used by ROLV)
+
+// load state array to avx2 registers
+// FIXME: remove macro, not needed
+#define AVX2_LOAD(s) __m256i \
+ r0_lo = _mm256_loadu_epi64(s + 0), /* row 0, cols 0-3 */ \
+ r1_lo = _mm256_loadu_epi64(s + 5), /* row 1, cols 0-3 */ \
+ r2_lo = _mm256_loadu_epi64(s + 10), /* row 2, cols 0-3 */ \
+ r3_lo = _mm256_loadu_epi64(s + 15), /* row 3, cols 0-3 */ \
+ r4_lo = _mm256_loadu_epi64(s + 20), /* row 4, cols 0-3 */ \
+ r0_hi = { s[ 4] }, /* row 0, col 4 */ \
+ r1_hi = { s[ 9] }, /* row 1, col 4 */ \
+ r2_hi = { s[14] }, /* row 2, col 4 */ \
+ r3_hi = { s[19] }, /* row 3, col 4 */ \
+ r4_hi = { s[24] }; /* row 4, col 4 */
+
+// store avx2 registers to state array
+#define AVX2_STORE(s) do { \
+ union { long long int *i64; uint64_t *u64; } p = { .u64 = s }; \
+ \
+ /* store rows */ \
+ _mm256_storeu_epi64(p.i64 + 0, r0_lo); /* row 0, cols 0-3 */ \
+ _mm256_storeu_epi64(p.i64 + 5, r1_lo); /* row 1, cols 0-3 */ \
+ _mm256_storeu_epi64(p.i64 + 10, r2_lo); /* row 2, cols 0-3 */ \
+ _mm256_storeu_epi64(p.i64 + 15, r3_lo); /* row 3, cols 0-3 */ \
+ _mm256_storeu_epi64(p.i64 + 20, r4_lo); /* row 4, cols 0-3 */ \
+ _mm256_maskstore_epi64(p.i64 + 4, M0, r0_hi); /* row 0, col 4 */ \
+ _mm256_maskstore_epi64(p.i64 + 9, M0, r1_hi); /* row 1, col 4 */ \
+ _mm256_maskstore_epi64(p.i64 + 14, M0, r2_hi); /* row 2, col 4 */ \
+ _mm256_maskstore_epi64(p.i64 + 19, M0, r3_hi); /* row 3, col 4 */ \
+ _mm256_maskstore_epi64(p.i64 + 24, M0, r4_hi); /* row 4, col 4 */ \
+} while (0)
+
+// rotate left immediate
+#define AVX2_ROLI(v, n) (_mm256_slli_epi64((v), (n)) | _mm256_srli_epi64((v), (64-(n))))
+
+// rotate left by vector
+#define AVX2_ROLV(v, n) (_mm256_sllv_epi64((v), (n)) | _mm256_srlv_epi64((v), (K64-(n))))
+
+// pi permute IDs
+#define PI_I0_LO 0x90 // 0, 0, 1, 2 -> 0b10010000 -> 0x90
+#define PI_I0_HI 0x03 // 3, 0, 0, 0 -> 0b00000011 -> 0x03
+#define PI_I1_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39
+#define PI_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00
+
+// chi permute IDs
+#define CHI_I0_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39
+#define CHI_I1_LO 0x0e // 2, 3, 0, 0 -> 0b00001110 -> 0x0e
+#define CHI_I1_HI 0x01 // 1, 0, 0, 0 -> 0b00000001 -> 0x01
+
+/**
+ * @brief AVX2 Keccak permutation.
+ *
+ * @param[in,out] s Keccak state (array of 25 64-bit integers).
+ * @param[in] num_rounds Number of rounds (12 or 24).
+ *
+ * How it works:
+ *
+ * 1. The Keccak state for cells in columns 0-3 is loaded from `s` (an
+ * array of 25 64-bit unsigned integers) into four 64-bit lanes of 5
+ * 256-bit registers r0_lo-r4_lo.
+ *
+ * 2. The Keccak state for cells in column 4 is loaded from `s` into
+ * the first 64-bit lanes of 5 256-bit registers r0_hi-r4_hi.
+ *
+ * When steps #1 and #2 are done, the registers look like this:
+ *
+ * ------------------------------------------------
+ * | | 64-bit Lanes of 256-bit Registers |
+ * |----------|-----------------------------------|
+ * | Register | Lane 0 | Lane 1 | Lane 2 | Lane 3 |
+ * |----------|--------|--------|--------|--------|
+ * | r0_lo | s[ 0] | s[ 1] | s[ 2] | s[ 3] |
+ * | r1_lo | s[ 5] | s[ 6] | s[ 7] | s[ 8] |
+ * | r2_lo | s[10] | s[11] | s[12] | s[13] |
+ * | r3_lo | s[15] | s[16] | s[17] | s[18] |
+ * | r4_lo | s[20] | s[21] | s[22] | s[23] |
+ * | r0_hi | s[ 4] | n/a | n/a | n/a |
+ * | r1_hi | s[ 9] | n/a | n/a | n/a |
+ * | r2_hi | s[14] | n/a | n/a | n/a |
+ * | r3_hi | s[19] | n/a | n/a | n/a |
+ * | r4_hi | s[24] | n/a | n/a | n/a |
+ * ------------------------------------------------
+ *
+ * 3. The Keccak permutation is applied `num_rounds` times, where
+ * `num_rounds` is either 12 for TurboSHAKE and KangarooTwelve or 24
+ * otherwise.
+ *
+ * (Note: for the Pi step the registers are stored back to the state
+ * array and then gathered to permute the state. This is different than
+ * the AVX-512 implementation because of register pressure).
+ *
+ * 4. The permuted Keccak state is copied back to `s`.
+ */
+static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds) {
+ // load state
+ AVX2_LOAD(s);
+
+ // loop over rounds
+ for (size_t i = (SHA3_NUM_ROUNDS - num_rounds); __builtin_expect(i < SHA3_NUM_ROUNDS, 1); i++) {
+ // theta
+ {
+ // c = xor(r0, r1, r2, r3, r4)
+ const __m256i c_lo = r0_lo ^ r1_lo ^ r2_lo ^ r3_lo ^ r4_lo,
+ c_hi = r0_hi ^ r1_hi ^ r2_hi ^ r3_hi ^ r4_hi;
+
+ // avx512 permute ids (for reference)
+ // static const __m512i I0 = { 4, 0, 1, 2, 3 },
+ // I1 = { 1, 2, 3, 4, 0 };
+
+ // masks
+ static const __m256i M0 = { ~0, 0, 0, 0 }, // { 1, 0, 0, 0 }
+ M1 = { ~0, ~0, ~0, 0 }; // { 1, 1, 1, 0 }
+
+ // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
+ const __m256i d0_lo = (_mm256_permute4x64_epi64(c_lo, PI_I0_LO) & ~M0) | (c_hi & M0),
+ d0_hi = _mm256_permute4x64_epi64(c_lo, PI_I0_HI) & M0,
+ d1_lo = (_mm256_permute4x64_epi64(c_lo, PI_I1_LO) & M1) | (_mm256_permute4x64_epi64(c_hi, PI_I1_HI) & ~M1),
+ d1_hi = (c_lo & M0),
+ d_lo = d0_lo ^ AVX2_ROLI(d1_lo, 1),
+ d_hi = d0_hi ^ AVX2_ROLI(d1_hi, 1);
+
+ // row = xor(row, d)
+ r0_lo ^= d_lo; r1_lo ^= d_lo; r2_lo ^= d_lo; r3_lo ^= d_lo; r4_lo ^= d_lo;
+ r0_hi ^= d_hi; r1_hi ^= d_hi; r2_hi ^= d_hi; r3_hi ^= d_hi; r4_hi ^= d_hi;
+ }
+
+ // rho
+ {
+ // rotate values
+ static const __m256i V0_LO = { 0, 1, 62, 28 }, V0_HI = { 27 },
+ V1_LO = { 36, 44, 6, 55 }, V1_HI = { 20 },
+ V2_LO = { 3, 10, 43, 25 }, V2_HI = { 39 },
+ V3_LO = { 41, 45, 15, 21 }, V3_HI = { 8 },
+ V4_LO = { 18, 2, 61, 56 }, V4_HI = { 14 };
+
+ // rotate rows
+ // FIXME: could reduce rotates by permuting
+ r0_lo = AVX2_ROLV(r0_lo, V0_LO); r0_hi = AVX2_ROLV(r0_hi, V0_HI);
+ r1_lo = AVX2_ROLV(r1_lo, V1_LO); r1_hi = AVX2_ROLV(r1_hi, V1_HI);
+ r2_lo = AVX2_ROLV(r2_lo, V2_LO); r2_hi = AVX2_ROLV(r2_hi, V2_HI);
+ r3_lo = AVX2_ROLV(r3_lo, V3_LO); r3_hi = AVX2_ROLV(r3_hi, V3_HI);
+ r4_lo = AVX2_ROLV(r4_lo, V4_LO); r4_hi = AVX2_ROLV(r4_hi, V4_HI);
+ }
+
+ // pi
+ //
+ // store state array, then gather to permute the state. note: with
+ // some work we could probably do in-register permutes, but
+ // benchmark first to see if this is worth the trouble.
+ {
+ static const __m256i V0_LO = { 0, 6, 12, 18 },
+ V1_LO = { 3, 9, 10, 16 },
+ V2_LO = { 1, 7, 13, 19 },
+ V3_LO = { 4, 5, 11, 17 },
+ V4_LO = { 2, 8, 14, 15 };
+ static const size_t V0_HI = 24, V1_HI = 22, V2_HI = 20, V3_HI = 23, V4_HI = 21;
+
+ // store rows to state, then gather to permute
+ AVX2_STORE(s);
+
+ // re-load using gather to permute
+ union { long long int *i64; uint64_t *u64; } p = { .u64 = s };
+ r0_lo = _mm256_i64gather_epi64(p.i64, V0_LO, 8); r0_hi = ((__m256i) { s[V0_HI] });
+ r1_lo = _mm256_i64gather_epi64(p.i64, V1_LO, 8); r1_hi = ((__m256i) { s[V1_HI] });
+ r2_lo = _mm256_i64gather_epi64(p.i64, V2_LO, 8); r2_hi = ((__m256i) { s[V2_HI] });
+ r3_lo = _mm256_i64gather_epi64(p.i64, V3_LO, 8); r3_hi = ((__m256i) { s[V3_HI] });
+ r4_lo = _mm256_i64gather_epi64(p.i64, V4_LO, 8); r4_hi = ((__m256i) { s[V4_HI] });
+ }
+
+ // chi
+ {
+ // masks
+ static const __m256i M0 = { ~0, 0, 0, 0 }, // { 1, 0, 0, 0 }
+ M1 = { ~0, ~0, ~0, 0 }, // { 1, 1, 1, 0 }
+ M2 = { ~0, ~0, 0, ~0 }; // { 1, 1, 0, 1 }
+
+ // r0
+ {
+ const __m256i a_lo = (_mm256_permute4x64_epi64(r0_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r0_hi, CHI_I0_LO) & ~M1),
+ a_hi = r0_lo & M0,
+ b_lo = (_mm256_permute4x64_epi64(r0_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r0_hi, CHI_I1_LO) & ~M0),
+ b_hi = _mm256_permute4x64_epi64(r0_lo, CHI_I1_HI) & M0;
+
+ r0_lo ^= ~a_lo & b_lo; r0_hi ^= ~a_hi & b_hi; // r0 ^= ~a & b
+ }
+
+ // r1
+ {
+ const __m256i a_lo = (_mm256_permute4x64_epi64(r1_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r1_hi, CHI_I0_LO) & ~M1),
+ a_hi = r1_lo & M0,
+ b_lo = (_mm256_permute4x64_epi64(r1_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r1_hi, CHI_I1_LO) & ~M0),
+ b_hi = _mm256_permute4x64_epi64(r1_lo, CHI_I1_HI) & M0;
+
+ r1_lo ^= ~a_lo & b_lo; r1_hi ^= ~a_hi & b_hi; // r1 ^= ~a & b
+ }
+
+ // r2
+ {
+ const __m256i a_lo = (_mm256_permute4x64_epi64(r2_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r2_hi, CHI_I0_LO) & ~M1),
+ a_hi = r2_lo & M0,
+ b_lo = (_mm256_permute4x64_epi64(r2_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r2_hi, CHI_I1_LO) & ~M0),
+ b_hi = _mm256_permute4x64_epi64(r2_lo, CHI_I1_HI) & M0;
+
+ r2_lo ^= ~a_lo & b_lo; r2_hi ^= ~a_hi & b_hi; // r2 ^= ~a & b
+ }
+
+ // r3
+ {
+ const __m256i a_lo = (_mm256_permute4x64_epi64(r3_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r3_hi, CHI_I0_LO) & ~M1),
+ a_hi = r3_lo & M0,
+ b_lo = (_mm256_permute4x64_epi64(r3_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r3_hi, CHI_I1_LO) & ~M0),
+ b_hi = _mm256_permute4x64_epi64(r3_lo, CHI_I1_HI) & M0;
+
+ r3_lo ^= ~a_lo & b_lo; r3_hi ^= ~a_hi & b_hi; // r3 ^= ~a & b
+ }
+
+ // r4
+ {
+ const __m256i a_lo = (_mm256_permute4x64_epi64(r4_lo, CHI_I0_LO) & M1) | (_mm256_permute4x64_epi64(r4_hi, CHI_I0_LO) & ~M1),
+ a_hi = r4_lo & M0,
+ b_lo = (_mm256_permute4x64_epi64(r4_lo, CHI_I1_LO) & M2) | (_mm256_permute4x64_epi64(r4_hi, CHI_I1_LO) & ~M0),
+ b_hi = _mm256_permute4x64_epi64(r4_lo, CHI_I1_HI) & M0;
+
+ r4_lo ^= ~a_lo & b_lo; r4_hi ^= ~a_hi & b_hi; // r4 ^= ~a & b
+ }
+ }
+
+ // iota
+ const __m256i rc = { RCS[i], 0, 0, 0 };
+ r0_lo ^= rc;
+ }
+
+ // store rows to state
+ AVX2_STORE(s);
+}
+#endif /* BACKEND == BACKEND_AVX2 */
+
#if BACKEND == BACKEND_NEON
#include <arm_neon.h>
@@ -1310,6 +1555,8 @@ static inline void permute_n_hybrid(uint64_t a[static 25], const size_t num_roun
// map permute_n() to active backend
#if BACKEND == BACKEND_AVX512
#define permute_n permute_n_avx512 // use avx512 backend
+#elif BACKEND == BACKEND_AVX2
+#define permute_n permute_n_avx2 // use AVX2 backend
#elif BACKEND == BACKEND_NEON
#define permute_n permute_n_neon // use neon backend
#elif BACKEND == BACKEND_DIET_NEON
@@ -2968,6 +3215,8 @@ void k12_once(const uint8_t *src, const size_t src_len, uint8_t *dst, const size
const char *sha3_backend(void) {
#if BACKEND == BACKEND_AVX512
return "avx512";
+#elif BACKEND == BACKEND_AVX2
+ return "avx2";
#elif BACKEND == BACKEND_NEON
return "neon";
#elif BACKEND == BACKEND_DIET_NEON