aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/permute/permute.c62
1 files changed, 52 insertions, 10 deletions
diff --git a/tests/permute/permute.c b/tests/permute/permute.c
index 9ba5095..7c8afa6 100644
--- a/tests/permute/permute.c
+++ b/tests/permute/permute.c
@@ -12,6 +12,7 @@
// number of rounds for permute()
#define SHA3_NUM_ROUNDS 24
+// scalar impl of theta step
static void theta_scalar(uint64_t a[static 25]) {
const uint64_t c[5] = {
a[0] ^ a[5] ^ a[10] ^ a[15] ^ a[20],
@@ -36,6 +37,7 @@ static void theta_scalar(uint64_t a[static 25]) {
a[20] ^= d[0]; a[21] ^= d[1]; a[22] ^= d[2]; a[23] ^= d[3]; a[24] ^= d[4];
}
+// scalar impl of rho step
static void rho_scalar(uint64_t a[static 25]) {
a[1] = ROL(a[1], 1); // 1 % 64 = 1
a[2] = ROL(a[2], 62); // 190 % 64 = 62
@@ -63,6 +65,7 @@ static void rho_scalar(uint64_t a[static 25]) {
a[24] = ROL(a[24], 14); // 78 % 64 = 14
}
+// scalar impl of pi step
static void pi_scalar(uint64_t a[static 25]) {
uint64_t t[25] = { 0 };
memcpy(t, a, sizeof(t));
@@ -92,6 +95,7 @@ static void pi_scalar(uint64_t a[static 25]) {
a[24] = t[21];
}
+// scalar impl of chi step
static void chi_scalar(uint64_t a[static 25]) {
uint64_t t[25] = { 0 };
memcpy(t, a, sizeof(t));
@@ -122,6 +126,7 @@ static void chi_scalar(uint64_t a[static 25]) {
a[24] = t[24] ^ (~t[20] & t[21]);
}
+// scalar impl of iota step
static void iota_scalar(uint64_t a[static 25], const int i) {
// round constants (ambiguous in spec)
static const uint64_t RCS[] = {
@@ -136,7 +141,7 @@ static void iota_scalar(uint64_t a[static 25], const int i) {
a[0] ^= RCS[i];
}
-// keccak permutation (scalar implementation).
+// scalar impl of keccak permutation.
void permute_scalar(uint64_t a[static 25], const size_t num_rounds) {
for (int i = 0; i < (int) num_rounds; i++) {
theta_scalar(a);
@@ -150,6 +155,7 @@ void permute_scalar(uint64_t a[static 25], const size_t num_rounds) {
#ifdef __AVX512F__
#include <immintrin.h>
+// avx512 impl of theta step
static void theta_avx512(uint64_t s[static 25]) {
// unaligned load mask and permutation indices
uint8_t mask = 0x1f;
@@ -190,6 +196,7 @@ static void theta_avx512(uint64_t s[static 25]) {
_mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
}
+// avx512 impl of rho step
static void rho_avx512(uint64_t s[static 25]) {
// unaligned load mask and rotate values
uint8_t mask = 0x1f;
@@ -229,6 +236,15 @@ static void rho_avx512(uint64_t s[static 25]) {
_mm512_mask_storeu_epi64((void*) (s + 20), m, r4);
}
+// avx512 impl of pi step
+//
+// note: i originally tried a simpler implementation which just copied
+// the state array out to a temporary buffer and then gathered it back
+// in shuffled order.
+//
+// the "simpler" implementation did not work correctly, so i rewrote the
+// shuffling as in-register permutations (which should actually be
+// faster).
static void pi_avx512(uint64_t s[static 25]) {
// mask bytes
uint8_t mask = 0x1f,
@@ -241,14 +257,8 @@ static void pi_avx512(uint64_t s[static 25]) {
m01 = _load_mask8(&m01b),
m23 = _load_mask8(&m23b),
m4 = _load_mask8(&m4b);
- // permutation indices (offsets into state array)
- // static uint64_t vs0[8] = { 0, 6, 12, 18, 24, 0, 0, 0 },
- // vs1[8] = { 3, 9, 10, 16, 22, 0, 0, 0 },
- // vs2[8] = { 1, 7, 13, 19, 20, 0, 0, 0 },
- // vs3[8] = { 4, 5, 11, 17, 23, 0, 0, 0 },
- // vs4[8] = { 2, 8, 14, 15, 21, 0, 0, 0 };
- // permutation indices
+ // permutation indices (offsets into state array)
//
// (note: these are masked so only the relevant indices for
// _mm512_maskz_permutex2var_epi64() in each array are filled in)
@@ -334,6 +344,7 @@ static void pi_avx512(uint64_t s[static 25]) {
_mm512_mask_storeu_epi64((void*) (s + 20), m, t4);
}
+// avx512 impl of chi step
static void chi_avx512(uint64_t s[static 25]) {
// mask bytes
uint8_t mask = 0x1f;
@@ -384,6 +395,7 @@ static void chi_avx512(uint64_t s[static 25]) {
_mm512_mask_storeu_epi64((void*) (s + 20), m, t4);
}
+// avx512 impl of iota step
static void iota_avx512(uint64_t s[static 25], const int i) {
// round constants (ambiguous in spec)
static const uint64_t RCS[] = {
@@ -407,7 +419,7 @@ static void iota_avx512(uint64_t s[static 25], const int i) {
_mm512_mask_storeu_epi64((void*) s, m0, t0);
}
-// keccak permutation (slow avx512 implementation).
+// slow avx512 impl of keccak permutation.
void permute_avx512(uint64_t a[static 25], const size_t num_rounds) {
for (int i = 0; i < (int) num_rounds; i++) {
theta_avx512(a);
@@ -418,7 +430,18 @@ void permute_avx512(uint64_t a[static 25], const size_t num_rounds) {
}
}
-// keccak permutation (fast avx512 implementation).
+// fast avx512 impl of keccak permutation
+//
+// this version is similar to permute_avx_512_slow(), except the
+// function calls are inlined as blocks, duplicate definitions have
+// been removed, and the state array loads and stores only happen at the
+// beginning and end of the function.
+//
+// there are still several optimizations that can be done. for example:
+// - iota does not need to reload the round constant every round
+// (only every 8th round)
+// - some spills could be addressed
+// - probably more unnecessary register usage
void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) {
// unaligned load mask and permutation indices
uint8_t mask = 0x1f;
@@ -654,6 +677,14 @@ void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) {
}
#endif /* __AVX512F__ */
+// verify that both state arrays are identical.
+//
+// if the state arrays are equal, print "$TEST_NAME passed" to standard
+// output. if the state arrays are NOT equal, then print the contents
+// of both state arrays with the differing cells highlighted to standard
+// error, then exit with an error.
+//
+// used by test_*() functions below.
static void check(const char *name, uint64_t a_scalar[static 25], uint64_t a_avx512[static 25]) {
// compare
if (!memcmp(a_scalar, a_avx512, 25 * sizeof(uint64_t))) {
@@ -689,6 +720,7 @@ static void check(const char *name, uint64_t a_scalar[static 25], uint64_t a_avx
exit(-1);
}
+// test avx512 theta
static void test_theta(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -700,6 +732,7 @@ static void test_theta(void) {
check("test_theta()", a_scalar, a_avx512);
}
+// test avx512 rho
static void test_rho(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -711,6 +744,7 @@ static void test_rho(void) {
check("test_rho()", a_scalar, a_avx512);
}
+// test avx512 pi
static void test_pi(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -722,6 +756,7 @@ static void test_pi(void) {
check("test_pi()", a_scalar, a_avx512);
}
+// test avx512 chi
static void test_chi(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -733,6 +768,7 @@ static void test_chi(void) {
check("test_chi()", a_scalar, a_avx512);
}
+// test avx512 iota
static void test_iota(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -751,6 +787,7 @@ static void test_iota(void) {
}
}
+// test avx512 permute_slow
static void test_permute_slow(void) {
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
uint64_t a_avx512[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -763,6 +800,7 @@ static void test_permute_slow(void) {
check("test_permute_slow()", a_scalar, a_avx512);
}
+// test avx512 permute_fast
static void test_permute_fast(void) {
#ifdef __AVX512F__
uint64_t a_scalar[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -777,8 +815,10 @@ static void test_permute_fast(void) {
#endif /* __AVX512F__ */
}
+// number of times to run permutation in timing tests below
#define NUM_TIME_PERMUTES 10000000
+// time scalar keccak permutation
static void time_permute_scalar(void) {
uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -794,6 +834,7 @@ static void time_permute_scalar(void) {
printf("time_permute_scalar(): %zu\n", t1 - t0);
}
+// time slow avx512 keccak permutation
static void time_permute_avx512_slow(void) {
#ifdef __AVX512F__
uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };
@@ -813,6 +854,7 @@ static void time_permute_avx512_slow(void) {
#endif /* __AVX512F__ */
}
+// time fast avx512 keccak permutation
static void time_permute_avx512_fast(void) {
#ifdef __AVX512F__
uint64_t a[25] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24 };