Commit d9019494 authored by s_kleplj's avatar s_kleplj
Browse files

diagonal solution for avx512

parent 9d4e96b8
......@@ -5,68 +5,292 @@
#include <cstdint>
#include <cstddef>
#include <utility>
#include <type_traits>
#include <immintrin.h>
namespace levensol {
struct policy_sse {
using data_type = __m128i;
};
struct policy_avx {
using data_type = __m256i;
};
struct policy_avx512 {
using data_type = __m512i;
};
template< typename policy>
class levenstein {
public:
using data_type = typename policy::data_type;
static constexpr std::size_t multiplier = sizeof(data_type) / sizeof(std::uint32_t);
static constexpr std::size_t multimask = multiplier - 1;
levenstein(std::size_t a_size, std::size_t b_size) :
a_size_{a_size},
b_size_{b_size},
fst_vector(std::max(a_size, b_size) + 1, 0),
snd_vector(std::max(a_size, b_size) + 1, 0)
a_vector_size{a_size/multiplier + 2},
b_vector_size{b_size/multiplier + 2},
odd_vector((a_vector_size + b_vector_size) / 2 + 1),
even_vector((a_vector_size + b_vector_size) / 2 + 1),
a_vector(a_vector_size),
b_vector(b_vector_size)
{
std::uint32_t *const odd_v = (std::uint32_t*)&odd_vector[0];
std::uint32_t *const even_v = (std::uint32_t*)&even_vector[0];
for (std::size_t i = ((a_vector_size + b_vector_size) / 2 + 1) * multiplier; i-- > 0;) {
odd_v[i] = even_v[i] = a_size + b_size;
}
}
// a_size >= b_size
std::uint32_t compute_impl(const std::uint32_t* a, const std::uint32_t* b, const std::size_t a_size, const std::size_t b_size)
// a_vector is in reverse
std::uint32_t compute_impl(const std::size_t a_size, const std::size_t b_size)
{
for (std::size_t i = 0; i <= b_size; ++i) {
fst_vector[i] = i;
bool odd = (a_size & 1) == 0;
std::uint32_t *const odd_v = (std::uint32_t*)&odd_vector[0];
std::uint32_t *const even_v = (std::uint32_t*)&even_vector[0];
std::uint32_t *const a_v = (std::uint32_t*)&a_vector[0];
std::uint32_t *const b_v = (std::uint32_t*)&b_vector[0];
for (std::size_t i = 0; i < a_size; ++i) {
if (i & 1) {
odd_v[i/2] = a_size - i;
} else {
even_v[i/2] = a_size - i;
}
}
for (std::size_t i = 0; i < b_size; ++i) {
if ((i + a_size) & 1) {
odd_v[(i + a_size) / 2] = i + 2;
} else {
even_v[(i + a_size) / 2] = i + 2;
}
}
if (odd) {
odd_v[(a_size - 1)/2] = 1;
} else {
even_v[(a_size - 1)/2] = 1;
}
for (std::size_t x = 0; x < a_size; ++x) {
snd_vector[0] = x + 1;
std::uint32_t a_x = a[x];
for (std::size_t y = 0; y < b_size; ++y) {
std::uint32_t tmp = std::min(
fst_vector[y] + (a_x == b[y] ? 0 : 1),
fst_vector[y + 1] + 1);
snd_vector[y + 1] = std::min(
tmp,
snd_vector[y] + 1);
for (std::size_t i = 0; i < a_size + b_size; ++i) {
std::size_t x, y, d, cycles;
// i == 0 means only first chars overlap
// then there is something in-between
// i == a_size + b_size - 1 means only last chars overlap
// i is the offset of first indices of a_vector and b_vector (a is in reverse)
// relative to each other
x = (i < a_size) ? a_size - 1 - i : 0;
y = (i < a_size) ? 0 : i + 1 - a_size;
d = ((i < a_size) ? a_size - i - 1: i - (a_size-1)) / 2;
if (!odd && d == 0) {
even_v[d] = std::min(
even_v[d] + (a_v[x] == b_v[y] ? 0 : 1),
odd_v[d] + 1);
++x; ++y; ++d;
}
#ifdef USE_AVX512
if constexpr (std::is_same<policy,policy_avx512>::value) {
const unsigned char x_rest = (x + (((d - 1) | multimask) + 1) - d) & multimask, y_rest = (y + (((d - 1) | multimask) + 1) - d) & multimask;
cycles = std::min({a_size - x, b_size - y, (((d - 1) | multimask) + 1) - d});
for(; cycles != 0; ++x, ++y, ++d, --cycles) {
if (odd) {
odd_v[d] = std::min(
odd_v[d] + (a_v[x] == b_v[y] ? 0 : 1),
std::min(
even_v[d],
even_v[d + 1]) + 1);
} else {
even_v[d] = std::min(
even_v[d] + (a_v[x] == b_v[y] ? 0 : 1),
std::min(
odd_v[d - 1],
odd_v[d]) + 1);
}
}
snd_vector.swap(fst_vector);
cycles = std::min(a_size - x, b_size - y) / multiplier;
std::size_t X = x / multiplier, Y = y / multiplier, D = d / multiplier;
for (; cycles != 0; x += multiplier, y += multiplier, d += multiplier, ++X, ++Y, ++D, --cycles) {
__m512i a, b;
switch(x_rest) {
case 0:
a = a_vector[X];
break;
case 1:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 1);
break;
case 2:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 2);
break;
case 3:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 3);
break;
case 4:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 4);
break;
case 5:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 5);
break;
case 6:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 6);
break;
case 7:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 7);
break;
case 8:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 8);
break;
case 9:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 9);
break;
case 10:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 10);
break;
case 11:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 11);
break;
case 12:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 12);
break;
case 13:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 13);
break;
case 14:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 14);
break;
case 15:
a = _mm512_alignr_epi32(a_vector[X + 1], a_vector[X], 15);
break;
}
switch(y_rest) {
case 0:
b = b_vector[Y];
break;
case 1:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 1);
break;
case 2:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 2);
break;
case 3:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 3);
break;
case 4:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 4);
break;
case 5:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 5);
break;
case 6:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 6);
break;
case 7:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 7);
break;
case 8:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 8);
break;
case 9:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 9);
break;
case 10:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 10);
break;
case 11:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 11);
break;
case 12:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 12);
break;
case 13:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 13);
break;
case 14:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 14);
break;
case 15:
b = _mm512_alignr_epi32(b_vector[Y + 1], b_vector[Y], 15);
break;
}
const auto mask = _mm512_cmpeq_epu32_mask(a, b);
if (odd) {
auto tmp = odd_vector[D];
tmp = _mm512_mask_sub_epi32(tmp, mask, tmp, _mm512_set1_epi32(1));
tmp = _mm512_min_epu32(even_vector[D], tmp);
auto tmp2 = _mm512_alignr_epi32(even_vector[D + 1], even_vector[D], 1);
tmp = _mm512_min_epu32(tmp, tmp2);
odd_vector[D] = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
} else {
auto tmp = even_vector[D];
tmp = _mm512_mask_sub_epi32(tmp, mask, tmp, _mm512_set1_epi32(1));
tmp = _mm512_min_epu32(odd_vector[D], tmp);
auto tmp2 = _mm512_alignr_epi32(odd_vector[D], odd_vector[D - 1], multimask);
tmp = _mm512_min_epu32(tmp, tmp2);
even_vector[D] = _mm512_add_epi32(tmp, _mm512_set1_epi32(1));
}
}
}
#endif
cycles = std::min(a_size - x, b_size - y);
for(; cycles != 0; ++x, ++y, ++d, --cycles) {
if (odd) {
odd_v[d] = std::min({
odd_v[d] + (a_v[x] == b_v[y] ? 0 : 1),
even_v[d] + 1,
even_v[d + 1] + 1});
} else {
even_v[d] = std::min({
even_v[d] + (a_v[x] == b_v[y] ? 0 : 1),
odd_v[d - 1] + 1,
odd_v[d] + 1});
}
}
odd = !odd;
}
return fst_vector[b_size];
return ((b_size & 1) == 0 ? odd_v[(b_size - 1)/2] : even_v[(b_size - 1)/2]) - 1;
}
std::uint32_t compute(const std::uint32_t* a, const std::uint32_t* b)
{
if (a_size_ < b_size_) {
return compute_impl(a, b, a_size_, b_size_);
} else {
return compute_impl(b, a, b_size_, a_size_);
if (a == b) {
return std::max(a_size_, b_size_) - std::min(a_size_, b_size_);
}
}
const std::size_t a_size_, b_size_;
std::vector<std::uint32_t> fst_vector, snd_vector;
};
std::uint32_t *const a_v = (std::uint32_t*)&a_vector[0];
std::uint32_t *const b_v = (std::uint32_t*)&b_vector[0];
struct policy_sse {
};
for (std::size_t x = 0; x < a_size_; ++x) {
a_v[x] = a[a_size_ - x - 1];
}
struct policy_avx {
};
for (std::size_t y = 0; y < b_size_; ++y) {
b_v[y] = b[y];
}
struct policy_avx512 {
return compute_impl(a_size_, b_size_);
}
const std::size_t a_size_, b_size_;
const std::size_t a_vector_size, b_vector_size;
std::vector<data_type> odd_vector, even_vector, a_vector, b_vector;
};
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment