diff options
Diffstat (limited to 'sha3.c')
-rw-r--r-- | sha3.c | 329 |
1 files changed, 329 insertions, 0 deletions
@@ -30,6 +30,7 @@ #define BACKEND_SCALAR 1 // scalar backend #define BACKEND_AVX512 2 // AVX-512 backend #define BACKEND_NEON 3 // Neon backend +#define BACKEND_DIET_NEON 4 // Neon backend which uses fewer registers // if SHA3_BACKEND is defined and set to 0 (the default), then unset it // and auto-detect the appropriate backend @@ -758,10 +759,302 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds } #endif /* SHA3_BACKEND == BACKEND_NEON */ +#if SHA3_BACKEND == BACKEND_DIET_NEON +#include <arm_neon.h> + +// rotate element in uint64x1_t left by N bits +#define VROL(A, N) (vorr_u64(vshl_n_u64((A), (N)), vshr_n_u64((A), 64-(N)))) + +// rotate elements in uint64x2_t left by N bits +// note: vrax1q_u64() not supported on pizza +#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N)))) + +// keccak row, represented as 3 128-bit vector registers +// +// columns are stored in the low 5 64-bit lanes. this wastes one +// 64-bit lane per row at the expense of making many of the instructions +// simpler. +typedef struct { + uint64x2x2_t head; // first 4 columns + uint64x1_t tail; // last column +} row_t; + +// set contents of row +static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x1_t c) { + return (row_t) { .head = { { a, b } }, .tail = c }; +} + +// get Nth pair of u64s from row +static inline uint64x2_t row_get(const row_t a, const size_t n) { + return a.head.val[n]; +} + +// get tail column from row +static inline uint64x1_t row_get_tail(const row_t a) { + return a.tail; +} + +// rotate row lanes left +// +// --------------------------- --------------------------- +// | 64-bit Lanes (Before) | | 64-bit Lanes (After) | +// |-------------------------| |-------------------------| +// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 | +// |---|---|---|---|---|-----| |---|---|---|---|---|-----| +// | A | B | C | D | E | n/a | | E | A | B | C | D | n/a | +// --------------------------- --------------------------- +// +static inline row_t row_rll(const row_t a) { + return row_set( + vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } + vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 } + vdup_laneq_u64(row_get(a, 1), 1) // { a3, n/a } + ); +} + +// rotate row lanes right +// +// --------------------------- --------------------------- +// | 64-bit Lanes (Before) | | 64-bit Lanes (After) | +// |-------------------------| |-------------------------| +// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 | +// |---|---|---|---|---|-----| |---|---|---|---|---|-----| +// | A | B | C | D | E | n/a | | B | C | D | E | A | n/a | +// --------------------------- --------------------------- +// +static inline row_t row_rlr(const row_t a) { + return row_set( + vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 } + vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_get_tail(a)), // { a3, a4 } + vdup_laneq_u64(row_get(a, 0), 0) // { a0, n/a } + ); +} + +// c = a ^ b +static inline row_t row_eor(const row_t a, const row_t b) { + return row_set( + row_get(a, 0) ^ row_get(b, 0), + row_get(a, 1) ^ row_get(b, 1), + row_get_tail(a) ^ row_get_tail(b) + ); +} + +// f = a ^ b ^ c ^ d ^ e +// FIXME want: veor3_u64(a, b, c); +static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const row_t d, const row_t e) { + return row_set( + row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0), + row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1), + row_get_tail(a) ^ row_get_tail(b) ^ row_get_tail(c) ^ row_get_tail(d) ^ row_get_tail(e) + ); +} + +// rotate bits in each lane left one bit +static inline row_t row_rol1_u64(const row_t a) { + return row_set( + VROLQ(row_get(a, 0), 1), + VROLQ(row_get(a, 1), 1), + VROL(row_get_tail(a), 1) + ); +} + +// rho lane rotate values +static const struct { + int64x2_t head[2]; + int64x1_t tail; +} RHO_IDS[] = { + { .head = { { 0, 1 }, { 62, 28 } }, .tail = { 27 } }, + { .head = { { 36, 44 }, { 6, 55 } }, .tail = { 20 } }, + { .head = { { 3, 10 }, { 43, 25 } }, .tail = { 39 } }, + { .head = { { 41, 45 }, { 15, 21 } }, .tail = { 8 } }, + { .head = { { 18, 2 }, { 61, 56 } }, .tail = { 14 } }, +}; + +// apply rho rotation to row +static inline row_t row_rho(const row_t a, const size_t id) { + const int64x2_t *vh = RHO_IDS[id].head; + const int64x1_t vt = RHO_IDS[id].tail; + return row_set( + vorrq_u64(vshlq_u64(row_get(a, 0), vh[0]), vshlq_u64(row_get(a, 0), vh[0] - 64)), + vorrq_u64(vshlq_u64(row_get(a, 1), vh[1]), vshlq_u64(row_get(a, 1), vh[1] - 64)), + vorr_u64(vshl_u64(row_get_tail(a), vt), vshl_u64(row_get_tail(a), vt - 64)) + ); +} + +// c = (~a & b) +// note: was using ~(a | ~b) = (~a & b) (demorgan's laws), but changed +// to BIC b, a instead (b & ~a) +static inline row_t row_andn(const row_t a, const row_t b) { + return row_set( + vbicq_u64(row_get(b, 0), row_get(a, 0)), + vbicq_u64(row_get(b, 1), row_get(a, 1)), + vbic_u64(row_get_tail(b), row_get_tail(a)) + ); +} + +// apply chi permutation to entire row +// note: ~(a | ~b) = (~a & b) (demorgan's laws) +static inline row_t row_chi(const row_t a) { + return row_eor(a, row_andn(row_rlr(a), row_set( + row_get(a, 1), // { a2, a3 } + vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } + vdup_laneq_u64(row_get(a, 0), 1) // { a1 } + ))); +} + +// return new vector with low lane of first argument and high lane of +// second argument +static inline uint64x2_t pi_lo_hi(const uint64x2_t a, const uint64x2_t b) { + // was using vqtbl2q_u8() with tables, but this is faster + const uint64x2_t c = vextq_u64(b, a, 1); + return vextq_u64(c, c, 1); +} + +// neon keccak permutation with inlined steps +static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) { + // load rows + row_t r0, r1, r2, r3, r4; + { + // 3 loads of 8 and 1 load of 1 cell (3*8 + 1 = 25) + const uint64x2x4_t m0 = vld1q_u64_x4(a + 0), // r0 cols 0-4, r1 cols 0-2 + m1 = vld1q_u64_x4(a + 8), // r1 cols 3-4, r2 cols 0-4, r3 col 1 + m2 = vld1q_u64_x4(a + 16); // r3 cols 1-4, r4 cols 0-3 + const uint64x1_t m3 = vld1_u64(a + 24); // r4 col 4 + + // permute loaded data into rows + r0 = row_set(m0.val[0], m0.val[1], vdup_laneq_u64(m0.val[2], 0)); + r1 = row_set(vextq_u64(m0.val[2], m0.val[3], 1), vextq_u64(m0.val[3], m1.val[0], 1), vdup_laneq_u64(m1.val[0], 1)); + r2 = row_set(m1.val[1], m1.val[2], vdup_laneq_u64(m1.val[3], 0)); + r3 = row_set(vextq_u64(m1.val[3], m2.val[0], 1), vextq_u64(m2.val[0], m2.val[1], 1), vdup_laneq_u64(m2.val[1], 1)); + r4 = row_set(m2.val[2], m2.val[3], m3); + } + + // loop for num rounds + for (size_t i = 0; i < num_rounds; i++) { + // theta + { + // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1) + const row_t c = row_eor5(r0, r1, r2, r3, r4), + d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c))); + + r0 = row_eor(r0, d); // r0 ^= d + r1 = row_eor(r1, d); // r1 ^= d + r2 = row_eor(r2, d); // r2 ^= d + r3 = row_eor(r3, d); // r3 ^= d + r4 = row_eor(r4, d); // r4 ^= d + } + + // rho + r0 = row_rho(r0, 0); + r1 = row_rho(r1, 1); + r2 = row_rho(r2, 2); + r3 = row_rho(r3, 3); + r4 = row_rho(r4, 4); + + // pi + { + const row_t t0 = row_set( + pi_lo_hi(row_get(r0, 0), row_get(r1, 0)), + pi_lo_hi(row_get(r2, 1), row_get(r3, 1)), + row_get_tail(r4) + ); + + const row_t t1 = row_set( + vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_get_tail(r1)), + pi_lo_hi(row_get(r2, 0), row_get(r3, 0)), + vdup_laneq_u64(row_get(r4, 1), 0) + ); + + const row_t t2 = row_set( + vextq_u64(row_get(r0, 0), row_get(r1, 1), 1), + vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_get_tail(r3)), + vdup_laneq_u64(row_get(r4, 0), 0) + ); + + const row_t t3 = row_set( + vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)), + vextq_u64(row_get(r2, 0), row_get(r3, 1), 1), + vdup_laneq_u64(row_get(r4, 1), 1) + ); + + const row_t t4 = row_set( + pi_lo_hi(row_get(r0, 1), row_get(r1, 1)), + vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)), + vdup_laneq_u64(row_get(r4, 0), 1) + ); + + // store rows + r0 = t0; + r1 = t1; + r2 = t2; + r3 = t3; + r4 = t4; + } + + // chi + r0 = row_chi(r0); + r1 = row_chi(r1); + r2 = row_chi(r2); + r3 = row_chi(r3); + r4 = row_chi(r4); + + // iota + const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 }; + r0.head.val[0] ^= rc; + } + + // store rows + { + // store columns 0-4 of r0 and columns 0-2 of r1 + const uint64x2x4_t m0 = { + .val = { + row_get(r0, 0), + row_get(r0, 1), + vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)), + vextq_u64(row_get(r1, 0), row_get(r1, 1), 1) + }, + }; + vst1q_u64_x4(a + 0, m0); + } + + { + // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3 + const uint64x2x4_t m1 = { + .val = { + vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_get_tail(r1)), + row_get(r2, 0), + row_get(r2, 1), + vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)), + }, + }; + vst1q_u64_x4(a + 8, m1); + } + + { + // store columns 1-4 of r3 and columns 03 of r4 + const uint64x2x4_t m2 = { + .val = { + vextq_u64(row_get(r3, 0), row_get(r3, 1), 1), + + vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_get_tail(r3)), + row_get(r4, 0), + row_get(r4, 1), + }, + }; + vst1q_u64_x4(a + 16, m2); + } + + // store column 4 of r4 + vst1_u64(a + 24, row_get_tail(r4)); +} +#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ + #if SHA3_BACKEND == BACKEND_AVX512 #define permute_n permute_n_avx512 // use avx512 backend #elif SHA3_BACKEND == BACKEND_NEON #define permute_n permute_n_neon // use neon backend +#elif SHA3_BACKEND == BACKEND_DIET_NEON +#define permute_n permute_n_diet_neon // use diet-neon backend #elif SHA3_BACKEND == BACKEND_SCALAR #define permute_n permute_n_scalar // use scalar backend #else @@ -2257,6 +2550,8 @@ const char *sha3_backend(void) { return "avx512"; #elif SHA3_BACKEND == BACKEND_NEON return "neon"; +#elif SHA3_BACKEND == BACKEND_DIET_NEON + return "diet-neon"; #elif SHA3_BACKEND == BACKEND_SCALAR return "scalar"; #endif /* SHA3_BACKEND */ @@ -2550,6 +2845,22 @@ static void test_permute_neon(void) { #endif /* SHA3_BACKEND == BACKEND_NEON */ } +static void test_permute_diet_neon(void) { +#if SHA3_BACKEND == BACKEND_DIET_NEON + for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) { + const size_t exp_len = PERMUTE_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE_TESTS[i].a, sizeof(got)); + permute_n_diet_neon(got, 24); // call permute_n_diet_neon() directly + + if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len); + } + } +#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ +} + static const struct { uint64_t a[25]; // input state const uint64_t exp[25]; // expected value @@ -2606,6 +2917,22 @@ static void test_permute12_neon(void) { #endif /* SHA3_BACKEND == BACKEND_NEON */ } +static void test_permute12_diet_neon(void) { +#if SHA3_BACKEND == BACKEND_DIET_NEON + for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) { + const size_t exp_len = PERMUTE12_TESTS[i].exp_len; + + uint64_t got[25] = { 0 }; + memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got)); + permute_n_diet_neon(got, 12); // call permute_n_diet_neon() directly + + if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) { + fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len); + } + } +#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */ +} + static void test_sha3_224(void) { static const struct { const char *name; // test name @@ -6761,9 +7088,11 @@ int main(void) { test_permute_scalar(); test_permute_avx512(); test_permute_neon(); + test_permute_diet_neon(); test_permute12_scalar(); test_permute12_avx512(); test_permute12_neon(); + test_permute12_diet_neon(); test_sha3_224(); test_sha3_256(); test_sha3_384(); |