From 8c79eb770f232439691d8d253c691275ce534512 Mon Sep 17 00:00:00 2001 From: Paul Duncan Date: Wed, 8 May 2024 16:52:25 -0400 Subject: sha3.c: diet-neon: misc fixes. still too slow --- sha3.c | 82 +++++++++++++++++++++++++++++------------------------------------- 1 file changed, 36 insertions(+), 46 deletions(-) diff --git a/sha3.c b/sha3.c index 0107dcb..21b2c78 100644 --- a/sha3.c +++ b/sha3.c @@ -797,24 +797,19 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds // 64-bit lane per row at the expense of making many of the instructions // simpler. typedef struct { - uint64x2x2_t head; // first 4 columns - uint64x1_t tail; // last column + uint64x2_t p0, p1; // first 4 columns + uint64x1_t p2; // last column } row_t; // set contents of row static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x1_t c) { - return (row_t) { .head = { { a, b } }, .tail = c }; + return (row_t) { a, b, c }; } // get Nth pair of u64s from row -static inline uint64x2_t row_get(const row_t a, const size_t n) { - return a.head.val[n]; -} - -// get tail column from row -static inline uint64x1_t row_get_tail(const row_t a) { - return a.tail; -} +// (N = [0, 1], returns uint64x2_t, N == 2, returns uint64x1_t) +#define row_get(A, N) ((A).p ## N) +#define row_last(A) ((A).p2) // rotate row lanes left // @@ -828,7 +823,7 @@ static inline uint64x1_t row_get_tail(const row_t a) { // static inline row_t row_rll(const row_t a) { return row_set( - vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } + vcombine_u64(row_last(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 } vdup_laneq_u64(row_get(a, 1), 1) // { a3, n/a } ); @@ -847,7 +842,7 @@ static inline row_t row_rll(const row_t a) { static inline row_t row_rlr(const row_t a) { return row_set( vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 } - vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_get_tail(a)), // { a3, a4 } + vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_last(a)), // { a3, a4 } vdup_laneq_u64(row_get(a, 0), 0) // { a0, n/a } ); } @@ -857,7 +852,7 @@ static inline row_t row_eor(const row_t a, const row_t b) { return row_set( row_get(a, 0) ^ row_get(b, 0), row_get(a, 1) ^ row_get(b, 1), - row_get_tail(a) ^ row_get_tail(b) + row_last(a) ^ row_last(b) ); } @@ -867,7 +862,7 @@ static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const return row_set( row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0), row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1), - row_get_tail(a) ^ row_get_tail(b) ^ row_get_tail(c) ^ row_get_tail(d) ^ row_get_tail(e) + row_last(a) ^ row_last(b) ^ row_last(c) ^ row_last(d) ^ row_last(e) ); } @@ -876,30 +871,25 @@ static inline row_t row_rol1_u64(const row_t a) { return row_set( VROLQ(row_get(a, 0), 1), VROLQ(row_get(a, 1), 1), - VROL(row_get_tail(a), 1) + VROL(row_last(a), 1) ); } -// rho lane rotate values -static const struct { - int64x2_t head[2]; - int64x1_t tail; -} RHO_IDS[] = { - { .head = { { 0, 1 }, { 62, 28 } }, .tail = { 27 } }, - { .head = { { 36, 44 }, { 6, 55 } }, .tail = { 20 } }, - { .head = { { 3, 10 }, { 43, 25 } }, .tail = { 39 } }, - { .head = { { 41, 45 }, { 15, 21 } }, .tail = { 8 } }, - { .head = { { 18, 2 }, { 61, 56 } }, .tail = { 14 } }, -}; +static const int64x2_t r_ab = { 0x0f4fac6f92a43900LL, 0x0e15677702b4ab01LL }; +static const int64x1_t r_c = { 0x000000000e22751bLL }; // apply rho rotation to row static inline row_t row_rho(const row_t a, const size_t id) { - const int64x2_t *vh = RHO_IDS[id].head; - const int64x1_t vt = RHO_IDS[id].tail; + const int64x2_t v0_hi = (r_ab >> (6 * id)) & 0x3f, + v0_lo = ((r_ab >> (6 * id)) & 0x3f) - 64, + v1_hi = (r_ab >> (30 + (6 * id))) & 0x3f, + v1_lo = ((r_ab >> (30 + (6 * id))) & 0x3f) - 64; + const int64x1_t v2_hi = (r_c >> (6 * id)) & 0x3f, + v2_lo = ((r_c >> (6 * id)) & 0x3f) - 64; return row_set( - vorrq_u64(vshlq_u64(row_get(a, 0), vh[0]), vshlq_u64(row_get(a, 0), vh[0] - 64)), - vorrq_u64(vshlq_u64(row_get(a, 1), vh[1]), vshlq_u64(row_get(a, 1), vh[1] - 64)), - vorr_u64(vshl_u64(row_get_tail(a), vt), vshl_u64(row_get_tail(a), vt - 64)) + vorrq_u64(vshlq_u64(row_get(a, 0), v0_hi), vshlq_u64(row_get(a, 0), v0_lo)), + vorrq_u64(vshlq_u64(row_get(a, 1), v1_hi), vshlq_u64(row_get(a, 1), v1_lo)), + vorr_u64(vshl_u64(row_last(a), v2_hi), vshl_u64(row_last(a), v2_lo)) ); } @@ -910,7 +900,7 @@ static inline row_t row_andn(const row_t a, const row_t b) { return row_set( vbicq_u64(row_get(b, 0), row_get(a, 0)), vbicq_u64(row_get(b, 1), row_get(a, 1)), - vbic_u64(row_get_tail(b), row_get_tail(a)) + vbic_u64(row_last(b), row_last(a)) ); } @@ -919,7 +909,7 @@ static inline row_t row_andn(const row_t a, const row_t b) { static inline row_t row_chi(const row_t a) { return row_eor(a, row_andn(row_rlr(a), row_set( row_get(a, 1), // { a2, a3 } - vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } + vcombine_u64(row_last(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 } vdup_laneq_u64(row_get(a, 0), 1) // { a1 } ))); } @@ -952,7 +942,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r } // loop for num rounds - for (size_t i = 0; i < num_rounds; i++) { + for (size_t i = 24 - num_rounds; i < 24; i++) { // theta { // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1) @@ -978,30 +968,30 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r const row_t t0 = row_set( pi_lo_hi(row_get(r0, 0), row_get(r1, 0)), pi_lo_hi(row_get(r2, 1), row_get(r3, 1)), - row_get_tail(r4) + row_last(r4) ); const row_t t1 = row_set( - vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_get_tail(r1)), + vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_last(r1)), pi_lo_hi(row_get(r2, 0), row_get(r3, 0)), vdup_laneq_u64(row_get(r4, 1), 0) ); const row_t t2 = row_set( vextq_u64(row_get(r0, 0), row_get(r1, 1), 1), - vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_get_tail(r3)), + vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_last(r3)), vdup_laneq_u64(row_get(r4, 0), 0) ); const row_t t3 = row_set( - vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)), + vcombine_u64(row_last(r0), vdup_laneq_u64(row_get(r1, 0), 0)), vextq_u64(row_get(r2, 0), row_get(r3, 1), 1), vdup_laneq_u64(row_get(r4, 1), 1) ); const row_t t4 = row_set( pi_lo_hi(row_get(r0, 1), row_get(r1, 1)), - vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)), + vcombine_u64(row_last(r2), vdup_laneq_u64(row_get(r3, 0), 0)), vdup_laneq_u64(row_get(r4, 0), 1) ); @@ -1021,8 +1011,8 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r r4 = row_chi(r4); // iota - const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 }; - r0.head.val[0] ^= rc; + const uint64x2_t rc = { RCS[i], 0 }; + r0.p0 ^= rc; } // store rows @@ -1032,7 +1022,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r .val = { row_get(r0, 0), row_get(r0, 1), - vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)), + vcombine_u64(row_last(r0), vdup_laneq_u64(row_get(r1, 0), 0)), vextq_u64(row_get(r1, 0), row_get(r1, 1), 1) }, }; @@ -1043,10 +1033,10 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3 const uint64x2x4_t m1 = { .val = { - vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_get_tail(r1)), + vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_last(r1)), row_get(r2, 0), row_get(r2, 1), - vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)), + vcombine_u64(row_last(r2), vdup_laneq_u64(row_get(r3, 0), 0)), }, }; vst1q_u64_x4(a + 8, m1); @@ -1058,7 +1048,7 @@ static inline void permute_n_diet_neon(uint64_t a[static 25], const size_t num_r .val = { vextq_u64(row_get(r3, 0), row_get(r3, 1), 1), - vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_get_tail(r3)), + vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_last(r3)), row_get(r4, 0), row_get(r4, 1), }, -- cgit v1.2.3