From 4fdd7953fde90c0a68a3679bcdca1001b96188d6 Mon Sep 17 00:00:00 2001
From: Paul Duncan
Date: Fri, 3 May 2024 22:47:23 -0400
Subject: sha3.c: add neon backend

---
 sha3.c | 314 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 305 insertions(+), 9 deletions(-)

diff --git a/sha3.c b/sha3.c
index 8563d27..375ee3b 100644
--- a/sha3.c
+++ b/sha3.c
@@ -26,16 +26,19 @@
 /** @cond INTERNAL */
 
 // available backends
-#define BACKEND_AVX512 8 // avx512 backend
-#define BACKEND_SCALAR 0 // scalar (default) backend
+#define BACKEND_AVX512 8 // AVX-512 backend
+#define BACKEND_NEON 4 // A64 Neon backend
+#define BACKEND_SCALAR 0 // scalar (default) backend
 
 // auto-detect backend
 #ifndef SHA3_BACKEND
-#ifdef __AVX512F__
+#if defined(__AVX512F__)
 #define SHA3_BACKEND BACKEND_AVX512
-#else /* !__AVX512F__ */
+#elif defined(__ARM_NEON) && defined(__aarch64__)
+#define SHA3_BACKEND BACKEND_NEON
+#else
 #define SHA3_BACKEND BACKEND_SCALAR
-#endif /* __AVX512F__ */
+#endif
 #endif /* SHA3_BACKEND */
 
 // 64-bit rotate left
@@ -462,12 +465,303 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun
 }
 #endif /* SHA3_BACKEND == BACKEND_AVX512 */
 
