From 86770194e53447a0da5f8aef9c57e45feb7cc557 Mon Sep 17 00:00:00 2001
From: Paul Duncan
Date: Wed, 8 May 2024 06:15:41 -0400
Subject: sha3.c: neon: refactor, add documentation

- switch row_eor() from macro to static inline function
- compress rho rotate values from 15 128-bit registers into two to
  reduce register pressure (still spilling, though)
- remove NEON_PERMUTE_ROUND() macro
- switch permute_n_neon() from an unrolled loop with a macro body to a
  regular loop
- add documentation for register/lane layout and for compressed rho
  rotations

with these changes the neon backend still uses ~50% more cycles than
the scalar backend, so i will probably leave it disabled for the
initial release.

scalar (pi5):

> ./bench 2000
info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
info: backend=scalar num_trials=2000 src_lens=64,256,1024,4096,16384 dst_lens=32
function,dst_len,64,256,1024,4096,16384
sha3_224,28,20.2,10.3,10.3,9.3,9.2
sha3_256,32,20.2,10.3,10.3,9.9,9.7
sha3_384,48,20.9,15.3,12.8,12.7,12.5
sha3_512,64,20.2,20.2,18.9,25.3,17.9
shake128,32,20.2,10.1,9.0,8.1,7.9
shake256,32,20.2,10.3,10.3,9.9,9.7

neon backend bench results (pi5):

> ./bench 2000
info: cpucycles: version=20240318 implementation=arm64-vct persecond=2400000000
info: backend=neon num_trials=2000 src_lens=64,256,1024,4096,16384 dst_lens=32
function,dst_len,64,256,1024,4096,16384
sha3_224,28,32.7,16.3,16.4,14.9,14.6
sha3_256,32,32.0,16.2,16.4,15.9,15.5
sha3_384,48,32.7,24.2,20.4,20.2,20.0
sha3_512,64,32.0,32.2,30.1,28.6,28.5
shake128,32,32.7,16.2,14.2,12.8,12.5
shake256,32,32.7,16.2,16.3,15.7,15.4
---
 sha3.c | 272 +++++++++++++++++++++++++++++++++++++----------------------------
 1 file changed, 157 insertions(+), 115 deletions(-)

diff --git a/sha3.c b/sha3.c
index 581bbff..c00f2b7 100644
--- a/sha3.c
+++ b/sha3.c
@@ -552,11 +552,13 @@ static inline row_t row_rlr(const row_t a) {
 }
 
 // c = a ^ b
-#define ROW_EOR(A, B) row_set( \
-  ROW_GET(A, 0) ^ ROW_GET(B, 0), \
-  ROW_GET(A, 1) ^ ROW_GET(B, 1), \
-  ROW_GET(A, 2) ^ ROW_GET(B, 2) \
-)
+static inline row_t row_eor(const row_t A, const row_t B) {
+  return row_set(
+    ROW_GET(A, 0) ^ ROW_GET(B, 0),
+    ROW_GET(A, 1) ^ ROW_GET(B, 1),
+    ROW_GET(A, 2) ^ ROW_GET(B, 2)
+  );
+}
 
 // f = a ^ b ^ c ^ d ^ e
 // FIXME want: veor3_u64(a, b, c);
@@ -577,12 +579,69 @@ static inline row_t row_rol1_u64(const row_t a) {
   );
 }
 
+// encoded rho rotate values
+//
+// original values:
+//
+//   static const int64x2_t
+//     r0_a = { 0, 1 },   r0_b = { 62, 28 }, r02_c = { 27, 39 },
+//     r1_a = { 36, 44 }, r1_b = { 6, 55 },  r13_c = { 20, 8 },
+//     r2_a = { 3, 10 },  r2_b = { 43, 25 },
+//     r3_a = { 41, 45 }, r3_b = { 15, 21 },
+//     r4_a = { 18, 2 },  r4_b = { 61, 56 }, r4_c = { 14, 0 };
+//
+// low element of r[0-4]_{a,b} packed into low lane of r_ab, like so:
+//
+//   >> v = [0, 36, 3, 41, 18, 62, 6, 43, 15, 61].each_with_index.reduce(0) { |r, (c, i)| r+(64**i)*
+// c }
+//   => 1103290028930644224
+//   >> (v >> 6*9) & 0x3f
+//   => 61
+//   >> 6*9
+//   => 54
+//   >> v
+//   => 1103290028930644224
+//   >> '0x%016x' % v
+//   => "0x0f4fac6f92a43900"
+//
+// high element of r[0-4]_{a,b} packed into high lane of r_ab, like so:
+//
+//   >> v = [1, 44, 10, 45, 2, 28, 55, 25, 21, 56].each_with_index.reduce(0) { |r, (c, i)| r+(64**i)
+// *c }
+//   => 1014831051886078721
+//   >> '0x%016x' % v
+//   => "0x0e15677702b4ab01"
+//
+// low elements of r[0-4]_c packed into low lane of r_c, like so:
+//
+//   >> v = [27, 20, 39, 8, 14].each_with_index.reduce(0) { |r, (c, i)| r+(64**i)*c }
+//   => 237139227
+//   >> '0x%016x' % v
+//   => "0x000000000e22751b"
+//
+// (there are no high elements of r[0-4]_c, all zero)
+//
+// to extract an element, right shift by 6*N (where N is the element
+// index: 0-4 for the "a" values, 5-9 for the "b" values), then mask
+// to the lower 6 bits (0x3f). for example, to extract r4_b (N = 9):
+//
+//   >> (v >> 6*9) & 0x3f
+//   => 61
+static const int64x2_t r_ab = { 0x0f4fac6f92a43900LL, 0x0e15677702b4ab01LL },
+                       r_c  = { 0x000000000e22751bLL, 0 };
+
 // apply rho rotation to row
-static inline row_t row_rho(const row_t a, const int64x2_t v0, const int64x2_t v1, const int64x2_t v2) {
+static inline row_t row_rho(const row_t a, const size_t id) {
+  const int64x2_t v0_hi = (r_ab >> (6 * id)) & 0x3f,
+                  v0_lo = ((r_ab >> (6 * id)) & 0x3f) - 64,
+                  v1_hi = (r_ab >> (30 + (6 * id))) & 0x3f,
+                  v1_lo = ((r_ab >> (30 + (6 * id))) & 0x3f) - 64,
+                  v2_hi = (r_c >> (6 * id)) & 0x3f,
+                  v2_lo = ((r_c >> (6 * id)) & 0x3f) - 64;
   return row_set(
-    vorrq_u64(vshlq_u64(ROW_GET(a, 0), v0), vshlq_u64(ROW_GET(a, 0), v0 - 64)),
-    vorrq_u64(vshlq_u64(ROW_GET(a, 1), v1), vshlq_u64(ROW_GET(a, 1), v1 - 64)),
-    vorrq_u64(vshlq_u64(ROW_GET(a, 2), v2), vshlq_u64(ROW_GET(a, 2), v2 - 64))
+    vorrq_u64(vshlq_u64(ROW_GET(a, 0), v0_hi), vshlq_u64(ROW_GET(a, 0), v0_lo)),
+    vorrq_u64(vshlq_u64(ROW_GET(a, 1), v1_hi), vshlq_u64(ROW_GET(a, 1), v1_lo)),
+    vorrq_u64(vshlq_u64(ROW_GET(a, 2), v2_hi), vshlq_u64(ROW_GET(a, 2), v2_lo))
   );
 }
 
@@ -600,7 +659,7 @@ static inline row_t row_andn(const row_t a, const row_t b) {
 // apply chi permutation to entire row
 // note: ~(a | ~b) = (~a & b) (demorgan's laws)
 static inline row_t row_chi(const row_t a) {
-  return ROW_EOR(a, row_andn(row_rlr(a), row_set(
+  return row_eor(a, row_andn(row_rlr(a), row_set(
     ROW_GET(a, 1),                            // { a2, a3 }
     vtrn1q_u64(ROW_GET(a, 2), ROW_GET(a, 0)), // { a4, a0 }
     vdupq_laneq_u64(ROW_GET(a, 0), 1)         // { a1, n/a }
@@ -615,118 +674,101 @@ static inline uint64x2_t pi_lo_hi(const uint64x2_t a, const uint64x2_t b) {
   return vextq_u64(c, c, 1);
 }
 
-// perform one neon permutation round
-#define NEON_PERMUTE_ROUND(RC) do { \
-  /* theta */ \
-  { \
-    /* c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1) */ \
-    const row_t c = row_eor5(r0, r1, r2, r3, r4), \
-                d = ROW_EOR(row_rll(c), row_rol1_u64(row_rlr(c))); \
-    \
-    r0 = ROW_EOR(r0, d); /* r0 ^= d */ \
-    r1 = ROW_EOR(r1, d); /* r1 ^= d */ \
-    r2 = ROW_EOR(r2, d); /* r2 ^= d */ \
-    r3 = ROW_EOR(r3, d); /* r3 ^= d */ \
-    r4 = ROW_EOR(r4, d); /* r4 ^= d */ \
-  } \
-  \
-  /* rho */ \
-  { \
-    r0 = row_rho(r0, r0_a, r0_b, r02_c); \
-    r1 = row_rho(r1, r1_a, r1_b, r13_c); \
-    r2 = row_rho(r2, r2_a, r2_b, vextq_s64(r02_c, r02_c, 1)); \
-    r3 = row_rho(r3, r3_a, r3_b, vextq_s64(r13_c, r13_c, 1)); \
-    r4 = row_rho(r4, r4_a, r4_b, r4_c); \
-  } \
-  \
-  /* pi */ \
-  { \
-    const row_t t0 = row_set( \
-      pi_lo_hi(ROW_GET(r0, 0), ROW_GET(r1, 0)), \
-      pi_lo_hi(ROW_GET(r2, 1), ROW_GET(r3, 1)), \
-      ROW_GET(r4, 2) \
-    ); \
-    \
-    const row_t t1 = row_set( \
-      vextq_u64(ROW_GET(r0, 1), ROW_GET(r1, 2), 1), \
-      pi_lo_hi(ROW_GET(r2, 0), ROW_GET(r3, 0)), \
-      ROW_GET(r4, 1) \
-    ); \
-    \
-    const row_t t2 = row_set( \
-      vextq_u64(ROW_GET(r0, 0), ROW_GET(r1, 1), 1), \
-      vextq_u64(ROW_GET(r2, 1), ROW_GET(r3, 2), 1), \
-      ROW_GET(r4, 0) \
-    ); \
-    \
-    const row_t t3 = row_set( \
-      vtrn1q_u64(ROW_GET(r0, 2), ROW_GET(r1, 0)), \
-      vextq_u64(ROW_GET(r2, 0), ROW_GET(r3, 1), 1), \
-      vdupq_laneq_u64(ROW_GET(r4, 1), 1) \
-    ); \
-    \
-    const row_t t4 = row_set( \
-      pi_lo_hi(ROW_GET(r0, 1), ROW_GET(r1, 1)), \
-      vtrn1q_u64(ROW_GET(r2, 2), ROW_GET(r3, 0)), \
-      vdupq_laneq_u64(ROW_GET(r4, 0), 1) \
-    ); \
-    \
-    /* store rows */ \
-    r0 = t0; \
-    r1 = t1; \
-    r2 = t2; \
-    r3 = t3; \
-    r4 = t4; \
-  } \
-  \
-  /* chi */ \
-  r0 = row_chi(r0); \
-  r1 = row_chi(r1); \
-  r2 = row_chi(r2); \
-  r3 = row_chi(r3); \
-  r4 = row_chi(r4); \
-  \
-  /* iota */ \
-  r0.p0 ^= RC; \
-} while (0)
-
 // neon keccak permutation with inlined steps
 static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) {
-  // rho rotate ids
-  static const int64x2_t
-    r0_a = { 0, 1 },   r0_b = { 62, 28 }, r02_c = { 27, 39 },
-    r1_a = { 36, 44 }, r1_b = { 6, 55 },  r13_c = { 20, 8 },
-    r2_a = { 3, 10 },  r2_b = { 43, 25 },
-    r3_a = { 41, 45 }, r3_b = { 15, 21 },
-    r4_a = { 18, 2 },  r4_b = { 61, 56 }, r4_c = { 14, 0 };
-
-  // iota round constant mask
-  static const uint64x2_t rc_mask = { 0xffffffffffffffffULL, 0 };
+  // Map of Keccak state to 64-bit lanes of 128-bit registers
+  // ---------------------------------------------------------
+  // |     |        Column / Register and 64-Bit Lane        |
+  // |-------------------------------------------------------|
+  // | Row |    3    |    4    |    0    |    1    |    2    |
+  // |-----|---------|---------|---------|---------|---------|
+  // |  2  | r2.p1.1 | r2.p2.0 | r2.p0.0 | r2.p0.1 | r2.p1.0 |
+  // |  1  | r1.p1.1 | r1.p2.0 | r1.p0.0 | r1.p0.1 | r1.p1.0 |
+  // |  0  | r0.p1.1 | r0.p2.0 | r0.p0.0 | r0.p0.1 | r0.p1.0 |
+  // |  4  | r4.p1.1 | r4.p2.0 | r4.p0.0 | r4.p0.1 | r4.p1.0 |
+  // |  3  | r3.p1.1 | r3.p2.0 | r3.p0.0 | r3.p0.1 | r3.p1.0 |
+  // ---------------------------------------------------------
 
   // load rows
-  row_t r0, r1, r2, r3, r4;
-  {
-    // 3 loads of 8 and 1 load of 1 cell (3*8 + 1 = 25)
-    const uint64x2x4_t m0 = vld1q_u64_x4(a + 0),  // r0 cols 0-4, r1 cols 0-2
-                       m1 = vld1q_u64_x4(a + 8),  // r1 cols 3-4, r2 cols 0-4, r3 col 1
-                       m2 = vld1q_u64_x4(a + 16); // r3 cols 1-4, r4 cols 0-3
-    const uint64x2_t m3 = vld1q_dup_u64(a + 24); // r4 col 4
-
-    // permute loaded data into rows
-    r0 = row_set(m0.val[0], m0.val[1], m0.val[2]);
-    r1 = row_set(vextq_u64(m0.val[2], m0.val[3], 1), vextq_u64(m0.val[3], m1.val[0], 1), vextq_u64(m1.val[0], m1.val[0], 1));
-    r2 = row_set(m1.val[1], m1.val[2], m1.val[3]);
-    r3 = row_set(vextq_u64(m1.val[3], m2.val[0], 1), vextq_u64(m2.val[0], m2.val[1], 1), vextq_u64(m2.val[1], m2.val[1], 1));
-    r4 = row_set(m2.val[2], m2.val[3], m3);
-  }
+  row_t r0 = row_load(a + 0),
+        r1 = row_load(a + 5),
+        r2 = row_load(a + 10),
+        r3 = row_load(a + 15),
+        r4 = row_load(a + 20);
 
   // round loop (two rounds per iteration)
-  for (size_t i = 0; i < num_rounds; i+=2) {
-    // load next two round constants
-    const uint64x2_t rcs = vld1q_u64(RCS + (24 - num_rounds + i));
+  for (size_t i = 24 - num_rounds; i < SHA3_NUM_ROUNDS; i++) {
+    /* theta */
+    {
+      /* c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1) */
+      const row_t c = row_eor5(r0, r1, r2, r3, r4),
+                  d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
+
+      r0 = row_eor(r0, d); /* r0 ^= d */
+      r1 = row_eor(r1, d); /* r1 ^= d */
+      r2 = row_eor(r2, d); /* r2 ^= d */
+      r3 = row_eor(r3, d); /* r3 ^= d */
+      r4 = row_eor(r4, d); /* r4 ^= d */
+    }
+
+    /* rho */
+    {
+      r0 = row_rho(r0, 0);
+      r1 = row_rho(r1, 1);
+      r2 = row_rho(r2, 2);
+      r3 = row_rho(r3, 3);
+      r4 = row_rho(r4, 4);
+    }
+
+    /* pi */
+    {
+      const row_t t0 = row_set(
+        pi_lo_hi(ROW_GET(r0, 0), ROW_GET(r1, 0)),
+        pi_lo_hi(ROW_GET(r2, 1), ROW_GET(r3, 1)),
+        ROW_GET(r4, 2)
+      );
+
+      const row_t t1 = row_set(
+        vextq_u64(ROW_GET(r0, 1), ROW_GET(r1, 2), 1),
+        pi_lo_hi(ROW_GET(r2, 0), ROW_GET(r3, 0)),
+        ROW_GET(r4, 1)
+      );
+
+      const row_t t2 = row_set(
+        vextq_u64(ROW_GET(r0, 0), ROW_GET(r1, 1), 1),
+        vextq_u64(ROW_GET(r2, 1), ROW_GET(r3, 2), 1),
+        ROW_GET(r4, 0)
+      );
+
+      const row_t t3 = row_set(
+        vtrn1q_u64(ROW_GET(r0, 2), ROW_GET(r1, 0)),
+        vextq_u64(ROW_GET(r2, 0), ROW_GET(r3, 1), 1),
+        vdupq_laneq_u64(ROW_GET(r4, 1), 1)
+      );
+
+      const row_t t4 = row_set(
+        pi_lo_hi(ROW_GET(r0, 1), ROW_GET(r1, 1)),
+        vtrn1q_u64(ROW_GET(r2, 2), ROW_GET(r3, 0)),
+        vdupq_laneq_u64(ROW_GET(r4, 0), 1)
+      );
+
+      /* store rows */
+      r0 = t0;
+      r1 = t1;
+      r2 = t2;
+      r3 = t3;
+      r4 = t4;
+    }
+
+    /* chi */
+    r0 = row_chi(r0);
+    r1 = row_chi(r1);
+    r2 = row_chi(r2);
+    r3 = row_chi(r3);
+    r4 = row_chi(r4);
 
-    NEON_PERMUTE_ROUND(rcs & rc_mask);
-    NEON_PERMUTE_ROUND(vextq_u64(rcs, rcs, 1) & rc_mask);
+    /* iota */
+    r0.p0 ^= (uint64x2_t) { RCS[i], 0 };
   }
 
   // store rows
-- 
cgit v1.2.3
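
as a cross-check on the packed rho constants introduced above, here is a small
standalone sketch (not part of the patch; plain C, no NEON required, and the
file name rho-check.c is only illustrative). it decodes r_ab and r_c with the
same shifts and mask that row_rho() uses and compares the results against the
original per-row rotate values quoted in the comment block:

// rho-check.c: sanity check for the packed rho rotate constants.
// build and run: cc -std=c11 -o rho-check rho-check.c && ./rho-check
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

// packed constants from the patch: ten 6-bit values per 64-bit lane
static const uint64_t R_AB_LO = 0x0f4fac6f92a43900ULL, // low lanes of r[0-4]_a, then r[0-4]_b
                      R_AB_HI = 0x0e15677702b4ab01ULL, // high lanes of r[0-4]_a, then r[0-4]_b
                      R_C_LO  = 0x000000000e22751bULL; // low lanes of r[0-4]_c

// original per-row rotate values from the comment in the patch:
// { a.lo, a.hi, b.lo, b.hi, c.lo } for rows 0-4
static const int ROTS[5][5] = {
  {  0,  1, 62, 28, 27 }, // row 0
  { 36, 44,  6, 55, 20 }, // row 1
  {  3, 10, 43, 25, 39 }, // row 2
  { 41, 45, 15, 21,  8 }, // row 3
  { 18,  2, 61, 56, 14 }, // row 4
};

int main(void) {
  for (int id = 0; id < 5; id++) {
    // same extraction row_rho() does, minus the vector types
    const int a_lo = (int) ((R_AB_LO >> (6 * id)) & 0x3f),
              a_hi = (int) ((R_AB_HI >> (6 * id)) & 0x3f),
              b_lo = (int) ((R_AB_LO >> (30 + 6 * id)) & 0x3f),
              b_hi = (int) ((R_AB_HI >> (30 + 6 * id)) & 0x3f),
              c_lo = (int) ((R_C_LO >> (6 * id)) & 0x3f);

    assert(a_lo == ROTS[id][0]); // rX_a, low lane
    assert(a_hi == ROTS[id][1]); // rX_a, high lane
    assert(b_lo == ROTS[id][2]); // rX_b, low lane
    assert(b_hi == ROTS[id][3]); // rX_b, high lane
    assert(c_lo == ROTS[id][4]); // rX_c, low lane
  }

  printf("packed rho constants match the original rotate values\n");
  return 0;
}

packing ten 6-bit rotate amounts into each 64-bit lane is what lets all of the
per-row rho constants live in the two vector registers r_ab and r_c, at the
cost of the shift-and-mask decode in row_rho().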