summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sha3.c329
1 files changed, 329 insertions, 0 deletions
diff --git a/sha3.c b/sha3.c
index 4073287..b2158f5 100644
--- a/sha3.c
+++ b/sha3.c
@@ -30,6 +30,7 @@
#define BACKEND_SCALAR 1 // scalar backend
#define BACKEND_AVX512 2 // AVX-512 backend
#define BACKEND_NEON 3 // Neon backend
+#define BACKEND_DIET_NEON 4 // Neon backend which uses fewer registers
// if SHA3_BACKEND is defined and set to 0 (the default), then unset it
// and auto-detect the appropriate backend
@@ -758,10 +759,302 @@ static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds
}
#endif /* SHA3_BACKEND == BACKEND_NEON */
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+#include <arm_neon.h>
+
+// rotate element in uint64x1_t left by N bits
+#define VROL(A, N) (vorr_u64(vshl_n_u64((A), (N)), vshr_n_u64((A), 64-(N))))
+
+// rotate elements in uint64x2_t left by N bits
+// note: vrax1q_u64() not supported on pizza
+#define VROLQ(A, N) (vorrq_u64(vshlq_n_u64((A), (N)), vshrq_n_u64((A), 64-(N))))
+
+// keccak row, represented as 3 128-bit vector registers
+//
+// columns are stored in the low 5 64-bit lanes. this wastes one
+// 64-bit lane per row at the expense of making many of the instructions
+// simpler.
+typedef struct {
+ uint64x2x2_t head; // first 4 columns
+ uint64x1_t tail; // last column
+} row_t;
+
+// set contents of row
+static inline row_t row_set(const uint64x2_t a, const uint64x2_t b, const uint64x1_t c) {
+ return (row_t) { .head = { { a, b } }, .tail = c };
+}
+
+// get Nth pair of u64s from row
+static inline uint64x2_t row_get(const row_t a, const size_t n) {
+ return a.head.val[n];
+}
+
+// get tail column from row
+static inline uint64x1_t row_get_tail(const row_t a) {
+ return a.tail;
+}
+
+// rotate row lanes left
+//
+// --------------------------- ---------------------------
+// | 64-bit Lanes (Before) | | 64-bit Lanes (After) |
+// |-------------------------| |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 |
+// |---|---|---|---|---|-----| |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a | | E | A | B | C | D | n/a |
+// --------------------------- ---------------------------
+//
+static inline row_t row_rll(const row_t a) {
+ return row_set(
+ vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+ vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
+ vdup_laneq_u64(row_get(a, 1), 1) // { a3, n/a }
+ );
+}
+
+// rotate row lanes right
+//
+// --------------------------- ---------------------------
+// | 64-bit Lanes (Before) | | 64-bit Lanes (After) |
+// |-------------------------| |-------------------------|
+// | 0 | 1 | 2 | 3 | 4 | 5 | --> | 0 | 1 | 2 | 3 | 4 | 5 |
+// |---|---|---|---|---|-----| |---|---|---|---|---|-----|
+// | A | B | C | D | E | n/a | | B | C | D | E | A | n/a |
+// --------------------------- ---------------------------
+//
+static inline row_t row_rlr(const row_t a) {
+ return row_set(
+ vextq_u64(row_get(a, 0), row_get(a, 1), 1), // { a1, a2 }
+ vcombine_u64(vdup_laneq_u64(row_get(a, 1), 1), row_get_tail(a)), // { a3, a4 }
+ vdup_laneq_u64(row_get(a, 0), 0) // { a0, n/a }
+ );
+}
+
+// c = a ^ b
+static inline row_t row_eor(const row_t a, const row_t b) {
+ return row_set(
+ row_get(a, 0) ^ row_get(b, 0),
+ row_get(a, 1) ^ row_get(b, 1),
+ row_get_tail(a) ^ row_get_tail(b)
+ );
+}
+
+// f = a ^ b ^ c ^ d ^ e
+// FIXME want: veor3_u64(a, b, c);
+static inline row_t row_eor5(const row_t a, const row_t b, const row_t c, const row_t d, const row_t e) {
+ return row_set(
+ row_get(a, 0) ^ row_get(b, 0) ^ row_get(c, 0) ^ row_get(d, 0) ^ row_get(e, 0),
+ row_get(a, 1) ^ row_get(b, 1) ^ row_get(c, 1) ^ row_get(d, 1) ^ row_get(e, 1),
+ row_get_tail(a) ^ row_get_tail(b) ^ row_get_tail(c) ^ row_get_tail(d) ^ row_get_tail(e)
+ );
+}
+
+// rotate bits in each lane left one bit
+static inline row_t row_rol1_u64(const row_t a) {
+ return row_set(
+ VROLQ(row_get(a, 0), 1),
+ VROLQ(row_get(a, 1), 1),
+ VROL(row_get_tail(a), 1)
+ );
+}
+
+// rho lane rotate values
+static const struct {
+ int64x2_t head[2];
+ int64x1_t tail;
+} RHO_IDS[] = {
+ { .head = { { 0, 1 }, { 62, 28 } }, .tail = { 27 } },
+ { .head = { { 36, 44 }, { 6, 55 } }, .tail = { 20 } },
+ { .head = { { 3, 10 }, { 43, 25 } }, .tail = { 39 } },
+ { .head = { { 41, 45 }, { 15, 21 } }, .tail = { 8 } },
+ { .head = { { 18, 2 }, { 61, 56 } }, .tail = { 14 } },
+};
+
+// apply rho rotation to row
+static inline row_t row_rho(const row_t a, const size_t id) {
+ const int64x2_t *vh = RHO_IDS[id].head;
+ const int64x1_t vt = RHO_IDS[id].tail;
+ return row_set(
+ vorrq_u64(vshlq_u64(row_get(a, 0), vh[0]), vshlq_u64(row_get(a, 0), vh[0] - 64)),
+ vorrq_u64(vshlq_u64(row_get(a, 1), vh[1]), vshlq_u64(row_get(a, 1), vh[1] - 64)),
+ vorr_u64(vshl_u64(row_get_tail(a), vt), vshl_u64(row_get_tail(a), vt - 64))
+ );
+}
+
+// c = (~a & b)
+// note: was using ~(a | ~b) = (~a & b) (demorgan's laws), but changed
+// to BIC b, a instead (b & ~a)
+static inline row_t row_andn(const row_t a, const row_t b) {
+ return row_set(
+ vbicq_u64(row_get(b, 0), row_get(a, 0)),
+ vbicq_u64(row_get(b, 1), row_get(a, 1)),
+ vbic_u64(row_get_tail(b), row_get_tail(a))
+ );
+}
+
+// apply chi permutation to entire row
+// note: ~(a | ~b) = (~a & b) (demorgan's laws)
+static inline row_t row_chi(const row_t a) {
+ return row_eor(a, row_andn(row_rlr(a), row_set(
+ row_get(a, 1), // { a2, a3 }
+ vcombine_u64(row_get_tail(a), vdup_laneq_u64(row_get(a, 0), 0)), // { a4, a0 }
+ vdup_laneq_u64(row_get(a, 0), 1) // { a1 }
+ )));
+}
+
+// return new vector with low lane of first argument and high lane of
+// second argument
+static inline uint64x2_t pi_lo_hi(const uint64x2_t a, const uint64x2_t b) {
+ // was using vqtbl2q_u8() with tables, but this is faster
+ const uint64x2_t c = vextq_u64(b, a, 1);
+ return vextq_u64(c, c, 1);
+}
+
+// neon keccak permutation with inlined steps
+static inline void permute_n_neon(uint64_t a[static 25], const size_t num_rounds) {
+ // load rows
+ row_t r0, r1, r2, r3, r4;
+ {
+ // 3 loads of 8 and 1 load of 1 cell (3*8 + 1 = 25)
+ const uint64x2x4_t m0 = vld1q_u64_x4(a + 0), // r0 cols 0-4, r1 cols 0-2
+ m1 = vld1q_u64_x4(a + 8), // r1 cols 3-4, r2 cols 0-4, r3 col 1
+ m2 = vld1q_u64_x4(a + 16); // r3 cols 1-4, r4 cols 0-3
+ const uint64x1_t m3 = vld1_u64(a + 24); // r4 col 4
+
+ // permute loaded data into rows
+ r0 = row_set(m0.val[0], m0.val[1], vdup_laneq_u64(m0.val[2], 0));
+ r1 = row_set(vextq_u64(m0.val[2], m0.val[3], 1), vextq_u64(m0.val[3], m1.val[0], 1), vdup_laneq_u64(m1.val[0], 1));
+ r2 = row_set(m1.val[1], m1.val[2], vdup_laneq_u64(m1.val[3], 0));
+ r3 = row_set(vextq_u64(m1.val[3], m2.val[0], 1), vextq_u64(m2.val[0], m2.val[1], 1), vdup_laneq_u64(m2.val[1], 1));
+ r4 = row_set(m2.val[2], m2.val[3], m3);
+ }
+
+ // loop for num rounds
+ for (size_t i = 0; i < num_rounds; i++) {
+ // theta
+ {
+ // c = r0 ^ r1 ^ r2 ^ r3 ^ r4, d = rll(c) ^ (rlr(c) << 1)
+ const row_t c = row_eor5(r0, r1, r2, r3, r4),
+ d = row_eor(row_rll(c), row_rol1_u64(row_rlr(c)));
+
+ r0 = row_eor(r0, d); // r0 ^= d
+ r1 = row_eor(r1, d); // r1 ^= d
+ r2 = row_eor(r2, d); // r2 ^= d
+ r3 = row_eor(r3, d); // r3 ^= d
+ r4 = row_eor(r4, d); // r4 ^= d
+ }
+
+ // rho
+ r0 = row_rho(r0, 0);
+ r1 = row_rho(r1, 1);
+ r2 = row_rho(r2, 2);
+ r3 = row_rho(r3, 3);
+ r4 = row_rho(r4, 4);
+
+ // pi
+ {
+ const row_t t0 = row_set(
+ pi_lo_hi(row_get(r0, 0), row_get(r1, 0)),
+ pi_lo_hi(row_get(r2, 1), row_get(r3, 1)),
+ row_get_tail(r4)
+ );
+
+ const row_t t1 = row_set(
+ vcombine_u64(vdup_laneq_u64(row_get(r0, 1), 1), row_get_tail(r1)),
+ pi_lo_hi(row_get(r2, 0), row_get(r3, 0)),
+ vdup_laneq_u64(row_get(r4, 1), 0)
+ );
+
+ const row_t t2 = row_set(
+ vextq_u64(row_get(r0, 0), row_get(r1, 1), 1),
+ vcombine_u64(vdup_laneq_u64(row_get(r2, 1), 1), row_get_tail(r3)),
+ vdup_laneq_u64(row_get(r4, 0), 0)
+ );
+
+ const row_t t3 = row_set(
+ vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+ vextq_u64(row_get(r2, 0), row_get(r3, 1), 1),
+ vdup_laneq_u64(row_get(r4, 1), 1)
+ );
+
+ const row_t t4 = row_set(
+ pi_lo_hi(row_get(r0, 1), row_get(r1, 1)),
+ vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+ vdup_laneq_u64(row_get(r4, 0), 1)
+ );
+
+ // store rows
+ r0 = t0;
+ r1 = t1;
+ r2 = t2;
+ r3 = t3;
+ r4 = t4;
+ }
+
+ // chi
+ r0 = row_chi(r0);
+ r1 = row_chi(r1);
+ r2 = row_chi(r2);
+ r3 = row_chi(r3);
+ r4 = row_chi(r4);
+
+ // iota
+ const uint64x2_t rc = { RCS[24 - num_rounds + i], 0 };
+ r0.head.val[0] ^= rc;
+ }
+
+ // store rows
+ {
+ // store columns 0-4 of r0 and columns 0-2 of r1
+ const uint64x2x4_t m0 = {
+ .val = {
+ row_get(r0, 0),
+ row_get(r0, 1),
+ vcombine_u64(row_get_tail(r0), vdup_laneq_u64(row_get(r1, 0), 0)),
+ vextq_u64(row_get(r1, 0), row_get(r1, 1), 1)
+ },
+ };
+ vst1q_u64_x4(a + 0, m0);
+ }
+
+ {
+ // store columns 3-4 of r1, columns 0-4 of r2, and column 0 of r3
+ const uint64x2x4_t m1 = {
+ .val = {
+ vcombine_u64(vdup_laneq_u64(row_get(r1, 1), 1), row_get_tail(r1)),
+ row_get(r2, 0),
+ row_get(r2, 1),
+ vcombine_u64(row_get_tail(r2), vdup_laneq_u64(row_get(r3, 0), 0)),
+ },
+ };
+ vst1q_u64_x4(a + 8, m1);
+ }
+
+ {
+ // store columns 1-4 of r3 and columns 03 of r4
+ const uint64x2x4_t m2 = {
+ .val = {
+ vextq_u64(row_get(r3, 0), row_get(r3, 1), 1),
+
+ vcombine_u64(vdup_laneq_u64(row_get(r3, 1), 1), row_get_tail(r3)),
+ row_get(r4, 0),
+ row_get(r4, 1),
+ },
+ };
+ vst1q_u64_x4(a + 16, m2);
+ }
+
+ // store column 4 of r4
+ vst1_u64(a + 24, row_get_tail(r4));
+}
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+
#if SHA3_BACKEND == BACKEND_AVX512
#define permute_n permute_n_avx512 // use avx512 backend
#elif SHA3_BACKEND == BACKEND_NEON
#define permute_n permute_n_neon // use neon backend
+#elif SHA3_BACKEND == BACKEND_DIET_NEON
+#define permute_n permute_n_diet_neon // use diet-neon backend
#elif SHA3_BACKEND == BACKEND_SCALAR
#define permute_n permute_n_scalar // use scalar backend
#else
@@ -2257,6 +2550,8 @@ const char *sha3_backend(void) {
return "avx512";
#elif SHA3_BACKEND == BACKEND_NEON
return "neon";
+#elif SHA3_BACKEND == BACKEND_DIET_NEON
+ return "diet-neon";
#elif SHA3_BACKEND == BACKEND_SCALAR
return "scalar";
#endif /* SHA3_BACKEND */
@@ -2550,6 +2845,22 @@ static void test_permute_neon(void) {
#endif /* SHA3_BACKEND == BACKEND_NEON */
}
+static void test_permute_diet_neon(void) {
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+ for (size_t i = 0; i < sizeof(PERMUTE_TESTS) / sizeof(PERMUTE_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
+ permute_n_diet_neon(got, 24); // call permute_n_diet_neon() directly
+
+ if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
+ }
+ }
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+}
+
static const struct {
uint64_t a[25]; // input state
const uint64_t exp[25]; // expected value
@@ -2606,6 +2917,22 @@ static void test_permute12_neon(void) {
#endif /* SHA3_BACKEND == BACKEND_NEON */
}
+static void test_permute12_diet_neon(void) {
+#if SHA3_BACKEND == BACKEND_DIET_NEON
+ for (size_t i = 0; i < sizeof(PERMUTE12_TESTS) / sizeof(PERMUTE12_TESTS[0]); i++) {
+ const size_t exp_len = PERMUTE12_TESTS[i].exp_len;
+
+ uint64_t got[25] = { 0 };
+ memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
+ permute_n_diet_neon(got, 12); // call permute_n_diet_neon() directly
+
+ if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
+ fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
+ }
+ }
+#endif /* SHA3_BACKEND == BACKEND_DIET_NEON */
+}
+
static void test_sha3_224(void) {
static const struct {
const char *name; // test name
@@ -6761,9 +7088,11 @@ int main(void) {
test_permute_scalar();
test_permute_avx512();
test_permute_neon();
+ test_permute_diet_neon();
test_permute12_scalar();
test_permute12_avx512();
test_permute12_neon();
+ test_permute12_diet_neon();
test_sha3_224();
test_sha3_256();
test_sha3_384();