From 8c79eb770f232439691d8d253c691275ce534512 Mon Sep 17 00:00:00 2001
From: Paul Duncan <pabs@pablotron.org>
Date: Wed, 8 May 2024 16:52:25 -0400
Subject: sha3.c: diet-neon: misc fixes.  still too slow

---
 sha3.c | 82 +++++++++++++++++++++++++++++-------------------------------------
 1 file changed, 36 insertions(+), 46 deletions(-)

diff --git a/sha3.c b/sha3.c
index 0107dcb..21b2c78 100644
--- a/sha3.c
+++ b/sha3.c
@@ -797,24 +797,19 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds
 // 64-bit lane per row at the expense of making many of the instructions
 // simpler.
 typedef struct {
-  uint64x2x2_t head; // first 4 columns
-  uint64x1_t tail; // last column
+  uint64x2_t p0, p1; // first 4 columns
+  uint64x1_t p2; // last column
 } row_t;
 
 // set contents of row
 static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x1_t c) {
-  return (row_t) { .head = { { a, b } }, .tail = c };
+  return (row_t) { a, b, c };
 }
 
 // get Nth pair of u64s from row
-static inline uint64x2_t row_get(const row_t a, const size_t n) {
-  return a.head.val[n];
-}
-
-// get tail column from row
-static inline uint64x1_t row_get_tail(const row_t a) {
-  return a.tail;
-}
+// (N is 0 or 1: yields a uint64x2_t pair; N is 2: yields the uint64x1_t last column)
+#define row_get(A, N) ((A).p ## N)
+#define row_last(A) ((A).p2)
 
 // rotate row lanes left
 //
@@ -828,7 +823,7 @@ static inline uint64x1_t row_get_tail(const row_t a) {
 //
 static inline row_t row_rll(const row_t a) {
   return row_set(
-    vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+    vcombine_u64(row_last(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
     vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
     vdup_laneq_u64(row_get(a, 1), 1) // { a3, n/a }
   );
@@ -847,7 +842,7 @@ static inline row_t row_rll(const row_t a) {
 static inline row_t row_rlr(const row_t a) {
   return row_set(
     vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
-    vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_get_tail(a)), // { a3, a4 }
+    vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_last(a)), // { a3, a4 }
     vdup_laneq_u64(row_get(a, 0), 0) // { a0, n/a }
   );
 }
@@ -857,7 +852,7 @@ static inline row_t row_eor(const row_t a, const row_t b) {
   return row_set(
     row_get(a, 0) ^ row_get(b, 0),
     row_get(a, 1) ^ row_get(b, 1),
-    row_get_tail(a) ^ row_get_tail(b)
+    row_last(a) ^ row_last(b)
   );
 }
 
@@ -867,7 +862,7 @@ static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const
   return row_set(
     row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0),
     row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1),
-    row_get_tail(a) ^ row_get_tail(b) ^ row_get_tail(c) ^ row_get_tail(d) ^ row_get_tail(e)
+    row_last(a) ^ row_last(b) ^ row_last(c) ^ row_last(d) ^ row_last(e)
   );
 }
 
@@ -876,30 +871,25 @@ static inline row_t row_rol1_u64(const row_t a) {
   return row_set(
     VROLQ(row_get(a, 0), 1),
     VROLQ(row_get(a, 1), 1),
-    VROL(row_get_tail(a), 1)
+    VROL(row_last(a), 1)
   );
 }
 
-// rho lane rotate values
-static const struct {
-  int64x2_t head[2];
-  int64x1_t tail;
-} RHO_IDS[] = {
-  { .head = { {  0,  1 }, { 62, 28 } }, .tail = { 27 } },
-  { .head = { { 36, 44 }, {  6, 55 } }, .tail = { 20 } },
-  { .head = { {  3, 10 }, { 43, 25 } }, .tail = { 39 } },
-  { .head = { { 41, 45 }, { 15, 21 } }, .tail = {  8 } },
-  { .head = { { 18,  2 }, { 61, 56 } }, .tail = { 14 } },
-};
+static const int64x2_t r_ab = { 0x0f4fac6f92a43900LL, 0x0e15677702b4ab01LL }; // rho rotate counts from the removed RHO_IDS table, packed as 6-bit fields: lane 0 = columns 0 and 2, lane 1 = columns 1 and 3; field `id` (bit 6 * id) feeds p0, field at bit 30 + 6 * id feeds p1
+static const int64x1_t r_c =  { 0x000000000e22751bLL }; // rho rotate counts for the last column (p2), same 6-bit packing: field `id` starts at bit 6 * id
 
 // apply rho rotation to row
 static inline row_t row_rho(const row_t a, const size_t id) {
-  const int64x2_t *vh = RHO_IDS[id].head;
-  const int64x1_t vt = RHO_IDS[id].tail;
+  const int64x2_t v0_hi = (r_ab >> (6 * id)) & 0x3f, // p0 rotate counts: 6-bit field starting at bit 6 * id of each r_ab lane
+                  v0_lo = ((r_ab >> (6 * id)) & 0x3f) - 64, // same counts minus 64 (negative), for the second shift OR'd into p0
+                  v1_hi = (r_ab >> (30 + (6 * id))) & 0x3f, // p1 rotate counts: 6-bit field starting at bit 30 + 6 * id of each r_ab lane
+                  v1_lo = ((r_ab >> (30 + (6 * id))) & 0x3f) - 64; // same counts minus 64 (negative), for the second shift OR'd into p1
+  const int64x1_t v2_hi = (r_c >> (6 * id)) & 0x3f, // p2 rotate count: 6-bit field starting at bit 6 * id of r_c
+                  v2_lo = ((r_c >> (6 * id)) & 0x3f) - 64; // same count minus 64 (negative), for the second shift OR'd into p2
   return row_set(
-    vorrq_u64(vshlq_u64(row_get(a, 0), vh[0]), vshlq_u64(row_get(a, 0), vh[0] - 64)),
-    vorrq_u64(vshlq_u64(row_get(a, 1), vh[1]), vshlq_u64(row_get(a, 1), vh[1] - 64)),
-    vorr_u64(vshl_u64(row_get_tail(a), vt), vshl_u64(row_get_tail(a), vt - 64))
+    vorrq_u64(vshlq_u64(row_get(a, 0), v0_hi), vshlq_u64(row_get(a, 0), v0_lo)), // rotate columns 0-1: OR of the two shifts
+    vorrq_u64(vshlq_u64(row_get(a, 1), v1_hi), vshlq_u64(row_get(a, 1), v1_lo)), // rotate columns 2-3: OR of the two shifts
+    vorr_u64(vshl_u64(row_last(a), v2_hi), vshl_u64(row_last(a), v2_lo)) // rotate last column: OR of the two shifts
   );
 }
 
