summaryrefslogtreecommitdiff
path: root/sha3.c
diff options
context:
space:
mode:
Diffstat (limited to 'sha3.c')
-rw-r--r--sha3.c61
1 files changed, 21 insertions, 40 deletions
diff --git a/sha3.c b/sha3.c
index 6f936b9..8563d27 100644
--- a/sha3.c
+++ b/sha3.c
@@ -231,23 +231,6 @@ static inline void permute_n_scalar(uint64_t a[static 25], const size_t num_roun
iota(a, (SHA3_NUM_ROUNDS - num_rounds + i));
}
}
-
-/**
- * @brief 24 round scalar Keccak permutation.
- * @param[in,out] a Keccak state (array of 25 64-bit integers).
- */
-static inline void permute_scalar(uint64_t a[static 25]) {
- permute_n_scalar(a, 24);
-}
-
-/**
- * @brief 12 round scalar Keccak permutation.
- * @note Only used by TurboSHAKE and KangarooTwelve.
- * @param[in,out] a Keccak state (array of 25 64-bit integers).
- */
-static inline void permute12_scalar(uint64_t a[static 25]) {
- permute_n_scalar(a, 12);
-}
#endif /* (SHA3_BACKEND == BACKEND_SCALAR) || defined(SHA3_TEST) */
#if SHA3_BACKEND == BACKEND_AVX512
@@ -477,36 +460,34 @@ static inline void permute_n_avx512(uint64_t s[static 25], const size_t num_roun
_mm512_mask_storeu_epi64(s + 5 * 3, 0x1f, r3);
_mm512_mask_storeu_epi64(s + 5 * 4, 0x1f, r4);
}
+#endif /* SHA3_BACKEND == BACKEND_AVX512 */
+
+#if SHA3_BACKEND == BACKEND_AVX512
+// use avx512 backend
+#define permute_n permute_n_avx512
+#elif SHA3_BACKEND == BACKEND_SCALAR
+// use scalar backend
+#define permute_n permute_n_scalar
+#else
+#error "unknown sha3 backend"
+#endif /* SHA3_BACKEND */
/**
- * @brief 24 round AVX-512 Keccak permutation.
+ * @brief 24 round Keccak permutation.
* @param[in,out] a Keccak state (array of 25 64-bit integers).
*/
-static inline void permute_avx512(uint64_t s[static 25]) {
- permute_n_avx512(s, 24);
+static inline void permute(uint64_t s[static 25]) {
+ permute_n(s, 24);
}
/**
- * @brief 12 round AVX-512 Keccak permutation.
+ * @brief 12 round Keccak permutation.
* @note Only used by TurboSHAKE and KangarooTwelve.
* @param[in,out] a Keccak state (array of 25 64-bit integers).
*/
-static inline void permute12_avx512(uint64_t s[static 25]) {
- permute_n_avx512(s, 12);
+static inline void permute12(uint64_t s[static 25]) {
+ permute_n(s, 12);
}
-#endif /* SHA3_BACKEND == BACKEND_AVX512 */
-
-#if SHA3_BACKEND == BACKEND_AVX512
-// use avx512 backend
-#define permute permute_avx512
-#define permute12 permute12_avx512
-#elif SHA3_BACKEND == BACKEND_SCALAR
-// use scalar backend
-#define permute permute_scalar
-#define permute12 permute12_scalar
-#else
-#error "unknown sha3 backend"
-#endif /* SHA3_BACKEND */
// absorb message into state, return updated byte count
// used by `hash_absorb()`, `hash_once()`, and `xof_absorb_raw()`
@@ -2231,7 +2212,7 @@ static void test_permute_scalar(void) {
uint64_t got[25] = { 0 };
memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
- permute_scalar(got);
+ permute_n_scalar(got, 24); // call permute_n_scalar() directly
if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
@@ -2246,7 +2227,7 @@ static void test_permute_avx512(void) {
uint64_t got[25] = { 0 };
memcpy(got, PERMUTE_TESTS[i].a, sizeof(got));
- permute_avx512(got);
+ permute_n_avx512(got, 24); // call permute_n_avx512() directly
if (memcmp(got, PERMUTE_TESTS[i].exp, exp_len)) {
fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE_TESTS[i].exp, exp_len);
@@ -2271,7 +2252,7 @@ static void test_permute12_scalar(void) {
uint64_t got[25] = { 0 };
memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
- permute12_scalar(got);
+ permute_n_scalar(got, 12); // call permute_n_scalar() directly
if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);
@@ -2286,7 +2267,7 @@ static void test_permute12_avx512(void) {
uint64_t got[25] = { 0 };
memcpy(got, PERMUTE12_TESTS[i].a, sizeof(got));
- permute12_avx512(got);
+ permute_n_avx512(got, 12); // call permute_n_avx512() directly
if (memcmp(got, PERMUTE12_TESTS[i].exp, exp_len)) {
fail_test(__func__, "", (uint8_t*) got, exp_len, (uint8_t*) PERMUTE12_TESTS[i].exp, exp_len);