summaryrefslogtreecommitdiff
path: root/sha3.c
diff options
context:
space:
mode:
authorPaul Duncan <pabs@pablotron.org>2024-04-29 19:48:16 -0400
committerPaul Duncan <pabs@pablotron.org>2024-04-29 19:48:16 -0400
commitd898f431491846c5742874d81f834e45c984e535 (patch)
tree552e9765832d1af8a69a7f821daa1831b7693cb6 /sha3.c
parentd9f691ab935591725110c74554e3fdc614e9b775 (diff)
downloadsha3-d898f431491846c5742874d81f834e45c984e535.tar.bz2
sha3-d898f431491846c5742874d81f834e45c984e535.zip
sha3.c: add permute_n_{scalar,avx512}() and refactor permute{,12}_{scalar,avx512}() to use them
Diffstat (limited to 'sha3.c')
-rw-r--r--sha3.c221
1 files changed, 23 insertions, 198 deletions
diff --git a/sha3.c b/sha3.c
index b364ae2..3925f25 100644
--- a/sha3.c
+++ b/sha3.c
@@ -194,36 +194,34 @@ static inline void iota(uint64_t a[static 25], const int i) {
a[0] ^= RCS[i];
}
-// 24-round keccak permutation (scalar implementation)
-static inline void permute_scalar(uint64_t a[static 25]) {
+// keccak permutation (scalar implementation)
+static inline void permute_n_scalar(uint64_t a[static 25], const size_t num_rounds) {
uint64_t tmp[25] = { 0 };
- for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) {
+ for (size_t i = 0; i < num_rounds; i++) {
theta(a);
rho(a);
pi(tmp, a);
chi(a, tmp);
- iota(a, i);
+ iota(a, (SHA3_NUM_ROUNDS - num_rounds + i));
}
}
+// 24 round keccak permutation (scalar implementation)
+static inline void permute_scalar(uint64_t a[static 25]) {
+ permute_n_scalar(a, 24);
+}
+
// 12 round keccak permutation (scalar implementation)
// (only used by turboshake)
static inline void permute12_scalar(uint64_t a[static 25]) {
- uint64_t tmp[25] = { 0 };
- for (size_t i = 0; i < 12; i++) {
- theta(a);
- rho(a);
- pi(tmp, a);
- chi(a, tmp);
- iota(a, 12 + i);
- }
+ permute_n_scalar(a, 12);
}
#endif /* (SHA3_BACKEND == SHA3_BACKEND_SCALAR) || defined(SHA3_TEST) */
#if SHA3_BACKEND == SHA3_BACKEND_AVX512
#include <immintrin.h>
-// 24 round keccak permutation (avx512 implementation).
+// keccak permutation (avx512 implementation).
//
// how it operates (roughly):
//
@@ -251,7 +249,7 @@ static inline void permute12_scalar(uint64_t a[static 25]) {
// 3. store the rows first 5 64-bit lanes of registers r0-r4 back to the
// state `s`.
//
-static inline void permute_avx512(uint64_t s[static 25]) {
+static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_rounds) {
// load rows (r0-r4)
__m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0
r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1
@@ -260,7 +258,7 @@ static inline void permute_avx512(uint64_t s[static 25]) {
r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
// 24 rounds
- for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) {
+ for (size_t i = 0; i < num_rounds; i++) {
// theta
{
// permute ids
@@ -423,8 +421,11 @@ static inline void permute_avx512(uint64_t s[static 25]) {
// iota
{
+ // calculate RCS offset
+ const size_t ofs = SHA3_NUM_ROUNDS - num_rounds + i;
+
// xor round constant to first cell
- r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + i));
+ r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + ofs));
}
}
@@ -436,190 +437,14 @@ static inline void permute_avx512(uint64_t s[static 25]) {
_mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
}
+// 24 round keccak permutation (avx512 implementation).
+static inline void permute_avx512(uint64_t s[static 25]) {
+ permute_n_avx512(s, 24);
+}
+
// 12 round keccak permutation (avx512 implementation).
static inline void permute12_avx512(uint64_t s[static 25]) {
- // load rows (r0-r4)
- __m512i r0 = _mm512_maskz_loadu_epi64(0x1f, s + 0), // row 0
- r1 = _mm512_maskz_loadu_epi64(0x1f, s + 5), // row 1
- r2 = _mm512_maskz_loadu_epi64(0x1f, s + 10), // row 2
- r3 = _mm512_maskz_loadu_epi64(0x1f, s + 15), // row 3
- r4 = _mm512_maskz_loadu_epi64(0x1f, s + 20); // row 4
-
- // 12 rounds
- for (size_t i = 0; i < 12; i++) {
- // theta
- {
- // permute ids
- static const __m512i I0 = { 4, 0, 1, 2, 3 },
- I1 = { 1, 2, 3, 4, 0 };
-
- // c = xor(r0, r1, r2, r3, r4)
- const __m512i r01 = _mm512_maskz_xor_epi64(0x1f, r0, r1),
- r23 = _mm512_maskz_xor_epi64(0x1f, r2, r3),
- c = _mm512_maskz_ternarylogic_epi64(0x1f, r01, r23, r4, 0x96);
-
- // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
- const __m512i d0 = _mm512_permutexvar_epi64(I0, c),
- d1 = _mm512_permutexvar_epi64(I1, _mm512_rol_epi64(c, 1)),
- d = _mm512_xor_epi64(d0, d1);
-
- // row = xor(row, d)
- r0 = _mm512_xor_epi64(r0, d);
- r1 = _mm512_xor_epi64(r1, d);
- r2 = _mm512_xor_epi64(r2, d);
- r3 = _mm512_xor_epi64(r3, d);
- r4 = _mm512_xor_epi64(r4, d);
- }
-
- // rho
- {
- // rotate values
- //
- // note: switching from maskz_load_epi64()s to static const
- // __m512i incurs a 500 cycle penalty; leaving them for now
- static const uint64_t V0_VALS[5] ALIGN(64) = { 0, 1, 62, 28, 27 },
- V1_VALS[5] ALIGN(64) = { 36, 44, 6, 55, 20 },
- V2_VALS[5] ALIGN(64) = { 3, 10, 43, 25, 39 },
- V3_VALS[5] ALIGN(64) = { 41, 45, 15, 21, 8 },
- V4_VALS[5] ALIGN(64) = { 18, 2, 61, 56, 14 };
-
- // rotate rows
- r0 = _mm512_rolv_epi64(r0, _mm512_maskz_load_epi64(0x1f, V0_VALS));
- r1 = _mm512_rolv_epi64(r1, _mm512_maskz_load_epi64(0x1f, V1_VALS));
- r2 = _mm512_rolv_epi64(r2, _mm512_maskz_load_epi64(0x1f, V2_VALS));
- r3 = _mm512_rolv_epi64(r3, _mm512_maskz_load_epi64(0x1f, V3_VALS));
- r4 = _mm512_rolv_epi64(r4, _mm512_maskz_load_epi64(0x1f, V4_VALS));
- }
-
- // pi
- //
- // The cells are permuted across all rows of the state array. each
- // output row is the combination of three permutations:
- //
- // - e0: row 0 and row 1
- // - e2: row 2 and row 3
- // - e4: row 4 and row 0
- //
- // the IDs for each permutation are merged into a single array
- // (T*_IDS) to reduce register pressure, and the permute operations
- // are masked so that each permutation only uses the relevant IDs.
- //
- // afterwards, the permutations are combined to form a temporary
- // row:
- //
- // t0 = t0e0 | t0e2 | t0e4
- //
- // once the permutations for all rows are complete, the temporary
- // rows are saved to the actual row registers:
- //
- // r0 = t0
- //
- {
- // permute ids
- static const __m512i T0_IDS = { 0, 8 + 1, 2, 8 + 3, 4 },
- T1_IDS = { 3, 8 + 4, 0, 8 + 1, 2 },
- T2_IDS = { 1, 8 + 2, 3, 8 + 4, 0 },
- T3_IDS = { 4, 8 + 0, 1, 8 + 2, 3 },
- T4_IDS = { 2, 8 + 3, 4, 8 + 0, 1 };
-
- __m512i t0, t1, t2, t3, t4;
- {
- // permute r0
- const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T0_IDS, r1),
- t0e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T0_IDS, r3),
- t0e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T0_IDS, r0);
-
- // permute r1
- const __m512i t1e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T1_IDS, r1),
- t1e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T1_IDS, r3),
- t1e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T1_IDS, r0);
-
- // permute r2
- const __m512i t2e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T2_IDS, r1),
- t2e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T2_IDS, r3),
- t2e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T2_IDS, r0);
-
- // permute r3
- const __m512i t3e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T3_IDS, r1),
- t3e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T3_IDS, r3),
- t3e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T3_IDS, r0);
-
- // permute r4
- const __m512i t4e0 = _mm512_maskz_permutex2var_epi64(0x03, r0, T4_IDS, r1),
- t4e2 = _mm512_maskz_permutex2var_epi64(0x0c, r2, T4_IDS, r3),
- t4e4 = _mm512_maskz_permutex2var_epi64(0x10, r4, T4_IDS, r0);
-
- // combine permutes: tN = e0 | e2 | e4
- t0 = _mm512_maskz_ternarylogic_epi64(0x1f, t0e0, t0e2, t0e4, 0xfe);
- t1 = _mm512_maskz_ternarylogic_epi64(0x1f, t1e0, t1e2, t1e4, 0xfe);
- t2 = _mm512_maskz_ternarylogic_epi64(0x1f, t2e0, t2e2, t2e4, 0xfe);
- t3 = _mm512_maskz_ternarylogic_epi64(0x1f, t3e0, t3e2, t3e4, 0xfe);
- t4 = _mm512_maskz_ternarylogic_epi64(0x1f, t4e0, t4e2, t4e4, 0xfe);
- }
-
- // store rows
- r0 = t0;
- r1 = t1;
- r2 = t2;
- r3 = t3;
- r4 = t4;
- }
-
- // chi
- {
- // permute ids
- static const __m512i P0 = { 1, 2, 3, 4, 0 },
- P1 = { 2, 3, 4, 0, 1 };
-
- {
- // r0 ^= ~e0 & e1
- const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r0),
- t0_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r0);
- r0 = _mm512_maskz_ternarylogic_epi64(0x1f, r0, t0_e0, t0_e1, 0xd2);
- }
-
- {
- // r1 ^= ~e0 & e1
- const __m512i t1_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r1),
- t1_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r1);
- r1 = _mm512_maskz_ternarylogic_epi64(0x1f, r1, t1_e0, t1_e1, 0xd2);
- }
-
- {
- // r2 ^= ~e0 & e1
- const __m512i t2_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r2),
- t2_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r2);
- r2 = _mm512_maskz_ternarylogic_epi64(0x1f, r2, t2_e0, t2_e1, 0xd2);
- }
-
- {
- // r3 ^= ~e0 & e1
- const __m512i t3_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r3),
- t3_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r3);
- r3 = _mm512_maskz_ternarylogic_epi64(0x1f, r3, t3_e0, t3_e1, 0xd2);
- }
-
- {
- // r4 ^= ~e0 & e1
- const __m512i t4_e0 = _mm512_maskz_permutexvar_epi64(0x1f, P0, r4),
- t4_e1 = _mm512_maskz_permutexvar_epi64(0x1f, P1, r4);
- r4 = _mm512_maskz_ternarylogic_epi64(0x1f, r4, t4_e0, t4_e1, 0xd2);
- }
- }
-
- // iota
- {
- // xor round constant to first cell
- r0 = _mm512_mask_xor_epi64(r0, 1, r0, _mm512_maskz_loadu_epi64(1, RCS + (12 + i)));
- }
- }
-
- // store rows
- _mm512_mask_storeu_epi64(s + 5 * 0, 0x1f, r0);
- _mm512_mask_storeu_epi64(s + 5 * 1, 0x1f, r1);
- _mm512_mask_storeu_epi64(s + 5 * 2, 0x1f, r2);
- _mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
- _mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
+ permute_n_avx512(s, 12);
}
#endif /* SHA3_BACKEND == SHA3_BACKEND_AVX512 */