From 6b2eb7da27a27b5d48caacfb928d9279799bcd6c Mon Sep 17 00:00:00 2001
From: Paul Duncan
Date: Mon, 29 Apr 2024 12:08:54 -0400
Subject: sha3.c: permute{,12}_avx512(): optimize, update header comment

---
 sha3.c | 662 +++++++++++++++++++++++++++++------------------------------------
 1 file changed, 295 insertions(+), 367 deletions(-)

(limited to 'sha3.c')

diff --git a/sha3.c b/sha3.c
index 887c5c2..83ef2b9 100644
--- a/sha3.c
+++ b/sha3.c
@@ -34,6 +34,9 @@
 // number of rounds for permute()
 #define SHA3_NUM_ROUNDS 24
 
+// align memory to N bytes
+#define ALIGN(N) __attribute__((aligned(N)))
+
 // round constants (used by iota)
 static const uint64_t RCS[] = {
   0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
@@ -211,66 +214,54 @@ static inline void permute12_scalar(uint64_t a[static 25]) {
 //
 // how it operates (roughly):
 //
-// 1. load rows from state `s` into avx512 registers r0-r4, like so:
-//
-// r0 <- | s[ 0] | s[ 1] | s[ 2] | s[ 3] | s[ 4] | n/a | n/a | n/a |
-// r1 <- | s[ 5] | s[ 6] | s[ 7] | s[ 8] | s[ 9] | n/a | n/a | n/a |
-// r2 <- | s[10] | s[11] | s[12] | s[13] | s[14] | n/a | n/a | n/a |
-// r3 <- | s[15] | s[16] | s[17] | s[18] | s[19] | n/a | n/a | n/a |
-// r4 <- | s[20] | s[21] | s[22] | s[23] | s[24] | n/a | n/a | n/a |
+// 1. load rows from state `s` into the first 5 64-bit lanes of AVX-512
+//    registers r0-r4, like so:
 //
-// 2. load the first 8 round constants for iota into an avx512 `ra`
-// (round constants) register.
+//    -----------------------------------------------------------------
+//    |     |                          Lanes                          |
+//    |-----|---------------------------------------------------------|
+//    | Reg |   0   |   1   |   2   |   3   |   4   |  5  |  6  |  7  |
+//    |-----|-------|-------|-------|-------|-------|-----|-----|-----|
+//    | r0  | s[ 0] | s[ 1] | s[ 2] | s[ 3] | s[ 4] | n/a | n/a | n/a |
+//    | r1  | s[ 5] | s[ 6] | s[ 7] | s[ 8] | s[ 9] | n/a | n/a | n/a |
+//    | r2  | s[10] | s[11] | s[12] | s[13] | s[14] | n/a | n/a | n/a |
+//    | r3  | s[15] | s[16] | s[17] | s[18] | s[19] | n/a | n/a | n/a |
+//    | r4  | s[20] | s[21] | s[22] | s[23] | s[24] | n/a | n/a | n/a |
+//    -----------------------------------------------------------------
 //
-// 3. for each round:
+// 2. for each of the 24 rounds:
 //      a. Perform theta, rho, pi, and chi steps. pi, in particular, has
 //         a large number of permutation registers (so it may spill).
-//      b. Perform iota with the first round constant in `rc`, then permute
-//         `rc`. If we have exhausted all 8 round constants and we are not
-//         at the final round, then load the next 8 round constants.
+//      b. Load the round constant for the current round and perform the
+//         iota step.
 //
-// 4. copy the rows from registers r0-r4 back to the state `s`.
+// 3. store the first 5 64-bit lanes of registers r0-r4 back to the
+//    state `s`.
 //
-// as noted above, this is not the most efficient avx512 implementation;
-// the row registers have three empty slots and there are a lot of loads
-// that could be removed with a little more work.
 static inline void permute_avx512(uint64_t s[static 25]) {
-  // unaligned load mask and permutation indices
-  uint8_t mask = 0x1f,
-          m0b = 0x01;
-  const __mmask8 m = _load_mask8(&mask),
-                 m0 = _load_mask8(&m0b);
-
-  // load round constant
-  // note: this will bomb if num_rounds < 8 or num_rounds > 24.
-  __m512i rc = _mm512_loadu_epi64((void*) RCS);
-
-  // load rc permutation
-  static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
-  const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
-
-  // load rows
-  __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
-          r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
-          r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
-          r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
-          r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
-
-  for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
+  // load rows (r0-r4)
+  __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s +  0), // row 0
+          r1 = _mm512_maskz_loadu_epi64(0x1f, s +  5), // row 1
+          r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2
+          r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3
+          r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
+
+  // 24 rounds
+  for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) {
     // theta
     {
-      const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
-                    i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+      // permute ids
+      static const __m512i I0 = { 4, 0, 1, 2, 3 },
+                           I1 = { 1, 2, 3, 4, 0 };
 
       // c = xor(r0, r1, r2, r3, r4)
-      const __m512i r01 = _mm512_xor_epi64(r0, r1),
-                    r23 = _mm512_xor_epi64(r2, r3),
-                    r03 = _mm512_xor_epi64(r01, r23),
-                    c = _mm512_xor_epi64(r03, r4);
+      const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1),
+                    r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3),
+                    c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96);
 
       // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
-      const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
-                    d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+      const __m512i d0 = _mm512_permutexvar_epi64(I0, c),
+                    d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)),
                     d = _mm512_xor_epi64(d0, d1);
 
       // row = xor(row, d)
@@ -283,110 +274,89 @@ static inline void permute_avx512(uint64_t s[static 25]) {
     // rho
     {
-      // unaligned load mask and rotate values
-      static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
-                            vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
-                            vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
-                            vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
-                            vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
-
-      // load rotate values
-      const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
-                    v1 = _mm512_loadu_epi64((void*) vs1),
-                    v2 = _mm512_loadu_epi64((void*) vs2),
-                    v3 = _mm512_loadu_epi64((void*) vs3),
-                    v4 = _mm512_loadu_epi64((void*) vs4);
-
-      // rotate
-      r0 = _mm512_rolv_epi64(r0, v0);
-      r1 = _mm512_rolv_epi64(r1, v1);
-      r2 = _mm512_rolv_epi64(r2, v2);
-      r3 = _mm512_rolv_epi64(r3, v3);
-      r4 = _mm512_rolv_epi64(r4, v4);
+      // rotate values
+      //
+      // note: switching from maskz_load_epi64()s to static const
+      // __m512i incurs a 500 cycle penalty; leaving them for now
+      static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 },
+                            V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 },
+                            V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 },
+                            V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 },
+                            V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 };
+
+      // rotate rows
+      r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS));
+      r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS));
+      r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS));
+      r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS));
+      r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS));
     }
 
     // pi
+    //
+    // the cells are permuted across all rows of the state array.  each
+    // output row is the combination of three permutations:
+    //
+    //   - e0: row 0 and row 1
+    //   - e2: row 2 and row 3
+    //   - e4: row 4 and row 0
+    //
+    // the IDs for each permutation are merged into a single array
+    // (T*_IDS) to reduce register pressure, and the permute operations
+    // are masked so that each permutation only uses the relevant IDs.
+    //
+    // afterwards, the permutations are combined to form a temporary
+    // row:
+    //
+    //   t0 = t0e0 | t0e2 | t0e4
+    //
+    // once the permutations for all rows are complete, the temporary
+    // rows are saved to the actual row registers:
+    //
+    //   r0 = t0
+    //
     {
-      // mask bytes
-      uint8_t m01b = 0x03,
-              m23b = 0x0c,
-              m4b = 0x10;
-
-      // load masks
-      const __mmask8 m01 = _load_mask8(&m01b),
-                     m23 = _load_mask8(&m23b),
-                     m4 = _load_mask8(&m4b);
-
-      // permutation indices
-      //
-      // (note: these are masked so only the relevant indices for
-      // _mm512_maskz_permutex2var_epi64() in each array are filled in)
-      static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
-                            t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
-                            t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
-
-                            t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
-                            t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
-                            t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
-
-                            t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
-                            t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
-                            t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
-
-                            t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
-                            t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
-                            t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
-
-                            t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
-                            t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
-                            t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
-
-      // load permutation indices
-      const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
-                    t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
-                    t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
-
-                    t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
-                    t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
-                    t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
-
-                    t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
-                    t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
-                    t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
-
-                    t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
-                    t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
-                    t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
-
-                    t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
-                    t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
-                    t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
-
-      // permute rows
-      const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
-                    t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
-                    t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
-                    t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
-
-                    t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
-                    t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
-                    t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
-                    t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
-
-                    t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
-                    t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
-                    t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
-                    t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
-
-                    t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
-                    t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
-                    t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
-                    t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
-
-                    t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
-                    t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
-                    t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
-                    t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+      // permute ids
+      static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 },
+                           T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 },
+                           T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 },
+                           T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 },
+                           T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 };
+
+      __m512i t0, t1, t2, t3, t4;
+      {
+        // permute r0
+        const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1),
+                      t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3),
+                      t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0);
+
+        // permute r1
+        const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1),
+                      t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3),
+                      t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0);
+
+        // permute r2
+        const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1),
+                      t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3),
+                      t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0);
+
+        // permute r3
+        const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1),
+                      t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3),
+                      t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0);
+
+        // permute r4
+        const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1),
+                      t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3),
+                      t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0);
+
+        // combine permutes: tN = e0 | e2 | e4
+        t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe);
+        t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe);
+        t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe);
+        t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe);
+        t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe);
+      }
 
       // store rows
       r0 = t0;
@@ -398,103 +368,86 @@ static inline void permute_avx512(uint64_t s[static 25]) {
 
     // chi
     {
-      // permutation indices
-      static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
-                            pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
-
-      // load permutation indices
-      const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
-                    p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
-
-      // permute rows
-      const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
-                    t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
-                    t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
-
-                    t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
-                    t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
-                    t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+      // permute ids
+      static const __m512i P0 = { 1, 2, 3, 4, 0 },
+                           P1 = { 2, 3, 4, 0, 1 };
+
+      {
+        // r0 ^= ~e0 & e1
+        const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0),
+                      t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0);
+        r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2);
+      }
 
-                    t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
-                    t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
-                    t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+      {
+        // r1 ^= ~e0 & e1
+        const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1),
+                      t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1);
+        r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2);
+      }
 
-                    t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
-                    t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
-                    t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+      {
+        // r2 ^= ~e0 & e1
+        const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2),
+                      t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2);
+        r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2);
+      }
 
-                    t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
-                    t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
-                    t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+      {
+        // r3 ^= ~e0 & e1
+        const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3),
+                      t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3);
+        r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2);
+      }
 
-      // store rows
-      r0 = t0;
-      r1 = t1;
-      r2 = t2;
-      r3 = t3;
-      r4 = t4;
+      {
+        // r4 ^= ~e0 & e1
+        const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4),
+                      t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4);
+        r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2);
+      }
     }
 
     // iota
     {
-      // xor round constant, shuffle rc register
-      r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
-      rc = _mm512_permutexvar_epi64(rc_p, rc);
-
-      if (((i + 1) % 8) == 0 && i != 23) {
-        // load next set of round constants
-        // note: this will bomb if num_rounds < 8 or num_rounds > 24.
-        rc = _mm512_loadu_epi64((void*) (RCS + (i + 1)));
-      }
+      // xor round constant to first cell
+      r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + i));
     }
   }
 
   // store rows
-  _mm512_mask_storeu_epi64((void*) (s), m, r0),
-  _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
-  _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
-  _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
-  _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+  _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0);
+  _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1);
+  _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2);
+  _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
+  _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
 }
 
 // 12 round keccak permutation (avx512 implementation).
 static inline void permute12_avx512(uint64_t s[static 25]) {
-  // unaligned load mask and permutation indices
-  uint8_t mask = 0x1f,
-          m0b = 0x01;
-  const __mmask8 m = _load_mask8(&mask),
-                 m0 = _load_mask8(&m0b);
-
-  // load round constant
-  // note: this will bomb if num_rounds < 8 or num_rounds > 24.
-  __m512i rc = _mm512_loadu_epi64((void*) (RCS + 12));
-
-  // load rc permutation
-  static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
-  const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
-
-  // load rows
-  __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
-          r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
-          r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
-          r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
-          r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
-
-  for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
+  // load rows (r0-r4)
+  __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s +  0), // row 0
+          r1 = _mm512_maskz_loadu_epi64(0x1f, s +  5), // row 1
+          r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2
+          r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3
+          r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
+
+  // 12 rounds
+  for (size_t i = 0; i < 12; i++) {
     // theta
     {
-      const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
-                    i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+      // permute ids
+      static const __m512i I0 = { 4, 0, 1, 2, 3 },
+                           I1 = { 1, 2, 3, 4, 0 };
 
       // c = xor(r0, r1, r2, r3, r4)
-      const __m512i r01 = _mm512_xor_epi64(r0, r1),
-                    r23 = _mm512_xor_epi64(r2, r3),
-                    r03 = _mm512_xor_epi64(r01, r23),
-                    c = _mm512_xor_epi64(r03, r4);
+      const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1),
+                    r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3),
+                    c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96);
 
       // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
-      const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
-                    d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+      const __m512i d0 = _mm512_permutexvar_epi64(I0, c),
+                    d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)),
                     d = _mm512_xor_epi64(d0, d1);
 
       // row = xor(row, d)
@@ -507,110 +460,89 @@ static inline void permute12_avx512(uint64_t s[static 25]) {
     // rho
     {
-      // unaligned load mask and rotate values
-      static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
-                            vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
-                            vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
-                            vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
-                            vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
-
-      // load rotate values
-      const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
-                    v1 = _mm512_loadu_epi64((void*) vs1),
-                    v2 = _mm512_loadu_epi64((void*) vs2),
-                    v3 = _mm512_loadu_epi64((void*) vs3),
-                    v4 = _mm512_loadu_epi64((void*) vs4);
-
-      // rotate
-      r0 = _mm512_rolv_epi64(r0, v0);
-      r1 = _mm512_rolv_epi64(r1, v1);
-      r2 = _mm512_rolv_epi64(r2, v2);
-      r3 = _mm512_rolv_epi64(r3, v3);
-      r4 = _mm512_rolv_epi64(r4, v4);
+      // rotate values
+      //
+      // note: switching from maskz_load_epi64()s to static const
+      // __m512i incurs a 500 cycle penalty; leaving them for now
+      static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 },
+                            V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 },
+                            V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 },
+                            V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 },
+                            V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 };
+
+      // rotate rows
+      r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS));
+      r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS));
+      r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS));
+      r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS));
+      r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS));
    }
 
     // pi
+    //
+    // the cells are permuted across all rows of the state array.  each
+    // output row is the combination of three permutations:
+    //
+    //   - e0: row 0 and row 1
+    //   - e2: row 2 and row 3
+    //   - e4: row 4 and row 0
+    //
+    // the IDs for each permutation are merged into a single array
+    // (T*_IDS) to reduce register pressure, and the permute operations
+    // are masked so that each permutation only uses the relevant IDs.
+    //
+    // afterwards, the permutations are combined to form a temporary
+    // row:
+    //
+    //   t0 = t0e0 | t0e2 | t0e4
+    //
+    // once the permutations for all rows are complete, the temporary
+    // rows are saved to the actual row registers:
+    //
+    //   r0 = t0
+    //
     {
-      // mask bytes
-      uint8_t m01b = 0x03,
-              m23b = 0x0c,
-              m4b = 0x10;
-
-      // load masks
-      const __mmask8 m01 = _load_mask8(&m01b),
-                     m23 = _load_mask8(&m23b),
-                     m4 = _load_mask8(&m4b);
-
-      // permutation indices
-      //
-      // (note: these are masked so only the relevant indices for
-      // _mm512_maskz_permutex2var_epi64() in each array are filled in)
-      static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
-                            t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
-                            t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
-
-                            t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
-                            t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
-                            t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
-
-                            t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
-                            t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
-                            t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
-
-                            t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
-                            t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
-                            t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
-
-                            t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
-                            t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
-                            t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
-
-      // load permutation indices
-      const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
-                    t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
-                    t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
-
-                    t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
-                    t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
-                    t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
-
-                    t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
-                    t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
-                    t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
-
-                    t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
-                    t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
-                    t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
-
-                    t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
-                    t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
-                    t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
-
-      // permute rows
-      const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
-                    t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
-                    t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
-                    t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
-
-                    t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
-                    t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
-                    t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
-                    t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
-
-                    t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
-                    t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
-                    t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
-                    t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
-
-                    t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
-                    t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
-                    t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
-                    t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
-
-                    t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
-                    t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
-                    t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
-                    t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+      // permute ids
+      static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 },
+                           T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 },
+                           T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 },
+                           T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 },
+                           T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 };
+
+      __m512i t0, t1, t2, t3, t4;
+      {
+        // permute r0
+        const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1),
+                      t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3),
+                      t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0);
+
+        // permute r1
+        const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1),
+                      t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3),
+                      t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0);
+
+        // permute r2
+        const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1),
+                      t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3),
+                      t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0);
+
+        // permute r3
+        const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1),
+                      t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3),
+                      t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0);
+
+        // permute r4
+        const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1),
+                      t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3),
+                      t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0);
+
+        // combine permutes: tN = e0 | e2 | e4
+        t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe);
+        t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe);
+        t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe);
+        t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe);
+        t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe);
+      }
 
       // store rows
       r0 = t0;
@@ -622,63 +554,59 @@ static inline void permute12_avx512(uint64_t s[static 25]) {
 
     // chi
     {
-      // permutation indices
-      static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
-                            pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
-
-      // load permutation indices
-      const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
-                    p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
-
-      // permute rows
-      const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
-                    t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
-                    t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
-
-                    t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
-                    t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
-                    t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+      // permute ids
+      static const __m512i P0 = { 1, 2, 3, 4, 0 },
+                           P1 = { 2, 3, 4, 0, 1 };
+
+      {
+        // r0 ^= ~e0 & e1
+        const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0),
+                      t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0);
+        r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2);
+      }
 
-                    t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
-                    t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
-                    t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+      {
+        // r1 ^= ~e0 & e1
+        const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1),
+                      t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1);
+        r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2);
+      }
 
-                    t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
-                    t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
-                    t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+      {
+        // r2 ^= ~e0 & e1
+        const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2),
+                      t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2);
+        r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2);
+      }
 
-                    t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
-                    t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
-                    t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+      {
+        // r3 ^= ~e0 & e1
+        const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3),
+                      t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3);
+        r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2);
+      }
 
-      // store rows
-      r0 = t0;
-      r1 = t1;
-      r2 = t2;
-      r3 = t3;
-      r4 = t4;
+      {
+        // r4 ^= ~e0 & e1
+        const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4),
+                      t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4);
+        r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2);
+      }
     }
 
     // iota
     {
-      // xor round constant, shuffle rc register
-      r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
-      rc = _mm512_permutexvar_epi64(rc_p, rc);
-
-      if (((12 + i + 1) % 8) == 0 && (12 + i) != 23) {
-        // load next set of round constants
-        // note: this will bomb if num_rounds < 8 or num_rounds > 24.
-        rc = _mm512_loadu_epi64((void*) (RCS + (12 + i + 1)));
-      }
+      // xor round constant to first cell
+      r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + (12 + i)));
    }
  }
 
   // store rows
-  _mm512_mask_storeu_epi64((void*) (s), m, r0),
-  _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
-  _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
-  _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
-  _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+  _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0);
+  _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1);
+  _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2);
+  _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
+  _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
 }
 
 #endif /* __AVX512F__ */
-- 
cgit v1.2.3
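
The rewritten theta, pi, and chi steps above all lean on
`_mm512_ternarylogic_epi64()`, which evaluates an arbitrary 3-input boolean
function bitwise; its 8-bit immediate is simply the truth table of that
function, indexed by `(a << 2) | (b << 1) | c`. The standalone scalar sketch
below (not part of the patch; `ternlog_imm()` and the helper names are
illustrative) derives the three immediates used in the patch: 0x96
(a ^ b ^ c, theta's column parity), 0xfe (a | b | c, combining pi's masked
permutes), and 0xd2 (a ^ (~b & c), chi).

    // derive vpternlog immediates; build with: cc -o ternlog ternlog.c
    #include <assert.h>
    #include <stdint.h>
    #include <stdio.h>

    // a 3-input boolean function, applied bitwise by vpternlog
    typedef uint8_t (*ternfn_t)(uint8_t, uint8_t, uint8_t);

    static uint8_t xor3(uint8_t a, uint8_t b, uint8_t c) { return (a ^ b ^ c) & 1; }   // theta: c = r01 ^ r23 ^ r4
    static uint8_t or3(uint8_t a, uint8_t b, uint8_t c) { return (a | b | c) & 1; }    // pi: tN = e0 | e2 | e4
    static uint8_t chi(uint8_t a, uint8_t b, uint8_t c) { return (a ^ (~b & c)) & 1; } // chi: rN ^= ~e0 & e1

    // build the immediate: bit ((a << 2) | (b << 1) | c) of the
    // immediate is f(a, b, c)
    static uint8_t ternlog_imm(ternfn_t f) {
      uint8_t imm = 0;
      for (int i = 0; i < 8; i++) {
        imm |= f((i >> 2) & 1, (i >> 1) & 1, i & 1) << i;
      }
      return imm;
    }

    int main(void) {
      assert(ternlog_imm(xor3) == 0x96); // theta
      assert(ternlog_imm(or3) == 0xfe);  // pi
      assert(ternlog_imm(chi) == 0xd2);  // chi
      printf("xor3: 0x%02x, or3: 0x%02x, chi: 0x%02x\n",
             ternlog_imm(xor3), ternlog_imm(or3), ternlog_imm(chi));
      return 0;
    }

Much of the win here presumably comes from these folds: per row, chi's
andnot+xor pair collapses into a single vpternlog, pi's two vector ORs
collapse into one, and theta's five-way column parity drops from four XORs
to two XORs plus one vpternlog.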