diff options
-rw-r--r-- | sha3.c | 75 |
1 files changed, 46 insertions, 29 deletions
@@ -573,10 +573,16 @@ static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only #define THETA_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00 // pi permute IDs -#define PI_I0_LO 0x90 // 0, 0, 1, 2 -> 0b10010000 -> 0x90 -#define PI_I0_HI 0x03 // 3, 0, 0, 0 -> 0b00000011 -> 0x03 -#define PI_I1_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39 -#define PI_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00 +#define PI_T0_LO 0xe4 // 0b11100100 -> 0xe4 +#define PI_T0_HI 0x00 +#define PI_T1_LO 0x43 // 0b01000011 -> 0x43 +#define PI_T1_HI 0x02 +#define PI_T2_LO 0x39 // 0b00111001 -> 0x39 +#define PI_T2_HI 0x00 +#define PI_T3_LO 0x90 // 0b10010000 -> 0x90 +#define PI_T3_HI 0x03 +#define PI_T4_LO 0x0e // 0b00001110 -> 0x0e +#define PI_T4_HI 0x01 // chi permute IDs #define CHI_I0_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39 @@ -621,10 +627,6 @@ static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only * `num_rounds` is either 12 for TurboSHAKE and KangarooTwelve or 24 * otherwise. * - * (Note: for the Pi step the registers are stored back to the state - * array and then gathered to permute the state. This is different than - * the AVX-512 implementation because of register pressure). - * * 4. The permuted Keccak state is copied back to `s`. */ static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds) { @@ -679,28 +681,43 @@ static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds } // pi - // - // store state array, then gather to permute the state. note: with - // some work we could probably do in-register permutes, but - // benchmark first to see if this is worth the trouble. { - static const __m256i V0_LO = { 0, 6, 12, 18 }, - V1_LO = { 3, 9, 10, 16 }, - V2_LO = { 1, 7, 13, 19 }, - V3_LO = { 4, 5, 11, 17 }, - V4_LO = { 2, 8, 14, 15 }; - static const size_t V0_HI = 24, V1_HI = 22, V2_HI = 20, V3_HI = 23, V4_HI = 21; - - // store rows to state, then gather to permute - AVX2_STORE(s); - - // re-load using gather to permute - union { long long int *i64; uint64_t *u64; } p = { .u64 = s }; - r0_lo = _mm256_i64gather_epi64(p.i64, V0_LO, 8); r0_hi = ((__m256i) { s[V0_HI] }); - r1_lo = _mm256_i64gather_epi64(p.i64, V1_LO, 8); r1_hi = ((__m256i) { s[V1_HI] }); - r2_lo = _mm256_i64gather_epi64(p.i64, V2_LO, 8); r2_hi = ((__m256i) { s[V2_HI] }); - r3_lo = _mm256_i64gather_epi64(p.i64, V3_LO, 8); r3_hi = ((__m256i) { s[V3_HI] }); - r4_lo = _mm256_i64gather_epi64(p.i64, V4_LO, 8); r4_hi = ((__m256i) { s[V4_HI] }); + static const __m256i LM0 = { ~0, 0, 0, 0 }, + LM1 = { 0, ~0, 0, 0 }, + LM2 = { 0, 0, ~0, 0 }, + LM3 = { 0, 0, 0, ~0 }; + + const __m256i t0_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T0_LO) & LM0) | + (_mm256_permute4x64_epi64(r1_lo, PI_T0_LO) & LM1) | + (_mm256_permute4x64_epi64(r2_lo, PI_T0_LO) & LM2) | + (_mm256_permute4x64_epi64(r3_lo, PI_T0_LO) & LM3), + t0_hi = (_mm256_permute4x64_epi64(r4_hi, PI_T0_HI) & LM0), + t1_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T1_LO) & LM0) | + (_mm256_permute4x64_epi64(r1_hi, PI_T1_LO) & LM1) | + (_mm256_permute4x64_epi64(r2_lo, PI_T1_LO) & LM2) | + (_mm256_permute4x64_epi64(r3_lo, PI_T1_LO) & LM3), + t1_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T1_HI) & LM0), + t2_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T2_LO) & LM0) | + (_mm256_permute4x64_epi64(r1_lo, PI_T2_LO) & LM1) | + (_mm256_permute4x64_epi64(r2_lo, PI_T2_LO) & LM2) | + (_mm256_permute4x64_epi64(r3_hi, PI_T2_LO) & LM3), + t2_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T2_HI) & LM0), + t3_lo = (_mm256_permute4x64_epi64(r0_hi, PI_T3_LO) & LM0) | + (_mm256_permute4x64_epi64(r1_lo, PI_T3_LO) & LM1) | + (_mm256_permute4x64_epi64(r2_lo, PI_T3_LO) & LM2) | + (_mm256_permute4x64_epi64(r3_lo, PI_T3_LO) & LM3), + t3_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T3_HI) & LM0), + t4_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T4_LO) & LM0) | + (_mm256_permute4x64_epi64(r1_lo, PI_T4_LO) & LM1) | + (_mm256_permute4x64_epi64(r2_hi, PI_T4_LO) & LM2) | + (_mm256_permute4x64_epi64(r3_lo, PI_T4_LO) & LM3), + t4_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T4_HI) & LM0); + + r0_lo = t0_lo; r0_hi = t0_hi; + r1_lo = t1_lo; r1_hi = t1_hi; + r2_lo = t2_lo; r2_hi = t2_hi; + r3_lo = t3_lo; r3_hi = t3_hi; + r4_lo = t4_lo; r4_hi = t4_hi; } // chi |