aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sha3.c314
1 files changed, 305 insertions, 9 deletions
diff --git a/sha3.c b/sha3.c
index 8563d27..375ee3b 100644
--- a/sha3.c
+++ b/sha3.c
@@ -26,16 +26,19 @@
/** @cond INTERNAL */
// available backends
-#define BACKEND_AVX512 8 // avx512 backend
-#define BACKEND_SCALAR 0 // scalar (default) backend
+#define BACKEND_AVX512 8 // AVX-512 backend
+#define BACKEND_NEON 4 // A64 Neon backend
+#define BACKEND_SCALAR 0 // scalar (default) backend
// auto-detect backend
#ifndef SHA3_BACKEND
-#ifdef __AVX512F__
+#if defined(__AVX512F__)
#define SHA3_BACKEND BACKEND_AVX512
-#else /* !__AVX512F__ */
+#elif defined(__ARM_NEON)
+#define SHA3_BACKEND BACKEND_NEON
+#else
#define SHA3_BACKEND BACKEND_SCALAR
-#endif /* __AVX512F__ */
+#endif
#endif /* SHA3_BACKEND */
// 64-bit rotate left
@@ -462,12 +465,303 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun
}
#endif /* SHA3_BACKEND == BACKEND_AVX512 */
+#if SHA3_BACKEND == BACKEND_NEON
+#include <arm_neon.h>
+
+// vrax1q_u64() not supported on pizza
+#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N))))
+
+// keccak row, represented as 3 128-bit vector registers
+//
+// columns are stored in the low 5 64-bit lanes. this wastes one
+// 64-bit lane per row at the expense of making many of the instructions
+// simpler.
+typedef union {
+ // uint64_t u64[6];
+ uint64x2x3_t u64x2x3;
+ // uint8_t u8[48];
+ // uint8x16_t u8x16[3];
+ uint8x16x3_t u8x16x3;
+} row_t;
+
+// TODO
+// add row_load_fast which reads 6 elems and does this
+// r2 = { .u64x2x3 = vld1q_u64_x3(a + 10) },
+
+// load row from array
+static inline row_t row_load(const uint64_t p[static 5]) {
+ row_t a = { 0 };
+
+ a.u64x2x3.val[0] = vld1q_u64(p + 0);
+ a.u64x2x3.val[1] = vld1q_u64(p + 2);
+ a.u64x2x3.val[2] = vdupq_n_u64(p[4]);
+
+ return a;
+}
+
+// store row to array
+static inline void row_store(uint64_t p[static 5], const row_t a) {
+ // row_print(stderr, __func__, a);
+ vst1q_u64(p + 0, a.u64x2x3.val[0]);
+ vst1q_u64(p + 2, a.u64x2x3.val[1]);
+ vst1_u64(p + 4, vdup_laneq_u64(a.u64x2x3.val[2], 0));
+ // p[4] = vgetq_lane_u64(a.u64x2x3.val[2], 0);
+}
+
+// low lane ids for rol_rc{l,r}()
+static const uint8x16_t ROW_RL_LO_IDS = {
+ 8, 9, 10, 11, 12, 13, 14, 15, 99, 99, 99, 99, 99, 99, 99, 99,
+};
+
+// high lane ids for rol_rc{l,r}()
+static const uint8x16_t ROW_RL_HI_IDS = {
+ 99, 99, 99, 99, 99, 99, 99, 99, 0, 1, 2, 3, 4, 5, 6, 7,
+};
+
+// low lanes for last iteration of row_rlll() and first iteration of row_rlr()
+static const uint8x16_t ROW_RL_TAIL_IDS = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 99, 99, 99, 99, 99, 99, 99, 99,
+};
+
+// rotate row lanes left
+//
+// --------------------------- ---------------------------
+// | 64-bit Lanes (Before) | | 64-bit Lanes (After) |
+// |-------------------------| |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 |
+// |---|---|---|---|---|-----| |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a | | E | A | B | C | D | n/a |
+// --------------------------- ---------------------------
+//
+static inline row_t row_rll(const row_t a) {
+ row_t b = { 0 };
+ for (size_t i = 0; i < 3; i++) {
+ const uint8x16_t lo_ids = i ? ROW_RL_LO_IDS : ROW_RL_TAIL_IDS,
+ hi = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_HI_IDS),
+ lo = vqtbl1q_u8(a.u8x16x3.val[(i + 2) % 3], lo_ids);
+ b.u8x16x3.val[i] = vorrq_u8(lo, hi);
+ }
+ return b;
+}
+
+// rotate row lanes right
+//
+// --------------------------- ---------------------------
+// | 64-bit Lanes (Before) | | 64-bit Lanes (After) |
+// |-------------------------| |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 |
+// |---|---|---|---|---|-----| |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a | | B | C | D | E | A | n/a |
+// --------------------------- ---------------------------
+//
+static row_t row_rlr(const row_t a) {
+ row_t b = { 0 };
+ for (size_t i = 0; i < 2; i++) {
+ const uint8x16_t lo = vqtbl1q_u8(a.u8x16x3.val[i], ROW_RL_LO_IDS),
+ hi = vqtbl1q_u8(a.u8x16x3.val[(i + 1) % 3], ROW_RL_HI_IDS);
+ b.u8x16x3.val[i] = vorrq_u8(lo, hi);
+ }
+ b.u8x16x3.val[2] = vqtbl1q_u8(a.u8x16x3.val[0], ROW_RL_TAIL_IDS);
+ return b;
+}
+
+// c = a ^ b
+static inline row_t row_eor(const row_t a, const row_t b) {
+ row_t c = a;
+ for (size_t i = 0; i < 3; i++) {
+ c.u8x16x3.val[i] ^= b.u8x16x3.val[i];
+ }
+ return c;
+}
+
+// rotate bits in each lane left one bit
+static inline row_t row_rol1_u64(const row_t a) {
+ row_t b = { 0 };
+ for (size_t i = 0; i < 3; i++) {
+ b.u64x2x3.val[i] = VROLQ(a.u64x2x3.val[i], 1);
+ }
+ return b;
+}
+
+// rotate bits in each lane left by amounts in vector
+static inline row_t row_rotn_u64(const row_t a, const int64_t v[static 5]) {
+ row_t b = { 0 };
+ static const int64x2_t k64 = { 64, 64 };
+ for (size_t i = 0; i < 3; i++) {
+ const int64x2_t hi_ids = (i < 2) ? vld1q_s64(v + 2 * i) : vdupq_n_s64(v[4]),
+ lo_ids = vsubq_s64(hi_ids, k64);
+ b.u64x2x3.val[i] = vorrq_u64(vshlq_u64(a.u64x2x3.val[i], hi_ids), vshlq_u64(a.u64x2x3.val[i], lo_ids));
+ }
+ return b;
+}
+
+// return logical NOT of row
+static inline row_t row_not(const row_t a) {
+ row_t b;
+ for (size_t i = 0; i < 3; i++) {
+ b.u8x16x3.val[i] = vmvnq_u8(a.u8x16x3.val[i]);
+ }
+ return b;
+}
+
+// return logical OR NOT of rows
+static inline row_t row_orn(const row_t a, const row_t b) {
+ row_t c;
+ for (size_t i = 0; i < 3; i++) {
+ c.u8x16x3.val[i] = vornq_u8(a.u8x16x3.val[i], b.u8x16x3.val[i]);
+ }
+ return c;
+}
+
+// apply chi permutation to entire row
+// note: ~(a | ~b) = (~a & b) (demorgan's laws)
+static inline row_t row_chi(const row_t a) {
+ const row_t b = row_rlr(a),
+ c = row_rlr(b); // fixme, permute would be faster
+ return row_eor(a, row_not(row_orn(b, c)));
+}
+
+// permute IDS to take low lane of first pair and hi lane of second pair
+// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a0, b1 ]
+static const uint8x16_t PI_LO_HI_IDS = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 24, 25, 26, 27, 28, 29, 30, 31,
+};
+
+// permute IDS to take high lane of first pair and low lane of second pair
+// a = [ a0, a1 ], b = [ b0, b1 ] => c = [ a1, b0 ]
+static const uint8x16_t PI_HI_LO_IDS = {
+ 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+};
+
+static inline uint8x16_t pi_tbl(const uint8x16_t a, const uint8x16_t b, const uint8x16_t ids) {
+ uint8x16x2_t quad = { .val = { a, b } };
+ return vqtbl2q_u8(quad, ids);
+}
+
+// 24-round neon keccak permutation with inlined steps
+void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) {
+ // load rows
+ row_t r0 = row_load(a + 0),
+ r1 = row_load(a + 5),
+ r2 = row_load(a + 10),
+ r3 = row_load(a + 15),
+ r4 = row_load(a + 20);
+
+ // loop for num rounds
+ for (size_t i = 0; i < num_rounds; i++) {
+ // theta
+ {
+ // c = r0 ^ r1 ^ r2 ^ r3 ^ r4
+ const row_t c = row_eor(row_eor(row_eor(r0, r1), row_eor(r2, r3)), r4);
+
+ // calculate d...
+ const row_t d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
+
+ r0 = row_eor(r0, d);
+ r1 = row_eor(r1, d);
+ r2 = row_eor(r2, d);
+ r3 = row_eor(r3, d);
+ r4 = row_eor(r4, d);
+ }
+
+ // rho
+ {
+ r0 = row_rotn_u64(r0, RHO_IDS + 0);
+ r1 = row_rotn_u64(r1, RHO_IDS + 5);
+ r2 = row_rotn_u64(r2, RHO_IDS + 10);
+ r3 = row_rotn_u64(r3, RHO_IDS + 15);
+ r4 = row_rotn_u64(r4, RHO_IDS + 20);
+ }
+
+ // pi
+ {
+ row_t t0 = { 0 };
+ {
+ // dst[ 0] = src[ 0]; dst[ 1] = src[ 6];
+ t0.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[0], PI_LO_HI_IDS);
+ // dst[ 2] = src[12]; dst[ 3] = src[18];
+ t0.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[1], PI_LO_HI_IDS);
+ // dst[ 4] = src[24];
+ t0.u8x16x3.val[2] = r4.u8x16x3.val[2];
+ }
+
+ row_t t1 = { 0 };
+ {
+
+ // dst[ 5] = src[ 3]; dst[ 6] = src[ 9];
+ t1.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[2], PI_HI_LO_IDS);
+ // dst[ 7] = src[10]; dst[ 8] = src[16];
+ t1.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[0], PI_LO_HI_IDS);
+ // dst[ 9] = src[22];
+ t1.u8x16x3.val[2] = r4.u8x16x3.val[1];
+ }
+
+ row_t t2 = { 0 };
+ {
+ // dst[10] = src[ 1]; dst[11] = src[ 7];
+ t2.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[0], r1.u8x16x3.val[1], PI_HI_LO_IDS);
+ // dst[12] = src[13]; dst[13] = src[19];
+ t2.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[1], r3.u8x16x3.val[2], PI_HI_LO_IDS);
+ // dst[14] = src[20];
+ t2.u8x16x3.val[2] = r4.u8x16x3.val[0];
+ }
+
+ row_t t3 = { 0 };
+ {
+ // dst[15] = src[ 4]; dst[16] = src[ 5];
+ // t3.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[2], r1.u8x16x3.val[0], PI_LO_LO_IDS);
+ t3.u64x2x3.val[0] = vtrn1q_u64(r0.u64x2x3.val[2], r1.u64x2x3.val[0]);
+ // dst[17] = src[11]; dst[18] = src[17];
+ t3.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[0], r3.u8x16x3.val[1], PI_HI_LO_IDS);
+ // dst[19] = src[23];
+ t3.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[1], r4.u8x16x3.val[1], PI_HI_LO_IDS);
+ }
+
+ row_t t4 = { 0 };
+ {
+ // dst[20] = src[ 2]; dst[21] = src[ 8];
+ t4.u8x16x3.val[0] = pi_tbl(r0.u8x16x3.val[1], r1.u8x16x3.val[1], PI_LO_HI_IDS);
+ // dst[22] = src[14]; dst[23] = src[15];
+ // t4.u8x16x3.val[1] = pi_tbl(r2.u8x16x3.val[2], r3.u8x16x3.val[0], PI_LO_LO_IDS);
+ t4.u64x2x3.val[1] = vtrn1q_u64(r2.u64x2x3.val[2], r3.u64x2x3.val[0]);
+ // dst[24] = src[21];
+ t4.u8x16x3.val[2] = pi_tbl(r4.u8x16x3.val[0], r4.u8x16x3.val[0], PI_HI_LO_IDS);
+ }
+
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // chi
+ r0 = row_chi(r0);
+ r1 = row_chi(r1);
+ r2 = row_chi(r2);
+ r3 = row_chi(r3);
+ r4 = row_chi(r4);
+
+ // iota
+ const uint64x2_t rc = { RCS[i], 0 };
+ r0.u64x2x3.val[0] ^= rc;
+ }
+
+ // store rows
+ row_store(a + 0, r0);
+ row_store(a + 5, r1);
+ row_store(a + 10, r2);
+ row_store(a + 15, r3);
+ row_store(a + 20, r4);
+}
+#endif /* SHA3_BACKEND == BACKEND_NEON */
+
#if SHA3_BACKEND == BACKEND_AVX512
-// use avx512 backend
-#define permute_n permute_n_avx512
+#define permute_n permute_n_avx512 // use avx512 backend
+#elif SHA3_BACKEND == BACKEND_NEON
+#define permute_n permute_n_neon // use neon backend
#elif SHA3_BACKEND == BACKEND_SCALAR
-// use scalar backend
-#define permute_n permute_n_scalar
+#define permute_n permute_n_scalar // use scalar backend
#else
#error "unknown sha3 backend"
#endif /* SHA3_BACKEND */
@@ -1959,6 +2253,8 @@ void k12_once(const uint8_t *src, const size_t src_len, uint8_t *dst, const size
const char *sha3_backend(void) {
#if SHA3_BACKEND == BACKEND_AVX512
return "avx512";
+#elif SHA3_BACKEND == BACKEND_NEON
+ return "neon";
#elif SHA3_BACKEND == BACKEND_SCALAR
return "scalar";
#endif /* SHA3_BACKEND */