diff options
| -rw-r--r-- | sha3.c | 314 | 
1 files changed, 305 insertions, 9 deletions
@@ -26,16 +26,19 @@  /** @cond INTERNAL */  // available backends -#define BACKEND_AVX512 8 // avx512 backend -#define BACKEND_SCALAR 0 // scalar (default) backend +#define BACKEND_AVX512 8    // AVX-512 backend +#define BACKEND_NEON 4  // A64 Neon backend +#define BACKEND_SCALAR 0    // scalar (default) backend  // auto-detect backend  #ifndef SHA3_BACKEND -#ifdef __AVX512F__ +#if defined(__AVX512F__)  #define SHA3_BACKEND BACKEND_AVX512 -#else /* !__AVX512F__ */ +#elif defined(__ARM_NEON) +#define SHA3_BACKEND BACKEND_NEON +#else  #define SHA3_BACKEND BACKEND_SCALAR -#endif /* __AVX512F__ */ +#endif  #endif /* SHA3_BACKEND */  // 64-bit rotate left @@ -462,12 +465,303 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun  }  #endif /* SHA3_BACKEND == BACKEND_AVX512 */ +#if SHA3_BACKEND == BACKEND_NEON +#include <arm_neon.h> + +// vrax1q_u64() not supported on pizza +#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N)))) + +// keccak row, represented as 3 128-bit vector registers +// +// columns are stored in the low 5 64-bit lanes.  this wastes one +// 64-bit lane per row at the expense of making many of the instructions +// simpler. +typedef union { +  // uint64_t u64[6]; +  uint64x2x3_t u64x2x3; +  // uint8_t u8[48]; +  // uint8x16_t u8x16[3]; +  uint8x16x3_t u8x16x3; +} row_t; + +// TODO +// add row_load_fast which reads 6 elems and does this +// r2 = { .u64x2x3 = vld1q_u64_x3(a + 10) }, + +// load row from array +static inline row_t row_load(const uint64_t p[static 5]) { +  row_t a = { 0 }; + +  a.u64x2x3.val[0] = vld1q_u64(p + 0); +  a.u64x2x3.val[1] = vld1q_u64(p + 2); +  a.u64x2x3.val[2] = vdupq_n_u64(p[4]); + +  return a; +} + +// store row to array +static inline void row_store(uint64_t p[static 5], const row_t a) { +  // row_print(stderr, __func__, a); +  vst1q_u64(p + 0, a.u64x2x3.val[0]); +  vst1q_u64(p + 2, a.u64x2x3.val[1]); +  vst1_u64(p + 4, vdup_laneq_u64(a.u64x2x3.val[2], 0)); +  // p[4] = vgetq_lane_u64(a.u64x2x3.val[2], 0); +} + +// low lane ids for rol_rc{l,r}() +static const uint8x16_t ROW_RL_LO_IDS = { +   8,  9, 10, 11, 12, 13, 14, 15,  99, 99, 99, 99, 99, 99, 99, 99, +}; + +// high lane ids for rol_rc{l,r}() +static const uint8x16_t ROW_RL_HI_IDS = { +  99, 99, 99, 99, 99, 99, 99, 99,   0,  1,  2,  3,  4,  5,  6,  7, +}; + +// low lanes for last iteration of row_rlll() and first iteration of row_rlr() +static const uint8x16_t ROW_RL_TAIL_IDS = { +   0,  1,  2,  3,  4,  5,  6,  7,  99, 99, 99, 99, 99, 99, 99, 99, +}; + +// rotate row lanes left +// +// ---------------------------     --------------------------- +// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   | +// |-------------------------|     |-------------------------| +// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  | +// |---|---|---|---|---|-----|     |---|---|---|---|---|-----| +// | A | B | C | D | E | n/a |     | E | A | B | C | D | n/a | +// ---------------------------     --------------------------- +// +static inline row_t row_rll(const row_t a) { +  row_t b = { 0 }; +  for (size_t i = 0; i < 3; i++) { +    const uint8x16_t lo_ids = i ? ROW_RL_LO_IDS : ROW_RL_TAIL_IDS, +                     hi = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_HI_IDS), +                     lo = vqtbl1q_u8(a.u8x16x3.val[(i + 2) % 3], lo_ids); +    b.u8x16x3.val[i] = vorrq_u8(lo, hi); +  } +  return b; +} + +// rotate row lanes right +// +// ---------------------------     --------------------------- +// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   | +// |-------------------------|     |-------------------------| +// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  | +// |---|---|---|---|---|-----|     |---|---|---|---|---|-----| +// | A | B | C | D | E | n/a |     | B | C | D | E | A | n/a | +// ---------------------------     --------------------------- +// +static row_t row_rlr(const row_t a) { +  row_t b = { 0 }; +  for (size_t i = 0; i < 2; i++) { +    const uint8x16_t lo = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_LO_IDS), +                     hi = vqtbl1q_u8(a.u8x16x3.val[(i + 1) % 3], ROW_RL_HI_IDS); +    b.u8x16x3.val[i] = vorrq_u8(lo, hi); +  } +  b.u8x16x3.val[2] = vqtbl1q_u8(a.u8x16x3.val[0], ROW_RL_TAIL_IDS); +  return b; +} + +// c = a ^ b +static inline row_t row_eor(const row_t a, const row_t b) { +  row_t c = a; +  for (size_t i = 0; i < 3; i++) { +    c.u8x16x3.val[i] ^= b.u8x16x3.val[i]; +  } +  return c; +} + +// rotate bits in each lane left one bit +static inline row_t row_rol1_u64(const row_t a) { +  row_t b = { 0 }; +  for (size_t i = 0; i < 3; i++) { +    b.u64x2x3.val[i] = VROLQ(a.u64x2x3.val[i], 1); +  } +  return b; +} + +// rotate bits in each lane left by amounts in vector +static inline row_t row_rotn_u64(const row_t a, const int64_t v[static 5]) { +  row_t b = { 0 }; +  static const int64x2_t k64 = { 64, 64 }; +  for (size_t i = 0; i < 3; i++) { +    const int64x2_t hi_ids = (i < 2) ? vld1q_s64(v + 2 * i) : vdupq_n_s64(v[4]), +                    lo_ids = vsubq_s64(hi_ids, k64); +    b.u64x2x3.val[i] = vorrq_u64(vshlq_u64(a.u64x2x3.val[i], hi_ids), vshlq_u64(a.u64x2x3.val[i], lo_ids)); +  } +  return b; +} + +// return logical NOT of row +static inline row_t row_not(const row_t a) { +  row_t b; +  for (size_t i = 0; i < 3; i++) { +    b.u8x16x3.val[i] = vmvnq_u8(a.u8x16x3.val[i]); +  } +  return b; +} + +// return logical OR NOT of rows +static inline row_t row_orn(const row_t a, const row_t b) { +  row_t c; +  for (size_t i = 0; i < 3; i++) { +    c.u8x16x3.val[i] = vornq_u8(a.u8x16x3.val[i], b.u8x16x3.val[i]); +  } +  return c; +} + +// apply chi permutation to entire row +// note: ~(a | ~b) = (~a & b) (demorgan's laws) +static inline row_t row_chi(const row_t a) { +  const row_t b = row_rlr(a), +              c = row_rlr(b); // fixme, permute would be faster +  return row_eor(a, row_not(row_orn(b, c))); +} + +// permute IDS to take low lane of first pair and hi lane of second pair +// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a0, b1 ] +static const uint8x16_t PI_LO_HI_IDS = { +   0,  1,  2,  3,  4,  5,  6,  7,  24, 25, 26, 27, 28, 29, 30, 31, +}; + +// permute IDS to take high lane of first pair and low lane of second pair +// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a1, b0 ] +static const uint8x16_t PI_HI_LO_IDS = { +   8,  9, 10, 11, 12, 13, 14, 15,  16, 17, 18, 19, 20, 21, 22, 23, +}; + +static inline uint8x16_t pi_tbl(const uint8x16_t a, const uint8x16_t b, const uint8x16_t ids) { +  uint8x16x2_t quad = { .val = { a, b } }; +  return vqtbl2q_u8(quad, ids); +} + +// 24-round neon keccak permutation with inlined steps +void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) { +  // load rows +  row_t r0 = row_load(a +  0), +        r1 = row_load(a +  5), +        r2 = row_load(a + 10), +        r3 = row_load(a + 15), +        r4 = row_load(a + 20); + +  // loop for num rounds +  for (size_t i = 0; i < num_rounds; i++) { +    // theta +    { +      // c = r0 ^ r1 ^ r2 ^ r3 ^ r4 +      const row_t c = row_eor(row_eor(row_eor(r0, r1), row_eor(r2, r3)), r4); + +      // calculate d... +      const row_t d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c))); + +      r0 = row_eor(r0, d); +      r1 = row_eor(r1, d); +      r2 = row_eor(r2, d); +      r3 = row_eor(r3, d); +      r4 = row_eor(r4, d); +    } + +    // rho +    { +      r0 = row_rotn_u64(r0, RHO_IDS +  0); +      r1 = row_rotn_u64(r1, RHO_IDS +  5); +      r2 = row_rotn_u64(r2, RHO_IDS + 10); +      r3 = row_rotn_u64(r3, RHO_IDS + 15); +      r4 = row_rotn_u64(r4, RHO_IDS + 20); +    } + +    // pi +    { +      row_t t0 = { 0 }; +      { +        // dst[ 0] = src[ 0]; dst[ 1] = src[ 6]; +        t0.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[0], PI_LO_HI_IDS); +        // dst[ 2] = src[12]; dst[ 3] = src[18]; +        t0.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[1], PI_LO_HI_IDS); +        // dst[ 4] = src[24]; +        t0.u8x16x3.val[2] = r4.u8x16x3.val[2]; +      } + +      row_t t1 = { 0 }; +      { + +        // dst[ 5] = src[ 3]; dst[ 6] = src[ 9]; +        t1.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[2], PI_HI_LO_IDS); +        // dst[ 7] = src[10]; dst[ 8] = src[16]; +        t1.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[0], PI_LO_HI_IDS); +        // dst[ 9] = src[22]; +        t1.u8x16x3.val[2] = r4.u8x16x3.val[1]; +      } + +      row_t t2 = { 0 }; +      { +        // dst[10] = src[ 1]; dst[11] = src[ 7]; +        t2.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[1], PI_HI_LO_IDS); +        // dst[12] = src[13]; dst[13] = src[19]; +        t2.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[2], PI_HI_LO_IDS); +        // dst[14] = src[20]; +        t2.u8x16x3.val[2] = r4.u8x16x3.val[0]; +      } + +      row_t t3 = { 0 }; +      { +        // dst[15] = src[ 4]; dst[16] = src[ 5]; +        // t3.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[2], r1.u8x16x3.val[0], PI_LO_LO_IDS); +        t3.u64x2x3.val[0] = vtrn1q_u64(r0.u64x2x3.val[2], r1.u64x2x3.val[0]); +        // dst[17] = src[11]; dst[18] = src[17]; +        t3.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[1], PI_HI_LO_IDS); +        // dst[19] = src[23]; +        t3.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[1], r4.u8x16x3.val[1], PI_HI_LO_IDS); +      } + +      row_t t4 = { 0 }; +      { +        // dst[20] = src[ 2]; dst[21] = src[ 8]; +        t4.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[1], PI_LO_HI_IDS); +        // dst[22] = src[14]; dst[23] = src[15]; +        // t4.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[2], r3.u8x16x3.val[0], PI_LO_LO_IDS); +        t4.u64x2x3.val[1] = vtrn1q_u64(r2.u64x2x3.val[2], r3.u64x2x3.val[0]); +        // dst[24] = src[21]; +        t4.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[0], r4.u8x16x3.val[0], PI_HI_LO_IDS); +      } + +      r0 = t0; +      r1 = t1; +      r2 = t2; +      r3 = t3; +      r4 = t4; +    } + +    // chi +    r0 = row_chi(r0); +    r1 = row_chi(r1); +    r2 = row_chi(r2); +    r3 = row_chi(r3); +    r4 = row_chi(r4); + +    // iota +    const uint64x2_t rc = { RCS[i], 0 }; +    r0.u64x2x3.val[0] ^= rc; +  } + +  // store rows +  row_store(a +  0, r0); +  row_store(a +  5, r1); +  row_store(a + 10, r2); +  row_store(a + 15, r3); +  row_store(a + 20, r4); +} +#endif /* SHA3_BACKEND == BACKEND_NEON */ +  #if SHA3_BACKEND == BACKEND_AVX512 -// use avx512 backend -#define permute_n permute_n_avx512 +#define permute_n permute_n_avx512 // use avx512 backend +#elif SHA3_BACKEND == BACKEND_NEON +#define permute_n permute_n_neon // use neon backend  #elif SHA3_BACKEND == BACKEND_SCALAR -// use scalar backend -#define permute_n permute_n_scalar +#define permute_n permute_n_scalar // use scalar backend  #else  #error "unknown sha3 backend"  #endif /* SHA3_BACKEND */ @@ -1959,6 +2253,8 @@ void k12_once(const uint8_t *src, const size_t src_len, uint8_t *dst, const size  const char *sha3_backend(void) {  #if SHA3_BACKEND == BACKEND_AVX512    return "avx512"; +#elif SHA3_BACKEND == BACKEND_NEON +  return "neon";  #elif SHA3_BACKEND == BACKEND_SCALAR    return "scalar";  #endif /* SHA3_BACKEND */  | 
