diff options
-rw-r--r-- | sha3.c | 625 |
1 files changed, 547 insertions, 78 deletions
@@ -34,6 +34,16 @@ // number of rounds for permute() #define SHA3_NUM_ROUNDS 24 +// round constants (used by iota) +static const uint64_t RCS[] = { + 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, + 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, + 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, + 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, + 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, + 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, +}; + #if !defined(__AVX512F__) || defined(SHA3_TEST) // If AVX512 is supported and we are not building the test suite, // then do not compile the scalar step functions below. @@ -153,35 +163,33 @@ static inline void chi(uint64_t dst[static 25], const uint64_t src[static 25]) { // iota step of keccak permutation (scalar implementation) static inline void iota(uint64_t a[static 25], const int i) { - // round constants (ambiguous in spec) - static const uint64_t RCS[] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, - 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, - }; - a[0] ^= RCS[i]; } #endif /* !defined(__AVX512F__) || defined(SHA3_TEST) */ #ifndef __AVX512F__ -// keccak permutation (scalar implementation) -// -// note: clang is better about inlining this than gcc with a -// configurable number of rounds. the configurable number of rounds is -// only used by turboshake, so it might be worth creating a specialized -// `permute12()` to handle turboshake. -static inline void permute(uint64_t a[static 25], const size_t num_rounds) { +// 24-round keccak permutation (scalar implementation) +static inline void permute_scalar(uint64_t a[static 25]) { uint64_t tmp[25] = { 0 }; - for (int i = 0; i < (int) num_rounds; i++) { + for (size_t i = 0; i < SHA3_NUM_ROUNDS; i++) { theta(a); rho(a); pi(tmp, a); chi(a, tmp); - iota(a, 24 - num_rounds + i); + iota(a, i); + } +} + +// 12 round keccak permutation (scalar implementation) +// (only used by turboshake) +static inline void permute12_scalar(uint64_t a[static 25]) { + uint64_t tmp[25] = { 0 }; + for (size_t i = 0; i < 12; i++) { + theta(a); + rho(a); + pi(tmp, a); + chi(a, tmp); + iota(a, 12 + i); } } #endif /* !__AVX512F__ */ @@ -189,7 +197,7 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) { #ifdef __AVX512F__ #include <immintrin.h> -// keccak permutation (avx512 implementation). +// 24 round keccak permutation (avx512 implementation). // // copied from `permute_avx512_fast()` in `tests/permute/permute.c`. all // steps are inlined as blocks. ~3x faster than scalar implementation, @@ -220,26 +228,240 @@ static inline void permute(uint64_t a[static 25], const size_t num_rounds) { // as noted above, this is not the most efficient avx512 implementation; // the row registers have three empty slots and there are a lot of loads // that could be removed with a little more work. -static inline void permute(uint64_t s[static 25], const size_t num_rounds) { +static inline void permute_avx512(uint64_t s[static 25]) { // unaligned load mask and permutation indices uint8_t mask = 0x1f, m0b = 0x01; const __mmask8 m = _load_mask8(&mask), m0 = _load_mask8(&m0b); - // round constants (used in iota) - static const uint64_t RCS[] = { - 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL, - 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL, - 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL, - 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL, - 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL, - 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL, - }; + // load round constant + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + __m512i rc = _mm512_loadu_epi64((void*) RCS); + + // load rc permutation + static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }; + const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps); + + // load rows + __m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)), + r1 = _mm512_maskz_loadu_epi64(m, (void*) (s + 5)), + r2 = _mm512_maskz_loadu_epi64(m, (void*) (s + 10)), + r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)), + r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20)); + + for (int i = 0; i < SHA3_NUM_ROUNDS; i++) { + // theta + { + const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7), + i1 = _mm512_setr_epi64(1, 2, 3, 4, 0, 5, 6, 7); + + // c = xor(r0, r1, r2, r3, r4) + const __m512i r01 = _mm512_xor_epi64(r0, r1), + r23 = _mm512_xor_epi64(r2, r3), + r03 = _mm512_xor_epi64(r01, r23), + c = _mm512_xor_epi64(r03, r4); + + // d = xor(permute(i0, c), permute(i1, rol(c, 1))) + const __m512i d0 = _mm512_permutexvar_epi64(i0, c), + d1 = _mm512_permutexvar_epi64(i1, _mm512_rol_epi64(c, 1)), + d = _mm512_xor_epi64(d0, d1); + + // row = xor(row, d) + r0 = _mm512_xor_epi64(r0, d); + r1 = _mm512_xor_epi64(r1, d); + r2 = _mm512_xor_epi64(r2, d); + r3 = _mm512_xor_epi64(r3, d); + r4 = _mm512_xor_epi64(r4, d); + } + + // rho + { + // unaligned load mask and rotate values + static const uint64_t vs0[8] = { 0, 1, 62, 28, 27, 0, 0, 0 }, + vs1[8] = { 36, 44, 6, 55, 20, 0, 0, 0 }, + vs2[8] = { 3, 10, 43, 25, 39, 0, 0, 0 }, + vs3[8] = { 41, 45, 15, 21, 8, 0, 0, 0 }, + vs4[8] = { 18, 2, 61, 56, 14, 0, 0, 0 }; + + // load rotate values + const __m512i v0 = _mm512_loadu_epi64((void*) vs0), + v1 = _mm512_loadu_epi64((void*) vs1), + v2 = _mm512_loadu_epi64((void*) vs2), + v3 = _mm512_loadu_epi64((void*) vs3), + v4 = _mm512_loadu_epi64((void*) vs4); + + // rotate + r0 = _mm512_rolv_epi64(r0, v0); + r1 = _mm512_rolv_epi64(r1, v1); + r2 = _mm512_rolv_epi64(r2, v2); + r3 = _mm512_rolv_epi64(r3, v3); + r4 = _mm512_rolv_epi64(r4, v4); + } + + // pi + { + // mask bytes + uint8_t m01b = 0x03, + m23b = 0x0c, + m4b = 0x10; + + // load masks + const __mmask8 m01 = _load_mask8(&m01b), + m23 = _load_mask8(&m23b), + m4 = _load_mask8(&m4b); + + // permutation indices + // + // (note: these are masked so only the relevant indices for + // _mm512_maskz_permutex2var_epi64() in each array are filled in) + static const uint64_t t0_pis_01[8] = { 0, 8 + 1, 0, 0, 0, 0, 0, 0 }, + t0_pis_23[8] = { 0, 0, 2, 8 + 3, 0, 0, 0, 0 }, + t0_pis_4[8] = { 0, 0, 0, 0, 4, 0, 0, 0 }, + + t1_pis_01[8] = { 3, 8 + 4, 0, 0, 0, 0, 0, 0 }, + t1_pis_23[8] = { 0, 0, 0, 8 + 1, 0, 0, 0, 0 }, + t1_pis_4[8] = { 0, 0, 0, 0, 2, 0, 0, 0 }, + + t2_pis_01[8] = { 1, 8 + 2, 0, 0, 0, 0, 0, 0 }, + t2_pis_23[8] = { 0, 0, 3, 8 + 4, 0, 0, 0, 0 }, + t2_pis_4[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }, + + t3_pis_01[8] = { 4, 8 + 0, 0, 0, 0, 0, 0, 0 }, + t3_pis_23[8] = { 0, 0, 1, 8 + 2, 0, 0, 0, 0 }, + t3_pis_4[8] = { 0, 0, 0, 0, 3, 0, 0, 0 }, + + t4_pis_01[8] = { 2, 8 + 3, 0, 0, 0, 0, 0, 0 }, + t4_pis_23[8] = { 0, 0, 4, 8 + 0, 0, 0, 0, 0 }, + t4_pis_4[8] = { 0, 0, 0, 0, 1, 0, 0, 0 }; + + // load permutation indices + const __m512i t0_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t0_pis_01), + t0_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t0_pis_23), + t0_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t0_pis_4), + + t1_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t1_pis_01), + t1_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t1_pis_23), + t1_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t1_pis_4), + + t2_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t2_pis_01), + t2_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t2_pis_23), + t2_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t2_pis_4), + + t3_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t3_pis_01), + t3_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t3_pis_23), + t3_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t3_pis_4), + + t4_p01 = _mm512_maskz_loadu_epi64(m01, (void*) t4_pis_01), + t4_p23 = _mm512_maskz_loadu_epi64(m23, (void*) t4_pis_23), + t4_p4 = _mm512_maskz_loadu_epi64(m4, (void*) t4_pis_4); + + // permute rows + const __m512i t0e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t0_p01, r1), + t0e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t0_p23, r3), + t0e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t0_p4, r0), + t0 = _mm512_or_epi64(_mm512_or_epi64(t0e0, t0e2), t0e4), + + t1e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t1_p01, r1), + t1e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t1_p23, r3), + t1e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t1_p4, r0), + t1 = _mm512_or_epi64(_mm512_or_epi64(t1e0, t1e2), t1e4), + + t2e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t2_p01, r1), + t2e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t2_p23, r3), + t2e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t2_p4, r0), + t2 = _mm512_or_epi64(_mm512_or_epi64(t2e0, t2e2), t2e4), + + t3e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t3_p01, r1), + t3e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t3_p23, r3), + t3e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t3_p4, r0), + t3 = _mm512_or_epi64(_mm512_or_epi64(t3e0, t3e2), t3e4), + + t4e0 = _mm512_maskz_permutex2var_epi64(m01, r0, t4_p01, r1), + t4e2 = _mm512_maskz_permutex2var_epi64(m23, r2, t4_p23, r3), + t4e4 = _mm512_maskz_permutex2var_epi64(m4, r4, t4_p4, r0), + t4 = _mm512_or_epi64(_mm512_or_epi64(t4e0, t4e2), t4e4); + + // store rows + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; + r4 = t4; + } + + // chi + { + // permutation indices + static const uint64_t pis0[8] = { 1, 2, 3, 4, 0, 0, 0, 0 }, + pis1[8] = { 2, 3, 4, 0, 1, 0, 0, 0 }; + + // load permutation indices + const __m512i p0 = _mm512_maskz_loadu_epi64(m, (void*) pis0), + p1 = _mm512_maskz_loadu_epi64(m, (void*) pis1); + + // permute rows + const __m512i t0_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r0), + t0_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r0), + t0 = _mm512_xor_epi64(r0, _mm512_andnot_epi64(t0_e0, t0_e1)), + + t1_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r1), + t1_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r1), + t1 = _mm512_xor_epi64(r1, _mm512_andnot_epi64(t1_e0, t1_e1)), + + t2_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r2), + t2_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r2), + t2 = _mm512_xor_epi64(r2, _mm512_andnot_epi64(t2_e0, t2_e1)), + + t3_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r3), + t3_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r3), + t3 = _mm512_xor_epi64(r3, _mm512_andnot_epi64(t3_e0, t3_e1)), + + t4_e0 = _mm512_maskz_permutexvar_epi64(m, p0, r4), + t4_e1 = _mm512_maskz_permutexvar_epi64(m, p1, r4), + t4 = _mm512_xor_epi64(r4, _mm512_andnot_epi64(t4_e0, t4_e1)); + + // store rows + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; + r4 = t4; + } + + // iota + { + // xor round constant, shuffle rc register + r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc); + rc = _mm512_permutexvar_epi64(rc_p, rc); + + if (((i + 1) % 8) == 0 && i != 23) { + // load next set of round constants + // note: this will bomb if num_rounds < 8 or num_rounds > 24. + rc = _mm512_loadu_epi64((void*) (RCS + (i + 1))); + } + } + } + + // store rows + _mm512_mask_storeu_epi64((void*) (s), m, r0), + _mm512_mask_storeu_epi64((void*) (s + 5), m, r1), + _mm512_mask_storeu_epi64((void*) (s + 10), m, r2), + _mm512_mask_storeu_epi64((void*) (s + 15), m, r3), + _mm512_mask_storeu_epi64((void*) (s + 20), m, r4); +} + +// 12 round keccak permutation (avx512 implementation). +static inline void permute12_avx512(uint64_t s[static 25]) { + // unaligned load mask and permutation indices + uint8_t mask = 0x1f, + m0b = 0x01; + const __mmask8 m = _load_mask8(&mask), + m0 = _load_mask8(&m0b); // load round constant // note: this will bomb if num_rounds < 8 or num_rounds > 24. - __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds)); + __m512i rc = _mm512_loadu_epi64((void*) (RCS + 12)); // load rc permutation static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 }; @@ -252,7 +474,7 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) { r3 = _mm512_maskz_loadu_epi64(m, (void*) (s + 15)), r4 = _mm512_maskz_loadu_epi64(m, (void*) (s + 20)); - for (int i = 0; i < (int) num_rounds; i++) { + for (int i = 0; i < SHA3_NUM_ROUNDS; i++) { // theta { const __m512i i0 = _mm512_setr_epi64(4, 0, 1, 2, 3, 5, 6, 7), @@ -437,10 +659,10 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) { r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc); rc = _mm512_permutexvar_epi64(rc_p, rc); - if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) { + if (((12 + i + 1) % 8) == 0 && (12 + i) != 23) { // load next set of round constants // note: this will bomb if num_rounds < 8 or num_rounds > 24. - rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1))); + rc = _mm512_loadu_epi64((void*) (RCS + (12 + i + 1))); } } } @@ -454,9 +676,76 @@ static inline void permute(uint64_t s[static 25], const size_t num_rounds) { } #endif /* __AVX512F__ */ +#ifdef __AVX512F__ +#define permute permute_avx512 +#define permute12 permute12_avx512 +#else /* !__AVX512F__ */ +#define permute permute_scalar +#define permute12 permute12_scalar +#endif /* __AVX512F__ */ + // absorb message into state, return updated byte count // used by `hash_absorb()`, `hash_once()`, and `xof_absorb_raw()` -static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size_t rate, const size_t num_rounds, const uint8_t *m, size_t m_len) { +static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size_t rate, const uint8_t *m, size_t m_len) { + // absorb aligned chunks + if ((num_bytes & 7) == 0 && (((uintptr_t) m) & 7) == 0) { + // absorb 32 byte chunks (4 x uint64) + while (m_len >= 32 && num_bytes <= rate - 32) { + // xor chunk into state + // (FIXME: does not vectorize for some reason, even when unrolled) + for (size_t i = 0; i < 4; i++) { + a->u64[num_bytes/8 + i] ^= ((uint64_t*) m)[i]; + } + + // update counters + num_bytes += 32; + m += 32; + m_len -= 32; + + if (num_bytes == rate) { + // permute state + permute(a->u64); + num_bytes = 0; + } + } + + // absorb 8 byte chunks (1 x uint64) + while (m_len >= 8 && num_bytes <= rate - 8) { + // xor chunk into state + a->u64[num_bytes/8] ^= *((uint64_t*) m); + + // update counters + num_bytes += 8; + m += 8; + m_len -= 8; + + if (num_bytes == rate) { + // permute state + permute(a->u64); + num_bytes = 0; + } + } + } + + // absorb remaining bytes + for (size_t i = 0; i < m_len; i++) { + // xor byte into state + a->u8[num_bytes++] ^= m[i]; + + if (num_bytes == rate) { + // permute state + permute(a->u64); + num_bytes = 0; + } + } + + // return byte count + return num_bytes; +} + +// absorb message into xof12 state, return updated byte count +// used by `xof12_absorb_raw()` +static inline size_t absorb12(sha3_state_t * const a, size_t num_bytes, const size_t rate, const uint8_t *m, size_t m_len) { // absorb aligned chunks if ((num_bytes & 7) == 0 && (((uintptr_t) m) & 7) == 0) { // absorb 32 byte chunks (4 x uint64) @@ -474,7 +763,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size if (num_bytes == rate) { // permute state - permute(a->u64, num_rounds); + permute12(a->u64); num_bytes = 0; } } @@ -491,7 +780,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size if (num_bytes == rate) { // permute state - permute(a->u64, num_rounds); + permute12(a->u64); num_bytes = 0; } } @@ -504,7 +793,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size if (num_bytes == rate) { // permute state - permute(a->u64, num_rounds); + permute12(a->u64); num_bytes = 0; } } @@ -513,6 +802,7 @@ static inline size_t absorb(sha3_state_t * const a, size_t num_bytes, const size return num_bytes; } + // Get rate (number of bytes that can be absorbed before the internal // state is permuted). // @@ -546,7 +836,7 @@ static inline void hash_once(const uint8_t *m, size_t m_len, uint8_t * const dst sha3_state_t a = { .u64 = { 0 } }; // absorb message, get new internal length - const size_t len = absorb(&a, 0, RATE(dst_len), SHA3_NUM_ROUNDS, m, m_len); + const size_t len = absorb(&a, 0, RATE(dst_len), m, m_len); // append suffix and padding // (note: suffix and padding are ambiguous in spec) @@ -554,7 +844,7 @@ static inline void hash_once(const uint8_t *m, size_t m_len, uint8_t * const dst a.u8[RATE(dst_len)-1] ^= 0x80; // final permutation - permute(a.u64, SHA3_NUM_ROUNDS); + permute(a.u64); // copy to destination memcpy(dst, a.u8, dst_len); @@ -575,7 +865,7 @@ static inline bool hash_absorb(sha3_t * const hash, const size_t rate, const uin } // absorb bytes, return success - hash->num_bytes = absorb(&(hash->a), hash->num_bytes, rate, SHA3_NUM_ROUNDS, src, len); + hash->num_bytes = absorb(&(hash->a), hash->num_bytes, rate, src, len); return true; } @@ -593,7 +883,7 @@ static inline void hash_final(sha3_t * const hash, const size_t rate, uint8_t * hash->a.u8[rate - 1] ^= 0x80; // permute - permute(hash->a.u64, SHA3_NUM_ROUNDS); + permute(hash->a.u64); } // copy to destination @@ -637,14 +927,14 @@ static inline void xof_init(sha3_xof_t * const xof) { // context has already been squeezed. // // Called by `xof_absorb()` and `xof_once()`. -static inline void xof_absorb_raw(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t *m, size_t m_len) { - xof->num_bytes = absorb(&(xof->a), xof->num_bytes, rate, num_rounds, m, m_len); +static inline void xof_absorb_raw(sha3_xof_t * const xof, const size_t rate, const uint8_t *m, size_t m_len) { + xof->num_bytes = absorb(&(xof->a), xof->num_bytes, rate, m, m_len); } // Absorb data into XOF context. // // Returns `false` if this context has already been squeezed. -static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t * const m, size_t m_len) { +static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const uint8_t * const m, size_t m_len) { // check context state if (xof->squeezing) { // xof has already been squeezed, return error @@ -652,19 +942,19 @@ static inline _Bool xof_absorb(sha3_xof_t * const xof, const size_t rate, const } // absorb, return success - xof_absorb_raw(xof, rate, num_rounds, m, m_len); + xof_absorb_raw(xof, rate, m, m_len); return true; } // Finalize absorb, switch mode of XOF context to squeezing. -static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t pad) { +static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, const uint8_t pad) { // append suffix (s6.2) and padding // (note: suffix and padding are ambiguous in spec) xof->a.u8[xof->num_bytes] ^= pad; xof->a.u8[rate - 1] ^= 0x80; // permute - permute(xof->a.u64, num_rounds); + permute(xof->a.u64); // switch to squeeze mode xof->num_bytes = 0; @@ -672,7 +962,7 @@ static inline void xof_absorb_done(sha3_xof_t * const xof, const size_t rate, co } // Squeeze data without checking mode (used by `xof_once()`). -static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, uint8_t *dst, size_t dst_len) { +static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, uint8_t *dst, size_t dst_len) { if (!xof->num_bytes) { // num_bytes is zero, so we are reading from the start of the // internal state buffer. while `dst_len` is greater than rate, @@ -681,7 +971,7 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co // rate-sized chunks to destination while (dst_len >= rate) { memcpy(dst, xof->a.u8, rate); // copy rate-sized chunk - permute(xof->a.u64, num_rounds); // permute state + permute(xof->a.u64); // permute state // update destination pointer and length dst += rate; @@ -705,7 +995,7 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co dst[i] = xof->a.u8[xof->num_bytes++]; // squeeze byte to destination if (xof->num_bytes == rate) { - permute(xof->a.u64, num_rounds); // permute state + permute(xof->a.u64); // permute state xof->num_bytes = 0; // clear read bytes count } } @@ -713,28 +1003,137 @@ static inline void xof_squeeze_raw(sha3_xof_t * const xof, const size_t rate, co } // squeeze data from xof -static inline void xof_squeeze(sha3_xof_t * const xof, const size_t rate, const size_t num_rounds, const uint8_t pad, uint8_t * const dst, const size_t dst_len) { +static inline void xof_squeeze(sha3_xof_t * const xof, const size_t rate, const uint8_t pad, uint8_t * const dst, const size_t dst_len) { // check state if (!xof->squeezing) { // finalize absorb - xof_absorb_done(xof, rate, num_rounds, pad); + xof_absorb_done(xof, rate, pad); } - xof_squeeze_raw(xof, rate, num_rounds, dst, dst_len); + xof_squeeze_raw(xof, rate, dst, dst_len); } // one-shot xof absorb and squeeze -static inline void xof_once(const size_t rate, const size_t num_rounds, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { +static inline void xof_once(const size_t rate, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { // init sha3_xof_t xof; xof_init(&xof); // absorb - xof_absorb_raw(&xof, rate, num_rounds, src, src_len); - xof_absorb_done(&xof, rate, num_rounds, pad); + xof_absorb_raw(&xof, rate, src, src_len); + xof_absorb_done(&xof, rate, pad); + + // squeeze + xof_squeeze_raw(&xof, rate, dst, dst_len); +} + +// initialize xof12 context +static inline void xof12_init(sha3_xof_t * const xof) { + memset(xof, 0, sizeof(sha3_xof_t)); +} + +// Absorb data into XOF12 context without checking to see if the +// context has already been squeezed. +// +// Called by `xof12_absorb()` and `xof12_once()`. +static inline void xof12_absorb_raw(sha3_xof_t * const xof, const size_t rate, const uint8_t *m, size_t m_len) { + xof->num_bytes = absorb12(&(xof->a), xof->num_bytes, rate, m, m_len); +} + +// Absorb data into XOF context. +// +// Returns `false` if this XOF12 context has already been squeezed. +static inline _Bool xof12_absorb(sha3_xof_t * const xof, const size_t rate, const uint8_t * const m, size_t m_len) { + // check context state + if (xof->squeezing) { + // xof has already been squeezed, return error + return false; + } + + // absorb, return success + xof12_absorb_raw(xof, rate, m, m_len); + return true; +} + +// Finalize absorb, switch mode of XOF12 context to squeezing. +static inline void xof12_absorb_done(sha3_xof_t * const xof, const size_t rate, const uint8_t pad) { + // append suffix (s6.2) and padding + // (note: suffix and padding are ambiguous in spec) + xof->a.u8[xof->num_bytes] ^= pad; + xof->a.u8[rate - 1] ^= 0x80; + + // permute + permute12(xof->a.u64); + + // switch to squeeze mode + xof->num_bytes = 0; + xof->squeezing = true; +} + +// Squeeze data without checking mode (used by `xof12_once()`). +static inline void xof12_squeeze_raw(sha3_xof_t * const xof, const size_t rate, uint8_t *dst, size_t dst_len) { + if (!xof->num_bytes) { + // num_bytes is zero, so we are reading from the start of the + // internal state buffer. while `dst_len` is greater than rate, + // copy `rate` sized chunks directly from the internal state buffer + // to the destination, then permute the internal state. squeeze + // rate-sized chunks to destination + while (dst_len >= rate) { + memcpy(dst, xof->a.u8, rate); // copy rate-sized chunk + permute12(xof->a.u64); // permute state + + // update destination pointer and length + dst += rate; + dst_len -= rate; + } + + if (dst_len > 0) { + // the remaining destination length is less than `rate`, so copy a + // `dst_len`-sized chunk from the internal state to the + // destination buffer, then update the read byte count. + + // squeeze dst_len-sized block to destination + memcpy(dst, xof->a.u8, dst_len); // copy dst_len-sized chunk + xof->num_bytes = dst_len; // update read byte count + } + } else { + // fall back to squeezing one byte at a time + + // squeeze bytes to destination + for (size_t i = 0; i < dst_len; i++) { + dst[i] = xof->a.u8[xof->num_bytes++]; // squeeze byte to destination + + if (xof->num_bytes == rate) { + permute12(xof->a.u64); // permute state + xof->num_bytes = 0; // clear read bytes count + } + } + } +} + +// squeeze data from xof12 context +static inline void xof12_squeeze(sha3_xof_t * const xof, const size_t rate, const uint8_t pad, uint8_t * const dst, const size_t dst_len) { + // check state + if (!xof->squeezing) { + // finalize absorb + xof12_absorb_done(xof, rate, pad); + } + + xof12_squeeze_raw(xof, rate, dst, dst_len); +} + +// one-shot xof12 absorb and squeeze +static inline void xof12_once(const size_t rate, const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { + // init + sha3_xof_t xof; + xof12_init(&xof); + + // absorb + xof12_absorb_raw(&xof, rate, src, src_len); + xof12_absorb_done(&xof, rate, pad); // squeeze - xof_squeeze_raw(&xof, rate, num_rounds, dst, dst_len); + xof12_squeeze_raw(&xof, rate, dst, dst_len); } // define shake iterative context and one-shot functions @@ -746,17 +1145,17 @@ static inline void xof_once(const size_t rate, const size_t num_rounds, const ui \ /* absorb bytes into shake context */ \ _Bool shake ## BITS ## _absorb(sha3_xof_t * const xof, const uint8_t * const m, const size_t len) { \ - return xof_absorb(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, m, len); \ + return xof_absorb(xof, SHAKE ## BITS ## _RATE, m, len); \ } \ \ /* squeeze bytes from shake context */ \ void shake ## BITS ## _squeeze(sha3_xof_t * const xof, uint8_t * const dst, const size_t dst_len) { \ - xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, SHAKE_PAD, dst, dst_len); \ + xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHAKE_PAD, dst, dst_len); \ } \ \ /* one-shot shake absorb and squeeze */ \ void shake ## BITS(const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \ - xof_once(SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, SHAKE_PAD, src, src_len, dst, dst_len); \ + xof_once(SHAKE ## BITS ## _RATE, SHAKE_PAD, src, src_len, dst, dst_len); \ } // shake padding byte and rates @@ -1070,12 +1469,12 @@ static inline bytepad_t bytepad(const size_t data_len, const size_t width) { #define DEF_CSHAKE(BITS) \ /* absorb data into cshake context */ \ _Bool cshake ## BITS ## _xof_absorb(sha3_xof_t * const xof, const uint8_t * const msg, const size_t len) { \ - return xof_absorb(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, msg, len); \ + return xof_absorb(xof, SHAKE ## BITS ## _RATE, msg, len); \ } \ \ /* squeeze data from cshake context */ \ void cshake ## BITS ## _xof_squeeze(sha3_xof_t * const xof, uint8_t * const dst, const size_t len) { \ - xof_squeeze(xof, SHAKE ## BITS ## _RATE, SHA3_NUM_ROUNDS, CSHAKE_PAD, dst, len); \ + xof_squeeze(xof, SHAKE ## BITS ## _RATE, CSHAKE_PAD, dst, len); \ } \ \ /* initialize cshake context */ \ @@ -1522,7 +1921,7 @@ static inline _Bool turboshake_init(turboshake_t * const ts, const uint8_t pad) } // init xof - xof_init(&(ts->xof)); + xof12_init(&(ts->xof)); ts->pad = pad; // return success @@ -1544,22 +1943,22 @@ static inline _Bool turboshake_init(turboshake_t * const ts, const uint8_t pad) \ /* absorb bytes into turboshake context. */ \ _Bool turboshake ## BITS ## _absorb(turboshake_t * const ts, const uint8_t * const m, const size_t len) { \ - return xof_absorb(&(ts->xof), SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, m, len); \ + return xof12_absorb(&(ts->xof), SHAKE ## BITS ## _RATE, m, len); \ } \ \ /* squeeze bytes from turboshake context */ \ void turboshake ## BITS ## _squeeze(turboshake_t * const ts, uint8_t * const dst, const size_t dst_len) { \ - xof_squeeze(&(ts->xof), SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, ts->pad, dst, dst_len); \ + xof12_squeeze(&(ts->xof), SHAKE ## BITS ## _RATE, ts->pad, dst, dst_len); \ } \ \ /* one-shot turboshake with default pad byte */ \ void turboshake ## BITS (const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \ - xof_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, TURBOSHAKE_PAD, src, src_len, dst, dst_len); \ + xof12_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_PAD, src, src_len, dst, dst_len); \ } \ \ /* one-shot turboshake with custom pad byte */ \ void turboshake ## BITS ## _custom(const uint8_t pad, const uint8_t * const src, const size_t src_len, uint8_t * const dst, const size_t dst_len) { \ - xof_once(SHAKE ## BITS ## _RATE, TURBOSHAKE_NUM_ROUNDS, pad, src, src_len, dst, dst_len); \ + xof12_once(SHAKE ## BITS ## _RATE, pad, src, src_len, dst, dst_len); \ } // declare turboshake functions @@ -1978,17 +2377,84 @@ static void test_iota(void) { } } -static void test_permute(void) { - uint64_t a[25] = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL }; - const uint64_t exp[] = { - 0xE95A9E40EF2F24C8ULL, 0x24C64DAE57C8F1D1ULL, - 0x8CAA629F80192BB9ULL, 0xD0B178A0541C4107ULL, - }; +static const struct { + uint64_t a[25]; // input state + const uint64_t exp[25]; // expected value + const size_t exp_len; // length of exp, in bytes +} PERMUTE_TESTS[] = {{ + .a = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL }, + .exp = { 0xE95A9E40EF2F24C8ULL, 0x24C64DAE57C8F1D1ULL, 0x8CAA629F80192BB9ULL, 0xD0B178A0541C4107ULL }, + .exp_len = 32, +}}; + +static void test_permute_scalar(void) { + for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) { + const size_t exp_len = PERMUTE_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); + permute_scalar(got); + + if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); + } + } +} - permute(a, SHA3_NUM_ROUNDS); - if (memcmp(exp, a, sizeof(exp))) { - fail_test(__func__, "", (uint8_t*) a, 32, (uint8_t*) exp, 32); +static void test_permute_avx512(void) { +#ifdef __AVX512F__ + for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) { + const size_t exp_len = PERMUTE_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); + permute_avx512(got); + + if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); + } + } +#endif /* __AVX512F__ */ +} + +static const struct { + uint64_t a[25]; // input state + const uint64_t exp[25]; // expected value + const size_t exp_len; // length of exp, in bytes +} PERMUTE12_TESTS[] = {{ + .a = { [0] = 0x00000001997b5853ULL, [16] = 0x8000000000000000ULL }, + .exp = { 0X8B346BAFF5DA94C6ULL, 0XD7D37EC35E3B2EECULL, 0XBBF724EABFD84018ULL, 0X5E3C1AFA4EA7B3A1ULL }, + .exp_len = 32, +}}; + +static void test_permute12_scalar(void) { + for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) { + const size_t exp_len = PERMUTE12_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); + permute12_scalar(got); + + if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); + } + } +} + +static void test_permute12_avx512(void) { +#ifdef __AVX512F__ + for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) { + const size_t exp_len = PERMUTE12_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); + permute12_avx512(got); + + if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); + } } +#endif /* __AVX512F__ */ } static void test_sha3_224(void) { @@ -6143,7 +6609,10 @@ int main(void) { test_pi(); test_chi(); test_iota(); - test_permute(); + test_permute_scalar(); + test_permute_avx512(); + test_permute12_scalar(); + test_permute12_avx512(); test_sha3_224(); test_sha3_256(); test_sha3_384(); |