From 7785d7419bf82205f43f935464df5d455965afa8 Mon Sep 17 00:00:00 2001
From: Paul Duncan <pabs@pablotron.org>
Date: Sun, 5 May 2024 11:46:39 -0400
Subject: sha3.c: add diet-neon backend (even slower than neon; benchmarks below)

scalar bench results:
  info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
  info: backend=scalar num_trials=50000 src_lens=64,256,1024,4096,16384 dst_lens=32
  function,dst_len,64,256,1024,4096,16384
  sha3_224,28,20.2,10.3,10.3,9.3,9.2
  sha3_256,32,20.2,10.1,10.3,9.9,9.7
  sha3_384,48,20.9,15.1,12.8,12.7,12.5
  sha3_512,64,20.2,20.2,18.9,18.0,18.0
  shake128,32,20.2,10.1,9.0,8.1,7.9
  shake256,32,20.2,10.1,10.3,9.9,9.7

neon bench results:
  info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
  info: backend=neon num_trials=50000 src_lens=64,256,1024,4096,16384 dst_lens=32
  function,dst_len,64,256,1024,4096,16384
  sha3_224,28,32.7,16.2,16.3,14.8,14.5
  sha3_256,32,32.7,16.2,16.3,15.8,15.4
  sha3_384,48,32.7,24.2,20.3,20.3,20.0
  sha3_512,64,32.0,32.3,30.2,28.7,29.3
  shake128,32,34.8,16.9,14.9,13.3,13.4
  shake256,32,35.5,18.1,17.4,17.2,16.4

diet-neon bench results:
  info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
  info: backend=diet-neon num_trials=50000 src_lens=64,256,1024,4096,16384 dst_lens=32
  function,dst_len,64,256,1024,4096,16384
  sha3_224,28,33.4,16.5,16.6,15.1,15.0
  sha3_256,32,33.4,16.5,16.6,16.1,15.9
  sha3_384,48,33.4,25.0,21.0,20.7,21.4
  sha3_512,64,33.4,34.9,33.5,31.1,32.0
  shake128,32,36.8,18.4,16.3,14.3,14.0
  shake256,32,34.1,17.7,18.2,17.6,17.3
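
To force this backend, define SHA3_BACKEND at compile time; when it is
unset or 0 (the default), the backend is auto-detected.  For example,
assuming a plain C compiler invocation:

  cc -DSHA3_BACKEND=4 -c sha3.c  # 4 = BACKEND_DIET_NEON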
---
 sha3.c | 329 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 329 insertions(+)

diff --git a/sha3.c b/sha3.c
index 4073287..b2158f5 100644
--- a/sha3.c
+++ b/sha3.c
@@ -30,6 +30,7 @@
 #define BACKEND_SCALAR 1      // scalar backend
 #define BACKEND_AVX512 2      // AVX-512 backend
 #define BACKEND_NEON 3        // Neon backend
+#define BACKEND_DIET_NEON 4   // Neon backend which uses fewer registers
 
 // if SHA3_BACKEND is defined and set to 0 (the default), then unset it
 // and auto-detect the appropriate backend
@@ -758,10 +759,302 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds
 }
 #endif /* SHA3_BACKEND == BACKEND_NEON */
 
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+#include <arm_neon.h>
+
+// rotate element in uint64x1_t left by N bits
+#define VROL(A, N) (vorr_u64(vshl_n_u64((A), (N)), vshr_n_u64((A), 64-(N))))
+
+// rotate elements in uint64x2_t left by N bits
+// note: vrax1q_u64() not supported on pizza
+#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N))))
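+//
+// for example, VROLQ({ x, y }, 3) yields { rotl64(x, 3), rotl64(y, 3) },
+// built from a left shift, a right shift, and an OR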
+
+// keccak row, represented as two 128-bit vector registers and one
+// 64-bit vector register
+//
+// columns are stored in the low 5 64-bit lanes.  this wastes one
+// 64-bit lane per row, but makes many of the instructions simpler.
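+//
+// layout of row r (state cells a[5*r+0..5*r+4]):
+//   head.val[0] = { a[5*r+0], a[5*r+1] }
+//   head.val[1] = { a[5*r+2], a[5*r+3] }
+//   tail        = { a[5*r+4] }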
+typedef struct {
+  uint64x2x2_t head; // first 4 columns
+  uint64x1_t tail; // last column
+} row_t;
+
+// set contents of row
+static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x1_t c) {
+  return (row_t) { .head = { { a, b } }, .tail = c };
+}
+
+// get Nth pair of u64s from row
+static inline uint64x2_t row_get(const row_t a, const size_t n) {
+  return a.head.val[n];
+}
+
+// get tail column from row
+static inline uint64x1_t row_get_tail(const row_t a) {
+  return a.tail;
+}
+
+// rotate row lanes left
+//
+// ---------------------------     ---------------------------
+// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   |
+// |-------------------------|     |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  |
+// |---|---|---|---|---|-----|     |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a |     | E | A | B | C | D | n/a |
+// ---------------------------     ---------------------------
+//
+static inline row_t row_rll(const row_t a) {
+  return row_set(
+    vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+    vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
+    vdup_laneq_u64(row_get(a, 1), 1) // { a3, n/a }
+  );
+}
+
+// rotate row lanes right
+//
+// ---------------------------     ---------------------------
+// |  64-bit Lanes (Before)  |     |  64-bit Lanes (After)   |
+// |-------------------------|     |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 |  5  | --> | 0 | 1 | 2 | 3 | 4 |  5  |
+// |---|---|---|---|---|-----|     |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a |     | B | C | D | E | A | n/a |
+// ---------------------------     ---------------------------
+//
+static inline row_t row_rlr(const row_t a) {
+  return row_set(
+    vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
+    vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_get_tail(a)), // { a3, a4 }
+    vdup_laneq_u64(row_get(a, 0), 0) // { a0, n/a }
+  );
+}
+
+// c = a ^ b
+static inline row_t row_eor(const row_t a, const row_t b) {
+  return row_set(
+    row_get(a, 0) ^ row_get(b, 0),
+    row_get(a, 1) ^ row_get(b, 1),
+    row_get_tail(a) ^ row_get_tail(b)
+  );
+}
+
+// f = a ^ b ^ c ^ d ^ e
+// FIXME want: veor3q_u64(a, b, c) (EOR3; needs the ARMv8.2 SHA3 extension)
+static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const row_t d, const row_t e) {
+  return row_set(
+    row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0),
+    row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1),
+    row_get_tail(a) ^ row_get_tail(b) ^ row_get_tail(c) ^ row_get_tail(d) ^ row_get_tail(e)
+  );
+}
+
+// rotate bits in each lane left one bit
+static inline row_t row_rol1_u64(const row_t a) {
+  return row_set(
+    VROLQ(row_get(a, 0), 1),
+    VROLQ(row_get(a, 1), 1),
+    VROL(row_get_tail(a), 1)
+  );
+}
+
+// rho lane rotate values
+static const struct {
+  int64x2_t head[2];
+  int64x1_t tail;
+} RHO_IDS[] = {
+  { .head = { {  0,  1 }, { 62, 28 } }, .tail = { 27 } },
+  { .head = { { 36, 44 }, {  6, 55 } }, .tail = { 20 } },
+  { .head = { {  3, 10 }, { 43, 25 } }, .tail = { 39 } },
+  { .head = { { 41, 45 }, { 15, 21 } }, .tail = {  8 } },
+  { .head = { { 18,  2 }, { 61, 56 } }, .tail = { 14 } },
+};
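+// (these are the standard Keccak rho rotation offsets, arranged row by
+// row and stored as signed shift counts for use with vshl below)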
+
+// apply rho rotation to row
+static inline row_t row_rho(const row_t a, const size_t id) {
+  const int64x2_t *vh = RHO_IDS[id].head;
+  const int64x1_t vt = RHO_IDS[id].tail;
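+  // vshl shifts each lane left for positive shift counts and right for
+  // negative ones, so rotl(x, n) = vshl(x, n) | vshl(x, n - 64)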
+  return row_set(
+    vorrq_u64(vshlq_u64(row_get(a, 0), vh[0]), vshlq_u64(row_get(a, 0), vh[0] - 64)),
+    vorrq_u64(vshlq_u64(row_get(a, 1), vh[1]), vshlq_u64(row_get(a, 1), vh[1] - 64)),
+    vorr_u64(vshl_u64(row_get_tail(a), vt), vshl_u64(row_get_tail(a), vt - 64))
+  );
+}
+
+// c = (~a & b)
+// note: was using ~(a | ~b) = (~a & b) (De Morgan's laws), but changed
+// to BIC b, a (b & ~a) instead
+static inline row_t row_andn(const row_t a, const row_t b) {
+  return row_set(
+    vbicq_u64(row_get(b, 0), row_get(a, 0)),
+    vbicq_u64(row_get(b, 1), row_get(a, 1)),
+    vbic_u64(row_get_tail(b), row_get_tail(a))
+  );
+}
+
+// apply chi permutation to entire row: a[x] ^= ~a[x+1] & a[x+2]
+// note: ~(a | ~b) = (~a & b) (De Morgan's laws)
+static inline row_t row_chi(const row_t a) {
+  return row_eor(a, row_andn(row_rlr(a), row_set(
+    row_get(a, 1),                            // { a2, a3 }
+    vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+    vdup_laneq_u64(row_get(a, 0), 1)          // { a1 }
+  )));
+}
+
+// return new vector with low lane of first argument and high lane of
+// second argument
+static inline uint64x2_t pi_lo_hi(const uint64x2_t a, const uint64x2_t b) {
+  // was using vqtbl2q_u8() with tables, but this is faster
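+  // vextq_u64(b, a, 1) yields { b1, a0 }; rotating that pair one more
+  // lane yields { a0, b1 }, the low lane of a and the high lane of b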
+  const uint64x2_t c = vextq_u64(b, a, 1);
+  return vextq_u64(c, c, 1);
+}
+
+// diet-neon keccak permutation with inlined steps
+static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_rounds) {
+  // load rows
+  row_t r0, r1, r2, r3, r4;
+  {
+    // 3 loads of 8 and 1 load of 1 cell (3*8 + 1 = 25)
+    const uint64x2x4_t m0 = vld1q_u64_x4(a +  0), // r0 cols 0-4, r1 cols 0-2
+                       m1 = vld1q_u64_x4(a +  8), // r1 cols 3-4, r2 cols 0-4, r3 col 0
+                       m2 = vld1q_u64_x4(a + 16); // r3 cols 1-4, r4 cols 0-3
+    const uint64x1_t m3 = vld1_u64(a + 24);  // r4 col 4
+
+    // permute loaded data into rows
+    r0 = row_set(m0.val[0], m0.val[1], vdup_laneq_u64(m0.val[2], 0));
+    r1 = row_set(vextq_u64(m0.val[2], m0.val[3], 1), vextq_u64(m0.val[3], m1.val[0], 1), vdup_laneq_u64(m1.val[0], 1));
+    r2 = row_set(m1.val[1], m1.val[2], vdup_laneq_u64(m1.val[3], 0));
+    r3 = row_set(vextq_u64(m1.val[3], m2.val[0], 1), vextq_u64(m2.val[0], m2.val[1], 1), vdup_laneq_u64(m2.val[1], 1));
+    r4 = row_set(m2.val[2], m2.val[3], m3);
+  }
+
+  // loop for num rounds
+  for (size_t i = 0; i < num_rounds; i++) {
+    // theta
+    {
+      // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ rol1(rlr(c))
+      // (per lane: d[x] = c[x-1] ^ rotl(c[x+1], 1))
+      const row_t c = row_eor5(r0, r1, r2, r3, r4),
+                  d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
+
+      r0 = row_eor(r0, d); // r0 ^= d
+      r1 = row_eor(r1, d); // r1 ^= d
+      r2 = row_eor(r2, d); // r2 ^= d
+      r3 = row_eor(r3, d); // r3 ^= d
+      r4 = row_eor(r4, d); // r4 ^= d
+    }
+
+    // rho
+    r0 = row_rho(r0, 0);
+    r1 = row_rho(r1, 1);
+    r2 = row_rho(r2, 2);
+    r3 = row_rho(r3, 3);
+    r4 = row_rho(r4, 4);
+
+    // pi
+    {
+      const row_t t0 = row_set(
+        pi_lo_hi(row_get(r0, 0), row_get(r1, 0)),
+        pi_lo_hi(row_get(r2, 1), row_get(r3, 1)),
+        row_get_tail(r4)
+      );
+
+      const row_t t1 = row_set(
+        vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_get_tail(r1)),
+        pi_lo_hi(row_get(r2, 0), row_get(r3, 0)),
+        vdup_laneq_u64(row_get(r4, 1), 0)
+      );
+
+      const row_t t2 = row_set(
+        vextq_u64(row_get(r0, 0), row_get(r1, 1), 1),
+        vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_get_tail(r3)),
+        vdup_laneq_u64(row_get(r4, 0), 0)
+      );
+
+      const row_t t3 = row_set(
+        vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+        vextq_u64(row_get(r2, 0), row_get(r3, 1), 1),
+        vdup_laneq_u64(row_get(r4, 1), 1)
+      );
+
+      const row_t t4 = row_set(
+        pi_lo_hi(row_get(r0, 1), row_get(r1, 1)),
+        vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+        vdup_laneq_u64(row_get(r4, 0), 1)
+      );
+
+      // store rows
+      r0 = t0;
+      r1 = t1;
+      r2 = t2;
+      r3 = t3;
+      r4 = t4;
+    }
+
+    // chi
+    r0 = row_chi(r0);
+    r1 = row_chi(r1);
+    r2 = row_chi(r2);
+    r3 = row_chi(r3);
+    r4 = row_chi(r4);
+
+    // iota
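+    // RCS holds all 24 round constants; indexing from the end of the
+    // table means reduced-round calls (e.g. num_rounds = 12) use the
+    // trailing constants of the full 24-round schedule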
+    const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 };
+    r0.head.val[0] ^= rc;
+  }
+
+  // store rows
+  {
+    // store columns 0-4 of r0 and columns 0-2 of r1
+    const uint64x2x4_t m0 = {
+      .val = {
+        row_get(r0, 0),
+        row_get(r0, 1),
+        vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+        vextq_u64(row_get(r1, 0), row_get(r1, 1), 1)
+      },
+    };
+    vst1q_u64_x4(a + 0, m0);
+  }
+
+  {
+    // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3
+    const uint64x2x4_t m1 = {
+      .val = {
+        vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_get_tail(r1)),
+        row_get(r2, 0),
+        row_get(r2, 1),
+        vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+      },
+    };
+    vst1q_u64_x4(a + 8, m1);
+  }
+
+  {
+    // store columns 1-4 of r3 and columns 0-3 of r4
+    const uint64x2x4_t m2 = {
+      .val = {
+        vextq_u64(row_get(r3, 0), row_get(r3, 1), 1),
+        vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_get_tail(r3)),
+        row_get(r4, 0),
+        row_get(r4, 1),
+      },
+    };
+    vst1q_u64_x4(a + 16, m2);
+  }
+
+  // store column 4 of r4
+  vst1_u64(a + 24, row_get_tail(r4));
+}
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+
 #if SHA3_BACKEND == BACKEND_AVX512
 #define permute_n permute_n_avx512 // use avx512 backend
 #elif SHA3_BACKEND == BACKEND_NEON
 #define permute_n permute_n_neon // use neon backend
+#elif SHA3_BACKEND == BACKEND_DIET_NEON
+#define permute_n permute_n_diet_neon // use diet-neon backend
 #elif SHA3_BACKEND == BACKEND_SCALAR
 #define permute_n permute_n_scalar // use scalar backend
 #else
@@ -2257,6 +2550,8 @@ const char *sha3_backend(void) {
   return "avx512";
 #elif SHA3_BACKEND == BACKEND_NEON
   return "neon";
+#elif SHA3_BACKEND == BACKEND_DIET_NEON
+  return "diet-neon";
 #elif SHA3_BACKEND == BACKEND_SCALAR
   return "scalar";
 #endif /* SHA3_BACKEND */
@@ -2550,6 +2845,22 @@ static void test_permute_neon(void) {
 #endif /* SHA3_BACKEND == BACKEND_NEON */
 }
 
+static void test_permute_diet_neon(void) {
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+  for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) {
+    const size_t exp_len = PERMUTE_TESTS[i].exp_len;
+
+    uint64_t got[25] = { 0 };
+    memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
+    permute_n_diet_neon(got, 24); // call permute_n_diet_neon() directly
+
+    if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
+      fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
+    }
+  }
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+}
+
 static const struct {
   uint64_t a[25]; // input state
   const uint64_t exp[25]; // expected value
@@ -2606,6 +2917,22 @@ static void test_permute12_neon(void) {
 #endif /* SHA3_BACKEND == BACKEND_NEON */
 }
 
+static void test_permute12_diet_neon(void) {
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+  for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) {
+    const size_t exp_len = PERMUTE12_TESTS[i].exp_len;
+
+    uint64_t got[25] = { 0 };
+    memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
+    permute_n_diet_neon(got, 12); // call permute_n_diet_neon() directly
+
+    if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
+      fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
+    }
+  }
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+}
+
 static void test_sha3_224(void) {
   static const struct {
     const char *name; // test name
@@ -6761,9 +7088,11 @@ int main(void) {
   test_permute_scalar();
   test_permute_avx512();
   test_permute_neon();
+  test_permute_diet_neon();
   test_permute12_scalar();
   test_permute12_avx512();
   test_permute12_neon();
+  test_permute12_diet_neon();
   test_sha3_224();
   test_sha3_256();
   test_sha3_384();
-- 
cgit v1.2.3