diff options
-rw-r--r-- | sha3.c | 203 |
1 files changed, 202 insertions, 1 deletions
@@ -31,6 +31,7 @@ #define BACKEND_AVX512 2 // AVX-512 backend #define BACKEND_NEON 3 // Neon backend (experimental) #define BACKEND_DIET_NEON 4 // Neon backend which uses fewer registers +#define BACKEND_HYBRID_NEON 5 // Hybrid neon backend // if SHA3_BACKEND is defined and set to 0 (the default), then unset it // and auto-detect the appropriate backend @@ -1057,16 +1058,180 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r } // store column 4 of r4 - vst1_u64(a + 24, row_get_tail(r4)); + vst1_u64(a + 24, row_last(r4)); } #endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ +#if (SHA3_BACKEND == BACKEND_HYBRID_NEON) +#include <arm_neon.h> + +/** + * @brief Scalar Keccak permutation. + * + * Apply `num_rounds` of Keccak permutation. This function is only + * called by: + * + * - `permute_scalar()`: 24 rounds + * - `permute12_scalar()`: 12 rounds. Used by TurboSHAKE and KangarooTwelve. + * + * @param[in,out] a Keccak state (array of 25 64-bit integers). + * @param[in] num_rounds Number of rounds (12 or 24). + */ +static inline void permute_n_hybrid_neon(uint64_t a[static 25], const size_t num_rounds) { + uint64_t tmp[25] = { 0 }; + for (size_t i = SHA3_NUM_ROUNDS - num_rounds; i < SHA3_NUM_ROUNDS; i++) { + // theta + const uint64_t c[5] = { + a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20], + a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21], + a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22], + a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23], + a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24], + }; + + const uint64_t d[5] = { + c[4] ^ ROL(c[1], 1), + c[0] ^ ROL(c[2], 1), + c[1] ^ ROL(c[3], 1), + c[2] ^ ROL(c[4], 1), + c[3] ^ ROL(c[0], 1), + }; + + const uint64x2_t d0 = { d[0], d[1] }, + d1 = { d[2], d[3] }; + + uint64x2_t r0a = { a[ 0], a[ 1] }, r0b = { a[ 2], a[ 3] }, + r1a = { a[ 5], a[ 6] }, r1b = { a[ 7], a[ 8] }, + r2a = { a[10], a[11] }, r2b = { a[12], a[13] }, + r3a = { a[15], a[16] }, r3b = { a[17], a[18] }, + r4a = { a[20], a[21] }, r4b = { a[22], a[23] }; + + r0a ^= d0; r0b ^= d1; a[ 4] ^= d[4]; + r1a ^= d0; r1b ^= d1; a[ 9] ^= d[4]; + r2a ^= d0; r2b ^= d1; a[14] ^= d[4]; + r3a ^= d0; r3b ^= d1; a[19] ^= d[4]; + r4a ^= d0; r4b ^= d1; a[24] ^= d[4]; + + + // rho + { + static const uint64x2_t r0a_ids = { 0, 1 }, + r0b_ids = { 62, 28 }, + r1a_ids = { 36, 44 }, + r1b_ids = { 6, 55 }, + r2a_ids = { 3, 10 }, + r2b_ids = { 43, 25 }, + r3a_ids = { 41, 45 }, + r3b_ids = { 15, 21 }, + r4a_ids = { 18, 2 }, + r4b_ids = { 61, 56 }; + + r0a = (r0a << r0a_ids) | (r0a >> (64 - r0a_ids)); + r0b = (r0b << r0b_ids) | (r0b >> (64 - r0b_ids)); + a[ 4] = ROL(a[ 4], 27); // 91 % 64 = 27 + + r1a = (r1a << r1a_ids) | (r1a >> (64 - r1a_ids)); + r1b = (r1b << r1b_ids) | (r1b >> (64 - r1b_ids)); + a[ 9] = ROL(a[ 9], 20); // 276 % 64 = 20 + + r2a = (r2a << r2a_ids) | (r2a >> (64 - r2a_ids)); + r2b = (r2b << r2b_ids) | (r2b >> (64 - r2b_ids)); + a[14] = ROL(a[14], 39); // 231 % 64 = 39 + + r3a = (r3a << r3a_ids) | (r3a >> (64 - r3a_ids)); + r3b = (r3b << r3b_ids) | (r3b >> (64 - r3b_ids)); + a[19] = ROL(a[19], 8); // 136 % 64 = 8 + + r4a = (r4a << r4a_ids) | (r4a >> (64 - r4a_ids)); + r4b = (r4b << r4b_ids) | (r4b >> (64 - r4b_ids)); + a[24] = ROL(a[24], 14); // 78 % 64 = 14 + + vst1q_u64(a + 0, r0a); vst1q_u64(a + 2, r0b); + vst1q_u64(a + 5, r1a); vst1q_u64(a + 7, r1b); + vst1q_u64(a + 10, r2a); vst1q_u64(a + 12, r2b); + vst1q_u64(a + 15, r3a); vst1q_u64(a + 17, r3b); + vst1q_u64(a + 20, r4a); vst1q_u64(a + 22, r4b); + } + + // pi + { + tmp[ 0] = a[ 0]; + tmp[ 1] = a[ 6]; + tmp[ 2] = a[12]; + tmp[ 3] = a[18]; + tmp[ 4] = a[24]; + + tmp[ 5] = a[ 3]; + tmp[ 6] = a[ 9]; + tmp[ 7] = a[10]; + tmp[ 8] = a[16]; + tmp[ 9] = a[22]; + + tmp[10] = a[ 1]; + tmp[11] = a[ 7]; + tmp[12] = a[13]; + tmp[13] = a[19]; + tmp[14] = a[20]; + + tmp[15] = a[ 4]; + tmp[16] = a[ 5]; + tmp[17] = a[11]; + tmp[18] = a[17]; + tmp[19] = a[23]; + + tmp[20] = a[ 2]; + tmp[21] = a[ 8]; + tmp[22] = a[14]; + tmp[23] = a[15]; + tmp[24] = a[21]; + } + + // chi + { + a[ 0] = tmp[ 0] ^ (~tmp[ 1] & tmp[ 2]); + a[ 1] = tmp[ 1] ^ (~tmp[ 2] & tmp[ 3]); + a[ 2] = tmp[ 2] ^ (~tmp[ 3] & tmp[ 4]); + a[ 3] = tmp[ 3] ^ (~tmp[ 4] & tmp[ 0]); + a[ 4] = tmp[ 4] ^ (~tmp[ 0] & tmp[ 1]); + + a[ 5] = tmp[ 5] ^ (~tmp[ 6] & tmp[ 7]); + a[ 6] = tmp[ 6] ^ (~tmp[ 7] & tmp[ 8]); + a[ 7] = tmp[ 7] ^ (~tmp[ 8] & tmp[ 9]); + a[ 8] = tmp[ 8] ^ (~tmp[ 9] & tmp[ 5]); + a[ 9] = tmp[ 9] ^ (~tmp[ 5] & tmp[ 6]); + + a[10] = tmp[10] ^ (~tmp[11] & tmp[12]); + a[11] = tmp[11] ^ (~tmp[12] & tmp[13]); + a[12] = tmp[12] ^ (~tmp[13] & tmp[14]); + a[13] = tmp[13] ^ (~tmp[14] & tmp[10]); + a[14] = tmp[14] ^ (~tmp[10] & tmp[11]); + + a[15] = tmp[15] ^ (~tmp[16] & tmp[17]); + a[16] = tmp[16] ^ (~tmp[17] & tmp[18]); + a[17] = tmp[17] ^ (~tmp[18] & tmp[19]); + a[18] = tmp[18] ^ (~tmp[19] & tmp[15]); + a[19] = tmp[19] ^ (~tmp[15] & tmp[16]); + + a[20] = tmp[20] ^ (~tmp[21] & tmp[22]); + a[21] = tmp[21] ^ (~tmp[22] & tmp[23]); + a[22] = tmp[22] ^ (~tmp[23] & tmp[24]); + a[23] = tmp[23] ^ (~tmp[24] & tmp[20]); + a[24] = tmp[24] ^ (~tmp[20] & tmp[21]); + } + + a[0] ^= RCS[i]; + } +} +#endif /* (SHA3_BACKEND == BACKEND_HYBRID_NEON) */ + #if SHA3_BACKEND == BACKEND_AVX512 #define permute_n permute_n_avx512 // use avx512 backend #elif SHA3_BACKEND == BACKEND_NEON #define permute_n permute_n_neon // use neon backend #elif SHA3_BACKEND == BACKEND_DIET_NEON #define permute_n permute_n_diet_neon // use diet-neon backend +#elif SHA3_BACKEND == BACKEND_HYBRID_NEON +#define permute_n permute_n_hybrid_neon // use hybrid-neon backend #elif SHA3_BACKEND == BACKEND_SCALAR #define permute_n permute_n_scalar // use scalar backend #else @@ -2564,6 +2729,8 @@ const char *sha3_backend(void) { return "neon"; #elif SHA3_BACKEND == BACKEND_DIET_NEON return "diet-neon"; +#elif SHA3_BACKEND == BACKEND_HYBRID_NEON + return "hybrid-neon"; #elif SHA3_BACKEND == BACKEND_SCALAR return "scalar"; #endif /* SHA3_BACKEND */ @@ -2873,6 +3040,22 @@ static void test_permute_diet_neon(void) { #endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ } +static void test_permute_hybrid_neon(void) { +#if SHA3_BACKEND == BACKEND_HYBRID_NEON + for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) { + const size_t exp_len = PERMUTE_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); + permute_n_hybrid_neon(got, 24); // call permute_n_diet_neon() directly + + if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); + } + } +#endif /* SHA3_BACKEND == BACKEND_HYBRID_NEON */ +} + static const struct { uint64_t a[25]; // input state const uint64_t exp[25]; // expected value @@ -2945,6 +3128,22 @@ static void test_permute12_diet_neon(void) { #endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ } +static void test_permute12_hybrid_neon(void) { +#if SHA3_BACKEND == BACKEND_HYBRID_NEON + for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) { + const size_t exp_len = PERMUTE12_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); + permute_n_hybrid_neon(got, 12); // call permute_n() directly + + if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); + } + } +#endif /* SHA3_BACKEND == BACKEND_HYBRID_NEON */ +} + static void test_sha3_224(void) { static const struct { const char *name; // test name @@ -7101,10 +7300,12 @@ int main(void) { test_permute_avx512(); test_permute_neon(); test_permute_diet_neon(); + test_permute_hybrid_neon(); test_permute12_scalar(); test_permute12_avx512(); test_permute12_neon(); test_permute12_diet_neon(); + test_permute12_hybrid_neon(); test_sha3_224(); test_sha3_256(); test_sha3_384(); |