From ae421618db3b68ccda95f54d1c9e8d05b2dab90a Mon Sep 17 00:00:00 2001
From: Paul Duncan
Date: Mon, 6 May 2024 21:56:10 -0400
Subject: sha3.c: neon backend now twice the speed of scalar backend (~50% fewer cycles, see commit message)

made the following changes:

- row_t contents are now 3 uint64x2_t instead of uint64x2x3_t (so they are
  kept in registers instead of memory)
- fetch round constants 2 at a time
- round loop unrolled once
- drop convoluted ext/trn store (hard to read, doesn't help)

bench results
-------------

scalar backend:

> make clean all SHA3_BACKEND=1
...
> ./bench 10000
info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
info: backend=scalar num_trials=10000 src_lens=64,256,1024,4096,16384 dst_lens=32
function,dst_len,64,256,1024,4096,16384
sha3_224,28,20.2,10.3,10.3,9.3,9.2
sha3_256,32,20.2,10.3,10.3,9.9,9.7
sha3_384,48,20.9,15.3,12.8,12.7,12.7
sha3_512,64,20.2,20.2,18.9,17.9,18.1
shake128,32,20.2,10.3,9.0,8.1,7.9
shake256,32,20.2,10.1,10.3,9.9,9.7

neon backend:

> make clean all SHA3_BACKEND=3
...
> ./bench 10000
info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
info: backend=neon num_trials=10000 src_lens=64,256,1024,4096,16384 dst_lens=32
function,dst_len,64,256,1024,4096,16384
sha3_224,28,9.7,5.0,5.0,4.6,4.5
sha3_256,32,9.7,5.0,5.0,4.9,4.8
sha3_384,48,9.7,7.3,6.2,6.2,6.1
sha3_512,64,9.7,9.7,9.1,8.7,8.7
shake128,32,9.7,5.0,4.5,4.0,4.0
shake256,32,9.7,5.0,5.1,4.9,4.8
---
 sha3.c | 290 +++++++++++++++++++++++++++++++----------------------------------
 1 file changed, 137 insertions(+), 153 deletions(-)

(limited to 'sha3.c')
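note: the two-at-a-time round constant handling above can be sketched roughly
as follows (not part of the patch); it assumes the 24-entry RCS[] table of
keccak round constants from sha3.c, an even num_rounds, and the gcc/clang
vector operators sha3.c already uses; the iota_pair() / r0_lo names are made
up for the sketch:

  #include <arm_neon.h>
  #include <stddef.h>
  #include <stdint.h>

  extern const uint64_t RCS[24]; // keccak round constants (as in sha3.c)

  // iota steps for rounds i and i+1: one 128-bit load fetches both round
  // constants, and the mask keeps only lane 0, since iota only touches
  // lane 0 of the first row register.
  static inline void iota_pair(uint64x2_t *r0_lo, size_t i, size_t num_rounds) {
    static const uint64x2_t rc_mask = { 0xffffffffffffffffULL, 0 };
    const uint64x2_t rcs = vld1q_u64(RCS + (24 - num_rounds + i));

    *r0_lo ^= rcs & rc_mask;                    // round i
    // ... theta/rho/pi/chi of round i+1 would run in between ...
    *r0_lo ^= vextq_u64(rcs, rcs, 1) & rc_mask; // round i+1 (high lane moved down)
  }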
diff --git a/sha3.c b/sha3.c
index 3101530..28916a3 100644
--- a/sha3.c
+++ b/sha3.c
@@ -486,18 +486,21 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun
 // 64-bit lane per row at the expense of making many of the instructions
 // simpler.
 typedef union {
-  uint64x2x3_t u64x2x3;
+  uint64x2_t p0, p1, p2;
 } row_t;
 
 // set contents of row
 static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x2_t c) {
-  return (row_t) { .u64x2x3 = { .val = { a, b, c } } };
+  row_t r;
+  r.p0 = a;
+  r.p1 = b;
+  r.p2 = c;
+
+  return r;
 }
 
 // get Nth pair of u64s from row
-static inline uint64x2_t row_get(const row_t a, const size_t n) {
-  return a.u64x2x3.val[n];
-}
+#define ROW_GET(A, N) ((A).p ## N)
 
 // load row from array
 static inline row_t row_load(const uint64_t p[static 5]) {
@@ -506,14 +509,15 @@ static inline row_t row_load(const uint64_t p[static 5]) {
 
 // load row from array
 static inline row_t row_load_unsafe(const uint64_t p[static 6]) {
-  return (row_t) { .u64x2x3 = vld1q_u64_x3(p) };
+  const uint64x2x3_t d = vld1q_u64_x3(p);
+  return row_set(d.val[0], d.val[1], d.val[2]);
 }
 
 // store row to array
 static inline void row_store(uint64_t p[static 5], const row_t a) {
-  const uint64x2x2_t vals = { .val = { row_get(a, 0), row_get(a, 1) } };
+  const uint64x2x2_t vals = { .val = { ROW_GET(a, 0), ROW_GET(a, 1) } };
   vst1q_u64_x2(p + 0, vals);
-  vst1_u64(p + 4, vdup_laneq_u64(row_get(a, 2), 0));
+  vst1_u64(p + 4, vdup_laneq_u64(ROW_GET(a, 2), 0));
 }
 
 // rotate row lanes left
@@ -528,9 +532,9 @@ static inline void row_store(uint64_t p[static 5], const row_t a) {
 //
 static inline row_t row_rll(const row_t a) {
   return row_set(
-    vzip1q_u64(row_get(a, 2), row_get(a, 0)),   // { a4, a0 }
-    vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
-    vdupq_laneq_u64(row_get(a, 1), 1)           // { a3, n/a }
+    vzip1q_u64(ROW_GET(a, 2), ROW_GET(a, 0)),   // { a4, a0 }
+    vextq_u64(ROW_GET(a, 0), ROW_GET(a, 1), 1), // { a1, a2 }
+    vdupq_laneq_u64(ROW_GET(a, 1), 1)           // { a3, n/a }
   );
 }
 
@@ -546,56 +550,44 @@ static inline row_t row_rll(const row_t a) {
 //
 static inline row_t row_rlr(const row_t a) {
   return row_set(
-    vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
-    vextq_u64(row_get(a, 1), row_get(a, 2), 1), // { a3, a4 }
-    row_get(a, 0)                               // { a0, n/a }
+    vextq_u64(ROW_GET(a, 0), ROW_GET(a, 1), 1), // { a1, a2 }
+    vextq_u64(ROW_GET(a, 1), ROW_GET(a, 2), 1), // { a3, a4 }
+    ROW_GET(a, 0)                               // { a0, n/a }
   );
 }
 
 // c = a ^ b
-static inline row_t row_eor(const row_t a, const row_t b) {
-  return row_set(
-    row_get(a, 0) ^ row_get(b, 0),
-    row_get(a, 1) ^ row_get(b, 1),
-    row_get(a, 2) ^ row_get(b, 2)
-  );
-}
+#define ROW_EOR(A, B) row_set( \
+  ROW_GET(A, 0) ^ ROW_GET(B, 0), \
+  ROW_GET(A, 1) ^ ROW_GET(B, 1), \
+  ROW_GET(A, 2) ^ ROW_GET(B, 2) \
+)
 
 // f = a ^ b ^ c ^ d ^ e
 // FIXME want: veor3_u64(a, b, c);
 static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const row_t d, const row_t e) {
   return row_set(
-    row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0),
-    row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1),
-    row_get(a, 2) ^ row_get(b, 2) ^ row_get(c, 2) ^ row_get(d, 2) ^ row_get(e, 2)
+    ROW_GET(a, 0) ^ ROW_GET(b, 0) ^ ROW_GET(c, 0) ^ ROW_GET(d, 0) ^ ROW_GET(e, 0),
+    ROW_GET(a, 1) ^ ROW_GET(b, 1) ^ ROW_GET(c, 1) ^ ROW_GET(d, 1) ^ ROW_GET(e, 1),
+    ROW_GET(a, 2) ^ ROW_GET(b, 2) ^ ROW_GET(c, 2) ^ ROW_GET(d, 2) ^ ROW_GET(e, 2)
   );
 }
 
 // rotate bits in each lane left one bit
 static inline row_t row_rol1_u64(const row_t a) {
   return row_set(
-    VROLQ(row_get(a, 0), 1),
-    VROLQ(row_get(a, 1), 1),
-    VROLQ(row_get(a, 2), 1)
+    VROLQ(ROW_GET(a, 0), 1),
+    VROLQ(ROW_GET(a, 1), 1),
+    VROLQ(ROW_GET(a, 2), 1)
   );
 }
 
-// rho lane rotate values
-static const int64x2x3_t RHO_IDS[] = {
-  { .val = { { 0, 1 }, { 62, 28 }, { 27, 0 } } },
-  { .val = { { 36, 44 }, { 6, 55 }, { 20, 0 } } },
-  { .val = { { 3, 10 }, { 43, 25 }, { 39, 0 } } },
-  { .val = { { 41, 45 }, { 15, 21 }, { 8, 0 } } },
-  { .val = { { 18, 2 }, { 61, 56 }, { 14, 0 } } },
-};
-
 // apply rho rotation to row
-static inline row_t row_rho(const row_t a, const size_t id) {
-  const int64x2x3_t v = RHO_IDS[id];
+static inline row_t row_rho(const row_t a, const int64x2_t v0, const int64x2_t v1, const int64x2_t v2) {
   return row_set(
-    vorrq_u64(vshlq_u64(row_get(a, 0), v.val[0]), vshlq_u64(row_get(a, 0), v.val[0] - 64)),
-    vorrq_u64(vshlq_u64(row_get(a, 1), v.val[1]), vshlq_u64(row_get(a, 1), v.val[1] - 64)),
-    vorrq_u64(vshlq_u64(row_get(a, 2), v.val[2]), vshlq_u64(row_get(a, 2), v.val[2] - 64))
+    vorrq_u64(vshlq_u64(ROW_GET(a, 0), v0), vshlq_u64(ROW_GET(a, 0), v0 - 64)),
+    vorrq_u64(vshlq_u64(ROW_GET(a, 1), v1), vshlq_u64(ROW_GET(a, 1), v1 - 64)),
+    vorrq_u64(vshlq_u64(ROW_GET(a, 2), v2), vshlq_u64(ROW_GET(a, 2), v2 - 64))
   );
 }
 
@@ -604,19 +596,19 @@ static inline row_t row_rho(const row_t a, const size_t id) {
 // to BIC b, a instead (b & ~a)
 static inline row_t row_andn(const row_t a, const row_t b) {
   return row_set(
-    vbicq_u64(row_get(b, 0), row_get(a, 0)),
-    vbicq_u64(row_get(b, 1), row_get(a, 1)),
-    vbicq_u64(row_get(b, 2), row_get(a, 2))
+    vbicq_u64(ROW_GET(b, 0), ROW_GET(a, 0)),
+    vbicq_u64(ROW_GET(b, 1), ROW_GET(a, 1)),
+    vbicq_u64(ROW_GET(b, 2), ROW_GET(a, 2))
   );
 }
 
 // apply chi permutation to entire row
 // note: ~(a | ~b) = (~a & b) (demorgan's laws)
 static inline row_t row_chi(const row_t a) {
-  return row_eor(a, row_andn(row_rlr(a), row_set(
-    row_get(a, 1),                            // { a2, a3 }
-    vtrn1q_u64(row_get(a, 2), row_get(a, 0)), // { a4, a0 }
-    vdupq_laneq_u64(row_get(a, 0), 1)         // { a1, n/a }
+  return ROW_EOR(a, row_andn(row_rlr(a), row_set(
+    ROW_GET(a, 1),                            // { a2, a3 }
+    vtrn1q_u64(ROW_GET(a, 2), ROW_GET(a, 0)), // { a4, a0 }
+    vdupq_laneq_u64(ROW_GET(a, 0), 1)         // { a1, n/a }
   )));
 }
 
@@ -628,8 +620,94 @@ static inline uint64x2_t pi_lo_hi(const uint64x2_t a, const uint64x2_t b) {
   return vextq_u64(c, c, 1);
 }
 
+// perform one neon permutation round
+#define NEON_PERMUTE_ROUND(RC) do { \
+  /* theta */ \
+  { \
+    /* c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1) */ \
+    const row_t c = row_eor5(r0, r1, r2, r3, r4), \
+                d = ROW_EOR(row_rll(c), row_rol1_u64(row_rlr(c))); \
+    \
+    r0 = ROW_EOR(r0, d); /* r0 ^= d */ \
+    r1 = ROW_EOR(r1, d); /* r1 ^= d */ \
+    r2 = ROW_EOR(r2, d); /* r2 ^= d */ \
+    r3 = ROW_EOR(r3, d); /* r3 ^= d */ \
+    r4 = ROW_EOR(r4, d); /* r4 ^= d */ \
+  } \
+  \
+  /* rho */ \
+  { \
+    r0 = row_rho(r0, r0_a, r0_b, r02_c); \
+    r1 = row_rho(r1, r1_a, r1_b, r13_c); \
+    r2 = row_rho(r2, r2_a, r2_b, vextq_s64(r02_c, r02_c, 1)); \
+    r3 = row_rho(r3, r3_a, r3_b, vextq_s64(r13_c, r13_c, 1)); \
+    r4 = row_rho(r4, r4_a, r4_b, r4_c); \
+  } \
+  \
+  /* pi */ \
+  { \
+    const row_t t0 = row_set( \
+      pi_lo_hi(ROW_GET(r0, 0), ROW_GET(r1, 0)), \
+      pi_lo_hi(ROW_GET(r2, 1), ROW_GET(r3, 1)), \
+      ROW_GET(r4, 2) \
+    ); \
+    \
+    const row_t t1 = row_set( \
+      vextq_u64(ROW_GET(r0, 1), ROW_GET(r1, 2), 1), \
+      pi_lo_hi(ROW_GET(r2, 0), ROW_GET(r3, 0)), \
+      ROW_GET(r4, 1) \
+    ); \
+    \
+    const row_t t2 = row_set( \
+      vextq_u64(ROW_GET(r0, 0), ROW_GET(r1, 1), 1), \
+      vextq_u64(ROW_GET(r2, 1), ROW_GET(r3, 2), 1), \
+      ROW_GET(r4, 0) \
+    ); \
+    \
+    const row_t t3 = row_set( \
+      vtrn1q_u64(ROW_GET(r0, 2), ROW_GET(r1, 0)), \
+      vextq_u64(ROW_GET(r2, 0), ROW_GET(r3, 1), 1), \
+      vdupq_laneq_u64(ROW_GET(r4, 1), 1) \
+    ); \
+    \
+    const row_t t4 = row_set( \
+      pi_lo_hi(ROW_GET(r0, 1), ROW_GET(r1, 1)), \
+      vtrn1q_u64(ROW_GET(r2, 2), ROW_GET(r3, 0)), \
+      vdupq_laneq_u64(ROW_GET(r4, 0), 1) \
+    ); \
+    \
+    /* store rows */ \
+    r0 = t0; \
+    r1 = t1; \
+    r2 = t2; \
+    r3 = t3; \
+    r4 = t4; \
+  } \
+  \
+  /* chi */ \
+  r0 = row_chi(r0); \
+  r1 = row_chi(r1); \
+  r2 = row_chi(r2); \
+  r3 = row_chi(r3); \
+  r4 = row_chi(r4); \
+  \
+  /* iota */ \
+  r0.p0 ^= RC; \
+} while (0)
+
 // neon keccak permutation with inlined steps
 static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) {
+  // rho rotate ids
+  static const int64x2_t
+    r0_a = { 0, 1 },   r0_b = { 62, 28 }, r02_c = { 27, 39 },
+    r1_a = { 36, 44 }, r1_b = { 6, 55 },  r13_c = { 20, 8 },
+    r2_a = { 3, 10 },  r2_b = { 43, 25 },
+    r3_a = { 41, 45 }, r3_b = { 15, 21 },
+    r4_a = { 18, 2 },  r4_b = { 61, 56 }, r4_c = { 14, 0 };
+
+  // iota round constant mask
+  static const uint64x2_t rc_mask = { 0xffffffffffffffffULL, 0 };
+
   // load rows
   row_t r0, r1, r2, r3, r4;
   {
@@ -647,115 +725,21 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds
     r4 = row_set(m2.val[2], m2.val[3], m3);
   }
 
-  // loop for num rounds
-  for (size_t i = 0; i < num_rounds; i++) {
-    // theta
-    {
-      // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1)
-      const row_t c = row_eor5(r0, r1, r2, r3, r4),
-                  d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
-
-      r0 = row_eor(r0, d); // r0 ^= d
-      r1 = row_eor(r1, d); // r1 ^= d
-      r2 = row_eor(r2, d); // r2 ^= d
-      r3 = row_eor(r3, d); // r3 ^= d
-      r4 = row_eor(r4, d); // r4 ^= d
-    }
-
-    // rho
-    r0 = row_rho(r0, 0);
-    r1 = row_rho(r1, 1);
-    r2 = row_rho(r2, 2);
-    r3 = row_rho(r3, 3);
-    r4 = row_rho(r4, 4);
+  // round loop (two rounds per iteration)
+  for (size_t i = 0; i < num_rounds; i+=2) {
+    // load next two round constants
+    const uint64x2_t rcs = vld1q_u64(RCS + (24 - num_rounds + i));
 
-    // pi
-    {
-      const row_t t0 = row_set(
-        pi_lo_hi(row_get(r0, 0), row_get(r1, 0)),
-        pi_lo_hi(row_get(r2, 1), row_get(r3, 1)),
-        row_get(r4, 2)
-      );
-
-      const row_t t1 = row_set(
-        vextq_u64(row_get(r0, 1), row_get(r1, 2), 1),
-        pi_lo_hi(row_get(r2, 0), row_get(r3, 0)),
-        row_get(r4, 1)
-      );
-
-      const row_t t2 = row_set(
-        vextq_u64(row_get(r0, 0), row_get(r1, 1), 1),
-        vextq_u64(row_get(r2, 1), row_get(r3, 2), 1),
-        row_get(r4, 0)
-      );
-
-      const row_t t3 = row_set(
-        vtrn1q_u64(row_get(r0, 2), row_get(r1, 0)),
-        vextq_u64(row_get(r2, 0), row_get(r3, 1), 1),
-        vdupq_laneq_u64(row_get(r4, 1), 1)
-      );
-
-      const row_t t4 = row_set(
-        pi_lo_hi(row_get(r0, 1), row_get(r1, 1)),
-        vtrn1q_u64(row_get(r2, 2), row_get(r3, 0)),
-        vdupq_laneq_u64(row_get(r4, 0), 1)
-      );
-
-      // store rows
-      r0 = t0;
-      r1 = t1;
-      r2 = t2;
-      r3 = t3;
-      r4 = t4;
-    }
-
-    // chi
-    r0 = row_chi(r0);
-    r1 = row_chi(r1);
-    r2 = row_chi(r2);
-    r3 = row_chi(r3);
-    r4 = row_chi(r4);
-
-    // iota
-    const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 };
-    r0.u64x2x3.val[0] ^= rc;
+    NEON_PERMUTE_ROUND(rcs & rc_mask);
+    NEON_PERMUTE_ROUND(vextq_u64(rcs, rcs, 1) & rc_mask);
   }
 
   // store rows
-  {
-    // store columns 0-4 of r0 and columns 0-2 of r1
-    vst1q_u64_x4(a + 0, (uint64x2x4_t) {
-      .val = {
-        row_get(r0, 0),
-        row_get(r0, 1),
-        vtrn1q_u64(row_get(r0, 2), row_get(r1, 0)),
-        vextq_u64(row_get(r1, 0), row_get(r1, 1), 1)
-      },
-    });
-
-    // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3
-    vst1q_u64_x4(a + 8, (uint64x2x4_t) {
-      .val = {
-        vextq_u64(row_get(r1, 1), row_get(r1, 2), 1),
-        row_get(r2, 0),
-        row_get(r2, 1),
-        vtrn1q_u64(row_get(r2, 2), row_get(r3, 0)),
-      },
-    });
-
-    // store columns 1-4 of r3 and columns 0-3 of r4
-    vst1q_u64_x4(a + 16, (uint64x2x4_t) {
-      .val = {
-        vextq_u64(row_get(r3, 0), row_get(r3, 1), 1),
-        vextq_u64(row_get(r3, 1), row_get(r3, 2), 1),
-        row_get(r4, 0),
-        row_get(r4, 1),
-      },
-    });
-
-    // store column 4 of r4
-    vst1_u64(a + 24, vdup_laneq_u64(row_get(r4, 2), 0));
-  }
+  row_store(a + 0, r0);
+  row_store(a + 5, r1);
+  row_store(a + 10, r2);
+  row_store(a + 15, r3);
+  row_store(a + 20, r4);
 }
 
 #endif /* SHA3_BACKEND == BACKEND_NEON */
--
cgit v1.2.3
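note: the row_rho() change in this patch leans on vshlq_u64() treating a
negative per-lane shift count as a right shift; a rough standalone sketch of
that rotate-left trick (rotl_lanes() is a made-up name, counts must be in
0..63, and it relies on the gcc/clang vector `-` operator like sha3.c does):

  #include <arm_neon.h>

  // rotate each 64-bit lane of x left by the per-lane count in n:
  // x << n is one VSHL, x >> (64 - n) is a second VSHL with a negative
  // count, and VORR merges the two halves.
  static inline uint64x2_t rotl_lanes(const uint64x2_t x, const int64x2_t n) {
    return vorrq_u64(vshlq_u64(x, n), vshlq_u64(x, n - 64));
  }

  // example: rotate lane 0 left by 1 bit and lane 1 left by 62 bits
  //   static const int64x2_t n = { 1, 62 };
  //   const uint64x2_t y = rotl_lanes(x, n);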