aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--tests/permute/permute.c66
1 files changed, 34 insertions, 32 deletions
diff --git a/tests/permute/permute.c b/tests/permute/permute.c
index 7c8afa6..3a236a0 100644
--- a/tests/permute/permute.c
+++ b/tests/permute/permute.c
@@ -444,12 +444,28 @@ void permute_avx512(uint64_t a[static 25], const size_t num_rounds) {
// - probably more unnecessary register usage
void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) {
// unaligned load mask and permutation indices
- uint8_t mask = 0x1f;
- const __mmask8 m = _load_mask8(&mask);
+ uint8_t mask = 0x1f,
+ m0b = 0x01;
+ const __mmask8 m = _load_mask8(&mask),
+ m0 = _load_mask8(&m0b);
+
+ // round constants (used in iota)
+ static const uint64_t RCS[] = {
+ 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
+ 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
+ 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
+ 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
+ 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
+ 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
+ };
+
+ // load round constant
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ __m512i rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds));
- // load rc permutation (TODO)
- // static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 },
- // const __m512i rc_p = _mm512_loadu_epi64(m, (void*) rc_ps);
+ // load rc permutation
+ static const uint64_t rc_ps[8] = { 1, 2, 3, 4, 5, 6, 7, 0 };
+ const __m512i rc_p = _mm512_loadu_epi64((void*) rc_ps);
// load rows
__m512i r0 = _mm512_maskz_loadu_epi64(m, (void*) (s)),
@@ -639,32 +655,15 @@ void permute_avx512_fast(uint64_t s[static 25], const size_t num_rounds) {
// iota
{
- // round constants (ambiguous in spec)
- static const uint64_t RCS[] = {
- 0x0000000000000001ULL, 0x0000000000008082ULL, 0x800000000000808aULL, 0x8000000080008000ULL,
- 0x000000000000808bULL, 0x0000000080000001ULL, 0x8000000080008081ULL, 0x8000000000008009ULL,
- 0x000000000000008aULL, 0x0000000000000088ULL, 0x0000000080008009ULL, 0x000000008000000aULL,
- 0x000000008000808bULL, 0x800000000000008bULL, 0x8000000000008089ULL, 0x8000000000008003ULL,
- 0x8000000000008002ULL, 0x8000000000000080ULL, 0x000000000000800aULL, 0x800000008000000aULL,
- 0x8000000080008081ULL, 0x8000000000008080ULL, 0x0000000080000001ULL, 0x8000000080008008ULL,
- };
-
- // TODO
- // if ((24 - num_rounds + i + 1) % 8) {
- // rc = _mm512_permutexvar_epi64(p0, r4),
- // } else {
- // }
-
- // get rc address
- const uint64_t *rc = RCS + (24 - num_rounds + i);
-
- // load mask
- uint8_t m0b = 0x01;
- const __mmask8 m0 = _load_mask8(&m0b);
-
- // mask/store row
- const __m512i c0 = _mm512_maskz_loadu_epi64(m0, (void*) rc);
- r0 = _mm512_xor_epi64(r0, c0);
+ // xor round constant, shuffle rc register
+ r0 = _mm512_mask_xor_epi64(r0, m0, r0, rc);
+ rc = _mm512_permutexvar_epi64(rc_p, rc);
+
+ if (((24 - num_rounds + i + 1) % 8) == 0 && i != 23) {
+ // load next set of round constants
+ // note: this will bomb if num_rounds < 8 or num_rounds > 24.
+ rc = _mm512_loadu_epi64((void*) (RCS + 24 - num_rounds + (i + 1)));
+ }
}
}
@@ -816,7 +815,7 @@ static void test_permute_fast(void) {
}
// number of times to run permutation in timing tests below
-#define NUM_TIME_PERMUTES 10000000
+#define NUM_TIME_PERMUTES 20000000
// time scalar keccak permutation
static void time_permute_scalar(void) {
@@ -882,8 +881,11 @@ int main() {
test_iota();
test_permute_slow();
test_permute_fast();
+
+ printf("timing permute, please wait...\n");
time_permute_scalar();
time_permute_avx512_slow();
time_permute_avx512_fast();
+
return 0;
}