 sha3.c | 203
 1 file changed, 202 insertions(+), 1 deletion(-)
diff --git a/sha3.c b/sha3.c
index 21b2c78..b12e20d 100644
--- a/sha3.c
+++ b/sha3.c
@@ -31,6 +31,7 @@
#define BACKEND_AVX512 2 // AVX-512 backend
#define BACKEND_NEON 3 // Neon backend (experimental)
#define BACKEND_DIET_NEON 4 // Neon backend which uses fewer registers
+#define BACKEND_HYBRID_NEON 5 // Hybrid Neon backend which mixes Neon and scalar code
// if SHA3_BACKEND is defined and set to 0 (the default), then unset it
// and auto-detect the appropriate backend
@@ -1057,16 +1058,180 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
}
// store column 4 of r4
- vst1_u64(a + 24, row_get_tail(r4));
+ vst1_u64(a + 24, row_last(r4));
}
#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+#if (SHA3_BACKEND == BACKEND_HYBRID_NEON)
+#include <arm_neon.h>
+
+/**
+ * @brief Hybrid Neon Keccak permutation.
+ *
+ * Apply `num_rounds` rounds of the Keccak permutation.  Each round
+ * mixes Neon vector and scalar code.  When `SHA3_BACKEND` is
+ * `BACKEND_HYBRID_NEON`, this function is called through the
+ * `permute_n` macro with:
+ *
+ * - 24 rounds: SHA-3 and SHAKE.
+ * - 12 rounds: TurboSHAKE and KangarooTwelve.
+ *
+ * @param[in,out] a Keccak state (array of 25 64-bit integers).
+ * @param[in] num_rounds Number of rounds (12 or 24).
+ */
+static inline void permute_n_hybrid_neon(uint64_t a[static 25], const size_t num_rounds) {
+ uint64_t tmp[25] = { 0 };
+ for (size_t i = SHA3_NUM_ROUNDS - num_rounds; i < SHA3_NUM_ROUNDS; i++) {
+ // theta
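+    // c[x] = XOR of the five lanes in column x (column parity)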
+ const uint64_t c[5] = {
+ a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20],
+ a[1] ^ a[6] ^ a[11] ^ a[16] ^ a[21],
+ a[2] ^ a[7] ^ a[12] ^ a[17] ^ a[22],
+ a[3] ^ a[8] ^ a[13] ^ a[18] ^ a[23],
+ a[4] ^ a[9] ^ a[14] ^ a[19] ^ a[24],
+ };
+
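+    // d[x] = c[(x + 4) % 5] ^ ROL(c[(x + 1) % 5], 1)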
+ const uint64_t d[5] = {
+ c[4] ^ ROL(c[1], 1),
+ c[0] ^ ROL(c[2], 1),
+ c[1] ^ ROL(c[3], 1),
+ c[2] ^ ROL(c[4], 1),
+ c[3] ^ ROL(c[0], 1),
+ };
+
+ const uint64x2_t d0 = { d[0], d[1] },
+ d1 = { d[2], d[3] };
+
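+    // load lanes 0-3 of each row into pairs of Neon vectors; lane 4 of
+    // each row (a[4], a[9], ..., a[24]) is handled with scalar code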
+ uint64x2_t r0a = { a[ 0], a[ 1] }, r0b = { a[ 2], a[ 3] },
+ r1a = { a[ 5], a[ 6] }, r1b = { a[ 7], a[ 8] },
+ r2a = { a[10], a[11] }, r2b = { a[12], a[13] },
+ r3a = { a[15], a[16] }, r3b = { a[17], a[18] },
+ r4a = { a[20], a[21] }, r4b = { a[22], a[23] };
+
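+    // XOR d into every lane: lanes 0-3 of each row with Neon, lane 4 scalar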
+ r0a ^= d0; r0b ^= d1; a[ 4] ^= d[4];
+ r1a ^= d0; r1b ^= d1; a[ 9] ^= d[4];
+ r2a ^= d0; r2b ^= d1; a[14] ^= d[4];
+ r3a ^= d0; r3b ^= d1; a[19] ^= d[4];
+ r4a ^= d0; r4b ^= d1; a[24] ^= d[4];
+
+ // rho
+ {
+ static const uint64x2_t r0a_ids = { 0, 1 },
+ r0b_ids = { 62, 28 },
+ r1a_ids = { 36, 44 },
+ r1b_ids = { 6, 55 },
+ r2a_ids = { 3, 10 },
+ r2b_ids = { 43, 25 },
+ r3a_ids = { 41, 45 },
+ r3b_ids = { 15, 21 },
+ r4a_ids = { 18, 2 },
+ r4b_ids = { 61, 56 };
+
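+      // rotate each lane left by its fixed rho offset:
+      // ROL(x, n) == (x << n) | (x >> (64 - n)) per 64-bit lane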
+ r0a = (r0a << r0a_ids) | (r0a >> (64 - r0a_ids));
+ r0b = (r0b << r0b_ids) | (r0b >> (64 - r0b_ids));
+ a[ 4] = ROL(a[ 4], 27); // 91 % 64 = 27
+
+ r1a = (r1a << r1a_ids) | (r1a >> (64 - r1a_ids));
+ r1b = (r1b << r1b_ids) | (r1b >> (64 - r1b_ids));
+ a[ 9] = ROL(a[ 9], 20); // 276 % 64 = 20
+
+ r2a = (r2a << r2a_ids) | (r2a >> (64 - r2a_ids));
+ r2b = (r2b << r2b_ids) | (r2b >> (64 - r2b_ids));
+ a[14] = ROL(a[14], 39); // 231 % 64 = 39
+
+ r3a = (r3a << r3a_ids) | (r3a >> (64 - r3a_ids));
+ r3b = (r3b << r3b_ids) | (r3b >> (64 - r3b_ids));
+ a[19] = ROL(a[19], 8); // 136 % 64 = 8
+
+ r4a = (r4a << r4a_ids) | (r4a >> (64 - r4a_ids));
+ r4b = (r4b << r4b_ids) | (r4b >> (64 - r4b_ids));
+ a[24] = ROL(a[24], 14); // 78 % 64 = 14
+
+ vst1q_u64(a + 0, r0a); vst1q_u64(a + 2, r0b);
+ vst1q_u64(a + 5, r1a); vst1q_u64(a + 7, r1b);
+ vst1q_u64(a + 10, r2a); vst1q_u64(a + 12, r2b);
+ vst1q_u64(a + 15, r3a); vst1q_u64(a + 17, r3b);
+ vst1q_u64(a + 20, r4a); vst1q_u64(a + 22, r4b);
+ }
+
+ // pi
+ {
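+      // pi: tmp[5*y + x] = a[5*x + (x + 3*y) % 5]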
+ tmp[ 0] = a[ 0];
+ tmp[ 1] = a[ 6];
+ tmp[ 2] = a[12];
+ tmp[ 3] = a[18];
+ tmp[ 4] = a[24];
+
+ tmp[ 5] = a[ 3];
+ tmp[ 6] = a[ 9];
+ tmp[ 7] = a[10];
+ tmp[ 8] = a[16];
+ tmp[ 9] = a[22];
+
+ tmp[10] = a[ 1];
+ tmp[11] = a[ 7];
+ tmp[12] = a[13];
+ tmp[13] = a[19];
+ tmp[14] = a[20];
+
+ tmp[15] = a[ 4];
+ tmp[16] = a[ 5];
+ tmp[17] = a[11];
+ tmp[18] = a[17];
+ tmp[19] = a[23];
+
+ tmp[20] = a[ 2];
+ tmp[21] = a[ 8];
+ tmp[22] = a[14];
+ tmp[23] = a[15];
+ tmp[24] = a[21];
+ }
+
+ // chi
+ {
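+      // chi: a[5*y + x] = tmp[5*y + x] ^
+      //   (~tmp[5*y + (x + 1) % 5] & tmp[5*y + (x + 2) % 5])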
+ a[ 0] = tmp[ 0] ^ (~tmp[ 1] & tmp[ 2]);
+ a[ 1] = tmp[ 1] ^ (~tmp[ 2] & tmp[ 3]);
+ a[ 2] = tmp[ 2] ^ (~tmp[ 3] & tmp[ 4]);
+ a[ 3] = tmp[ 3] ^ (~tmp[ 4] & tmp[ 0]);
+ a[ 4] = tmp[ 4] ^ (~tmp[ 0] & tmp[ 1]);
+
+ a[ 5] = tmp[ 5] ^ (~tmp[ 6] & tmp[ 7]);
+ a[ 6] = tmp[ 6] ^ (~tmp[ 7] & tmp[ 8]);
+ a[ 7] = tmp[ 7] ^ (~tmp[ 8] & tmp[ 9]);
+ a[ 8] = tmp[ 8] ^ (~tmp[ 9] & tmp[ 5]);
+ a[ 9] = tmp[ 9] ^ (~tmp[ 5] & tmp[ 6]);
+
+ a[10] = tmp[10] ^ (~tmp[11] & tmp[12]);
+ a[11] = tmp[11] ^ (~tmp[12] & tmp[13]);
+ a[12] = tmp[12] ^ (~tmp[13] & tmp[14]);
+ a[13] = tmp[13] ^ (~tmp[14] & tmp[10]);
+ a[14] = tmp[14] ^ (~tmp[10] & tmp[11]);
+
+ a[15] = tmp[15] ^ (~tmp[16] & tmp[17]);
+ a[16] = tmp[16] ^ (~tmp[17] & tmp[18]);
+ a[17] = tmp[17] ^ (~tmp[18] & tmp[19]);
+ a[18] = tmp[18] ^ (~tmp[19] & tmp[15]);
+ a[19] = tmp[19] ^ (~tmp[15] & tmp[16]);
+
+ a[20] = tmp[20] ^ (~tmp[21] & tmp[22]);
+ a[21] = tmp[21] ^ (~tmp[22] & tmp[23]);
+ a[22] = tmp[22] ^ (~tmp[23] & tmp[24]);
+ a[23] = tmp[23] ^ (~tmp[24] & tmp[20]);
+ a[24] = tmp[24] ^ (~tmp[20] & tmp[21]);
+ }
+
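+    // iota: XOR the round constant into lane (0, 0)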
+ a[0] ^= RCS[i];
+ }
+}
+#endif /* (SHA3_BACKEND == BACKEND_HYBRID_NEON) */
+
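+// NOTE: to force this backend at build time, something like
+// `cc -DSHA3_BACKEND=5 ...` should work (BACKEND_HYBRID_NEON == 5),
+// assuming backend auto-detection is not triggered first.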
#if SHA3_BACKEND == BACKEND_AVX512
#define permute_n permute_n_avx512 // use avx512 backend
#elif SHA3_BACKEND == BACKEND_NEON
#define permute_n permute_n_neon // use neon backend
#elif SHA3_BACKEND == BACKEND_DIET_NEON
#define permute_n permute_n_diet_neon // use diet-neon backend
+#elif SHA3_BACKEND == BACKEND_HYBRID_NEON
+#define permute_n permute_n_hybrid_neon // use hybrid-neon backend
#elif SHA3_BACKEND == BACKEND_SCALAR
#define permute_n permute_n_scalar // use scalar backend
#else
@@ -2564,6 +2729,8 @@ const char *sha3_backend(void) {
return "neon";
#elif SHA3_BACKEND == BACKEND_DIET_NEON
return "diet-neon";
+#elif SHA3_BACKEND == BACKEND_HYBRID_NEON
+ return "hybrid-neon";
#elif SHA3_BACKEND == BACKEND_SCALAR
return "scalar";
#endif /* SHA3_BACKEND */
@@ -2873,6 +3040,22 @@ static void test_permute_diet_neon(void) {
#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
}
+static void test_permute_hybrid_neon(void) {
+#if SHA3_BACKEND == BACKEND_HYBRID_NEON
+ for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
+    permute_n_hybrid_neon(got, 24); // call permute_n_hybrid_neon() directly
+
+ if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
+ }
+ }
+#endif /* SHA3_BACKEND == BACKEND_HYBRID_NEON */
+}
+
static const struct {
uint64_t a[25]; // input state
const uint64_t exp[25]; // expected value
@@ -2945,6 +3128,22 @@ static void test_permute12_diet_neon(void) {
#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
}
+static void test_permute12_hybrid_neon(void) {
+#if SHA3_BACKEND == BACKEND_HYBRID_NEON
+ for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE12_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
+    permute_n_hybrid_neon(got, 12); // call permute_n_hybrid_neon() directly
+
+ if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
+ }
+ }
+#endif /* SHA3_BACKEND == BACKEND_HYBRID_NEON */
+}
+
static void test_sha3_224(void) {
static const struct {
const char *name; // test name
@@ -7101,10 +7300,12 @@ int main(void) {
test_permute_avx512();
test_permute_neon();
test_permute_diet_neon();
+ test_permute_hybrid_neon();
test_permute12_scalar();
test_permute12_avx512();
test_permute12_neon();
test_permute12_diet_neon();
+ test_permute12_hybrid_neon();
test_sha3_224();
test_sha3_256();
test_sha3_384();