summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Duncan <pabs@pablotron.org>2023-09-18 19:48:10 -0400
committerPaul Duncan <pabs@pablotron.org>2023-09-18 19:48:10 -0400
commit75acc4ed20683c7d70f44b5b4b2122094a3d66d6 (patch)
tree675cf496edbec41959804b37caed5728bff5d4a3
parent0d39d865eec8f6def6068c3bc08b1c55e7358f3e (diff)
downloadsha3-75acc4ed20683c7d70f44b5b4b2122094a3d66d6.tar.bz2
sha3-75acc4ed20683c7d70f44b5b4b2122094a3d66d6.zip
sha3.c: add avx512 permute(), add step comments, do not build scalar steps unless necessary
-rw-r--r--sha3.c258
1 files changed, 257 insertions, 1 deletions
diff --git a/sha3.c b/sha3.c
index 3c3108a..0e9a63d 100644
--- a/sha3.c
+++ b/sha3.c
@@ -43,6 +43,13 @@
// number of rounds for permute()
#define SHA3_NUM_ROUNDS 24
+#if !defined(__AVX512F__) || defined(SHA3_TEST)
+// If AVX512 is supported and we are not building the test suite,
+// then do not compile the scalar step functions below.
+//
+// (because they aren't used by the AVX512 implementation).
+
+// theta step of keccak permutation (scalar implementation)
static inline void theta(uint64_t a[static 25]) {
const uint64_t c[5] = {
a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20],
@@ -67,6 +74,7 @@ static inline void theta(uint64_t a[static 25]) {
a[20] ^= d[0]; a[21] ^= d[1]; a[22] ^= d[2]; a[23] ^= d[3]; a[24] ^= d[4];
}
+// rho step of keccak permutation (scalar implementation)
static inline void rho(uint64_t a[static 25]) {
a[1] = ROL(a[1], 1); // 1 % 64 = 1
a[2] = ROL(a[2], 62); // 190 % 64 = 62
@@ -94,6 +102,7 @@ static inline void rho(uint64_t a[static 25]) {
a[24] = ROL(a[24], 14); // 78 % 64 = 14
}
+// pi step of keccak permutation (scalar implementation)
static inline void pi(uint64_t a[static 25]) {
uint64_t t[25] = { 0 };
memcpy(t, a, sizeof(t));
@@ -123,6 +132,7 @@ static inline void pi(uint64_t a[static 25]) {
a[24] = t[21];
}
+// chi step of keccak permutation (scalar implementation)
static inline void chi(uint64_t a[static 25]) {
uint64_t t[25] = { 0 };
memcpy(t, a, sizeof(t));
@@ -153,6 +163,7 @@ static inline void chi(uint64_t a[static 25]) {
a[24] = t[24] ^ (~t[20] & t[21]);
}
+// iota step of keccak permutation (scalar implementation)
static inline void iota(uint64_t a[static 25], const int i) {
// round constants (ambiguous in spec)
static const uint64_t RCS[] = {
@@ -166,8 +177,10 @@ static inline void iota(uint64_t a[static 25], const int i) {
a[0] ^= RCS[i];
}
+#endif /* !defined(__AVX512F__) || defined(SHA3_TEST) */
-// keccak permutation.
+#ifndef __AVX512F__
+// keccak permutation (scalar implementation)
//
// note: clang is better about inlining this than gcc with a
// configurable number of rounds. the configurable number of rounds is
@@ -182,6 +195,249 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) {
iota(a, 24 - num_rounds + i);
}
}
+#endif /* __AVX512F__ */
+
+#ifdef __AVX512F__
+#include <immintrin.h>
+
+// keccak permutation (avx512 implementation).
+//
+// copied from `permute_avx512_fast()` in `tests/permute/permute.c`. all
+// steps are inlined as blocks. ~3x faster than scalar implementation,
+// but could be sped up more.
+static inline void permute(uint64_t s[static 25], const size_t num_rounds) {
+ // unaligned load mask and permutation indices
+ uint8_t mask = 0x1f,
+ m0b = 0x01;
+ const __mmask8 m = _load_mask8(&mask),
+ m0 = _load_mask8(&m0b);
+
+ // round constants (used in iota)
+ static const uint64_t RCS[] = {
+ 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
+ 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
+ 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
+ 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
+ 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
+ 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
+ };
+
+ // load round constant
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds));
+
+ // load rc permutation
+ static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
+ const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
+
+ // load rows
+ __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
+ r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
+ r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
+ r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
+ r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
+
+ for (int i = 0; i < (int) num_rounds; i++) {
+ // theta
+ {
+ const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
+ i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+
+ // c = xor(r0, r1, r2, r3, r4)
+ const __m512i r01 = _mm512_xor_epi64(r0, r1),
+ r23 = _mm512_xor_epi64(r2, r3),
+ r03 = _mm512_xor_epi64(r01, r23),
+ c = _mm512_xor_epi64(r03, r4);
+
+ // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
+ const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
+ d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+ d = _mm512_xor_epi64(d0, d1);
+
+ // row = xor(row, d)
+ r0 = _mm512_xor_epi64(r0, d);
+ r1 = _mm512_xor_epi64(r1, d);
+ r2 = _mm512_xor_epi64(r2, d);
+ r3 = _mm512_xor_epi64(r3, d);
+ r4 = _mm512_xor_epi64(r4, d);
+ }
+
+ // rho
+ {
+ // unaligned load mask and rotate values
+ static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
+ vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
+ vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
+ vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
+ vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
+
+ // load rotate values
+ const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
+ v1 = _mm512_loadu_epi64((void*) vs1),
+ v2 = _mm512_loadu_epi64((void*) vs2),
+ v3 = _mm512_loadu_epi64((void*) vs3),
+ v4 = _mm512_loadu_epi64((void*) vs4);
+
+ // rotate
+ r0 = _mm512_rolv_epi64(r0, v0);
+ r1 = _mm512_rolv_epi64(r1, v1);
+ r2 = _mm512_rolv_epi64(r2, v2);
+ r3 = _mm512_rolv_epi64(r3, v3);
+ r4 = _mm512_rolv_epi64(r4, v4);
+ }
+
+ // pi
+ {
+ // mask bytes
+ uint8_t m01b = 0x03,
+ m23b = 0x0c,
+ m4b = 0x10;
+
+ // load masks
+ const __mmask8 m01 = _load_mask8(&m01b),
+ m23 = _load_mask8(&m23b),
+ m4 = _load_mask8(&m4b);
+
+ // permutation indices
+ //
+ // (note: these are masked so only the relevant indices for
+ // _mm512_maskz_permutex2var_epi64() in each array are filled in)
+ static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
+ t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
+ t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
+
+ t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
+ t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
+ t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
+
+ t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
+ t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
+ t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
+
+ t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
+ t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
+ t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
+
+ t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
+ t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
+ t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
+
+ // load permutation indices
+ const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
+ t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
+ t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
+
+ t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
+ t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
+ t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
+
+ t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
+ t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
+ t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
+
+ t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
+ t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
+ t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
+
+ t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
+ t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
+ t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
+
+ // permute rows
+ const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
+ t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
+ t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
+ t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
+
+ t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
+ t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
+ t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
+ t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
+
+ t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
+ t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
+ t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
+ t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
+
+ t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
+ t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
+ t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
+ t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
+
+ t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
+ t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
+ t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
+ t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+
+ // store rows
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // chi
+ {
+ // permutation indices
+ static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
+ pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
+
+ // load permutation indices
+ const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
+ p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
+
+ // permute rows
+ const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
+ t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
+ t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
+
+ t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
+ t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
+ t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+
+ t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
+ t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
+ t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+
+ t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
+ t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
+ t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+
+ t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
+ t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
+ t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+
+ // store rows
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // iota
+ {
+ // xor round constant, shuffle rc register
+ r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
+ rc = _mm512_permutexvar_epi64(rc_p, rc);
+
+ if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) {
+ // load next set of round constants
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1)));
+ }
+ }
+ }
+
+ // store rows
+ _mm512_mask_storeu_epi64((void*) (s), m, r0),
+ _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
+ _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
+ _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
+ _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+}
+#endif /* __AVX512F__ */
// one-shot keccak.
static inline size_t keccak(sha3_state_t * const a, const uint8_t *m, size_t m_len, const size_t rate) {