[CPU] Bug in jit_convert fixed (#9485)
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "cpu_memcpy.h"
|
||||
#include <utils/bfloat16.hpp>
|
||||
#include <utils/general_utils.h>
|
||||
#include <utils/jit_kernel.hpp>
|
||||
#include <mkldnn_selective_build.h>
|
||||
#include <ie_parallel.hpp>
|
||||
#include <openvino/core/type/float16.hpp>
|
||||
@@ -17,8 +18,8 @@
|
||||
|
||||
using namespace MKLDNNPlugin;
|
||||
using namespace InferenceEngine;
|
||||
using namespace dnnl::impl::cpu::x64;
|
||||
using namespace dnnl::impl::utils;
|
||||
using namespace mkldnn::impl::utils;
|
||||
using namespace mkldnn::impl::cpu::x64;
|
||||
using namespace Xbyak;
|
||||
|
||||
namespace {
|
||||
@@ -52,109 +53,51 @@ void convert_vec<float, ov::float16>(jit_generator & gen,
|
||||
gen.movdqu(gen.xword[dst], f16vec);
|
||||
}
|
||||
|
||||
class jit_convert_array : public jit_generator {
|
||||
class jit_convert_array : public jit_kernel {
|
||||
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_convert_array)
|
||||
|
||||
void generate() override {
|
||||
const size_t vlen = 8u;
|
||||
const size_t vlen_log2 = 3;
|
||||
|
||||
auto reg_src = rax;
|
||||
auto reg_dst = rbx;
|
||||
auto reg_sz = rdx;
|
||||
|
||||
Label tail, exit;
|
||||
constexpr size_t vlen = 8u;
|
||||
constexpr size_t vlen_log2 = 3;
|
||||
|
||||
preamble();
|
||||
|
||||
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
|
||||
mov(reg_dst, ptr[param1 + offsetof(args_t, out)]);
|
||||
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);
|
||||
// Get arguments addresses
|
||||
auto src = arg(&args_t::src);
|
||||
auto dst = arg(&args_t::out);
|
||||
auto size = arg(&args_t::count);
|
||||
|
||||
xor_(rsi, rsi);
|
||||
mov(r8, reg_sz);
|
||||
shr(r8, vlen_log2);
|
||||
size >>= vlen_log2;
|
||||
|
||||
foreach(rsi, 1, r8, [&, this](const Xbyak::Reg64& idx) {
|
||||
_convert_vec(*this, reg_src, reg_dst);
|
||||
add(reg_src, _src_size * vlen);
|
||||
add(reg_dst, _dst_size * vlen);
|
||||
foreach(0, size, [&, this](const Xbyak::Reg64& idx) {
|
||||
_convert_vec(*this, src, dst);
|
||||
src += _src_size * vlen;
|
||||
dst += _dst_size * vlen;
|
||||
});
|
||||
|
||||
L(tail);
|
||||
|
||||
shl(rsi, vlen_log2);
|
||||
sub(reg_sz, rsi);
|
||||
test(reg_sz, reg_sz);
|
||||
jz(exit);
|
||||
|
||||
// allocate array for 8 floats on stack
|
||||
sub(rsp, vlen * sizeof(float));
|
||||
mov(r8, rsp);
|
||||
|
||||
vpxor(ymm4, ymm4, ymm4);
|
||||
vmovups(yword[r8], ymm4);
|
||||
mov(size, argPtr(&args_t::count));
|
||||
size &= vlen - 1;
|
||||
|
||||
// Tail conversion
|
||||
copy(r8, reg_src, reg_sz, _src_size);
|
||||
_convert_vec(*this, r8, r8);
|
||||
copy(reg_dst, r8, reg_sz, _dst_size);
|
||||
_if(size != 0)
|
||||
._then([&] {
|
||||
auto tmp = stack(vlen * sizeof(float));
|
||||
tmp.clear();
|
||||
|
||||
// Free the array on stack
|
||||
add(rsp, vlen * sizeof(float));
|
||||
auto tail_size = var<size_t>();
|
||||
|
||||
L(exit);
|
||||
tail_size = size;
|
||||
tail_size <<= static_cast<size_t>(std::logb(_src_size)) - 1;
|
||||
copy<uint16_t>(tmp.pointer(), src, tail_size);
|
||||
|
||||
postamble();
|
||||
}
|
||||
_convert_vec(*this, tmp.pointer(), tmp.pointer());
|
||||
|
||||
void foreach(const Xbyak::Reg64& idx,
|
||||
size_t step,
|
||||
const Xbyak::Reg64& end,
|
||||
std::function<void(const Xbyak::Reg64&)> && fn) {
|
||||
Label loop, exit;
|
||||
|
||||
L(loop);
|
||||
cmp(idx, end);
|
||||
jge(exit);
|
||||
|
||||
fn(idx);
|
||||
|
||||
add(idx, step);
|
||||
jmp(loop);
|
||||
L(exit);
|
||||
}
|
||||
|
||||
void copy(const Xbyak::Reg64& dst,
|
||||
const Xbyak::Reg64& src,
|
||||
const Xbyak::Reg64& size,
|
||||
size_t item_size) {
|
||||
push(rsi);
|
||||
push(r15);
|
||||
|
||||
xor_(rsi, rsi);
|
||||
|
||||
auto address_frame = [this](size_t size) -> const AddressFrame& {
|
||||
switch (size) {
|
||||
case 1: return byte;
|
||||
case 2: return word;
|
||||
case 4: return dword;
|
||||
case 8: return qword;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
return ptr;
|
||||
};
|
||||
|
||||
const auto & addr_frame = address_frame(item_size);
|
||||
|
||||
foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
|
||||
mov(r15, addr_frame[src + idx * item_size]);
|
||||
mov(addr_frame[dst + idx * item_size], r15);
|
||||
tail_size = size;
|
||||
tail_size <<= static_cast<size_t>(std::logb(_dst_size)) - 1;
|
||||
copy<uint16_t>(dst, tmp.pointer(), tail_size);
|
||||
});
|
||||
|
||||
pop(r15);
|
||||
pop(rsi);
|
||||
postamble();
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -179,7 +122,8 @@ public:
|
||||
|
||||
template<typename src_t, typename dst_t>
|
||||
static fn_t get() {
|
||||
if (mayiuse(avx2) && cpu().has(util::Cpu::tF16C)) {
|
||||
if (mayiuse(cpu_isa_t::avx2)
|
||||
&& dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) {
|
||||
static jit_convert_array converter(convert_vec<src_t, dst_t>, sizeof(src_t), sizeof(dst_t));
|
||||
auto & generator = static_cast<jit_generator&>(converter);
|
||||
generator.create_kernel();
|
||||
@@ -216,7 +160,7 @@ struct PrecisionInfo {
|
||||
|
||||
template <>
|
||||
struct PrecisionInfo<Precision::BF16> {
|
||||
using value_type = bfloat16_t;
|
||||
using value_type = MKLDNNPlugin::bfloat16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -232,7 +176,7 @@ struct PrecisionInfo<Precision::BOOL> {
|
||||
template<typename T,
|
||||
typename U = typename std::conditional<
|
||||
std::is_same<ov::float16, T>::value
|
||||
|| std::is_same<bfloat16_t, T>::value,
|
||||
|| std::is_same<MKLDNNPlugin::bfloat16_t, T>::value,
|
||||
float, T>::type>
|
||||
struct Range {
|
||||
const std::tuple<U, U> & fit(const Precision & prec);
|
||||
@@ -250,8 +194,8 @@ const std::tuple<U, U> & Range<T, U>::fit(const Precision & prec) {
|
||||
double lbound, ubound;
|
||||
switch (prec) {
|
||||
case Precision::BF16:
|
||||
lbound = static_cast<double>(std::numeric_limits<bfloat16_t>::lowest());
|
||||
ubound = static_cast<double>(std::numeric_limits<bfloat16_t>::max());
|
||||
lbound = static_cast<double>(std::numeric_limits<MKLDNNPlugin::bfloat16_t>::lowest());
|
||||
ubound = static_cast<double>(std::numeric_limits<MKLDNNPlugin::bfloat16_t>::max());
|
||||
break;
|
||||
case Precision::FP16:
|
||||
lbound = static_cast<double>(std::numeric_limits<ov::float16>::lowest());
|
||||
@@ -366,20 +310,20 @@ struct ConvertPrecision<std::tuple<src_t, dst_t>> {
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ConvertPrecision<std::tuple<float, bfloat16_t>> {
|
||||
struct ConvertPrecision<std::tuple<float, MKLDNNPlugin::bfloat16_t>> {
|
||||
void operator()(ConvertContext & ctx) {
|
||||
auto src = static_cast<const float *>(ctx.srcPtr);
|
||||
auto dst = static_cast<bfloat16_t *>(ctx.dstPtr);
|
||||
auto dst = static_cast<MKLDNNPlugin::bfloat16_t *>(ctx.dstPtr);
|
||||
|
||||
if (ctx.interimPrc.is_float()) {
|
||||
parallel_for(ctx.size, [&](size_t i) {
|
||||
dst[i] = static_cast<bfloat16_t>(src[i]);
|
||||
dst[i] = static_cast<MKLDNNPlugin::bfloat16_t>(src[i]);
|
||||
});
|
||||
} else {
|
||||
float lbound, ubound;
|
||||
std::tie(lbound, ubound) = ctx.range<float>();
|
||||
parallel_for(ctx.size, [&](size_t i) {
|
||||
dst[i] = static_cast<bfloat16_t>(std::trunc(std::max(std::min(src[i], ubound), lbound)));
|
||||
dst[i] = static_cast<MKLDNNPlugin::bfloat16_t>(std::trunc(std::max(std::min(src[i], ubound), lbound)));
|
||||
});
|
||||
}
|
||||
|
||||
@@ -388,9 +332,9 @@ struct ConvertPrecision<std::tuple<float, bfloat16_t>> {
|
||||
};
|
||||
|
||||
template<>
|
||||
struct ConvertPrecision<std::tuple<bfloat16_t, float>> {
|
||||
struct ConvertPrecision<std::tuple<MKLDNNPlugin::bfloat16_t, float>> {
|
||||
void operator()(ConvertContext & ctx) {
|
||||
auto src = static_cast<const bfloat16_t *>(ctx.srcPtr);
|
||||
auto src = static_cast<const MKLDNNPlugin::bfloat16_t *>(ctx.srcPtr);
|
||||
auto dst = static_cast<float *>(ctx.dstPtr);
|
||||
|
||||
if (ctx.interimPrc.is_float()) {
|
||||
@@ -399,7 +343,7 @@ struct ConvertPrecision<std::tuple<bfloat16_t, float>> {
|
||||
});
|
||||
} else {
|
||||
float lbound, ubound;
|
||||
std::tie(lbound, ubound) = ctx.range<bfloat16_t>();
|
||||
std::tie(lbound, ubound) = ctx.range<MKLDNNPlugin::bfloat16_t>();
|
||||
parallel_for(ctx.size, [&](size_t i) {
|
||||
dst[i] = std::trunc(std::max(std::min(static_cast<float>(src[i]), ubound), lbound));
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user