Commit 43b918dd authored by s_kleplj's avatar s_kleplj
Browse files

added implementations for all policies, changed policy implementation slightly, simplified avx512

parent d9019494
......@@ -12,21 +12,46 @@
namespace levensol {
struct policy_sse {
using data_type = __m128i;
};
struct policy_avx {
using data_type = __m256i;
};
struct policy_avx512 {
};
template<typename policy>
struct policy_data {
using data_type = std::size_t;
static constexpr bool needs_alignment = false;
};
template<>
struct policy_data<policy_sse> {
using data_type = __m128i;
static constexpr bool needs_alignment = true;
};
#ifdef USE_AVX
template<>
struct policy_data<policy_avx> {
using data_type = __m256i;
static constexpr bool needs_alignment = true;
};
#endif
#ifdef USE_AVX512
template<>
struct policy_data<policy_avx512> {
using data_type = __m512i;
static constexpr bool needs_alignment = true;
};
#endif
template< typename policy>
class levenstein {
public:
using data_type = typename policy::data_type;
using data_type = typename policy_data<policy>::data_type;
static constexpr std::size_t multiplier = sizeof(data_type) / sizeof(std::uint32_t);
static constexpr std::size_t multimask = multiplier - 1;
levenstein(std::size_t a_size, std::size_t b_size) :
......@@ -88,7 +113,7 @@ namespace levensol {
x = (i < a_size) ? a_size - 1 - i : 0;
y = (i < a_size) ? 0 : i + 1 - a_size;
d = ((i < a_size) ? a_size - i - 1: i - (a_size-1)) / 2;
d = ((i < a_size) ? a_size - 1 - i: i + 1 - a_size) / 2;
if (!odd && d == 0) {
even_v[d] = std::min(
......@@ -96,10 +121,7 @@ namespace levensol {
odd_v[d] + 1);
++x; ++y; ++d;
}
#ifdef USE_AVX512
if constexpr (std::is_same<policy,policy_avx512>::value) {
const unsigned char x_rest = (x + (((d - 1) | multimask) + 1) - d) & multimask, y_rest = (y + (((d - 1) | multimask) + 1) - d) & multimask;
if constexpr (policy_data<policy>().needs_alignment) {
cycles = std::min({a_size - x, b_size - y, (((d - 1) | multimask) + 1) - d});
for(; cycles != 0; ++x, ++y, ++d, --cycles) {
if (odd) {
......@@ -116,133 +138,88 @@ if constexpr (std::is_same<policy,policy_avx512>::value) {
odd_v[d]) + 1);
}
}
}
if constexpr (std::is_same<policy, policy_sse>::value) {
cycles = std::min(a_size - x, b_size - y) / multiplier;
std::size_t X = x / multiplier, Y = y / multiplier, D = d / multiplier;
for (; cycles != 0; x += multiplier, y += multiplier, d += multiplier, ++X, ++Y, ++D, --cycles) {
__m512i a, b;
switch(x_rest) {
case 0:
a = a_vector[X];
break;
case 1:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 1);
break;
case 2:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 2);
break;
case 3:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 3);
break;
case 4:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 4);
break;
case 5:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 5);
break;
case 6:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 6);
break;
case 7:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 7);
break;
case 8:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 8);
break;
case 9:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 9);
break;
case 10:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 10);
break;
case 11:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 11);
break;
case 12:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 12);
break;
case 13:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 13);
break;
case 14:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 14);
break;
case 15:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 15);
break;
for (; cycles != 0; x += multiplier, y += multiplier, d += multiplier, --cycles) {
data_type
a = _mm_lddqu_si128((data_type*)(a_v + x)),
b = _mm_lddqu_si128((data_type*)(b_v + y));
auto tmp = _mm_cmpeq_epi32(a, b);
if (odd) {
tmp = _mm_add_epi32(tmp, *(data_type*)(odd_v + d));
tmp = _mm_min_epu32(*(data_type*)(even_v + d), tmp);
auto tmp2 = _mm_lddqu_si128((data_type*)(even_v + d + 1));
tmp = _mm_min_epu32(tmp, tmp2);
*(data_type*)(odd_v + d) = _mm_add_epi32(tmp, _mm_set1_epi32(1));
} else {
tmp = _mm_add_epi32(tmp, *(data_type*)(even_v + d));
tmp = _mm_min_epu32(*(data_type*)(odd_v + d), tmp);
auto tmp2 = _mm_lddqu_si128((data_type*)(odd_v + d - 1));
tmp = _mm_min_epu32(tmp, tmp2);
*(data_type*)(even_v + d) = _mm_add_epi32(tmp, _mm_set1_epi32(1));
}
}
}
#ifdef USE_AVX
if constexpr (std::is_same<policy, policy_avx>::value) {
cycles = std::min(a_size - x, b_size - y) / multiplier;
for (; cycles != 0; x += multiplier, y += multiplier, d += multiplier, --cycles) {
const data_type
a = _mm256_lddqu_si256((data_type*)(a_v + x)),
b = _mm256_lddqu_si256((data_type*)(b_v + y));
switch(y_rest) {
case 0:
b = b_vector[Y];
break;
case 1:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 1);
break;
case 2:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 2);
break;
case 3:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 3);
break;
case 4:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 4);
break;
case 5:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 5);
break;
case 6:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 6);
break;
case 7:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 7);
break;
case 8:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 8);
break;
case 9:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 9);
break;
case 10:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 10);
break;
case 11:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 11);
break;
case 12:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 12);
break;
case 13:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 13);
break;
case 14:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 14);
break;
case 15:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 15);
break;
auto tmp = _mm256_cmpeq_epi32(a, b);
if (odd) {
tmp = _mm256_add_epi32(tmp, *(data_type*)(odd_v + d));
tmp = _mm256_min_epu32(*(data_type*)(even_v + d), tmp);
const auto tmp2 = _mm256_lddqu_si256((data_type*)(even_v + d + 1));
tmp = _mm256_min_epu32(tmp, tmp2);
*(data_type*)(odd_v + d) = _mm256_add_epi32(tmp, _mm256_set1_epi32(1));
} else {
tmp = _mm256_add_epi32(tmp, *(data_type*)(even_v + d));
tmp = _mm256_min_epu32(*(data_type*)(odd_v + d), tmp);
const auto tmp2 = _mm256_lddqu_si256((data_type*)(odd_v + d - 1));
tmp = _mm256_min_epu32(tmp, tmp2);
*(data_type*)(even_v + d) = _mm256_add_epi32(tmp, _mm256_set1_epi32(1));
}
}
}
#endif
#ifdef USE_AVX512
if constexpr (std::is_same<policy, policy_avx512>::value) {
cycles = std::min(a_size - x, b_size - y) / multiplier;
const auto mask = _mm512_cmpeq_epu32_mask(a, b);
for (; cycles != 0; x += multiplier, y += multiplier, d += multiplier, --cycles) {
const auto mask = _mm512_cmpeq_epu32_mask(
_mm512_loadu_si512(a_v + x),
_mm512_loadu_si512(b_v + y));
if (odd) {
auto tmp = odd_vector[D];
auto tmp = *(data_type*)(odd_v + d);
tmp = _mm512_mask_sub_epi32(tmp, mask, tmp, _mm512_set1_epi32(1));
tmp = _mm512_min_epu32(even_vector[D], tmp);
auto tmp2 = _mm512_alignr_epi32(even_vector[D + 1], even_vector[D], 1);
tmp = _mm512_min_epu32(*(data_type*)(even_v + d), tmp);
const auto tmp2 = _mm512_loadu_si512(even_v + d + 1);
tmp = _mm512_min_epu32(tmp, tmp2);
odd_vector[D] = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
*(data_type*)(odd_v + d) = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
} else {
auto tmp = even_vector[D];
auto tmp = *(data_type*)(even_v + d);
tmp = _mm512_mask_sub_epi32(tmp, mask, tmp, _mm512_set1_epi32(1));
tmp = _mm512_min_epu32(odd_vector[D], tmp);
auto tmp2 = _mm512_alignr_epi32(odd_vector[D], odd_vector[D - 1], multimask);
tmp = _mm512_min_epu32(*(data_type*)(odd_v + d), tmp);
const auto tmp2 = _mm512_loadu_si512(odd_v + d - 1);
tmp = _mm512_min_epu32(tmp, tmp2);
even_vector[D] = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
*(data_type*)(even_v + d) = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
}
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment