[Opset13][FP8] FakeConvert evaluate and reference impl (#21034)

* FakeConvert op init

* Update dest types names

* Update op hpp

* Update opset ops number

* Init type_prop tests

* Add attributes tests

* Add op check test

* Update namespace in fc cpp

* Update getters

* Refactor static member

* Make destination_type lower case

* Update type in test

* Move get_valid_types out of class

* Update ops number in opset

* Evaluate init

* Init poc evaluate

* Update types name

* Style update

* FakeConvert eval tests

* Add support for FP16 input

* Remove unused functions

* Use const auto and constexpr

* Create fake_convert reference file and move ref functions

* FakeConvert eval and reference update

* Remove unused is_normal variable

* Add bf16 tests

* Update constructor in tests

* adjust convert helper name

* Update comments

* Make func params as constexpr

* Update types

* Non scalar scale shape tests

* Use autobroadcast_select

* Use single autobroadcast_select

* Check scale_shape size

* Add var to keep scale_shift lambda

* Use lambda only for autobroadcast

* More checks for per channel scale

* Remove union half_t

* Change template to fp16 type for emulate

* Use f8e4m3_min_val as constexpr

* Update unsigned short to uint16_t

* Minor style refactor

* Minor comments update

* Update fake_convert_details namespace to func

* Update apply_conversion return type

* Add doxygen text

* Update supported type assert

* Remove duplicated tests

* Update namespace move to cpp

* Fill init with zeroes instead of reserve

* Use add div mul sub reference for applying scale

* Use autobroadcast_select for applying scale

* Bring back opt broadcast cases

* Reuse scale_shift as func

* Adjust s and o var names

* Update multiplication in loop to accumulate
This commit is contained in:
Katarzyna Mitrus 2023-11-24 09:54:19 +01:00 committed by GitHub
parent 65c3c7c20a
commit 63e7e2dd3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 1673 additions and 6 deletions

View File

@ -27,6 +27,7 @@ public:
void validate_and_infer_types() override; void validate_and_infer_types() override;
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override; std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
bool visit_attributes(ov::AttributeVisitor& visitor) override; bool visit_attributes(ov::AttributeVisitor& visitor) override;
bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
bool has_evaluate() const override; bool has_evaluate() const override;
const std::string& get_destination_type() const; const std::string& get_destination_type() const;

View File

@ -0,0 +1,221 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>
#include "openvino/reference/add.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/divide.hpp"
#include "openvino/reference/multiply.hpp"
#include "openvino/reference/subtract.hpp"
namespace ov {
namespace reference {
namespace func {
/**
* @brief Emulation of conversion fp16 value to f8e5m2 format
*
* @param arg_f Pointer to the input data.
* @param out_f Pointer to the output data.
* @param count Number of elements in the data input.
*/
void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count);
/**
* @brief Emulation of conversion fp16 value to f8e4m3 format
*
* @param arg_f Pointer to the input data.
* @param out_f Pointer to the output data.
* @param count Number of elements in the data input.
*
* Exponent denormal values 0 -7
* Exponent normal values 1..15 -6..8 (7 - exponent)
* Exponent NaN values 15 8
*
*/
void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count);
} // namespace func
namespace fake_convert_details {
/**
 * @brief Applies (or undoes) the scale/shift transformation to a single value.
 *
 * With INVERT == false the value is mapped into the emulated range: val * scale - shift.
 * With INVERT == true the mapping is reversed: (val + shift) / scale.
 */
template <bool INVERT, typename T>
inline T scale_shift(T val, T scale, T shift) {
    if (INVERT) {
        return static_cast<T>((val + shift) / scale);
    } else {
        return static_cast<T>(val * scale - shift);
    }
}
/**
 * @brief Apply scale and shift for the data input.
 * The scale_shape and shift_shape should be equal and numpy-broadcastable to the data_shape.
 *
 * @param out Pointer to the output data.
 * @param data Pointer to the input data.
 * @param scale Pointer to the input scale.
 * @param shift Pointer to the input shift.
 * @param data_shape Shape of the input data.
 * @param scale_shape Shape of the input scale.
 * @param shift_shape Shape of the input shift.
 * @param invert Flag denoting applying scale before (if false) or after conversion (if true).
 */
template <typename T>
void apply_scale_shift(T* out,
                       const T* data,
                       const T* scale,
                       const T* shift,
                       const Shape& data_shape,
                       const Shape& scale_shape,
                       const Shape& shift_shape,
                       const bool invert = false) {
    // Select the forward or inverse transform once, outside the element loops.
    const auto scale_shift_func = invert ? scale_shift<true, T> : scale_shift<false, T>;
    const auto scale_size = shape_size(scale_shape);
    const auto data_size = shape_size(data_shape);

    // Fast path 1: scalar scale/shift applied uniformly to every element.
    if (scale_size == 1) {  // per tensor scale, probably for activation
        T scale_val = scale[0];
        T shift_val = shift[0];
        for (size_t j = 0; j < data_size; ++j) {
            out[j] = scale_shift_func(data[j], scale_val, shift_val);
        }
        return;
    }

    // Fast path 2: scale/shift of shape [1, C, 1, ...] matching data's channel dim
    // (axis 1); iterates batches, then channels, advancing both pointers by the
    // per-channel element count ("step") as it goes.
    if (scale_shape.size() > 1 && scale_shape[0] == 1 && data_shape.size() == scale_shape.size() &&
        data_shape[1] == scale_shape[1] && scale_size == data_shape[1]) {  // per channel scale for DW activations
        const auto step = shape_size(data_shape.begin() + 2, data_shape.end());
        for (size_t bs = 0; bs < data_shape[0]; ++bs) {
            for (size_t i = 0; i < scale_size; i++) {
                T scale_val = scale[i];
                T shift_val = shift[i];
                for (size_t j = 0; j < step; ++j) {
                    out[j] = scale_shift_func(data[j], scale_val, shift_val);
                }
                data += step;
                out += step;
            }
        }
        return;
    }

    // The specific cases above are optimized versions of the broadcast for specific case
    // Autobroadcast helper is generic approach for broadcast
    autobroadcast_select(data,
                         scale,
                         shift,
                         out,
                         data_shape,
                         scale_shape,
                         shift_shape,
                         op::AutoBroadcastType::NUMPY,
                         scale_shift_func);
}
/**
* @brief Call conversion of fp16 value to the desired destination type
*
* @param arg Pointer to the input data.
* @param out Pointer to the output data.
* @param count Number of elements in the data input.
* @param destination_type Name of the destination type.
*/
void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type);
} // namespace fake_convert_details
/**
 * @brief Reference implementation of the FakeConvert operation specialized for the float16 type.
 * It emulates the format given by "destination_type" directly on the float16 input, performing
 * element-wise quantization of the input values into the set of values representable by the
 * target low-precision floating-point type.
 *
 * @param data Pointer to the input data.
 * @param scale Pointer to the input scale.
 * @param shift Pointer to the input shift.
 * @param out Pointer to the output data.
 * @param data_shape Shape of the input data.
 * @param scale_shape Shape of the input scale.
 * @param shift_shape Shape of the input shift.
 * @param destination_type Name of the destination type.
 */
template <typename T, typename std::enable_if<std::is_same<T, float16>::value, bool>::type = true>
void fake_convert(const T* data,
                  const T* scale,
                  const T* shift,
                  T* out,
                  const Shape& data_shape,
                  const Shape& scale_shape,
                  const Shape& shift_shape,
                  const std::string& destination_type) {
    // Map the input into the emulated range, emulate the low-precision format
    // in place, then invert the scale/shift to restore the original range.
    fake_convert_details::
        apply_scale_shift<float16>(out, data, scale, shift, data_shape, scale_shape, shift_shape, false);
    fake_convert_details::apply_conversion(out, out, shape_size(data_shape), destination_type);
    fake_convert_details::
        apply_scale_shift<float16>(out, out, scale, shift, data_shape, scale_shape, shift_shape, true);
}
/**
* @brief Reference implementation of the FakeConvert operation for floating-point types, with middle conversion
* to float16. It emulates format specified by "destination_type" on the input type. FakeConvert performs
* element-wise quantization of floating-point input values into a set of values corresponding to a target low-precision
* floating-point type.
*
*
* @param data Pointer to the input data.
* @param scale Pointer to the input scale.
* @param shift Pointer to the input shift.
* @param out Pointer to the output data.
* @param data_shape Shape of the input data.
* @param scale_shape Shape of the input scale.
* @param shift_shape Shape of the input shift.
* @param destination_type Name of the destination type.
*/
template <typename T, typename std::enable_if<!std::is_same<T, float16>::value, bool>::type = true>
void fake_convert(const T* data,
const T* scale,
const T* shift,
T* out,
const Shape& data_shape,
const Shape& scale_shape,
const Shape& shift_shape,
const std::string& destination_type) {
const size_t element_count = shape_size(data_shape);
fake_convert_details::apply_scale_shift<T>(out, data, scale, shift, data_shape, scale_shape, shift_shape, false);
std::vector<ov::float16> tmp_fp16(element_count, 0.f);
reference::convert(out, tmp_fp16.data(), element_count);
fake_convert_details::apply_conversion(tmp_fp16.data(), tmp_fp16.data(), element_count, destination_type);
reference::convert(tmp_fp16.data(), out, element_count);
fake_convert_details::apply_scale_shift<T>(out, out, scale, shift, data_shape, scale_shape, shift_shape, true);
}
/**
 * @brief Reference implementation of the FakeConvert operation with a default (all-zero) `shift`
 * input; the shift is given the same shape as `scale` and the full overload is invoked.
 */
