aboutsummaryrefslogtreecommitdiff
path: root/sha3.c
diff options
context:
space:
mode:
authorPaul Duncan <pabs@pablotron.org>2024-04-29 12:08:54 -0400
committerPaul Duncan <pabs@pablotron.org>2024-04-29 12:08:54 -0400
commit6b2eb7da27a27b5d48caacfb928d9279799bcd6c (patch)
tree077cc64c9fe14eb33216ba6ac678d1fa20e7572a /sha3.c
parent4f2e3ab022e887e33aff5e2dccb8e6dc7074cbcf (diff)
downloadsha3-6b2eb7da27a27b5d48caacfb928d9279799bcd6c.tar.bz2
sha3-6b2eb7da27a27b5d48caacfb928d9279799bcd6c.zip
sha3.c: permute{,12}_avx512(): optimize, update header comment
Diffstat (limited to 'sha3.c')
-rw-r--r--sha3.c662
1 files changed, 295 insertions, 367 deletions
diff --git a/sha3.c b/sha3.c
index 887c5c2..83ef2b9 100644
--- a/sha3.c
+++ b/sha3.c
@@ -34,6 +34,9 @@
// number of rounds for permute()
#define SHA3_NUM_ROUNDS 24
+// align memory to N bytes
+#define ALIGN(N) __attribute__((aligned(N)))
+
// round constants (used by iota)
static const uint64_t RCS[] = {
0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
@@ -211,66 +214,54 @@ static inline void permute12_scalar(uint64_t a[static 25]) {
//
// how it operates (roughly):
//
-// 1. load rows from state `s` into avx512 registers r0-r4, like so:
-//
-// r0 <- | s[ 0] | s[ 1] | s[ 2] | s[ 3] | s[ 4] | n/a | n/a | n/a |
-// r1 <- | s[ 5] | s[ 6] | s[ 7] | s[ 8] | s[ 9] | n/a | n/a | n/a |
-// r2 <- | s[10] | s[11] | s[12] | s[13] | s[14] | n/a | n/a | n/a |
-// r3 <- | s[15] | s[16] | s[17] | s[18] | s[19] | n/a | n/a | n/a |
-// r4 <- | s[20] | s[21] | s[22] | s[23] | s[24] | n/a | n/a | n/a |
+// 1. load rows from state `s` into the first 5 64-bit lanes of AVX-512
+// registers r0-r4, like so:
//
-// 2. load the first 8 round constants for iota into an avx512 `ra`
-// (round constants) register.
+// -----------------------------------------------------------------
+// | | Lanes |
+// |-----|---------------------------------------------------------|
+// | Reg | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 |
+// |-----|-------|-------|-------|-------|-------|-----|-----|-----|
+// | r0 | s[ 0] | s[ 1] | s[ 2] | s[ 3] | s[ 4] | n/a | n/a | n/a |
+// | r0 | s[ 0] | s[ 1] | s[ 2] | s[ 3] | s[ 4] | n/a | n/a | n/a |
+// | r1 | s[ 5] | s[ 6] | s[ 7] | s[ 8] | s[ 9] | n/a | n/a | n/a |
+// | r2 | s[10] | s[11] | s[12] | s[13] | s[14] | n/a | n/a | n/a |
+// | r3 | s[15] | s[16] | s[17] | s[18] | s[19] | n/a | n/a | n/a |
+// | r4 | s[20] | s[21] | s[22] | s[23] | s[24] | n/a | n/a | n/a |
+// -----------------------------------------------------------------
//
-// 3. for each round:
+// 2. For each round of 24 rounds:
// a. Perform theta, rho, pi, and chi steps. pi, in particular, has
// a large number of permutation registers (so it may spill).
-// b. Perform iota with the first round constant in `rc`, then permute
-// `rc`. If we have exhausted all 8 round constants and we are not
-// at the final round, then load the next 8 round constants.
+// b. Load round constant for current round and perform iota step.
//
-// 4. copy the rows from registers r0-r4 back to the state `s`.
+// 3. store the rows first 5 64-bit lanes of registers r0-r4 back to the
+// state `s`.
//
-// as noted above, this is not the most efficient avx512 implementation;
-// the row registers have three empty slots and there are a lot of loads
-// that could be removed with a little more work.
static inline void permute_avx512(uint64_t s[static 25]) {
- // unaligned load mask and permutation indices
- uint8_t mask = 0x1f,
- m0b = 0x01;
- const __mmask8 m = _load_mask8(&mask),
- m0 = _load_mask8(&m0b);
-
- // load round constant
- // note: this will bomb if num_rounds < 8 or num_rounds > 24.
- __m512i rc = _mm512_loadu_epi64((void*) RCS);
-
- // load rc permutation
- static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
- const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
-
- // load rows
- __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
- r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
- r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
- r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
- r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
-
- for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
+ // load rows (r0-r4)
+ __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0
+ r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1
+ r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2
+ r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3
+ r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
+
+ // 24 rounds
+ for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) {
// theta
{
- const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
- i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+ // permute ids
+ static const __m512i I0 = { 4, 0, 1, 2, 3 },
+ I1 = { 1, 2, 3, 4, 0 };
// c = xor(r0, r1, r2, r3, r4)
- const __m512i r01 = _mm512_xor_epi64(r0, r1),
- r23 = _mm512_xor_epi64(r2, r3),
- r03 = _mm512_xor_epi64(r01, r23),
- c = _mm512_xor_epi64(r03, r4);
+ const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1),
+ r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3),
+ c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96);
// d = xor(permute(i0, c), permute(i1, rol(c, 1)))
- const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
- d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+ const __m512i d0 = _mm512_permutexvar_epi64(I0, c),
+ d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)),
d = _mm512_xor_epi64(d0, d1);
// row = xor(row, d)
@@ -283,110 +274,89 @@ static inline void permute_avx512(uint64_t s[static 25]) {
// rho
{
- // unaligned load mask and rotate values
- static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
- vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
- vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
- vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
- vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
-
- // load rotate values
- const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
- v1 = _mm512_loadu_epi64((void*) vs1),
- v2 = _mm512_loadu_epi64((void*) vs2),
- v3 = _mm512_loadu_epi64((void*) vs3),
- v4 = _mm512_loadu_epi64((void*) vs4);
-
- // rotate
- r0 = _mm512_rolv_epi64(r0, v0);
- r1 = _mm512_rolv_epi64(r1, v1);
- r2 = _mm512_rolv_epi64(r2, v2);
- r3 = _mm512_rolv_epi64(r3, v3);
- r4 = _mm512_rolv_epi64(r4, v4);
+ // rotate values
+ //
+ // note: switching from maskz_load_epi64()s to static const
+ // __m512i incurs a 500 cycle penalty; leaving them for now
+ static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 },
+ V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 },
+ V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 },
+ V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 },
+ V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 };
+
+ // rotate rows
+ r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS));
+ r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS));
+ r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS));
+ r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS));
+ r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS));
}
// pi
+ //
+ // The cells are permuted across all rows of the state array. each
+ // output row is the combination of three permutations:
+ //
+ // - e0: row 0 and row 1
+ // - e2: row 2 and row 3
+ // - e4: row 4 and row 0
+ //
+ // the IDs for each permutation are merged into a single array
+ // (T*_IDS) to reduce register pressure, and the permute operations
+ // are masked so that each permutation only uses the relevant IDs.
+ //
+ // afterwards, the permutations are combined to form a temporary
+ // row:
+ //
+ // t0 = t0e0 | t0e2 | t0e4
+ //
+ // once the permutations for all rows are complete, the temporary
+ // rows are saved to the actual row registers:
+ //
+ // r0 = t0
+ //
{
- // mask bytes
- uint8_t m01b = 0x03,
- m23b = 0x0c,
- m4b = 0x10;
-
- // load masks
- const __mmask8 m01 = _load_mask8(&m01b),
- m23 = _load_mask8(&m23b),
- m4 = _load_mask8(&m4b);
-
- // permutation indices
- //
- // (note: these are masked so only the relevant indices for
- // _mm512_maskz_permutex2var_epi64() in each array are filled in)
- static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
- t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
- t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
-
- t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
- t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
- t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
-
- t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
- t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
- t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
-
- t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
- t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
- t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
-
- t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
- t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
- t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
-
- // load permutation indices
- const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
- t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
- t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
-
- t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
- t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
- t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
-
- t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
- t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
- t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
-
- t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
- t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
- t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
-
- t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
- t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
- t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
-
- // permute rows
- const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
- t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
- t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
- t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
-
- t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
- t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
- t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
- t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
-
- t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
- t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
- t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
- t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
-
- t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
- t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
- t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
- t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
-
- t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
- t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
- t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
- t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+ // permute ids
+ static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 },
+ T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 },
+ T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 },
+ T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 },
+ T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 };
+
+ __m512i t0, t1, t2, t3, t4;
+ {
+ // permute r0
+ const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1),
+ t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3),
+ t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0);
+
+ // permute r1
+ const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1),
+ t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3),
+ t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0);
+
+ // permute r2
+ const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1),
+ t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3),
+ t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0);
+
+ // permute r3
+ const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1),
+ t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3),
+ t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0);
+
+ // permute r4
+ const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1),
+ t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3),
+ t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0);
+
+ // combine permutes: tN = e0 | e2 | e4
+ t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe);
+ t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe);
+ t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe);
+ t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe);
+ t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe);
+ }
// store rows
r0 = t0;
@@ -398,103 +368,86 @@ static inline void permute_avx512(uint64_t s[static 25]) {
// chi
{
- // permutation indices
- static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
- pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
-
- // load permutation indices
- const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
- p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
-
- // permute rows
- const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
- t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
- t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
-
- t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
- t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
- t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+ // permute ids
+ static const __m512i P0 = { 1, 2, 3, 4, 0 },
+ P1 = { 2, 3, 4, 0, 1 };
+
+ {
+ // r0 ^= ~e0 & e1
+ const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0),
+ t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0);
+ r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2);
+ }
- t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
- t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
- t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+ {
+ // r1 ^= ~e0 & e1
+ const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1),
+ t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1);
+ r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2);
+ }
- t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
- t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
- t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+ {
+ // r2 ^= ~e0 & e1
+ const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2),
+ t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2);
+ r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2);
+ }
- t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
- t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
- t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+ {
+ // r3 ^= ~e0 & e1
+ const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3),
+ t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3);
+ r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2);
+ }
- // store rows
- r0 = t0;
- r1 = t1;
- r2 = t2;
- r3 = t3;
- r4 = t4;
+ {
+ // r4 ^= ~e0 & e1
+ const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4),
+ t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4);
+ r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2);
+ }
}
// iota
{
- // xor round constant, shuffle rc register
- r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
- rc = _mm512_permutexvar_epi64(rc_p, rc);
-
- if (((i + 1) % 8) == 0 && i != 23) {
- // load next set of round constants
- // note: this will bomb if num_rounds < 8 or num_rounds > 24.
- rc = _mm512_loadu_epi64((void*) (RCS + (i + 1)));
- }
+ // xor round constant to first cell
+ r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + i));
}
}
// store rows
- _mm512_mask_storeu_epi64((void*) (s), m, r0),
- _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
- _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
- _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
- _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+ _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0);
+ _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1);
+ _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2);
+ _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
+ _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
}
// 12 round keccak permutation (avx512 implementation).
static inline void permute12_avx512(uint64_t s[static 25]) {
- // unaligned load mask and permutation indices
- uint8_t mask = 0x1f,
- m0b = 0x01;
- const __mmask8 m = _load_mask8(&mask),
- m0 = _load_mask8(&m0b);
-
- // load round constant
- // note: this will bomb if num_rounds < 8 or num_rounds > 24.
- __m512i rc = _mm512_loadu_epi64((void*) (RCS + 12));
-
- // load rc permutation
- static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
- const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
-
- // load rows
- __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
- r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
- r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
- r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
- r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
-
- for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
+ // load rows (r0-r4)
+ __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0
+ r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1
+ r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2
+ r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3
+ r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
+
+ // 12 rounds
+ for (size_t i = 0; i < 12; i++) {
// theta
{
- const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
- i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+ // permute ids
+ static const __m512i I0 = { 4, 0, 1, 2, 3 },
+ I1 = { 1, 2, 3, 4, 0 };
// c = xor(r0, r1, r2, r3, r4)
- const __m512i r01 = _mm512_xor_epi64(r0, r1),
- r23 = _mm512_xor_epi64(r2, r3),
- r03 = _mm512_xor_epi64(r01, r23),
- c = _mm512_xor_epi64(r03, r4);
+ const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1),
+ r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3),
+ c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96);
// d = xor(permute(i0, c), permute(i1, rol(c, 1)))
- const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
- d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+ const __m512i d0 = _mm512_permutexvar_epi64(I0, c),
+ d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)),
d = _mm512_xor_epi64(d0, d1);
// row = xor(row, d)
@@ -507,110 +460,89 @@ static inline void permute12_avx512(uint64_t s[static 25]) {
// rho
{
- // unaligned load mask and rotate values
- static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
- vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
- vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
- vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
- vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
-
- // load rotate values
- const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
- v1 = _mm512_loadu_epi64((void*) vs1),
- v2 = _mm512_loadu_epi64((void*) vs2),
- v3 = _mm512_loadu_epi64((void*) vs3),
- v4 = _mm512_loadu_epi64((void*) vs4);
-
- // rotate
- r0 = _mm512_rolv_epi64(r0, v0);
- r1 = _mm512_rolv_epi64(r1, v1);
- r2 = _mm512_rolv_epi64(r2, v2);
- r3 = _mm512_rolv_epi64(r3, v3);
- r4 = _mm512_rolv_epi64(r4, v4);
+ // rotate values
+ //
+ // note: switching from maskz_load_epi64()s to static const
+ // __m512i incurs a 500 cycle penalty; leaving them for now
+ static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 },
+ V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 },
+ V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 },
+ V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 },
+ V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 };
+
+ // rotate rows
+ r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS));
+ r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS));
+ r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS));
+ r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS));
+ r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS));
}
// pi
+ //
+ // The cells are permuted across all rows of the state array. each
+ // output row is the combination of three permutations:
+ //
+ // - e0: row 0 and row 1
+ // - e2: row 2 and row 3
+ // - e4: row 4 and row 0
+ //
+ // the IDs for each permutation are merged into a single array
+ // (T*_IDS) to reduce register pressure, and the permute operations
+ // are masked so that each permutation only uses the relevant IDs.
+ //
+ // afterwards, the permutations are combined to form a temporary
+ // row:
+ //
+ // t0 = t0e0 | t0e2 | t0e4
+ //
+ // once the permutations for all rows are complete, the temporary
+ // rows are saved to the actual row registers:
+ //
+ // r0 = t0
+ //
{
- // mask bytes
- uint8_t m01b = 0x03,
- m23b = 0x0c,
- m4b = 0x10;
-
- // load masks
- const __mmask8 m01 = _load_mask8(&m01b),
- m23 = _load_mask8(&m23b),
- m4 = _load_mask8(&m4b);
-
- // permutation indices
- //
- // (note: these are masked so only the relevant indices for
- // _mm512_maskz_permutex2var_epi64() in each array are filled in)
- static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
- t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
- t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
-
- t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
- t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
- t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
-
- t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
- t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
- t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
-
- t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
- t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
- t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
-
- t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
- t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
- t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
-
- // load permutation indices
- const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
- t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
- t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
-
- t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
- t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
- t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
-
- t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
- t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
- t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
-
- t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
- t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
- t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
-
- t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
- t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
- t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
-
- // permute rows
- const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
- t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
- t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
- t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
-
- t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
- t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
- t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
- t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
-
- t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
- t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
- t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
- t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
-
- t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
- t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
- t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
- t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
-
- t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
- t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
- t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
- t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+ // permute ids
+ static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 },
+ T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 },
+ T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 },
+ T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 },
+ T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 };
+
+ __m512i t0, t1, t2, t3, t4;
+ {
+ // permute r0
+ const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1),
+ t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3),
+ t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0);
+
+ // permute r1
+ const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1),
+ t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3),
+ t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0);
+
+ // permute r2
+ const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1),
+ t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3),
+ t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0);
+
+ // permute r3
+ const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1),
+ t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3),
+ t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0);
+
+ // permute r4
+ const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1),
+ t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3),
+ t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0);
+
+ // combine permutes: tN = e0 | e2 | e4
+ t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe);
+ t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe);
+ t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe);
+ t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe);
+ t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe);
+ }
// store rows
r0 = t0;
@@ -622,63 +554,59 @@ static inline void permute12_avx512(uint64_t s[static 25]) {
// chi
{
- // permutation indices
- static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
- pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
-
- // load permutation indices
- const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
- p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
-
- // permute rows
- const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
- t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
- t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
-
- t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
- t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
- t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+ // permute ids
+ static const __m512i P0 = { 1, 2, 3, 4, 0 },
+ P1 = { 2, 3, 4, 0, 1 };
+
+ {
+ // r0 ^= ~e0 & e1
+ const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0),
+ t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0);
+ r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2);
+ }
- t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
- t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
- t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+ {
+ // r1 ^= ~e0 & e1
+ const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1),
+ t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1);
+ r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2);
+ }
- t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
- t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
- t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+ {
+ // r2 ^= ~e0 & e1
+ const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2),
+ t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2);
+ r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2);
+ }
- t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
- t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
- t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+ {
+ // r3 ^= ~e0 & e1
+ const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3),
+ t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3);
+ r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2);
+ }
- // store rows
- r0 = t0;
- r1 = t1;
- r2 = t2;
- r3 = t3;
- r4 = t4;
+ {
+ // r4 ^= ~e0 & e1
+ const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4),
+ t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4);
+ r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2);
+ }
}
// iota
{
- // xor round constant, shuffle rc register
- r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
- rc = _mm512_permutexvar_epi64(rc_p, rc);
-
- if (((12 + i + 1) % 8) == 0 && (12 + i) != 23) {
- // load next set of round constants
- // note: this will bomb if num_rounds < 8 or num_rounds > 24.
- rc = _mm512_loadu_epi64((void*) (RCS + (12 + i + 1)));
- }
+ // xor round constant to first cell
+ r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + (12 + i)));
}
}
// store rows
- _mm512_mask_storeu_epi64((void*) (s), m, r0),
- _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
- _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
- _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
- _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+ _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0);
+ _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1);
+ _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2);
+ _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
+ _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
}
#endif /* __AVX512F__ */