@@ -910,7 +900,7 @@ static inline row_t row_andn(const row_t a, const row_t b) {
   return row_set(
     vbicq_u64(row_get(b, 0), row_get(a, 0)),
     vbicq_u64(row_get(b, 1), row_get(a, 1)),
-    vbic_u64(row_get_tail(b), row_get_tail(a))
+    vbic_u64(row_last(b), row_last(a))
   );
 }
 
@@ -919,7 +909,7 @@ static inline row_t row_andn(const row_t a, const row_t b) {
 static inline row_t row_chi(const row_t a) {
   return row_eor(a, row_andn(row_rlr(a), row_set(
     row_get(a, 1),                            // { a2, a3 }
-    vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+    vcombine_u64(row_last(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
     vdup_laneq_u64(row_get(a, 0), 1)          // { a1 }
   )));
 }
@@ -952,7 +942,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
   }
 
   // loop for num rounds
-  for (size_t i = 0; i < num_rounds; i++) {
+  for (size_t i = 24 - num_rounds; i < 24; i++) {
     // theta
     {
       // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1)
@@ -978,30 +968,30 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
       const row_t t0 = row_set(
         pi_lo_hi(row_get(r0, 0), row_get(r1, 0)),
         pi_lo_hi(row_get(r2, 1), row_get(r3, 1)),
-        row_get_tail(r4)
+        row_last(r4)
       );
 
       const row_t t1 = row_set(
-        vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_get_tail(r1)),
+        vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_last(r1)),
         pi_lo_hi(row_get(r2, 0), row_get(r3, 0)),
         vdup_laneq_u64(row_get(r4, 1), 0)
       );
 
       const row_t t2 = row_set(
         vextq_u64(row_get(r0, 0), row_get(r1, 1), 1),
-        vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_get_tail(r3)),
+        vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_last(r3)),
         vdup_laneq_u64(row_get(r4, 0), 0)
       );
 
       const row_t t3 = row_set(
-        vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+        vcombine_u64(row_last(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
         vextq_u64(row_get(r2, 0), row_get(r3, 1), 1),
         vdup_laneq_u64(row_get(r4, 1), 1)
       );
 
       const row_t t4 = row_set(
         pi_lo_hi(row_get(r0, 1), row_get(r1, 1)),
-        vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+        vcombine_u64(row_last(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
         vdup_laneq_u64(row_get(r4, 0), 1)
       );
 
@@ -1021,8 +1011,8 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
     r4 = row_chi(r4);
 
     // iota
-    const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 };
-    r0.head.val[0] ^= rc;
+    const uint64x2_t rc = { RCS[i], 0 };
+    r0.p0 ^= rc;
   }
 
   // store rows
@@ -1032,7 +1022,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
       .val = {
         row_get(r0, 0),
         row_get(r0, 1),
-        vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+        vcombine_u64(row_last(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
         vextq_u64(row_get(r1, 0), row_get(r1, 1), 1)
       },
     };
@@ -1043,10 +1033,10 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
     // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3
     const uint64x2x4_t m1 = {
       .val = {
-        vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_get_tail(r1)),
+        vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_last(r1)),
         row_get(r2, 0),
         row_get(r2, 1),
-        vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+        vcombine_u64(row_last(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
       },
     };
     vst1q_u64_x4(a + 8, m1);
@@ -1058,7 +1048,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r
       .val = {
         vextq_u64(row_get(r3, 0), row_get(r3, 1), 1),
 
-        vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_get_tail(r3)),
+        vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_last(r3)),
         row_get(r4, 0),
         row_get(r4, 1),
       },
-- 
cgit v1.2.3