[CPU] Bug in jit_convert fixed (#9485)

This commit is contained in:
Vladislav Volkov
2021-12-30 18:22:16 +03:00
committed by GitHub
parent ec5198094a
commit e52c96389d

View File

@@ -6,6 +6,7 @@
#include "cpu_memcpy.h"
#include <utils/bfloat16.hpp>
#include <utils/general_utils.h>
#include <utils/jit_kernel.hpp>
#include <mkldnn_selective_build.h>
#include <ie_parallel.hpp>
#include <openvino/core/type/float16.hpp>
@@ -17,8 +18,8 @@
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
using namespace dnnl::impl::cpu::x64;
using namespace dnnl::impl::utils;
using namespace mkldnn::impl::utils;
using namespace mkldnn::impl::cpu::x64;
using namespace Xbyak;
namespace {
@@ -52,109 +53,51 @@ void convert_vec<float, ov::float16>(jit_generator & gen,
gen.movdqu(gen.xword[dst], f16vec);
}
class jit_convert_array : public jit_generator {
class jit_convert_array : public jit_kernel {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_convert_array)
void generate() override {
const size_t vlen = 8u;
const size_t vlen_log2 = 3;
auto reg_src = rax;
auto reg_dst = rbx;
auto reg_sz = rdx;
Label tail, exit;
constexpr size_t vlen = 8u;
constexpr size_t vlen_log2 = 3;
preamble();
mov(reg_src, ptr[param1 + offsetof(args_t, src)]);
mov(reg_dst, ptr[param1 + offsetof(args_t, out)]);
mov(reg_sz, ptr[param1 + offsetof(args_t, count)]);
// Get arguments addresses
auto src = arg(&args_t::src);
auto dst = arg(&args_t::out);
auto size = arg(&args_t::count);
xor_(rsi, rsi);
mov(r8, reg_sz);
shr(r8, vlen_log2);
size >>= vlen_log2;
foreach(rsi, 1, r8, [&, this](const Xbyak::Reg64& idx) {
_convert_vec(*this, reg_src, reg_dst);
add(reg_src, _src_size * vlen);
add(reg_dst, _dst_size * vlen);
foreach(0, size, [&, this](const Xbyak::Reg64& idx) {
_convert_vec(*this, src, dst);
src += _src_size * vlen;
dst += _dst_size * vlen;
});
L(tail);
shl(rsi, vlen_log2);
sub(reg_sz, rsi);
test(reg_sz, reg_sz);
jz(exit);
// allocate array for 8 floats on stack
sub(rsp, vlen * sizeof(float));
mov(r8, rsp);
vpxor(ymm4, ymm4, ymm4);
vmovups(yword[r8], ymm4);
mov(size, argPtr(&args_t::count));
size &= vlen - 1;
// Tail conversion
copy(r8, reg_src, reg_sz, _src_size);
_convert_vec(*this, r8, r8);
copy(reg_dst, r8, reg_sz, _dst_size);
_if(size != 0)
._then([&] {
auto tmp = stack(vlen * sizeof(float));
tmp.clear();
// Free the array on stack
add(rsp, vlen * sizeof(float));
auto tail_size = var<size_t>();
L(exit);
tail_size = size;
tail_size <<= static_cast<size_t>(std::logb(_src_size)) - 1;
copy<uint16_t>(tmp.pointer(), src, tail_size);
postamble();
}
_convert_vec(*this, tmp.pointer(), tmp.pointer());
void foreach(const Xbyak::Reg64& idx,
size_t step,
const Xbyak::Reg64& end,
std::function<void(const Xbyak::Reg64&)> && fn) {
Label loop, exit;
L(loop);
cmp(idx, end);
jge(exit);
fn(idx);
add(idx, step);
jmp(loop);
L(exit);
}
void copy(const Xbyak::Reg64& dst,
const Xbyak::Reg64& src,
const Xbyak::Reg64& size,
size_t item_size) {
push(rsi);
push(r15);
xor_(rsi, rsi);
auto address_frame = [this](size_t size) -> const AddressFrame& {
switch (size) {
case 1: return byte;
case 2: return word;
case 4: return dword;
case 8: return qword;
default:
break;
}
return ptr;
};
const auto & addr_frame = address_frame(item_size);
foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
mov(r15, addr_frame[src + idx * item_size]);
mov(addr_frame[dst + idx * item_size], r15);
tail_size = size;
tail_size <<= static_cast<size_t>(std::logb(_dst_size)) - 1;
copy<uint16_t>(dst, tmp.pointer(), tail_size);
});
pop(r15);
pop(rsi);
postamble();
}
public:
@@ -179,7 +122,8 @@ public:
template<typename src_t, typename dst_t>
static fn_t get() {
if (mayiuse(avx2) && cpu().has(util::Cpu::tF16C)) {
if (mayiuse(cpu_isa_t::avx2)
&& dnnl::impl::cpu::x64::cpu().has(Xbyak::util::Cpu::tF16C)) {
static jit_convert_array converter(convert_vec<src_t, dst_t>, sizeof(src_t), sizeof(dst_t));
auto & generator = static_cast<jit_generator&>(converter);
generator.create_kernel();
@@ -216,7 +160,7 @@ struct PrecisionInfo {
// PrecisionInfo specialization for Precision::BF16: publishes the element type as value_type.
// NOTE(review): this text is a commit diff whose +/- markers were lost, so the two aliases
// below appear to be the removed and the added form of a single line — confirm against the file.
template <>
struct PrecisionInfo<Precision::BF16> {
using value_type = bfloat16_t;
using value_type = MKLDNNPlugin::bfloat16_t;
};
template <>
@@ -232,7 +176,7 @@ struct PrecisionInfo<Precision::BOOL> {
template<typename T,
typename U = typename std::conditional<
std::is_same<ov::float16, T>::value
|| std::is_same<bfloat16_t, T>::value,
|| std::is_same<MKLDNNPlugin::bfloat16_t, T>::value,
float, T>::type>
struct Range {
const std::tuple<U, U> & fit(const Precision & prec);
@@ -250,8 +194,8 @@ const std::tuple<U, U> & Range<T, U>::fit(const Precision & prec) {
double lbound, ubound;
switch (prec) {
case Precision::BF16:
lbound = static_cast<double>(std::numeric_limits<bfloat16_t>::lowest());
ubound = static_cast<double>(std::numeric_limits<bfloat16_t>::max());
lbound = static_cast<double>(std::numeric_limits<MKLDNNPlugin::bfloat16_t>::lowest());
ubound = static_cast<double>(std::numeric_limits<MKLDNNPlugin::bfloat16_t>::max());
break;
case Precision::FP16:
lbound = static_cast<double>(std::numeric_limits<ov::float16>::lowest());
@@ -366,20 +310,20 @@ struct ConvertPrecision<std::tuple<src_t, dst_t>> {
};
template<>
struct ConvertPrecision<std::tuple<float, bfloat16_t>> {
struct ConvertPrecision<std::tuple<float, MKLDNNPlugin::bfloat16_t>> {
void operator()(ConvertContext & ctx) {
auto src = static_cast<const float *>(ctx.srcPtr);
auto dst = static_cast<bfloat16_t *>(ctx.dstPtr);
auto dst = static_cast<MKLDNNPlugin::bfloat16_t *>(ctx.dstPtr);
if (ctx.interimPrc.is_float()) {
parallel_for(ctx.size, [&](size_t i) {
dst[i] = static_cast<bfloat16_t>(src[i]);
dst[i] = static_cast<MKLDNNPlugin::bfloat16_t>(src[i]);
});
} else {
float lbound, ubound;
std::tie(lbound, ubound) = ctx.range<float>();
parallel_for(ctx.size, [&](size_t i) {
dst[i] = static_cast<bfloat16_t>(std::trunc(std::max(std::min(src[i], ubound), lbound)));
dst[i] = static_cast<MKLDNNPlugin::bfloat16_t>(std::trunc(std::max(std::min(src[i], ubound), lbound)));
});
}
@@ -388,9 +332,9 @@ struct ConvertPrecision<std::tuple<float, bfloat16_t>> {
};
template<>
struct ConvertPrecision<std::tuple<bfloat16_t, float>> {
struct ConvertPrecision<std::tuple<MKLDNNPlugin::bfloat16_t, float>> {
void operator()(ConvertContext & ctx) {
auto src = static_cast<const bfloat16_t *>(ctx.srcPtr);
auto src = static_cast<const MKLDNNPlugin::bfloat16_t *>(ctx.srcPtr);
auto dst = static_cast<float *>(ctx.dstPtr);
if (ctx.interimPrc.is_float()) {
@@ -399,7 +343,7 @@ struct ConvertPrecision<std::tuple<bfloat16_t, float>> {
});
} else {
float lbound, ubound;
std::tie(lbound, ubound) = ctx.range<bfloat16_t>();
std::tie(lbound, ubound) = ctx.range<MKLDNNPlugin::bfloat16_t>();
parallel_for(ctx.size, [&](size_t i) {
dst[i] = std::trunc(std::max(std::min(static_cast<float>(src[i]), ubound), lbound));
});