template <typename T>
void fake_convert(const T* data,
                  const T* scale,
                  T* out,
                  const Shape& data_shape,
                  const Shape& scale_shape,
                  const std::string& destination_type) {
    const std::vector<T> zero_shift(shape_size(scale_shape), 0.f);
    fake_convert<T>(data, scale, zero_shift.data(), out, data_shape, scale_shape, scale_shape, destination_type);
}
} // namespace reference
} // namespace ov

View File

@ -0,0 +1,163 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/reference/fake_convert.hpp"
namespace ov {
namespace reference {
namespace func {
/// Emulates f8e5m2 precision on fp16 data: rounds each fp16 value to the nearest
/// value representable with a 2-bit mantissa (round-to-nearest-even), keeping the
/// result stored as fp16. Works directly on the fp16 bit pattern.
void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count) {
    // Reinterpret fp16 values as raw 16-bit patterns for bit manipulation.
    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
    auto out_u = reinterpret_cast<uint16_t*>(out_f);
    uint16_t val_bit_repr;

    constexpr auto exp_bits = 5;
    constexpr auto mbits = 8;
    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
    constexpr auto lshift = 10 - (mbits - non_mant_bits);
    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1100000000
    constexpr uint16_t grs_bitmask = 0x00FF;  /// 0 00000 0011111111, grs denotes guard, round, sticky bits
    constexpr uint16_t rne_tie = 0x0180;      /// 0 00000 0110000000, rne denotes round to nearest even
    constexpr uint16_t fp16_inf = 0x7C00;

    for (size_t i = 0; i < count; ++i) {
        /// converts float number to half precision in round-to-nearest-even mode and returns half with converted
        /// value.
        val_bit_repr = arg_u[i];
        /// 0x7c00 = 0111110000000000 - exponent mask
        /// s 11111 xxx xxxx xxxx - is nan (if some x is 1) or inf (if all x is 0)
        /// 0x7800 is 0111100000000000 and 0x400 is 0000010000000000
        /// number is not normal if all exponent is 1 or 0
        /// 0x7f00 is 0 11111 1100000000
        /// 0x7b00 is 0 11110 1100000000
        /// Rounding up past 0x7B00 would carry into the exponent's NaN/inf range,
        /// so such values are truncated instead of rounded.
        const bool can_round = ((val_bit_repr & 0x7F00) < 0x7B00) ? true : false;
        /// s 11111 xxx xxxx xxxx - is nan (if some x is 1) or inf (if all x is 0)
        const bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;
        /* nearest rounding masks */
        /// grs_bitmask - grs_bitmask is 0 00000 0011111111 or 0 00000 00grs11111
        uint16_t rnmask = (val_bit_repr & grs_bitmask);
        /// rne_tie - 0x180 is 0 00000 0110000000 or 384.0
        uint16_t rnmask_tie = (val_bit_repr & rne_tie);
        if (!is_naninf && can_round) {
            /* round to nearest even, if rne_mask is enabled */
            /* 0 00000 0010000000, find grs patterns */
            // 0xx - do nothing
            // 100 - this is a tie : round up if the mantissa's bit just before G is 1, else do nothing
            // 101, 110, 111 - round up > 0x0080
            val_bit_repr += (((rnmask > 0x0080) || (rnmask_tie == rne_tie)) << lshift);
        }
        val_bit_repr &= mask_mant; /* truncation */
        out_u[i] = val_bit_repr;
    }
}
/**
 * @brief Emulation of conversion fp16 value to f8e4m3 format
 *
 * Rounds each fp16 value to the nearest value representable in the f8e4m3 format
 * (4-bit exponent, 3-bit mantissa) and stores the result back as fp16, handling
 * subnormal flushing, denormal shifting and saturation/NaN clamping.
 *
 * @param arg_f Pointer to the input data.
 * @param out_f Pointer to the output data.
 * @param count Number of elements in the data input.
 *
 * Exponent denormal values 0 -7
 * Exponent normal values 1..15 -6..8 (7 - exponent)
 * Exponent NaN values 15 8
 *
 */
void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count) {
    // Reinterpret fp16 values as raw 16-bit patterns for bit manipulation.
    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
    auto out_u = reinterpret_cast<uint16_t*>(out_f);
    uint16_t val_bit_repr;

    constexpr auto use_clamp = true;  // saturate out-of-range values instead of producing inf
    constexpr auto exp_bits = 5;
    constexpr auto mbits = 9;
    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
    constexpr auto lshift = 10 - (mbits - non_mant_bits);
    constexpr auto fp16_exp_bias = 15;
    constexpr auto f8e4m3_min_val = 0.001953125f;  /// 2**-9
    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1111000000
    constexpr uint16_t grs_bitmask = 0x007F;  /// 0 00000 0001111111, grs denotes guard, round, sticky bits
    constexpr uint16_t rne_tie = 0x00C0;      /// 0 00000 0011000000, rne denotes round to nearest even
    constexpr uint16_t rne_mask = 1;
    constexpr uint16_t fp16_inf = 0x7C00;

    for (size_t i = 0; i < count; ++i) {
        val_bit_repr = arg_u[i];
        /* flush values below 1-4-3 (offset=4) subnormal range to zero */
        if (std::abs(static_cast<float>(arg_f[i])) < f8e4m3_min_val) {
            val_bit_repr = 0;
        }
        // Decompose the fp16 bit pattern into sign, unbiased exponent and mantissa.
        short exp_h = static_cast<short>((val_bit_repr & fp16_inf) >> 10) -
                      fp16_exp_bias;  /// 0111110000000000 -> 0000000000011111 - 15, biased exponent for fp16
        const short sign_h = (val_bit_repr & 0x8000);  /// & 1 00000 0000000000
        short mantissa_h = (val_bit_repr & 0x03FF);    /// & 0 00000 1111111111
        ///(val_bit_repr && 0111111111111111) < 0 10010 1110000000 (19326)
        const bool can_round = ((val_bit_repr & 0x7FFF) < 0b101111110000000) ? true : false;
        bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;

        int dshift = 0;
        if (exp_h > 8) {  // too large, set it to NaN or inf
            is_naninf = true;
            if (use_clamp) {
                // saturate to the largest f8e4m3 magnitude
                exp_h = 8;
                mantissa_h = 0b0000001100000000;
            } else {
                mantissa_h = 0;
                exp_h = 16;
            }
        } else if (exp_h < -9) {  /// -13, -12, -11 for rounding
            /* flush values below 1-4-3 (offset=4) subnormal range to zero */
            exp_h = -15;
            mantissa_h = 0;
        }
        // Clamp mantissa at the maximum representable f8e4m3 value for the top exponent.
        if (exp_h == 8 && mantissa_h >= 0b0000001100000000) {
            mantissa_h = 0b0000001100000000;
        }
        /* nearest rounding masks, & 0 00000 000 111 1111 - mantissa bits below f8e4m3 (grs) */
        const uint16_t rnmask = (mantissa_h & grs_bitmask);
        /* & 0 00000 0011000000 - edge between f8e4m3 and fp16 mantissa */
        const uint16_t rnmask_tie = (mantissa_h & rne_tie);
        if (!is_naninf && can_round && rne_mask) {
            /* round to nearest even, if rne_mask is enabled */
            /// rnmask > 0 00000 0001000000(64) or 0 00000 0011000000 - edge bits is 1
            /// += 0 00000 0010000000
            mantissa_h += (((rnmask > 0x0040) || (rnmask_tie == rne_tie)) << lshift);
        }
        if (exp_h < -6) { /* handle denormals -11, -10, -9, dshift 1, 2, 3 */
            dshift = (-6 - exp_h);
            mantissa_h = mantissa_h >> dshift;
        }
        mantissa_h &= mask_mant; /* truncation */
        mantissa_h <<= dshift;
        // Reassemble the fp16 bit pattern from the rounded components.
        mantissa_h += ((exp_h + 15) << 10);
        val_bit_repr = mantissa_h | sign_h;
        out_u[i] = val_bit_repr;
    }
}
} // namespace func
namespace fake_convert_details {
/**
 * @brief Dispatches the fp16 data to the emulation kernel matching the destination type.
 *
 * @param data Pointer to the input data.
 * @param out Pointer to the output data.
 * @param element_count Number of elements in the data input.
 * @param destination_type Name of the destination type ("f8e4m3" or "f8e5m2").
 */
void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type) {
    if (destination_type == "f8e4m3") {
        reference::func::emulate_f8e4m3_on_fp16(data, out, element_count);
    } else if (destination_type == "f8e5m2") {
        reference::func::emulate_f8e5m2_on_fp16(data, out, element_count);
    } else {
        OPENVINO_THROW("Unsupported destination type.");
    }
}
} // namespace fake_convert_details
} // namespace reference
} // namespace ov

