summaryrefslogtreecommitdiff
path: root/sha3.c
diff options
context:
space:
mode:
Diffstat (limited to 'sha3.c')
-rw-r--r--sha3.c625
1 files changed, 547 insertions, 78 deletions
diff --git a/sha3.c b/sha3.c
index 0f8f812..a7e2edf 100644
--- a/sha3.c
+++ b/sha3.c
@@ -34,6 +34,16 @@
// number of rounds for permute()
#define SHA3_NUM_ROUNDS 24
+// round constants (used by iota)
+static const uint64_t RCS[] = {
+ 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
+ 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
+ 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
+ 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
+ 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
+ 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
+};
+
#if !defined(__AVX512F__) || defined(SHA3_TEST)
// If AVX512 is supported and we are not building the test suite,
// then do not compile the scalar step functions below.
@@ -153,35 +163,33 @@ static inline void chi(uint64_t dst[static 25], const uint64_t src[static 25]) {
// iota step of keccak permutation (scalar implementation)
static inline void iota(uint64_t a[static 25], const int i) {
- // round constants (ambiguous in spec)
- static const uint64_t RCS[] = {
- 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
- 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
- 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
- 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
- 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
- 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
- };
-
a[0] ^= RCS[i];
}
#endif /* !defined(__AVX512F__) || defined(SHA3_TEST) */
#ifndef __AVX512F__
-// keccak permutation (scalar implementation)
-//
-// note: clang is better about inlining this than gcc with a
-// configurable number of rounds. the configurable number of rounds is
-// only used by turboshake, so it might be worth creating a specialized
-// `permute12()` to handle turboshake.
-static inline void permute(uint64_t a[static 25], const size_t num_rounds) {
+// 24-round keccak permutation (scalar implementation)
+static inline void permute_scalar(uint64_t a[static 25]) {
uint64_t tmp[25] = { 0 };
- for (int i = 0; i < (int) num_rounds; i++) {
+ for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) {
theta(a);
rho(a);
pi(tmp, a);
chi(a, tmp);
- iota(a, 24 - num_rounds + i);
+ iota(a, i);
+ }
+}
+
+// 12 round keccak permutation (scalar implementation)
+// (only used by turboshake)
+static inline void permute12_scalar(uint64_t a[static 25]) {
+ uint64_t tmp[25] = { 0 };
+ for (size_t i = 0; i < 12; i++) {
+ theta(a);
+ rho(a);
+ pi(tmp, a);
+ chi(a, tmp);
+ iota(a, 12 + i);
}
}
#endif /* !__AVX512F__ */
@@ -189,7 +197,7 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) {
#ifdef __AVX512F__
#include <immintrin.h>
-// keccak permutation (avx512 implementation).
+// 24 round keccak permutation (avx512 implementation).
//
// copied from `permute_avx512_fast()` in `tests/permute/permute.c`. all
// steps are inlined as blocks. ~3x faster than scalar implementation,
@@ -220,26 +228,240 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) {
// as noted above, this is not the most efficient avx512 implementation;
// the row registers have three empty slots and there are a lot of loads
// that could be removed with a little more work.
-static inline void permute(uint64_t s[static 25], const size_t num_rounds) {
+static inline void permute_avx512(uint64_t s[static 25]) {
// unaligned load mask and permutation indices
uint8_t mask = 0x1f,
m0b = 0x01;
const __mmask8 m = _load_mask8(&mask),
m0 = _load_mask8(&m0b);
- // round constants (used in iota)
- static const uint64_t RCS[] = {
- 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
- 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
- 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
- 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
- 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
- 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
- };
+ // load round constant
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ __m512i rc = _mm512_loadu_epi64((void*) RCS);
+
+ // load rc permutation
+ static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
+ const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
+
+ // load rows
+ __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
+ r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)),
+ r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)),
+ r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
+ r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
+
+ for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
+ // theta
+ {
+ const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
+ i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7);
+
+ // c = xor(r0, r1, r2, r3, r4)
+ const __m512i r01 = _mm512_xor_epi64(r0, r1),
+ r23 = _mm512_xor_epi64(r2, r3),
+ r03 = _mm512_xor_epi64(r01, r23),
+ c = _mm512_xor_epi64(r03, r4);
+
+ // d = xor(permute(i0, c), permute(i1, rol(c, 1)))
+ const __m512i d0 = _mm512_permutexvar_epi64(i0, c),
+ d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)),
+ d = _mm512_xor_epi64(d0, d1);
+
+ // row = xor(row, d)
+ r0 = _mm512_xor_epi64(r0, d);
+ r1 = _mm512_xor_epi64(r1, d);
+ r2 = _mm512_xor_epi64(r2, d);
+ r3 = _mm512_xor_epi64(r3, d);
+ r4 = _mm512_xor_epi64(r4, d);
+ }
+
+ // rho
+ {
+ // unaligned load mask and rotate values
+ static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 },
+ vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 },
+ vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 },
+ vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 },
+ vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 };
+
+ // load rotate values
+ const __m512i v0 = _mm512_loadu_epi64((void*) vs0),
+ v1 = _mm512_loadu_epi64((void*) vs1),
+ v2 = _mm512_loadu_epi64((void*) vs2),
+ v3 = _mm512_loadu_epi64((void*) vs3),
+ v4 = _mm512_loadu_epi64((void*) vs4);
+
+ // rotate
+ r0 = _mm512_rolv_epi64(r0, v0);
+ r1 = _mm512_rolv_epi64(r1, v1);
+ r2 = _mm512_rolv_epi64(r2, v2);
+ r3 = _mm512_rolv_epi64(r3, v3);
+ r4 = _mm512_rolv_epi64(r4, v4);
+ }
+
+ // pi
+ {
+ // mask bytes
+ uint8_t m01b = 0x03,
+ m23b = 0x0c,
+ m4b = 0x10;
+
+ // load masks
+ const __mmask8 m01 = _load_mask8(&m01b),
+ m23 = _load_mask8(&m23b),
+ m4 = _load_mask8(&m4b);
+
+ // permutation indices
+ //
+ // (note: these are masked so only the relevant indices for
+ // _mm512_maskz_permutex2var_epi64() in each array are filled in)
+ static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 },
+ t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 },
+ t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 },
+
+ t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 },
+ t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 },
+ t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 },
+
+ t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 },
+ t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 },
+ t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 },
+
+ t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 },
+ t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 },
+ t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 },
+
+ t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 },
+ t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 },
+ t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 };
+
+ // load permutation indices
+ const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01),
+ t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23),
+ t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4),
+
+ t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01),
+ t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23),
+ t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4),
+
+ t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01),
+ t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23),
+ t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4),
+
+ t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01),
+ t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23),
+ t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4),
+
+ t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01),
+ t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23),
+ t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4);
+
+ // permute rows
+ const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1),
+ t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3),
+ t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0),
+ t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4),
+
+ t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1),
+ t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3),
+ t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0),
+ t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4),
+
+ t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1),
+ t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3),
+ t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0),
+ t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4),
+
+ t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1),
+ t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3),
+ t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0),
+ t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4),
+
+ t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1),
+ t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3),
+ t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0),
+ t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4);
+
+ // store rows
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // chi
+ {
+ // permutation indices
+ static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 },
+ pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 };
+
+ // load permutation indices
+ const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0),
+ p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1);
+
+ // permute rows
+ const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0),
+ t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0),
+ t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)),
+
+ t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1),
+ t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1),
+ t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)),
+
+ t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2),
+ t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2),
+ t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)),
+
+ t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3),
+ t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3),
+ t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)),
+
+ t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4),
+ t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4),
+ t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1));
+
+ // store rows
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // iota
+ {
+ // xor round constant, shuffle rc register
+ r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
+ rc = _mm512_permutexvar_epi64(rc_p, rc);
+
+ if (((i + 1) % 8) == 0 && i != 23) {
+ // load next set of round constants
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ rc = _mm512_loadu_epi64((void*) (RCS + (i + 1)));
+ }
+ }
+ }
+
+ // store rows
+ _mm512_mask_storeu_epi64((void*) (s), m, r0),
+ _mm512_mask_storeu_epi64((void*) (s + 5), m, r1),
+ _mm512_mask_storeu_epi64((void*) (s + 10), m, r2),
+ _mm512_mask_storeu_epi64((void*) (s + 15), m, r3),
+ _mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
+}
+
+// 12 round keccak permutation (avx512 implementation).
+static inline void permute12_avx512(uint64_t s[static 25]) {
+ // unaligned load mask and permutation indices
+ uint8_t mask = 0x1f,
+ m0b = 0x01;
+ const __mmask8 m = _load_mask8(&mask),
+ m0 = _load_mask8(&m0b);
// load round constant
// note: this will bomb if num_rounds < 8 or num_rounds > 24.
- __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds));
+ __m512i rc = _mm512_loadu_epi64((void*) (RCS + 12));
// load rc permutation
static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
@@ -252,7 +474,7 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) {
r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)),
r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20));
- for (int i = 0; i < (int) num_rounds; i++) {
+ for (int i = 0; i < SHA3_NUM_ROUNDS; i++) {
// theta
{
const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7),
@@ -437,10 +659,10 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) {
r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
rc = _mm512_permutexvar_epi64(rc_p, rc);
- if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) {
+ if (((12 + i + 1) % 8) == 0 && (12 + i) != 23) {
// load next set of round constants
// note: this will bomb if num_rounds < 8 or num_rounds > 24.
- rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1)));
+ rc = _mm512_loadu_epi64((void*) (RCS + (12 + i + 1)));
}
}
}
@@ -454,9 +676,76 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) {
}
#endif /* __AVX512F__ */
+#ifdef __AVX512F__
+#define permute permute_avx512
+#define permute12 permute12_avx512
+#else /* !__AVX512F__ */
+#define permute permute_scalar
+#define permute12 permute12_scalar
+#endif /* __AVX512F__ */
+
// absorb message into state, return updated byte count
// used by `hash_absorb()`, `hash_once()`, and `xof_absorb_raw()`
-static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size_t rate, const size_t num_rounds, const uint8_t *m, size_t m_len) {
+static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size_t rate, const uint8_t *m, size_t m_len) {
+ // absorb aligned chunks
+ if ((num_bytes & 7) == 0 && (((uintptr_t) m) & 7) == 0) {
+ // absorb 32 byte chunks (4 x uint64)
+ while (m_len >= 32 && num_bytes <= rate - 32) {
+ // xor chunk into state
+ // (FIXME: does not vectorize for some reason, even when unrolled)
+ for (size_t i = 0; i < 4; i++) {
+ a->u64[num_bytes/8 + i] ^= ((uint64_t*) m)[i];
+ }
+
+ // update counters
+ num_bytes += 32;
+ m += 32;
+ m_len -= 32;
+
+ if (num_bytes == rate) {
+ // permute state
+ permute(a->u64);
+ num_bytes = 0;
+ }
+ }
+
+ // absorb 8 byte chunks (1 x uint64)
+ while (m_len >= 8 && num_bytes <= rate - 8) {
+ // xor chunk into state
+ a->u64[num_bytes/8] ^= *((uint64_t*) m);
+
+ // update counters
+ num_bytes += 8;
+ m += 8;
+ m_len -= 8;
+
+ if (num_bytes == rate) {
+ // permute state
+ permute(a->u64);
+ num_bytes = 0;
+ }
+ }
+ }
+
+ // absorb remaining bytes
+ for (size_t i = 0; i < m_len; i++) {
+ // xor byte into state
+ a->u8[num_bytes++] ^= m[i];
+
+ if (num_bytes == rate) {
+ // permute state
+ permute(a->u64);
+ num_bytes = 0;
+ }
+ }
+
+ // return byte count
+ return num_bytes;
+}
+
+// absorb message into xof12 state, return updated byte count
+// used by `xof12_absorb_raw()`
+static inline size_t absorb12(sha3_state_t * const a, size_t num_bytes, const size_t rate, const uint8_t *m, size_t m_len) {
// absorb aligned chunks
if ((num_bytes & 7) == 0 && (((uintptr_t) m) & 7) == 0) {
// absorb 32 byte chunks (4 x uint64)
@@ -474,7 +763,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size
if (num_bytes == rate) {
// permute state
- permute(a->u64, num_rounds);
+ permute12(a->u64);
num_bytes = 0;
}
}
@@ -491,7 +780,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size
if (num_bytes == rate) {
// permute state
- permute(a->u64, num_rounds);
+ permute12(a->u64);
num_bytes = 0;
}
}
@@ -504,7 +793,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size
if (num_bytes == rate) {
// permute state
- permute(a->u64, num_rounds);
+ permute12(a->u64);
num_bytes = 0;
}
}
@@ -513,6 +802,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size
return num_bytes;
}
+
// Get rate (number of bytes that can be absorbed before the internal
// state is permuted).
//
@@ -546,7 +836,7 @@ static inline void hash_once(const uint8_t *m, size_t m_len, uint8_t * const dst
sha3_state_t a = { .u64 = { 0 } };
// absorb message, get new internal length
- const size_t len = absorb(&a, 0, RATE(dst_len), SHA3_NUM_ROUNDS, m, m_len);
+ const size_t len = absorb(&a, 0, RATE(dst_len), m, m_len);
// append suffix and padding
// (note: suffix and padding are ambiguous in spec)
@@ -554,7 +844,7 @@ static inline void hash_once(const uint8_t *m, size_t m_len, uint8_t * const dst
a.u8[RATE(dst_len)-1] ^= 0x80;
// final permutation
- permute(a.u64, SHA3_NUM_ROUNDS);
+ permute(a.u64);
// copy to destination
memcpy(dst, a.u8, dst_len);
@@ -575,7 +865,7 @@ static inline bool hash_absorb(sha3_t * const hash, const size_t rate, const uin
}
// absorb bytes, return success
- hash->num_bytes = absorb(&(hash->a), hash->num_bytes, rate, SHA3_NUM_ROUNDS, src, len);
+ hash->num_bytes = absorb(&(hash->a), hash->num_bytes, rate, src, len);
return true;
}
@@ -593,7 +883,7 @@ static inline void hash_final(sha3_t * const hash, const size_t rate, uint8_t *
hash->a.u8[rate - 1] ^= 0x80;
// permute
- permute(hash->a.u64, SHA3_NUM_ROUNDS);
+ permute(hash->a.u64);
}
// copy to destination
@@ -637,14 +927,14 @@ static inline void xof_init(sha3_xof_t * const xof) {
// context has already been squeezed.
//
// Called by `xof_absorb()` and `xof_once()`.
-static inline void xof_absorb_raw(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t *m, size_t m_len) {
- xof->num_bytes = absorb(&(xof->a), xof->num_bytes, rate, num_rounds, m, m_len);
+static inline void xof_absorb_raw(sha3_xof_t * const xof, const size_t rate, const uint8_t *m, size_t m_len) {
+ xof->num_bytes = absorb(&(xof->a), xof->num_bytes, rate, m, m_len);
}
// Absorb data into XOF context.
//
// Returns `false` if this context has already been squeezed.
-static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t * const m, size_t m_len) {
+static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const uint8_t * const m, size_t m_len) {
// check context state
if (xof->squeezing) {
// xof has already been squeezed, return error
@@ -652,19 +942,19 @@ static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const
}
// absorb, return success
- xof_absorb_raw(xof, rate, num_rounds, m, m_len);
+ xof_absorb_raw(xof, rate, m, m_len);
return true;
}
// Finalize absorb, switch mode of XOF context to squeezing.
-static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t pad) {
+static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, const uint8_t pad) {
// append suffix (s6.2) and padding
// (note: suffix and padding are ambiguous in spec)
xof->a.u8[xof->num_bytes] ^= pad;
xof->a.u8[rate - 1] ^= 0x80;
// permute
- permute(xof->a.u64, num_rounds);
+ permute(xof->a.u64);
// switch to squeeze mode
xof->num_bytes = 0;
@@ -672,7 +962,7 @@ static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, co
}
// Squeeze data without checking mode (used by `xof_once()`).
-static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, uint8_t *dst, size_t dst_len) {
+static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, uint8_t *dst, size_t dst_len) {
if (!xof->num_bytes) {
// num_bytes is zero, so we are reading from the start of the
// internal state buffer. while `dst_len` is greater than rate,
@@ -681,7 +971,7 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co
// rate-sized chunks to destination
while (dst_len >= rate) {
memcpy(dst, xof->a.u8, rate); // copy rate-sized chunk
- permute(xof->a.u64, num_rounds); // permute state
+ permute(xof->a.u64); // permute state
// update destination pointer and length
dst += rate;
@@ -705,7 +995,7 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co
dst[i] = xof->a.u8[xof->num_bytes++]; // squeeze byte to destination
if (xof->num_bytes == rate) {
- permute(xof->a.u64, num_rounds); // permute state
+ permute(xof->a.u64); // permute state
xof->num_bytes = 0; // clear read bytes count
}
}
@@ -713,28 +1003,137 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co
}
// squeeze data from xof
-static inline void xof_squeeze(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t pad, uint8_t * const dst, const size_t dst_len) {
+static inline void xof_squeeze(sha3_xof_t * const xof, const size_t rate, const uint8_t pad, uint8_t * const dst, const size_t dst_len) {
// check state
if (!xof->squeezing) {
// finalize absorb
- xof_absorb_done(xof, rate, num_rounds, pad);
+ xof_absorb_done(xof, rate, pad);
}
- xof_squeeze_raw(xof, rate, num_rounds, dst, dst_len);
+ xof_squeeze_raw(xof, rate, dst, dst_len);
}
// one-shot xof absorb and squeeze
-static inline void xof_once(const size_t rate, const size_t num_rounds, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) {
+static inline void xof_once(const size_t rate, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) {
// init
sha3_xof_t xof;
xof_init(&xof);
// absorb
- xof_absorb_raw(&xof, rate, num_rounds, src, src_len);
- xof_absorb_done(&xof, rate, num_rounds, pad);
+ xof_absorb_raw(&xof, rate, src, src_len);
+ xof_absorb_done(&xof, rate, pad);
+
+ // squeeze
+ xof_squeeze_raw(&xof, rate, dst, dst_len);
+}
+
+// initialize xof12 context
+static inline void xof12_init(sha3_xof_t * const xof) {
+ memset(xof, 0, sizeof(sha3_xof_t));
+}
+
+// Absorb data into XOF12 context without checking to see if the
+// context has already been squeezed.
+//
+// Called by `xof12_absorb()` and `xof12_once()`.
+static inline void xof12_absorb_raw(sha3_xof_t * const xof, const size_t rate, const uint8_t *m, size_t m_len) {
+ xof->num_bytes = absorb12(&(xof->a), xof->num_bytes, rate, m, m_len);
+}
+
+// Absorb data into XOF context.
+//
+// Returns `false` if this XOF12 context has already been squeezed.
+static inline _Bool xof12_absorb(sha3_xof_t * const xof, const size_t rate, const uint8_t * const m, size_t m_len) {
+ // check context state
+ if (xof->squeezing) {
+ // xof has already been squeezed, return error
+ return false;
+ }
+
+ // absorb, return success
+ xof12_absorb_raw(xof, rate, m, m_len);
+ return true;
+}
+
+// Finalize absorb, switch mode of XOF12 context to squeezing.
+static inline void xof12_absorb_done(sha3_xof_t * const xof, const size_t rate, const uint8_t pad) {
+ // append suffix (s6.2) and padding
+ // (note: suffix and padding are ambiguous in spec)
+ xof->a.u8[xof->num_bytes] ^= pad;
+ xof->a.u8[rate - 1] ^= 0x80;
+
+ // permute
+ permute12(xof->a.u64);
+
+ // switch to squeeze mode
+ xof->num_bytes = 0;
+ xof->squeezing = true;
+}
+
+// Squeeze data without checking mode (used by `xof12_once()`).
+static inline void xof12_squeeze_raw(sha3_xof_t * const xof, const size_t rate, uint8_t *dst, size_t dst_len) {
+ if (!xof->num_bytes) {
+ // num_bytes is zero, so we are reading from the start of the
+ // internal state buffer. while `dst_len` is greater than rate,
+ // copy `rate` sized chunks directly from the internal state buffer
+ // to the destination, then permute the internal state. squeeze
+ // rate-sized chunks to destination
+ while (dst_len >= rate) {
+ memcpy(dst, xof->a.u8, rate); // copy rate-sized chunk
+ permute12(xof->a.u64); // permute state
+
+ // update destination pointer and length
+ dst += rate;
+ dst_len -= rate;
+ }
+
+ if (dst_len > 0) {
+ // the remaining destination length is less than `rate`, so copy a
+ // `dst_len`-sized chunk from the internal state to the
+ // destination buffer, then update the read byte count.
+
+ // squeeze dst_len-sized block to destination
+ memcpy(dst, xof->a.u8, dst_len); // copy dst_len-sized chunk
+ xof->num_bytes = dst_len; // update read byte count
+ }
+ } else {
+ // fall back to squeezing one byte at a time
+
+ // squeeze bytes to destination
+ for (size_t i = 0; i < dst_len; i++) {
+ dst[i] = xof->a.u8[xof->num_bytes++]; // squeeze byte to destination
+
+ if (xof->num_bytes == rate) {
+ permute12(xof->a.u64); // permute state
+ xof->num_bytes = 0; // clear read bytes count
+ }
+ }
+ }
+}
+
+// squeeze data from xof12 context
+static inline void xof12_squeeze(sha3_xof_t * const xof, const size_t rate, const uint8_t pad, uint8_t * const dst, const size_t dst_len) {
+ // check state
+ if (!xof->squeezing) {
+ // finalize absorb
+ xof12_absorb_done(xof, rate, pad);
+ }
+
+ xof12_squeeze_raw(xof, rate, dst, dst_len);
+}
+
+// one-shot xof12 absorb and squeeze
+static inline void xof12_once(const size_t rate, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) {
+ // init
+ sha3_xof_t xof;
+ xof12_init(&xof);
+
+ // absorb
+ xof12_absorb_raw(&xof, rate, src, src_len);
+ xof12_absorb_done(&xof, rate, pad);
// squeeze
- xof_squeeze_raw(&xof, rate, num_rounds, dst, dst_len);
+ xof12_squeeze_raw(&xof, rate, dst, dst_len);
}
// define shake iterative context and one-shot functions
@@ -746,17 +1145,17 @@ static inline void xof_once(const size_t rate, const size_t num_rounds, const ui
\
/* absorb bytes into shake context */ \
_Bool shake ## BITS ## _absorb(sha3_xof_t * const xof, const uint8_t * const m, const size_t len) { \
- return xof_absorb(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, m, len); \
+ return xof_absorb(xof, SHAKE ## BITS ## _RATE, m, len); \
} \
\
/* squeeze bytes from shake context */ \
void shake ## BITS ## _squeeze(sha3_xof_t * const xof, uint8_t * const dst, const size_t dst_len) { \
- xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, SHAKE_PAD, dst, dst_len); \
+ xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHAKE_PAD, dst, dst_len); \
} \
\
/* one-shot shake absorb and squeeze */ \
void shake ## BITS(const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \
- xof_once(SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, SHAKE_PAD, src, src_len, dst, dst_len); \
+ xof_once(SHAKE ## BITS ## _RATE, SHAKE_PAD, src, src_len, dst, dst_len); \
}
// shake padding byte and rates
@@ -1070,12 +1469,12 @@ static inline bytepad_t bytepad(const size_t data_len, const size_t width) {
#define DEF_CSHAKE(BITS) \
/* absorb data into cshake context */ \
_Bool cshake ## BITS ## _xof_absorb(sha3_xof_t * const xof, const uint8_t * const msg, const size_t len) { \
- return xof_absorb(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, msg, len); \
+ return xof_absorb(xof, SHAKE ## BITS ## _RATE, msg, len); \
} \
\
/* squeeze data from cshake context */ \
void cshake ## BITS ## _xof_squeeze(sha3_xof_t * const xof, uint8_t * const dst, const size_t len) { \
- xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, CSHAKE_PAD, dst, len); \
+ xof_squeeze(xof, SHAKE ## BITS ## _RATE, CSHAKE_PAD, dst, len); \
} \
\
/* initialize cshake context */ \
@@ -1522,7 +1921,7 @@ static inline _Bool turboshake_init(turboshake_t * const ts, const uint8_t pad)
}
// init xof
- xof_init(&(ts->xof));
+ xof12_init(&(ts->xof));
ts->pad = pad;
// return success
@@ -1544,22 +1943,22 @@ static inline _Bool turboshake_init(turboshake_t * const ts, const uint8_t pad)
\
/* absorb bytes into turboshake context. */ \
_Bool turboshake ## BITS ## _absorb(turboshake_t * const ts, const uint8_t * const m, const size_t len) { \
- return xof_absorb(&(ts->xof), SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, m, len); \
+ return xof12_absorb(&(ts->xof), SHAKE ## BITS ## _RATE, m, len); \
} \
\
/* squeeze bytes from turboshake context */ \
void turboshake ## BITS ## _squeeze(turboshake_t * const ts, uint8_t * const dst, const size_t dst_len) { \
- xof_squeeze(&(ts->xof), SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, ts->pad, dst, dst_len); \
+ xof12_squeeze(&(ts->xof), SHAKE ## BITS ## _RATE, ts->pad, dst, dst_len); \
} \
\
/* one-shot turboshake with default pad byte */ \
void turboshake ## BITS (const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \
- xof_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, TURBOSHAKE_PAD, src, src_len, dst, dst_len); \
+ xof12_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_PAD, src, src_len, dst, dst_len); \
} \
\
/* one-shot turboshake with custom pad byte */ \
void turboshake ## BITS ## _custom(const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \
- xof_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, pad, src, src_len, dst, dst_len); \
+ xof12_once(SHAKE ## BITS ## _RATE, pad, src, src_len, dst, dst_len); \
}
// declare turboshake functions
@@ -1978,17 +2377,84 @@ static void test_iota(void) {
}
}
-static void test_permute(void) {
- uint64_t a[25] = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL };
- const uint64_t exp[] = {
- 0xE95A9E40EF2F24C8ULL, 0x24C64DAE57C8F1D1ULL,
- 0x8CAA629F80192BB9ULL, 0xD0B178A0541C4107ULL,
- };
+static const struct {
+ uint64_t a[25]; // input state
+ const uint64_t exp[25]; // expected value
+ const size_t exp_len; // length of exp, in bytes
+} PERMUTE_TESTS[] = {{
+ .a = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL },
+ .exp = { 0xE95A9E40EF2F24C8ULL, 0x24C64DAE57C8F1D1ULL, 0x8CAA629F80192BB9ULL, 0xD0B178A0541C4107ULL },
+ .exp_len = 32,
+}};
+
+static void test_permute_scalar(void) {
+ for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
+ permute_scalar(got);
+
+ if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
+ }
+ }
+}
- permute(a, SHA3_NUM_ROUNDS);
- if (memcmp(exp, a, sizeof(exp))) {
- fail_test(__func__, "", (uint8_t*) a, 32, (uint8_t*) exp, 32);
+static void test_permute_avx512(void) {
+#ifdef __AVX512F__
+ for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
+ permute_avx512(got);
+
+ if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
+ }
+ }
+#endif /* __AVX512F__ */
+}
+
+static const struct {
+ uint64_t a[25]; // input state
+ const uint64_t exp[25]; // expected value
+ const size_t exp_len; // length of exp, in bytes
+} PERMUTE12_TESTS[] = {{
+ .a = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL },
+ .exp = { 0X8B346BAFF5DA94C6ULL, 0XD7D37EC35E3B2EECULL, 0XBBF724EABFD84018ULL, 0X5E3C1AFA4EA7B3A1ULL },
+ .exp_len = 32,
+}};
+
+static void test_permute12_scalar(void) {
+ for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE12_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
+ permute12_scalar(got);
+
+ if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
+ }
+ }
+}
+
+static void test_permute12_avx512(void) {
+#ifdef __AVX512F__
+ for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE12_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
+ permute12_avx512(got);
+
+ if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
+ }
}
+#endif /* __AVX512F__ */
}
static void test_sha3_224(void) {
@@ -6143,7 +6609,10 @@ int main(void) {
test_pi();
test_chi();
test_iota();
- test_permute();
+ test_permute_scalar();
+ test_permute_avx512();
+ test_permute12_scalar();
+ test_permute12_avx512();
test_sha3_224();
test_sha3_256();
test_sha3_384();