+#if SHA3_BACKEND == BACKEND_NEON
+#include <arm_neon.h>
+
+// rotate each 64-bit lane left by constant N (vrax1q_u64() not supported on pizza)
+#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N))))
+
+// keccak row, represented as 3 128-bit vector registers
+//
+// columns are stored in the low 5 64-bit lanes. this wastes one
+// 64-bit lane per row in exchange for making many of the instructions
+// simpler.
+typedef union {
+  // view as three pairs of 64-bit lanes (loads, stores, shifts, transposes)
+  uint64x2x3_t u64x2x3;
+  // view of the same 48 bytes as three 16-byte vectors (byte-table
+  // permutes and bitwise xor/not/or-not)
+  uint8x16x3_t u8x16x3;
+} row_t;
+
+// TODO: add a row_load_fast() which reads 6 elements at once, e.g.
+//   r2 = { .u64x2x3 = vld1q_u64_x3(a + 10) },
+// (not usable for the last row: it would read a[25], past the 25-word state)
+
+// load row from array; p[4] is duplicated into both lanes of the last vector
+static inline row_t row_load(const uint64_t p[static 5]) {
+  row_t a = { 0 };
+
+  a.u64x2x3.val[0] = vld1q_u64(p + 0);
+  a.u64x2x3.val[1] = vld1q_u64(p + 2);
+  a.u64x2x3.val[2] = vdupq_n_u64(p[4]);
+
+  return a;
+}
+
+// store row to array; only the low 5 lanes are written
+static inline void row_store(uint64_t p[static 5], const row_t a) {
+  // lanes 0-3: two full 128-bit stores
+  vst1q_u64(p + 0, a.u64x2x3.val[0]);
+  vst1q_u64(p + 2, a.u64x2x3.val[1]);
+  vst1_u64(p + 4, vdup_laneq_u64(a.u64x2x3.val[2], 0));
+  // ^ lane 4: 64-bit store of the low lane only, so p[5] is never written
+}
+
+// low lane ids for row_rl{l,r}(): high lane -> low lane, high lane zeroed (99 is out of range)
+static const uint8x16_t ROW_RL_LO_IDS = {
+  8, 9, 10, 11, 12, 13, 14, 15, 99, 99, 99, 99, 99, 99, 99, 99,
+};
+
+// high lane ids for row_rl{l,r}(): low lane -> high lane, low lane zeroed
+static const uint8x16_t ROW_RL_HI_IDS = {
+  99, 99, 99, 99, 99, 99, 99, 99, 0, 1, 2, 3, 4, 5, 6, 7,
+};
+
+// keep low lane, zero high lane: first iteration of row_rll(), last vector of row_rlr()
+static const uint8x16_t ROW_RL_TAIL_IDS = {
+  0, 1, 2, 3, 4, 5, 6, 7, 99, 99, 99, 99, 99, 99, 99, 99,
+};
+
+// rotate row lanes left
+//
+// ---------------------------     ---------------------------
+// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   |
+// |-------------------------|     |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  |
+// |---|---|---|---|---|-----|     |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a |     | E | A | B | C | D | n/a |
+// ---------------------------     ---------------------------
+static inline row_t row_rll(const row_t a) {
+  row_t b = { 0 };
+  for (size_t i = 0; i < 3; i++) {
+    const uint8x16_t lo_ids = i ? ROW_RL_LO_IDS : ROW_RL_TAIL_IDS,
+                     hi = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_HI_IDS),
+                     lo = vqtbl1q_u8(a.u8x16x3.val[(i + 2) % 3], lo_ids);
+    b.u8x16x3.val[i] = vorrq_u8(lo, hi);
+  }
+  return b;
+}
+
+// rotate row lanes right
+//
+// ---------------------------     ---------------------------
+// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   |
+// |-------------------------|     |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  |
+// |---|---|---|---|---|-----|     |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a |     | B | C | D | E | A | n/a |
+// ---------------------------     ---------------------------
+static inline row_t row_rlr(const row_t a) {
+  row_t b = { 0 };
+  for (size_t i = 0; i < 2; i++) {
+    const uint8x16_t lo = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_LO_IDS),
+                     hi = vqtbl1q_u8(a.u8x16x3.val[(i + 1) % 3], ROW_RL_HI_IDS);
+    b.u8x16x3.val[i] = vorrq_u8(lo, hi);
+  }
+  b.u8x16x3.val[2] = vqtbl1q_u8(a.u8x16x3.val[0], ROW_RL_TAIL_IDS);
+  return b;
+}
+
+// c = a ^ b
+static inline row_t row_eor(const row_t a, const row_t b) {
+  row_t c = a;
+  for (size_t i = 0; i < 3; i++) {
+    c.u8x16x3.val[i] ^= b.u8x16x3.val[i];
+  }
+  return c;
+}
+
+// rotate bits in each lane left one bit
+static inline row_t row_rol1_u64(const row_t a) {
+  row_t b = { 0 };
+  for (size_t i = 0; i < 3; i++) {
+    b.u64x2x3.val[i] = VROLQ(a.u64x2x3.val[i], 1);
+  }
+  return b;
+}
+
+// rotate bits in each lane left by amounts in vector (v[4] applies to lanes 4 and 5)
+static inline row_t row_rotn_u64(const row_t a, const int64_t v[static 5]) {
+  row_t b = { 0 };
+  static const int64x2_t k64 = { 64, 64 };
+  for (size_t i = 0; i < 3; i++) {
+    const int64x2_t shl = (i < 2) ? vld1q_s64(v + 2 * i) : vdupq_n_s64(v[4]),
+                    shr = vsubq_s64(shl, k64); // negative count: vshlq_u64 shifts right
+    b.u64x2x3.val[i] = vorrq_u64(vshlq_u64(a.u64x2x3.val[i], shl), vshlq_u64(a.u64x2x3.val[i], shr));
+  }
+  return b;
+}
+
+// return logical NOT of row
+static inline row_t row_not(const row_t a) {
+  row_t b;
+  for (size_t i = 0; i < 3; i++) {
+    b.u8x16x3.val[i] = vmvnq_u8(a.u8x16x3.val[i]);
+  }
+  return b;
+}
+
+// return logical OR NOT of rows (a | ~b)
+static inline row_t row_orn(const row_t a, const row_t b) {
+  row_t c;
+  for (size_t i = 0; i < 3; i++) {
+    c.u8x16x3.val[i] = vornq_u8(a.u8x16x3.val[i], b.u8x16x3.val[i]);
+  }
+  return c;
+}
+
+// apply chi permutation to entire row: a ^ (~b & c), with b, c = a rotated right 1, 2 lanes
+// note: ~(b | ~c) = (~b & c) (demorgan's laws)
+static inline row_t row_chi(const row_t a) {
+  const row_t b = row_rlr(a),
+              c = row_rlr(b); // fixme, permute would be faster
+  return row_eor(a, row_not(row_orn(b, c)));
+}
+
+// permute IDS to take low lane of first pair and hi lane of second pair
+// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a0, b1 ]
+static const uint8x16_t PI_LO_HI_IDS = {
+  0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 27, 28, 29, 30, 31,
+};
+
+// permute IDS to take high lane of first pair and low lane of second pair
+// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a1, b0 ]
+static const uint8x16_t PI_HI_LO_IDS = {
+  8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+};
+
+// two-register byte-table lookup: ids 0-15 select bytes of a, ids 16-31 bytes of b
+static inline uint8x16_t pi_tbl(const uint8x16_t a, const uint8x16_t b, const uint8x16_t ids) {
+  uint8x16x2_t quad = { .val = { a, b } };
+  return vqtbl2q_u8(quad, ids);
+}
+
+// neon keccak permutation with inlined steps: applies num_rounds rounds
+// (theta, rho, pi, chi, iota) to the 25-word state a, in place; num_rounds <= 24
+// NOTE(review): assumes RCS holds the 24 standard round constants, in order
+static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) {
+  // load rows
+  row_t r0 = row_load(a + 0),
+        r1 = row_load(a + 5),
+        r2 = row_load(a + 10),
+        r3 = row_load(a + 15),
+        r4 = row_load(a + 20);
+
+  // loop for num rounds
+  for (size_t i = 0; i < num_rounds; i++) {
+    // theta
+    {
+      // c = r0 ^ r1 ^ r2 ^ r3 ^ r4
+      const row_t c = row_eor(row_eor(row_eor(r0, r1), row_eor(r2, r3)), r4);
+
+      // d[x] = c[x - 1] ^ rol(c[x + 1], 1)
+      const row_t d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
+
+      r0 = row_eor(r0, d);
+      r1 = row_eor(r1, d);
+      r2 = row_eor(r2, d);
+      r3 = row_eor(r3, d);
+      r4 = row_eor(r4, d);
+    }
+
+    // rho
+    {
+      r0 = row_rotn_u64(r0, RHO_IDS + 0);
+      r1 = row_rotn_u64(r1, RHO_IDS + 5);
+      r2 = row_rotn_u64(r2, RHO_IDS + 10);
+      r3 = row_rotn_u64(r3, RHO_IDS + 15);
+      r4 = row_rotn_u64(r4, RHO_IDS + 20);
+    }
+
+    // pi
+    {
+      row_t t0 = { 0 };
+      {
+        // dst[ 0] = src[ 0]; dst[ 1] = src[ 6];
+        t0.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[0], PI_LO_HI_IDS);
+        // dst[ 2] = src[12]; dst[ 3] = src[18];
+        t0.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[1], PI_LO_HI_IDS);
+        // dst[ 4] = src[24];
+        t0.u8x16x3.val[2] = r4.u8x16x3.val[2];
+      }
+
+      row_t t1 = { 0 };
+      {
+        // dst[ 5] = src[ 3]; dst[ 6] = src[ 9];
+        t1.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[2], PI_HI_LO_IDS);
+        // dst[ 7] = src[10]; dst[ 8] = src[16];
+        t1.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[0], PI_LO_HI_IDS);
+        // dst[ 9] = src[22];
+        t1.u8x16x3.val[2] = r4.u8x16x3.val[1];
+      }
+
+      row_t t2 = { 0 };
+      {
+        // dst[10] = src[ 1]; dst[11] = src[ 7];
+        t2.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[1], PI_HI_LO_IDS);
+        // dst[12] = src[13]; dst[13] = src[19];
+        t2.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[2], PI_HI_LO_IDS);
+        // dst[14] = src[20];
+        t2.u8x16x3.val[2] = r4.u8x16x3.val[0];
+      }
+
+      row_t t3 = { 0 };
+      {
+        // dst[15] = src[ 4]; dst[16] = src[ 5];
+        // (vtrn1q_u64 takes the low lane of each operand: [ a0, b0 ])
+        t3.u64x2x3.val[0] = vtrn1q_u64(r0.u64x2x3.val[2], r1.u64x2x3.val[0]);
+        // dst[17] = src[11]; dst[18] = src[17];
+        t3.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[1], PI_HI_LO_IDS);
+        // dst[19] = src[23];
+        t3.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[1], r4.u8x16x3.val[1], PI_HI_LO_IDS);
+      }
+
+      row_t t4 = { 0 };
+      {
+        // dst[20] = src[ 2]; dst[21] = src[ 8];
+        t4.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[1], PI_LO_HI_IDS);
+        // dst[22] = src[14]; dst[23] = src[15];
+        // (vtrn1q_u64 takes the low lane of each operand: [ a0, b0 ])
+        t4.u64x2x3.val[1] = vtrn1q_u64(r2.u64x2x3.val[2], r3.u64x2x3.val[0]);
+        // dst[24] = src[21];
+        t4.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[0], r4.u8x16x3.val[0], PI_HI_LO_IDS);
+      }
+
+      r0 = t0;
+      r1 = t1;
+      r2 = t2;
+      r3 = t3;
+      r4 = t4;
+    }
+
+    // chi
+    r0 = row_chi(r0);
+    r1 = row_chi(r1);
+    r2 = row_chi(r2);
+    r3 = row_chi(r3);
+    r4 = row_chi(r4);
+
+    // iota: reduced-round variants use the LAST num_rounds constants (FIPS 202, KECCAK-p)
+    const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 };
+    r0.u64x2x3.val[0] ^= rc;
+  }
+
+  // store rows
+  row_store(a + 0, r0);
+  row_store(a + 5, r1);
+  row_store(a + 10, r2);
+  row_store(a + 15, r3);
+  row_store(a + 20, r4);
+}
+#endif /* SHA3_BACKEND == BACKEND_NEON */
+
 #if SHA3_BACKEND == BACKEND_AVX512
-// use avx512 backend
-#define permute_n permute_n_avx512
+#define permute_n permute_n_avx512 // use avx512 backend
+#elif SHA3_BACKEND == BACKEND_NEON
+#define permute_n permute_n_neon // use neon backend
 #elif SHA3_BACKEND == BACKEND_SCALAR
-// use scalar backend
-#define permute_n permute_n_scalar
+#define permute_n permute_n_scalar // use scalar backend
 #else
 #error "unknown sha3 backend"
 #endif /* SHA3_BACKEND */
@@ -1959,6 +2253,8 @@ void k12_once(const uint8_t *src, const size_t src_len, uint8_t *dst, const size
 const char *sha3_backend(void) {
 #if SHA3_BACKEND == BACKEND_AVX512
   return "avx512";
+#elif SHA3_BACKEND == BACKEND_NEON
+  return "neon";
 #elif SHA3_BACKEND == BACKEND_SCALAR
   return "scalar";
 #endif /* SHA3_BACKEND */
-- 
cgit v1.2.3