aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sha3.c75
1 files changed, 46 insertions, 29 deletions
diff --git a/sha3.c b/sha3.c
index 5353a26..0e0ebf3 100644
--- a/sha3.c
+++ b/sha3.c
@@ -573,10 +573,16 @@ static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only
#define THETA_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00
// pi permute IDs
-#define PI_I0_LO 0x90 // 0, 0, 1, 2 -> 0b10010000 -> 0x90
-#define PI_I0_HI 0x03 // 3, 0, 0, 0 -> 0b00000011 -> 0x03
-#define PI_I1_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39
-#define PI_I1_HI 0x00 // 0, 0, 0, 0 -> 0b00000000 -> 0x00
+#define PI_T0_LO 0xe4 // 0b11100100 -> 0xe4
+#define PI_T0_HI 0x00
+#define PI_T1_LO 0x43 // 0b01000011 -> 0x43
+#define PI_T1_HI 0x02
+#define PI_T2_LO 0x39 // 0b00111001 -> 0x39
+#define PI_T2_HI 0x00
+#define PI_T3_LO 0x90 // 0b10010000 -> 0x90
+#define PI_T3_HI 0x03
+#define PI_T4_LO 0x0e // 0b00001110 -> 0x0e
+#define PI_T4_HI 0x01
// chi permute IDs
#define CHI_I0_LO 0x39 // 1, 2, 3, 0 -> 0b00111001 -> 0x39
@@ -621,10 +627,6 @@ static const __m256i M0 = { ~0, 0, 0, 0 }, // mask, first lane only
* `num_rounds` is either 12 for TurboSHAKE and KangarooTwelve or 24
* otherwise.
*
- * (Note: for the Pi step the registers are stored back to the state
- * array and then gathered to permute the state. This is different than
- * the AVX-512 implementation because of register pressure).
- *
* 4. The permuted Keccak state is copied back to `s`.
*/
static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds) {
@@ -679,28 +681,43 @@ static inline void permute_n_avx2(uint64_t s[static 25], const size_t num_rounds
}
// pi
- //
- // store state array, then gather to permute the state. note: with
- // some work we could probably do in-register permutes, but
- // benchmark first to see if this is worth the trouble.
{
- static const __m256i V0_LO = { 0, 6, 12, 18 },
- V1_LO = { 3, 9, 10, 16 },
- V2_LO = { 1, 7, 13, 19 },
- V3_LO = { 4, 5, 11, 17 },
- V4_LO = { 2, 8, 14, 15 };
- static const size_t V0_HI = 24, V1_HI = 22, V2_HI = 20, V3_HI = 23, V4_HI = 21;
-
- // store rows to state, then gather to permute
- AVX2_STORE(s);
-
- // re-load using gather to permute
- union { long long int *i64; uint64_t *u64; } p = { .u64 = s };
- r0_lo = _mm256_i64gather_epi64(p.i64, V0_LO, 8); r0_hi = ((__m256i) { s[V0_HI] });
- r1_lo = _mm256_i64gather_epi64(p.i64, V1_LO, 8); r1_hi = ((__m256i) { s[V1_HI] });
- r2_lo = _mm256_i64gather_epi64(p.i64, V2_LO, 8); r2_hi = ((__m256i) { s[V2_HI] });
- r3_lo = _mm256_i64gather_epi64(p.i64, V3_LO, 8); r3_hi = ((__m256i) { s[V3_HI] });
- r4_lo = _mm256_i64gather_epi64(p.i64, V4_LO, 8); r4_hi = ((__m256i) { s[V4_HI] });
+ static const __m256i LM0 = { ~0, 0, 0, 0 },
+ LM1 = { 0, ~0, 0, 0 },
+ LM2 = { 0, 0, ~0, 0 },
+ LM3 = { 0, 0, 0, ~0 };
+
+ const __m256i t0_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T0_LO) & LM0) |
+ (_mm256_permute4x64_epi64(r1_lo, PI_T0_LO) & LM1) |
+ (_mm256_permute4x64_epi64(r2_lo, PI_T0_LO) & LM2) |
+ (_mm256_permute4x64_epi64(r3_lo, PI_T0_LO) & LM3),
+ t0_hi = (_mm256_permute4x64_epi64(r4_hi, PI_T0_HI) & LM0),
+ t1_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T1_LO) & LM0) |
+ (_mm256_permute4x64_epi64(r1_hi, PI_T1_LO) & LM1) |
+ (_mm256_permute4x64_epi64(r2_lo, PI_T1_LO) & LM2) |
+ (_mm256_permute4x64_epi64(r3_lo, PI_T1_LO) & LM3),
+ t1_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T1_HI) & LM0),
+ t2_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T2_LO) & LM0) |
+ (_mm256_permute4x64_epi64(r1_lo, PI_T2_LO) & LM1) |
+ (_mm256_permute4x64_epi64(r2_lo, PI_T2_LO) & LM2) |
+ (_mm256_permute4x64_epi64(r3_hi, PI_T2_LO) & LM3),
+ t2_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T2_HI) & LM0),
+ t3_lo = (_mm256_permute4x64_epi64(r0_hi, PI_T3_LO) & LM0) |
+ (_mm256_permute4x64_epi64(r1_lo, PI_T3_LO) & LM1) |
+ (_mm256_permute4x64_epi64(r2_lo, PI_T3_LO) & LM2) |
+ (_mm256_permute4x64_epi64(r3_lo, PI_T3_LO) & LM3),
+ t3_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T3_HI) & LM0),
+ t4_lo = (_mm256_permute4x64_epi64(r0_lo, PI_T4_LO) & LM0) |
+ (_mm256_permute4x64_epi64(r1_lo, PI_T4_LO) & LM1) |
+ (_mm256_permute4x64_epi64(r2_hi, PI_T4_LO) & LM2) |
+ (_mm256_permute4x64_epi64(r3_lo, PI_T4_LO) & LM3),
+ t4_hi = (_mm256_permute4x64_epi64(r4_lo, PI_T4_HI) & LM0);
+
+ r0_lo = t0_lo; r0_hi = t0_hi;
+ r1_lo = t1_lo; r1_hi = t1_hi;
+ r2_lo = t2_lo; r2_hi = t2_hi;
+ r3_lo = t3_lo; r3_hi = t3_hi;
+ r4_lo = t4_lo; r4_hi = t4_hi;
}
// chi