View File

@ -4,17 +4,47 @@
#include "openvino/op/fake_convert.hpp" #include "openvino/op/fake_convert.hpp"
#include "element_visitor.hpp"
#include "itt.hpp" #include "itt.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/fake_convert.hpp"
namespace ov { namespace ov {
namespace op { namespace op {
namespace v13 { namespace v13 {
namespace fake_convert { namespace fake_convert_details {
static const std::vector<std::string>& get_valid_types() { static const std::vector<std::string>& get_valid_types() {
static const std::vector<std::string> valid_types{"f8e4m3", "f8e5m2"}; static const std::vector<std::string> valid_types{"f8e4m3", "f8e5m2"};
return valid_types; return valid_types;
} }
} // namespace fake_convert
struct Evaluate : element::NoAction<bool> {
using element::NoAction<bool>::visit;
template <element::Type_t ET, class T = fundamental_type_for<ET>>
static result_type visit(ov::TensorVector& outputs,
const ov::TensorVector& inputs,
const std::string& destination_type) {
if (inputs.size() == 2) { // Default shift
reference::fake_convert<T>(inputs[0].data<const T>(),
inputs[1].data<const T>(),
outputs[0].data<T>(),
inputs[0].get_shape(),
inputs[1].get_shape(),
destination_type);
} else {
reference::fake_convert<T>(inputs[0].data<const T>(),
inputs[1].data<const T>(),
inputs[2].data<const T>(),
outputs[0].data<T>(),
inputs[0].get_shape(),
inputs[1].get_shape(),
inputs[2].get_shape(),
destination_type);
}
return true;
}
};
} // namespace fake_convert_details
FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg, FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg,
const ov::Output<ov::Node>& scale, const ov::Output<ov::Node>& scale,
@ -61,15 +91,40 @@ bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) {
} }
void FakeConvert::validate_type() const { void FakeConvert::validate_type() const {
const auto& valid_types = fake_convert::get_valid_types(); const auto& valid_types = fake_convert_details::get_valid_types();
OPENVINO_ASSERT(std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end(), const auto is_supported_type =
"Bad format for f8 conversion type: " + m_destination_type); std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end();
OPENVINO_ASSERT(is_supported_type, "Bad format for f8 conversion type: ", m_destination_type);
} }
bool FakeConvert::has_evaluate() const { bool FakeConvert::has_evaluate() const {
return false; OV_OP_SCOPE(v13_FakeConvert_has_evaluate);
switch (get_input_element_type(0)) {
case element::f16:
case element::bf16:
case element::f32:
return true;
default:
return false;
}
} }
bool FakeConvert::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
OV_OP_SCOPE(v13_FakeConvert_evaluate);
OPENVINO_ASSERT(outputs.size() == 1);
OPENVINO_ASSERT(inputs.size() == 2 || inputs.size() == 3);
outputs[0].set_shape(inputs[0].get_shape());
using namespace ov::element;
return IfTypeOf<bf16, f16, f32>::apply<fake_convert_details::Evaluate>(inputs[0].get_element_type(),
outputs,
inputs,
get_destination_type());
return true;
}
} // namespace v13 } // namespace v13
} // namespace op } // namespace op
} // namespace ov } // namespace ov

File diff suppressed because it is too large Load Diff