From 63e7e2dd3ecc253112373ed6b42afd2f8556bff5 Mon Sep 17 00:00:00 2001
From: Katarzyna Mitrus
Date: Fri, 24 Nov 2023 09:54:19 +0100
Subject: [PATCH] [Opset13][FP8] FakeConvert evaluate and reference impl
 (#21034)

* FakeConvert op init
* Update dest types names
* Update op hpp
* Update opset ops number
* Init type_prop tests
* Add attributes tests
* Add op check test
* Update namespace in fc cpp
* Update getters
* Refactor static member
* Make destination_type lower case
* Update type in test
* Move get_valid_types out of class
* Update ops number in opset
* Evaluate init
* Init poc evaluate
* Update types name
* Style update
* FakeConvert eval tests
* Add support for FP16 input
* Remove unused functions
* Use const auto and constexpr
* Create fake_convert reference file and move ref functions
* FakeConvert eval and reference update
* Remove unused is_normal variable
* Add bf16 tests
* Update constructor in tests
* Adjust convert helper name
* Update comments
* Make func params as constexpr
* Update types
* Non scalar scale shape tests
* Use autobroadcast_select
* Use single autobroadcast_select
* Check scale_shape size
* Add var to keep scale_shift lambda
* Use lambda only for autobroadcast
* More checks for per channel scale
* Remove union half_t
* Change template to fp16 type for emulate
* Use f8e4m3_min_val as constexpr
* Update unsigned short to uint16_t
* Minor style refactor
* Minor comments update
* Update fake_convert_details namespace to func
* Update apply_conversion return type
* Add doxygen text
* Update supported type assert
* Remove duplicated tests
* Update namespace move to cpp
* Fill init with zeroes instead of reserve
* Use add div mul sub reference for applying scale
* Use autobroadcast_select for applying scale
* Bring back opt broadcast cases
* Reuse scale_shift as func
* Adjust s and o var names
* Update multiplication in loop to accumulate
---
 src/core/include/openvino/op/fake_convert.hpp |    1 +
 .../openvino/reference/fake_convert.hpp       |  221 +++
 src/core/reference/src/op/fake_convert.cpp    |  163 +++
 src/core/src/op/fake_convert.cpp              |   67 +-
 src/core/tests/eval.cpp                       | 1227 +++++++++++++++++
 5 files changed, 1673 insertions(+), 6 deletions(-)
 create mode 100644 src/core/reference/include/openvino/reference/fake_convert.hpp
 create mode 100644 src/core/reference/src/op/fake_convert.cpp

diff --git a/src/core/include/openvino/op/fake_convert.hpp b/src/core/include/openvino/op/fake_convert.hpp
index 0c63de608c7..07c0908b08b 100644
--- a/src/core/include/openvino/op/fake_convert.hpp
+++ b/src/core/include/openvino/op/fake_convert.hpp
@@ -27,6 +27,7 @@ public:
     void validate_and_infer_types() override;
     std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
     bool visit_attributes(ov::AttributeVisitor& visitor) override;
+    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
     bool has_evaluate() const override;

     const std::string& get_destination_type() const;
diff --git a/src/core/reference/include/openvino/reference/fake_convert.hpp b/src/core/reference/include/openvino/reference/fake_convert.hpp
new file mode 100644
index 00000000000..a6e1283a568
--- /dev/null
+++ b/src/core/reference/include/openvino/reference/fake_convert.hpp
@@ -0,0 +1,221 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <cstddef>
+#include <string>
+#include <type_traits>
+#include <vector>
+
+#include "openvino/reference/add.hpp"
+#include "openvino/reference/autobroadcast_binop.hpp"
+#include "openvino/reference/convert.hpp"
+#include "openvino/reference/divide.hpp" +#include "openvino/reference/multiply.hpp" +#include "openvino/reference/subtract.hpp" + +namespace ov { +namespace reference { +namespace func { +/** + * @brief Emulation of conversion fp16 value to f8e5m2 format + * + * @param arg_f Pointer to the input data. + * @param out_f Pointer to the otuput data. + * @param count Number of elements in the data input. + */ +void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count); + +/** + * @brief Emulation of conversion fp16 value to f8e4m3 format + * + * @param arg_f Pointer to the input data. + * @param out_f Pointer to the otuput data. + * @param count Number of elements in the data input. + * + * Exponent denormal values 0 -7 + * Exponent normal values 1..15 -6..8 (7 - exponent) + * Exponent NaN values 15 8 + * + */ +void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count); +} // namespace func + +namespace fake_convert_details { +template +inline T scale_shift(T val, T scale, T shift) { + return INVERT ? static_cast((val + shift) / scale) : static_cast(val * scale - shift); +} +/** + * @brief Apply scale and shift for the data input. + * The scale_shape and shift_shape should be equal and numpy-broadcastable to the data_shape. + * + * @param data Pointer to the input data. + * @param scale Pointer to the input scale. + * @param shift Pointer to the input shift. + * @param data_shape Shape of the input data. + * @param scale_shape Shape of the input scale. + * @param shift_shape Shape of the input shift. + * @param invert Flag denoting applying scale before (if false) or after conversion (if true). + */ +template +void apply_scale_shift(T* out, + const T* data, + const T* scale, + const T* shift, + const Shape& data_shape, + const Shape& scale_shape, + const Shape& shift_shape, + const bool invert = false) { + const auto scale_shift_func = invert ? scale_shift : scale_shift; + const auto scale_size = shape_size(scale_shape); + const auto data_size = shape_size(data_shape); + + if (scale_size == 1) { // per tensor scale, probably for activation + T scale_val = scale[0]; + T shift_val = shift[0]; + for (size_t j = 0; j < data_size; ++j) { + out[j] = scale_shift_func(data[j], scale_val, shift_val); + } + return; + } + + if (scale_shape.size() > 1 && scale_shape[0] == 1 && data_shape.size() == scale_shape.size() && + data_shape[1] == scale_shape[1] && scale_size == data_shape[1]) { // per channel scale for DW activations + const auto step = shape_size(data_shape.begin() + 2, data_shape.end()); + for (size_t bs = 0; bs < data_shape[0]; ++bs) { + for (size_t i = 0; i < scale_size; i++) { + T scale_val = scale[i]; + T shift_val = shift[i]; + for (size_t j = 0; j < step; ++j) { + out[j] = scale_shift_func(data[j], scale_val, shift_val); + } + data += step; + out += step; + } + } + return; + } + + // The specific cases above are optimized verions of the broadcast for specific case + // Autobroadcast helper is generic approach for broadcast + autobroadcast_select(data, + scale, + shift, + out, + data_shape, + scale_shape, + shift_shape, + op::AutoBroadcastType::NUMPY, + scale_shift_func); +} +/** + * @brief Call conversion of fp16 value to the desired destination type + * + * @param arg Pointer to the input data. + * @param out Pointer to the otuput data. + * @param count Number of elements in the data input. + * @param destination_type Name of the destination type. 
+ */
+void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type);
+}  // namespace fake_convert_details
+
+/**
+ * @brief Reference implementation of the FakeConvert operation specialized for the float16 type.
+ * It emulates the format specified by "destination_type" on the input type.
+ * FakeConvert performs element-wise quantization of floating-point input values into a set of values
+ * corresponding to a target low-precision floating-point type.
+ *
+ * @param data Pointer to the input data.
+ * @param scale Pointer to the input scale.
+ * @param shift Pointer to the input shift.
+ * @param out Pointer to the output data.
+ * @param data_shape Shape of the input data.
+ * @param scale_shape Shape of the input scale.
+ * @param shift_shape Shape of the input shift.
+ * @param destination_type Name of the destination type.
+ */
+template <typename T, typename std::enable_if<std::is_same<T, float16>::value, bool>::type = true>
+void fake_convert(const T* data,
+                  const T* scale,
+                  const T* shift,
+                  T* out,
+                  const Shape& data_shape,
+                  const Shape& scale_shape,
+                  const Shape& shift_shape,
+                  const std::string& destination_type) {
+    const size_t element_count = shape_size(data_shape);
+    fake_convert_details::apply_scale_shift(out,
+                                            data,
+                                            scale,
+                                            shift,
+                                            data_shape,
+                                            scale_shape,
+                                            shift_shape,
+                                            false);
+    fake_convert_details::apply_conversion(out, out, element_count, destination_type);
+    fake_convert_details::apply_scale_shift(out,
+                                            out,
+                                            scale,
+                                            shift,
+                                            data_shape,
+                                            scale_shape,
+                                            shift_shape,
+                                            true);
+}
+
+/**
+ * @brief Reference implementation of the FakeConvert operation for floating-point types, with an intermediate
+ * conversion to float16. It emulates the format specified by "destination_type" on the input type. FakeConvert
+ * performs element-wise quantization of floating-point input values into a set of values corresponding to a
+ * target low-precision floating-point type.
+ *
+ * @param data Pointer to the input data.
+ * @param scale Pointer to the input scale.
+ * @param shift Pointer to the input shift.
+ * @param out Pointer to the output data.
+ * @param data_shape Shape of the input data.
+ * @param scale_shape Shape of the input scale.
+ * @param shift_shape Shape of the input shift.
+ * @param destination_type Name of the destination type.
+ */
+template <typename T, typename std::enable_if<!std::is_same<T, float16>::value, bool>::type = true>
+void fake_convert(const T* data,
+                  const T* scale,
+                  const T* shift,
+                  T* out,
+                  const Shape& data_shape,
+                  const Shape& scale_shape,
+                  const Shape& shift_shape,
+                  const std::string& destination_type) {
+    const size_t element_count = shape_size(data_shape);
+    fake_convert_details::apply_scale_shift(out, data, scale, shift, data_shape, scale_shape, shift_shape, false);
+
+    std::vector<float16> tmp_fp16(element_count, 0.f);
+    reference::convert(out, tmp_fp16.data(), element_count);
+    fake_convert_details::apply_conversion(tmp_fp16.data(), tmp_fp16.data(), element_count, destination_type);
+    reference::convert(tmp_fp16.data(), out, element_count);
+
+    fake_convert_details::apply_scale_shift(out, out, scale, shift, data_shape, scale_shape, shift_shape, true);
+}
+
+/**
+ * @brief Reference implementation of the FakeConvert operation, with the default `shift` input.
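+ *
+ * Illustrative usage (an editorial sketch, assuming float data with a per-tensor scale):
+ * @code
+ * std::vector<float> in(8, 2.f), out(8);
+ * const float scale_val = 1.f;
+ * fake_convert(in.data(), &scale_val, out.data(), Shape{8}, Shape{1}, "f8e4m3");
+ * @endcode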
+ */
+template <typename T>
+void fake_convert(const T* data,
+                  const T* scale,
+                  T* out,
+                  const Shape& data_shape,
+                  const Shape& scale_shape,
+                  const std::string& destination_type) {
+    const auto shift = std::vector<T>(shape_size(scale_shape), 0.f);
+    fake_convert(data, scale, shift.data(), out, data_shape, scale_shape, scale_shape, destination_type);
+}
+
+}  // namespace reference
+}  // namespace ov
diff --git a/src/core/reference/src/op/fake_convert.cpp b/src/core/reference/src/op/fake_convert.cpp
new file mode 100644
index 00000000000..fa7499e6a12
--- /dev/null
+++ b/src/core/reference/src/op/fake_convert.cpp
@@ -0,0 +1,163 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/reference/fake_convert.hpp"
+
+namespace ov {
+namespace reference {
+namespace func {
+void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count) {
+    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
+    auto out_u = reinterpret_cast<uint16_t*>(out_f);
+    uint16_t val_bit_repr;
+
+    constexpr auto exp_bits = 5;
+    constexpr auto mbits = 8;
+    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
+    constexpr auto lshift = 10 - (mbits - non_mant_bits);
+    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1100000000
+    constexpr uint16_t grs_bitmask = 0x00FF;  /// 0 00000 0011111111, grs denotes guard, round, sticky bits
+    constexpr uint16_t rne_tie = 0x0180;      /// 0 00000 0110000000, rne denotes round to nearest even
+    constexpr uint16_t fp16_inf = 0x7C00;
+
+    for (size_t i = 0; i < count; ++i) {
+        /// rounds the fp16 bit pattern to the nearest f8e5m2-representable value (round-to-nearest-even)
+        val_bit_repr = arg_u[i];
+        /// 0x7c00 = 0111110000000000 - exponent mask
+        /// s 11111 xxx xxxx xxxx - is nan (if some x is 1) or inf (if all x is 0)
+        /// 0x7800 is 0111100000000000 and 0x400 is 0000010000000000
+        /// number is not normal if all exponent is 1 or 0
+        /// 0x7f00 is 0 11111 1100000000
+        /// 0x7b00 is 0 11110 1100000000
+        const bool can_round = ((val_bit_repr & 0x7F00) < 0x7B00) ? true : false;
+        /// s 11111 xxx xxxx xxxx - is nan (if some x is 1) or inf (if all x is 0)
+        const bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;
+        /* nearest rounding masks */
+        /// grs_bitmask - grs_bitmask is 0 00000 0011111111 or 0 00000 00grs11111
+        uint16_t rnmask = (val_bit_repr & grs_bitmask);
+        /// rne_tie - 0x180 is 0 00000 0110000000 or 384.0
+        uint16_t rnmask_tie = (val_bit_repr & rne_tie);
+
+        if (!is_naninf && can_round) {
+            /* round to nearest even, if rne_mask is enabled */
+            /* 0 00000 0010000000, find grs patterns */
+            // 0xx - do nothing
+            // 100 - this is a tie: round up if the mantissa bit just before G is 1, else do nothing
+            // 101, 110, 111 - round up > 0x0080
+            val_bit_repr += (((rnmask > 0x0080) || (rnmask_tie == rne_tie)) << lshift);
+        }
+        val_bit_repr &= mask_mant; /* truncation */
+        out_u[i] = val_bit_repr;
+    }
+}
+
+/**
+ * @brief Emulation of conversion of an fp16 value to the f8e4m3 format
+ *
+ * @param arg_f Pointer to the input data.
+ * @param out_f Pointer to the output data.
+ * @param count Number of elements in the data input.
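+ *
+ * For example, with round-to-nearest-even onto the f8e4m3 grid, an fp16 input of
+ * 0.143f comes back as 0.140625f (an illustrative value, consistent with the
+ * eval.cpp tests added in this patch).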
+ *
+ * Exponent denormal values 0 -7
+ * Exponent normal values 1..15 -6..8 (7 - exponent)
+ * Exponent NaN values 15 8
+ *
+ */
+void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count) {
+    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
+    auto out_u = reinterpret_cast<uint16_t*>(out_f);
+    uint16_t val_bit_repr;
+
+    constexpr auto use_clamp = true;
+    constexpr auto exp_bits = 5;
+    constexpr auto mbits = 9;
+    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
+    constexpr auto lshift = 10 - (mbits - non_mant_bits);
+    constexpr auto fp16_exp_bias = 15;
+    constexpr auto f8e4m3_min_val = 0.001953125f;  /// 2**-9
+
+    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1111000000
+    constexpr uint16_t grs_bitmask = 0x007F;  /// 0 00000 0001111111, grs denotes guard, round, sticky bits
+    constexpr uint16_t rne_tie = 0x00C0;      /// 0 00000 0011000000, rne denotes round to nearest even
+    constexpr uint16_t rne_mask = 1;
+    constexpr uint16_t fp16_inf = 0x7C00;
+
+    for (size_t i = 0; i < count; ++i) {
+        val_bit_repr = arg_u[i];
+        /* flush values below the 1-4-3 (offset=4) subnormal range to zero */
+        if (std::abs(static_cast<float>(arg_f[i])) < f8e4m3_min_val) {
+            val_bit_repr = 0;
+        }
+
+        short exp_h = static_cast<short>((val_bit_repr & fp16_inf) >> 10) -
+                      fp16_exp_bias;  /// 0111110000000000 -> 0000000000011111 - 15, biased exponent for fp16
+        const short sign_h = (val_bit_repr & 0x8000);  /// & 1 00000 0000000000
+        short mantissa_h = (val_bit_repr & 0x03FF);    /// & 0 00000 1111111111
+        ///(val_bit_repr && 0111111111111111) < 0 10010 1110000000 (19326)
+        const bool can_round = ((val_bit_repr & 0x7FFF) < 0b101111110000000) ? true : false;
+        bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;
+
+        int dshift = 0;
+        if (exp_h > 8) {  // too large, set it to NaN or inf
+            is_naninf = true;
+            if (use_clamp) {
+                exp_h = 8;
+                mantissa_h = 0b0000001100000000;
+            } else {
+                mantissa_h = 0;
+                exp_h = 16;
+            }
+        } else if (exp_h < -9) {  /// -13, -12, -11 for rounding
+            /* flush values below the 1-4-3 (offset=4) subnormal range to zero */
+            exp_h = -15;
+            mantissa_h = 0;
+        }
+
+        if (exp_h == 8 && mantissa_h >= 0b0000001100000000) {
+            mantissa_h = 0b0000001100000000;
+        }
+        /* nearest rounding masks, & 0 00000 000 111 1111 - mantissa bits below f8e4m3 (grs) */
+        const uint16_t rnmask = (mantissa_h & grs_bitmask);
+        /* & 0 00000 0011000000 - edge between f8e4m3 and fp16 mantissa */
+        const uint16_t rnmask_tie = (mantissa_h & rne_tie);
+        if (!is_naninf && can_round && rne_mask) {
+            /* round to nearest even, if rne_mask is enabled */
+            /// rnmask > 0 00000 0001000000 (64) or 0 00000 0011000000 - edge bits are 1
+            /// += 0 00000 0010000000
+            mantissa_h += (((rnmask > 0x0040) || (rnmask_tie == rne_tie)) << lshift);
+        }
+        if (exp_h < -6) {  /* handle denormals -11, -10, -9, dshift 1, 2, 3 */
+            dshift = (-6 - exp_h);
+            mantissa_h = mantissa_h >> dshift;
+        }
+        mantissa_h &= mask_mant; /* truncation */
+        mantissa_h <<= dshift;
+        mantissa_h += ((exp_h + 15) << 10);
+        val_bit_repr = mantissa_h | sign_h;
+        out_u[i] = val_bit_repr;
+    }
+}
+}  // namespace func
+
+namespace fake_convert_details {
+/**
+ * @brief Call conversion of fp16 values to the desired destination type
+ *
+ * @param data Pointer to the input data.
+ * @param out Pointer to the output data.
+ * @param count Number of elements in the data input.
+ * @param destination_type Name of the destination type.
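+ *
+ * The dispatch below selects the f8e5m2 or f8e4m3 emulation and throws for any other name.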
+ */
+void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type) {
+    if (destination_type == "f8e5m2") {
+        reference::func::emulate_f8e5m2_on_fp16(data, out, element_count);
+    } else if (destination_type == "f8e4m3") {
+        reference::func::emulate_f8e4m3_on_fp16(data, out, element_count);
+    } else {
+        OPENVINO_THROW("Unsupported destination type.");
+    }
+}
+}  // namespace fake_convert_details
+}  // namespace reference
+}  // namespace ov
diff --git a/src/core/src/op/fake_convert.cpp b/src/core/src/op/fake_convert.cpp
index f32db7409dd..6f7dfd3614d 100644
--- a/src/core/src/op/fake_convert.cpp
+++ b/src/core/src/op/fake_convert.cpp
@@ -4,17 +4,47 @@

 #include "openvino/op/fake_convert.hpp"

+#include "element_visitor.hpp"
 #include "itt.hpp"
+#include "openvino/reference/convert.hpp"
+#include "openvino/reference/fake_convert.hpp"

 namespace ov {
 namespace op {
 namespace v13 {
-namespace fake_convert {
+namespace fake_convert_details {
 static const std::vector<std::string>& get_valid_types() {
     static const std::vector<std::string> valid_types{"f8e4m3", "f8e5m2"};
     return valid_types;
 }
-}  // namespace fake_convert
+
+struct Evaluate : element::NoAction<bool> {
+    using element::NoAction<bool>::visit;
+
+    template <element::Type_t ET, class T = fundamental_type_for<ET>>
+    static result_type visit(ov::TensorVector& outputs,
+                             const ov::TensorVector& inputs,
+                             const std::string& destination_type) {
+        if (inputs.size() == 2) {  // Default shift
+            reference::fake_convert(inputs[0].data<const T>(),
+                                    inputs[1].data<const T>(),
+                                    outputs[0].data<T>(),
+                                    inputs[0].get_shape(),
+                                    inputs[1].get_shape(),
+                                    destination_type);
+        } else {
+            reference::fake_convert(inputs[0].data<const T>(),
+                                    inputs[1].data<const T>(),
+                                    inputs[2].data<const T>(),
+                                    outputs[0].data<T>(),
+                                    inputs[0].get_shape(),
+                                    inputs[1].get_shape(),
+                                    inputs[2].get_shape(),
+                                    destination_type);
+        }
+        return true;
+    }
+};
+}  // namespace fake_convert_details

 FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg,
                          const ov::Output<ov::Node>& scale,
@@ -61,15 +91,40 @@ bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) {
 }

 void FakeConvert::validate_type() const {
-    const auto& valid_types = fake_convert::get_valid_types();
-    OPENVINO_ASSERT(std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end(),
-                    "Bad format for f8 conversion type: " + m_destination_type);
+    const auto& valid_types = fake_convert_details::get_valid_types();
+    const auto is_supported_type =
+        std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end();
+    OPENVINO_ASSERT(is_supported_type, "Bad format for f8 conversion type: ", m_destination_type);
 }

 bool FakeConvert::has_evaluate() const {
-    return false;
+    OV_OP_SCOPE(v13_FakeConvert_has_evaluate);
+    switch (get_input_element_type(0)) {
+    case element::f16:
+    case element::bf16:
+    case element::f32:
+        return true;
+    default:
+        return false;
+    }
 }

+bool FakeConvert::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
+    OV_OP_SCOPE(v13_FakeConvert_evaluate);
+
+    OPENVINO_ASSERT(outputs.size() == 1);
+    OPENVINO_ASSERT(inputs.size() == 2 || inputs.size() == 3);
+
+    outputs[0].set_shape(inputs[0].get_shape());
+
+    using namespace ov::element;
+    return IfTypeOf<bf16, f16, f32>::apply<fake_convert_details::Evaluate>(inputs[0].get_element_type(),
+                                                                           outputs,
+                                                                           inputs,
+                                                                           get_destination_type());
+}
 }  // namespace v13
 }  // namespace op
 }  // namespace ov
diff --git a/src/core/tests/eval.cpp b/src/core/tests/eval.cpp
index 1edce9028bb..fb85b589873 100644
--- a/src/core/tests/eval.cpp
+++ b/src/core/tests/eval.cpp
@@ -34,6 +34,7 @@
 #include "openvino/op/cum_sum.hpp"
#include "openvino/op/erf.hpp" #include "openvino/op/exp.hpp" +#include "openvino/op/fake_convert.hpp" #include "openvino/op/fake_quantize.hpp" #include "openvino/op/floor.hpp" #include "openvino/op/gather.hpp" @@ -62,6 +63,7 @@ #include "openvino/op/tan.hpp" #include "openvino/op/tanh.hpp" #include "openvino/op/topk.hpp" +#include "openvino/op/transpose.hpp" #include "openvino/op/unsqueeze.hpp" #include "openvino/runtime/tensor.hpp" #include "sequnce_generator.hpp" @@ -2655,6 +2657,1231 @@ TEST(eval, evaluate_fake_quantize_dynamic_input) { Pointwise(FloatEq(), std::vector{2.f, 2.6666667f, 2.6666667f, 3.3333333f, 3.3333333f, 4.f})); } +///////////////////////////////////// FakeConvert (FP8) + +namespace testing { +namespace fp8 { +constexpr float MAX_F8E4M3 = 448.f; +constexpr float MIN_F8E4M3 = 0.001953125f; + +constexpr float MAX_F8E5M2 = 57344.f; +constexpr float MIN_F8E5M2 = 0.0000152587890625f; +} // namespace fp8 +} // namespace testing + +///////////////////////////////////// FakeConvert f8e4m3 + +TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_scale_small) { + using namespace testing; + constexpr auto et = element::f32; + + std::vector input_data{fp8::MIN_F8E4M3, fp8::MIN_F8E4M3 / 2.f, fp8::MIN_F8E4M3 / 4.f}; + + const auto data_shape = Shape{input_data.size()}; + + auto data = make_shared(et, data_shape); + auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end())); + + auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E4M3 / max_input_val}); + auto shift = op::v0::Constant::create(et, Shape{1}, {0.f}); + + auto op = make_shared(data, scale, shift, "f8e4m3"); + + auto model = make_shared(OutputVector{op}, ParameterVector{data}); + + auto result = ov::Tensor(); + auto out_vector = ov::TensorVector{result}; + auto in_vector = ov::TensorVector{make_tensor(data_shape, input_data)}; + ASSERT_TRUE(model->evaluate(out_vector, in_vector)); + result = out_vector.at(0); + + EXPECT_EQ(result.get_element_type(), et); + EXPECT_EQ(result.get_shape(), data_shape); + EXPECT_THAT(read_vector(result), + Pointwise(FloatEq(), std::vector{0.001953125f, 0.0009765625f, 0.00048828125f})); +} + +TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_scale_1_small) { + using namespace testing; + constexpr auto et = element::f32; + + std::vector input_data{fp8::MIN_F8E4M3, fp8::MIN_F8E4M3 / 2.f, fp8::MIN_F8E4M3 / 4.f}; + + const auto data_shape = Shape{input_data.size()}; + auto data = make_shared(et, data_shape); + auto scale = op::v0::Constant::create(et, Shape{1}, {1.f}); + auto shift = op::v0::Constant::create(et, Shape{1}, {0.f}); + + auto op = make_shared(data, scale, shift, "f8e4m3"); + auto model = make_shared(OutputVector{op}, ParameterVector{data}); + + auto result = ov::Tensor(); + auto out_vector = ov::TensorVector{result}; + auto in_vector = ov::TensorVector{make_tensor(data_shape, input_data)}; + ASSERT_TRUE(model->evaluate(out_vector, in_vector)); + result = out_vector.at(0); + + EXPECT_EQ(result.get_element_type(), et); + EXPECT_EQ(result.get_shape(), data_shape); + EXPECT_THAT(read_vector(result), Pointwise(FloatEq(), std::vector{0.001953125f, 0.f, 0.f})); +} + +TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_scale_big) { + using namespace testing; + constexpr auto et = element::f32; + + std::vector input_data{fp8::MAX_F8E4M3 / 2.f, fp8::MAX_F8E4M3, fp8::MAX_F8E4M3 * 2.f, fp8::MAX_F8E4M3 * 4.f}; + + const auto data_shape = Shape{input_data.size()}; + auto data = make_shared(et, data_shape); + auto max_input_val = std::abs(*std::max_element(input_data.begin(), 
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E4M3 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), std::vector<float>{224.f, 448.f, 896.f, 1792.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_4x3_scale_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E4M3 / 4.f,
+                                  fp8::MAX_F8E4M3 / 3.f,
+                                  fp8::MAX_F8E4M3 / 2.f,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3 * 1.2f,
+                                  fp8::MAX_F8E4M3 * 2.3f,
+                                  fp8::MAX_F8E4M3 * 3.4f,
+                                  fp8::MAX_F8E4M3 * 2.f,
+                                  fp8::MAX_F8E4M3 * 3.f,
+                                  fp8::MAX_F8E4M3 * 4.f};
+
+    std::vector<float> output_data{112, 144, 224, 448, 448, 448, 560, 1008, 1568, 896, 1280, 1792};
+
+    const auto data_shape = Shape{4, 3};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et,
+                                          Shape{4, 1},
+                                          {fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 / 2.f),
+                                           1.0f,
+                                           fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 * 3.5f),
+                                           fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 * 4.f)});
+    auto shift = op::v0::Constant::create(et, Shape{4, 1}, {0.f, 0.f, 0.f, 0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_3x4_scale_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E4M3 / 4.f,
+                                  fp8::MAX_F8E4M3 / 3.f,
+                                  fp8::MAX_F8E4M3 / 2.f,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3,
+                                  fp8::MAX_F8E4M3 * 1.2f,
+                                  fp8::MAX_F8E4M3 * 2.3f,
+                                  fp8::MAX_F8E4M3 * 3.4f,
+                                  fp8::MAX_F8E4M3 * 2.f,
+                                  fp8::MAX_F8E4M3 * 3.f,
+                                  fp8::MAX_F8E4M3 * 4.f};
+
+    std::vector<float> output_data{112, 448, 560, 896, 144, 448, 1008, 1280, 224, 448, 1568, 1792};
+
+    const auto data_shape = Shape{4, 3};  // To be transposed to 3x4
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    std::vector<int32_t> order{1, 0};
+    auto transpose_order = make_shared<op::v0::Constant>(element::i32, Shape{order.size()}, order);
+    auto transposed_data = make_shared<op::v1::Transpose>(data, transpose_order);
+    const auto transposed_data_shape = Shape{3, 4};
+
+    auto scale = op::v0::Constant::create(et,
+                                          Shape{1, 4},
+                                          {fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 / 2.f),
+                                           1.0f,
+                                           fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 * 3.5f),
+                                           fp8::MAX_F8E4M3 / (fp8::MAX_F8E4M3 * 4.f)});
+    auto shift = op::v0::Constant::create(et, Shape{1, 4}, {0.f, 0.f, 0.f, 0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(transposed_data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
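+    // After the Transpose the data is 3x4, so the {1, 4} scale and shift broadcast
+    // along the second axis, one factor per column (editorial note).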
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), transposed_data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_scale_shift_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E4M3 / 2.f, fp8::MAX_F8E4M3, fp8::MAX_F8E4M3 * 2.f, fp8::MAX_F8E4M3 * 4.f};
+    std::vector<float> output_data{224.f, 448.f, 896.f, 1728.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    const auto scale_val = fp8::MAX_F8E4M3 / max_input_val;
+    auto scale = op::v0::Constant::create(et, Shape{1}, {scale_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E4M3 * scale_val});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_big_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E4M3 / 2.f, fp8::MAX_F8E4M3, fp8::MAX_F8E4M3 * 2.f, fp8::MAX_F8E4M3 * 4.f};
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), std::vector<float>{224.f, 448.f, 448.f, 448.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f};
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(
+        read_vector<float>(result),
+        Pointwise(
+            FloatEq(),
+            std::vector<
+                float>{0.f,
                       0.1015625f, 0.203125f, 0.3125f, 0.40625f, 0.5f, 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e4m3_no_scale_no_shift) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f};
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.0f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(
+        read_vector<float>(result),
+        Pointwise(
+            FloatEq(),
+            std::vector<
+                float>{0.f, 0.1015625f, 0.203125f, 0.3125f, 0.40625f, 0.5f, 0.625f, 0.6875f, 0.8125f, 0.875f, 1.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_seq_to_f8e4m3_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    const auto data_shape = Shape{8};
+
+    std::vector<float> input_data;
+    std::generate_n(std::back_inserter(input_data), shape_size(data_shape), ov::SeqGen<float>(0.143f));
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result),
+                Pointwise(FloatEq(), std::vector<float>{0.140625, 1.125, 2.25, 3.25, 4, 5, 6, 7}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_seq_to_f8e4m3_scale) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    const auto data_shape = Shape{8};
+    std::vector<float> input_data;
+    std::generate_n(std::back_inserter(input_data), shape_size(data_shape), ov::SeqGen<float>(0.143f));
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E4M3 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+
+    EXPECT_THAT(read_vector<float>(result),
+                Pointwise(FloatEq(),
+                          std::vector<float>{0.14349776506424f,
+                                             1.14798212051392f,
+                                             2.0408570766449f,
+                                             3.06128573417664f,
+                                             4.08171415328979f,
+                                             5.10214281082153f,
+                                             6.12257146835327f,
+                                             7.14300012588501f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_seq_to_f8e4m3_scale_shift) {
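+    // Forward pass applies x * scale - shift before the f8 emulation; the inverse pass
+    // applies (x + shift) / scale afterwards (see scale_shift in the reference; editorial note).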
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    const auto data_shape = Shape{8};
+    std::vector<float> input_data;
+    std::generate_n(std::back_inserter(input_data), shape_size(data_shape), ov::SeqGen<float>(0.143f));
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E4M3 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {5.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result),
+                Pointwise(FloatEq(),
+                          std::vector<float>{0.14349776506424f,
+                                             1.10014951229095f,
+                                             2.1205780506134f,
+                                             3.14100670814514f,
+                                             4.1614351272583f,
+                                             5.18186378479004f,
+                                             6.20229244232178f,
+                                             7.22272109985352f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_matching_f8_to_f8e4m3_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{
+        0.f,       0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
+        0.03125f,  0.03515625f,  0.0390625f,  0.04296875f,  0.046875f,  0.05078125f,  0.0546875f,  0.05859375f,
+        0.0625f,   0.0703125f,   0.078125f,   0.0859375f,   0.09375f,   0.1015625f,   0.109375f,   0.1171875f,
+        0.125f,    0.140625f,    0.15625f,    0.171875f,    0.1875f,    0.203125f,    0.21875f,    0.234375f,
+        0.25f,     0.28125f,     0.3125f,     0.34375f,     0.375f,     0.40625f,     0.4375f,     0.46875f,
+        0.5f,      0.5625f,      0.625f,      0.6875f,      0.75f,      0.8125f,      0.875f,      0.9375f,
+        1.f,       1.125f,       1.25f,       1.375f,       1.5f,       1.625f,       1.75f,       1.875f,
+        2.f,       2.25f,        2.5f,        2.75f,        3.f,        3.25f,        3.5f,        3.75f,
+        4.f,       4.5f,         5.f,         5.5f,         6.f,        6.5f,         7.f,         7.5f,
+        8.f,       9.f,          10.f,        11.f,         12.f,       13.f,         14.f,        15.f,
+        16.f,      18.f,         20.f,        22.f,         24.f,       26.f,         28.f,        30.f,
+        32.f,      36.f,         40.f,        44.f,         48.f,       52.f,         56.f,        60.f,
+        64.f,      72.f,         80.f,        88.f,         96.f,       104.f,        112.f,       120.f,
+        128.f,     144.f,        160.f,       176.f,        192.f,      208.f,        224.f,       240.f,
+        256.f,     288.f,        320.f,       352.f,        384.f,      416.f,        448.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), input_data));
+}
+
+TEST(eval, evaluate_fake_convert_f16_matching_f8_to_f8e4m3_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f16;
+
+    std::vector<float16> input_data{
+        0.f,       0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f,
        0.02734375f, 0.029296875f,
+        0.03125f,  0.03515625f,  0.0390625f,  0.04296875f,  0.046875f,  0.05078125f,  0.0546875f,  0.05859375f,
+        0.0625f,   0.0703125f,   0.078125f,   0.0859375f,   0.09375f,   0.1015625f,   0.109375f,   0.1171875f,
+        0.125f,    0.140625f,    0.15625f,    0.171875f,    0.1875f,    0.203125f,    0.21875f,    0.234375f,
+        0.25f,     0.28125f,     0.3125f,     0.34375f,     0.375f,     0.40625f,     0.4375f,     0.46875f,
+        0.5f,      0.5625f,      0.625f,      0.6875f,      0.75f,      0.8125f,      0.875f,      0.9375f,
+        1.f,       1.125f,       1.25f,       1.375f,       1.5f,       1.625f,       1.75f,       1.875f,
+        2.f,       2.25f,        2.5f,        2.75f,        3.f,        3.25f,        3.5f,        3.75f,
+        4.f,       4.5f,         5.f,         5.5f,         6.f,        6.5f,         7.f,         7.5f,
+        8.f,       9.f,          10.f,        11.f,         12.f,       13.f,         14.f,        15.f,
+        16.f,      18.f,         20.f,        22.f,         24.f,       26.f,         28.f,        30.f,
+        32.f,      36.f,         40.f,        44.f,         48.f,       52.f,         56.f,        60.f,
+        64.f,      72.f,         80.f,        88.f,         96.f,       104.f,        112.f,       120.f,
+        128.f,     144.f,        160.f,       176.f,        192.f,      208.f,        224.f,       240.f,
+        256.f,     288.f,        320.f,       352.f,        384.f,      416.f,        448.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, std::vector<float16>{1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, std::vector<float16>{0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float16>(result), Pointwise(FloatEq(), input_data));
+}
+
+TEST(eval, evaluate_fake_convert_bf16_matching_f8_to_f8e4m3_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::bf16;
+
+    std::vector<bfloat16> input_data{
+        0.f,       0.001953125f, 0.00390625f, 0.005859375f, 0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+        0.015625f, 0.017578125f, 0.01953125f, 0.021484375f, 0.0234375f, 0.025390625f, 0.02734375f, 0.029296875f,
+        0.03125f,  0.03515625f,  0.0390625f,  0.04296875f,  0.046875f,  0.05078125f,  0.0546875f,  0.05859375f,
+        0.0625f,   0.0703125f,   0.078125f,   0.0859375f,   0.09375f,   0.1015625f,   0.109375f,   0.1171875f,
+        0.125f,    0.140625f,    0.15625f,    0.171875f,    0.1875f,    0.203125f,    0.21875f,    0.234375f,
+        0.25f,     0.28125f,     0.3125f,     0.34375f,     0.375f,     0.40625f,     0.4375f,     0.46875f,
+        0.5f,      0.5625f,      0.625f,      0.6875f,      0.75f,      0.8125f,      0.875f,      0.9375f,
+        1.f,       1.125f,       1.25f,       1.375f,       1.5f,       1.625f,       1.75f,       1.875f,
+        2.f,       2.25f,        2.5f,        2.75f,        3.f,        3.25f,        3.5f,        3.75f,
+        4.f,       4.5f,         5.f,         5.5f,         6.f,        6.5f,         7.f,         7.5f,
+        8.f,       9.f,          10.f,        11.f,         12.f,       13.f,         14.f,        15.f,
+        16.f,      18.f,         20.f,        22.f,         24.f,       26.f,         28.f,        30.f,
+        32.f,      36.f,         40.f,        44.f,         48.f,       52.f,         56.f,        60.f,
+        64.f,      72.f,         80.f,        88.f,         96.f,       104.f,        112.f,       120.f,
+        128.f,     144.f,        160.f,       176.f,        192.f,      208.f,        224.f,       240.f,
+        256.f,     288.f,        320.f,       352.f,        384.f,      416.f,        448.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, std::vector<bfloat16>{1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, std::vector<bfloat16>{0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e4m3");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<bfloat16>(result), Pointwise(FloatEq(), input_data));
+}
+
+///////////////////////////////////// FakeConvert f8e5m2
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(
+        read_vector<float>(result),
+        Pointwise(
+            FloatEq(),
+            std::vector<float>{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, 0.75f, 0.875f, 1.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f16_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f16;
+
+    std::vector<float16> input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(
+        read_vector<float16>(result),
+        Pointwise(
+            FloatEq(),
+            std::vector<float16>{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, 0.75f, 0.875f, 1.f}));
+}
+
+TEST(eval, evaluate_fake_convert_bf16_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::bf16;
+
+    std::vector<bfloat16> input_data{0.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f};
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(
+        read_vector<bfloat16>(result),
+        Pointwise(
+            FloatEq(),
+            std::vector<bfloat16>{0.f, 0.09375f, 0.1875f, 0.3125f, 0.375f, 0.5f, 0.625f, 0.75f, 0.75f, 0.875f, 1.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_small) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
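+    // With the scale below, the tiny inputs are lifted into the representable range and
+    // survive the round trip; in the scale-1 variants that follow, the sub-minimum
+    // values flush to zero instead (editorial note).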
+    std::vector<float> input_data{fp8::MIN_F8E5M2, fp8::MIN_F8E5M2 / 2.f, fp8::MIN_F8E5M2 / 4.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result),
+                Pointwise(FloatEq(), std::vector<float>{1.52587890625e-05, 7.62939453125e-06, 3.814697265625e-06}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_1_small) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MIN_F8E5M2, fp8::MIN_F8E5M2 / 2.f, fp8::MIN_F8E5M2 / 4.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.0f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), std::vector<float>{1.52587890625e-05, 0.f, 0.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_small_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MIN_F8E5M2, fp8::MIN_F8E5M2 / 2.f, fp8::MIN_F8E5M2 / 4.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), std::vector<float>{1.52587890625e-05, 0.f, 0.f}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E5M2 / 2.f, fp8::MAX_F8E5M2, fp8::MAX_F8E5M2 * 2.f, fp8::MAX_F8E5M2 * 4.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
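+    // The scale maps the largest input onto the top of the f8e5m2 grid; the scaled
+    // values are exact grid points, so the output is expected to equal the input (editorial note).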
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), input_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_scale_shift_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E5M2 / 2.f, fp8::MAX_F8E5M2, fp8::MAX_F8E5M2 * 2.f, fp8::MAX_F8E5M2 * 4.f};
+    std::vector<float> output_data{28672.f, 57344.f, 114688.f, 221184.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    const auto scale_val = fp8::MAX_F8E5M2 / max_input_val;
+    auto scale = op::v0::Constant::create(et, Shape{1}, {scale_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 * scale_val});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f16_to_f8e5m2_scale_shift_big) {
+    using namespace testing;
+    constexpr auto et = element::f16;
+
+    std::vector<float16> input_data{fp8::MAX_F8E5M2 / 2.f,
+                                    fp8::MAX_F8E5M2,
+                                    fp8::MAX_F8E5M2 + 500.f,
+                                    fp8::MAX_F8E5M2 + 1000.f};
+    std::vector<float16> output_data{28192.f, 57344.f, 57888.f, 58400.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    const auto scale_val = fp8::MAX_F8E5M2 / max_input_val;
+    auto scale = op::v0::Constant::create(et, Shape{1}, {scale_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 * scale_val});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float16>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_bf16_to_f8e5m2_scale_shift_big) {
+    using namespace testing;
+    constexpr auto et = element::bf16;
+
+    std::vector<bfloat16> input_data{fp8::MAX_F8E5M2 / 2.f,
+                                     fp8::MAX_F8E5M2,
+                                     fp8::MAX_F8E5M2 + 500.f,
+                                     fp8::MAX_F8E5M2 + 1000.f};
+    std::vector<bfloat16> output_data{28032.f, 57088.f, 57856.f, 58368.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    const auto scale_val = fp8::MAX_F8E5M2 / max_input_val;
+    auto scale = op::v0::Constant::create(et, Shape{1}, {scale_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 * scale_val});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<bfloat16>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_big_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E5M2 / 2.f,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2 + 1,
+                                  fp8::MAX_F8E5M2 * 2.f,
+                                  fp8::MAX_F8E5M2 * 4.f};
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+
+    constexpr auto inf = std::numeric_limits<float>::infinity();
+    EXPECT_THAT(
+        read_vector<float>(result),
+        Pointwise(FloatEq(), std::vector<float>{fp8::MAX_F8E5M2 / 2.f, fp8::MAX_F8E5M2, fp8::MAX_F8E5M2, inf, inf}));
+}
+
+TEST(eval, evaluate_fake_convert_f32_matching_f8_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    // clang-format off
+    std::vector<float> input_data{
+        0.f, 0.0000152587890625f, 0.00003051758125f, 0.0000457763671875f,
+        0.00006103515625f, 0.0000762939453125f, 0.000091552734375, 0.0001068115234375,
+        0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f,
+        0.000244140625f, 0.00030517578125, 0.0003662109375f, 0.00042724609375f,
+        0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875,
+        0.0009765625f, 0.001220703125, 0.00146484375f, 0.001708984375f,
+        0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f,
+        0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
+        0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+
+        0.015625f, /*0.017578125f,*/ 0.01953125f, /*0.021484375f,*/ 0.0234375f, /*0.025390625f,*/ 0.02734375f, /*0.029296875f,*/
+        0.03125f, /*0.03515625f,*/ 0.0390625f, /*0.04296875f,*/ 0.046875f, /*0.05078125f,*/ 0.0546875f, /*0.05859375f,*/
+        0.0625f, /*0.0703125f,*/ 0.078125f, /*0.0859375f,*/ 0.09375f, /*0.1015625f,*/ 0.109375f, /*0.1171875f,*/
+        0.125f, /*0.140625f,*/ 0.15625f, /*0.171875f,*/ 0.1875f, /*0.203125f,*/ 0.21875f, /*0.234375f,*/
+        0.25f, /*0.28125f,*/ 0.3125f, /*0.34375f,*/ 0.375f, /*0.40625f,*/ 0.4375f, /*0.46875f,*/
+        0.5f, /*0.5625f,*/ 0.625f, /*0.6875f,*/ 0.75f, /*0.8125f,*/ 0.875f, /*0.9375f,*/
+        1.f, /*1.125f,*/ 1.25f, /*1.375f,*/ 1.5f, /*1.625f,*/ 1.75f, /*1.875f,*/
+        2.f, /*2.25f,*/ 2.5f, /*2.75f,*/ 3.f, /*3.25f,*/ 3.5f, /*3.75f,*/
+        4.f, /*4.5f,*/ 5.f, /*5.5f,*/ 6.f, /*6.5f,*/ 7.f, /*7.5f,*/
+        8.f, /*9.f,*/
        10.f, /*11.f,*/ 12.f, /*13.f,*/ 14.f, /*15.f,*/
+        16.f, /*18.f,*/ 20.f, /*22.f,*/ 24.f, /*26.f,*/ 28.f, /*30.f,*/
+        32.f, /*36.f,*/ 40.f, /*44.f,*/ 48.f, /*52.f,*/ 56.f, /*60.f,*/
+        64.f, /*72.f,*/ 80.f, /*88.f,*/ 96.f, /*104.f,*/ 112.f, /*120.f,*/
+        128.f, /*144.f,*/ 160.f, /*176.f,*/ 192.f, /*208.f,*/ 224.f, /*240.f,*/
+        256.f, /*288.f,*/ 320.f, /*352.f,*/ 384.f, /*416.f,*/ 448.f,
+
+        512.f, 640.f, 768.f, 896.f,
+        1024.f, 1280.f, 1536.f, 1792.f,
+        2048.f, 2560.f, 3072.f, 3584.f,
+        4096.f, 5120.f, 6144.f, 7168.f,
+        8192.f, 10240.f, 12288.f, 14336.f,
+        16384.f, 20480.f, 24576.f, 28672.f,
+        32768.f, 40960.f, 49152.f, 57344.0
+    };
+    // clang-format on
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<et>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), input_data));
+}
+
+TEST(eval, evaluate_fake_convert_f16_matching_f8_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f16;
+
+    // clang-format off
+    std::vector<float16> input_data{
+        0.f, 0.0000152587890625f, 0.00003051758125f, 0.0000457763671875f,
+        0.00006103515625f, 0.0000762939453125f, 0.000091552734375, 0.0001068115234375,
+        0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f,
+        0.000244140625f, 0.00030517578125, 0.0003662109375f, 0.00042724609375f,
+        0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875,
+        0.0009765625f, 0.001220703125, 0.00146484375f, 0.001708984375f,
+        0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f,
+        0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
+        0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+
+        0.015625f, /*0.017578125f,*/ 0.01953125f, /*0.021484375f,*/ 0.0234375f, /*0.025390625f,*/ 0.02734375f, /*0.029296875f,*/
+        0.03125f, /*0.03515625f,*/ 0.0390625f, /*0.04296875f,*/ 0.046875f, /*0.05078125f,*/ 0.0546875f, /*0.05859375f,*/
+        0.0625f, /*0.0703125f,*/ 0.078125f, /*0.0859375f,*/ 0.09375f, /*0.1015625f,*/ 0.109375f, /*0.1171875f,*/
+        0.125f, /*0.140625f,*/ 0.15625f, /*0.171875f,*/ 0.1875f, /*0.203125f,*/ 0.21875f, /*0.234375f,*/
+        0.25f, /*0.28125f,*/ 0.3125f, /*0.34375f,*/ 0.375f, /*0.40625f,*/ 0.4375f, /*0.46875f,*/
+        0.5f, /*0.5625f,*/ 0.625f, /*0.6875f,*/ 0.75f, /*0.8125f,*/ 0.875f, /*0.9375f,*/
+        1.f, /*1.125f,*/ 1.25f, /*1.375f,*/ 1.5f, /*1.625f,*/ 1.75f, /*1.875f,*/
+        2.f, /*2.25f,*/ 2.5f, /*2.75f,*/ 3.f, /*3.25f,*/ 3.5f, /*3.75f,*/
+        4.f, /*4.5f,*/ 5.f, /*5.5f,*/ 6.f, /*6.5f,*/ 7.f, /*7.5f,*/
+        8.f, /*9.f,*/ 10.f, /*11.f,*/ 12.f, /*13.f,*/ 14.f, /*15.f,*/
+        16.f, /*18.f,*/ 20.f, /*22.f,*/ 24.f, /*26.f,*/ 28.f, /*30.f,*/
+        32.f, /*36.f,*/ 40.f, /*44.f,*/ 48.f, /*52.f,*/ 56.f, /*60.f,*/
+        64.f, /*72.f,*/ 80.f, /*88.f,*/ 96.f, /*104.f,*/ 112.f, /*120.f,*/
+        128.f, /*144.f,*/ 160.f, /*176.f,*/ 192.f, /*208.f,*/ 224.f, /*240.f,*/
+        256.f, /*288.f,*/ 320.f, /*352.f,*/ 384.f, /*416.f,*/ 448.f,
+
+        512.f, 640.f, 768.f, 896.f,
+        1024.f, 1280.f, 1536.f, 1792.f,
+        2048.f, 2560.f, 3072.f, 3584.f,
+        4096.f,
+
+TEST(eval, evaluate_fake_convert_f16_matching_f8_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f16;
+
+    // clang-format off
+    std::vector<ov::float16> input_data{
+        0.f, 0.0000152587890625f, 0.00003051758125f, 0.0000457763671875f,
+        0.00006103515625f, 0.0000762939453125f, 0.000091552734375f, 0.0001068115234375f,
+        0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f,
+        0.000244140625f, 0.00030517578125f, 0.0003662109375f, 0.00042724609375f,
+        0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875f,
+        0.0009765625f, 0.001220703125f, 0.00146484375f, 0.001708984375f,
+        0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f,
+        0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
+        0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+
+        0.015625f, /*0.017578125f,*/ 0.01953125f, /*0.021484375f,*/ 0.0234375f, /*0.025390625f,*/ 0.02734375f, /*0.029296875f,*/
+        0.03125f, /*0.03515625f,*/ 0.0390625f, /*0.04296875f,*/ 0.046875f, /*0.05078125f,*/ 0.0546875f, /*0.05859375f,*/
+        0.0625f, /*0.0703125f,*/ 0.078125f, /*0.0859375f,*/ 0.09375f, /*0.1015625f,*/ 0.109375f, /*0.1171875f,*/
+        0.125f, /*0.140625f,*/ 0.15625f, /*0.171875f,*/ 0.1875f, /*0.203125f,*/ 0.21875f, /*0.234375f,*/
+        0.25f, /*0.28125f,*/ 0.3125f, /*0.34375f,*/ 0.375f, /*0.40625f,*/ 0.4375f, /*0.46875f,*/
+        0.5f, /*0.5625f,*/ 0.625f, /*0.6875f,*/ 0.75f, /*0.8125f,*/ 0.875f, /*0.9375f,*/
+        1.f, /*1.125f,*/ 1.25f, /*1.375f,*/ 1.5f, /*1.625f,*/ 1.75f, /*1.875f,*/
+        2.f, /*2.25f,*/ 2.5f, /*2.75f,*/ 3.f, /*3.25f,*/ 3.5f, /*3.75f,*/
+        4.f, /*4.5f,*/ 5.f, /*5.5f,*/ 6.f, /*6.5f,*/ 7.f, /*7.5f,*/
+        8.f, /*9.f,*/ 10.f, /*11.f,*/ 12.f, /*13.f,*/ 14.f, /*15.f,*/
+        16.f, /*18.f,*/ 20.f, /*22.f,*/ 24.f, /*26.f,*/ 28.f, /*30.f,*/
+        32.f, /*36.f,*/ 40.f, /*44.f,*/ 48.f, /*52.f,*/ 56.f, /*60.f,*/
+        64.f, /*72.f,*/ 80.f, /*88.f,*/ 96.f, /*104.f,*/ 112.f, /*120.f,*/
+        128.f, /*144.f,*/ 160.f, /*176.f,*/ 192.f, /*208.f,*/ 224.f, /*240.f,*/
+        256.f, /*288.f,*/ 320.f, /*352.f,*/ 384.f, /*416.f,*/ 448.f,
+
+        512.f, 640.f, 768.f, 896.f,
+        1024.f, 1280.f, 1536.f, 1792.f,
+        2048.f, 2560.f, 3072.f, 3584.f,
+        4096.f, 5120.f, 6144.f, 7168.f,
+        8192.f, 10240.f, 12288.f, 14336.f,
+        16384.f, 20480.f, 24576.f, 28672.f,
+        32768.f, 40960.f, 49152.f, 57344.f
+    };
+    // clang-format on
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f16>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<ov::float16>(result), Pointwise(FloatEq(), input_data));
+}
+
+TEST(eval, evaluate_fake_convert_bf16_matching_f8_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::bf16;
+
+    // clang-format off
+    std::vector<ov::bfloat16> input_data{
+        0.f, 0.0000152587890625f, 0.00003051758125f, 0.0000457763671875f,
+        0.00006103515625f, 0.0000762939453125f, 0.000091552734375f, 0.0001068115234375f,
+        0.0001220703125f, 0.000152587890625f, 0.00018310546875f, 0.000213623046875f,
+        0.000244140625f, 0.00030517578125f, 0.0003662109375f, 0.00042724609375f,
+        0.00048828125f, 0.0006103515625f, 0.000732421875f, 0.0008544921875f,
+        0.0009765625f, 0.001220703125f, 0.00146484375f, 0.001708984375f,
+        0.001953125f, 0.00244140625f, 0.0029296875f, 0.00341796875f,
+        0.00390625f, 0.0048828125f, 0.005859375f, 0.0068359375f,
+        0.0078125f, 0.009765625f, 0.01171875f, 0.013671875f,
+
+        0.015625f, /*0.017578125f,*/ 0.01953125f, /*0.021484375f,*/ 0.0234375f, /*0.025390625f,*/ 0.02734375f, /*0.029296875f,*/
+        0.03125f, /*0.03515625f,*/ 0.0390625f, /*0.04296875f,*/ 0.046875f, /*0.05078125f,*/ 0.0546875f, /*0.05859375f,*/
+        0.0625f, /*0.0703125f,*/ 0.078125f, /*0.0859375f,*/ 0.09375f, /*0.1015625f,*/ 0.109375f, /*0.1171875f,*/
+        0.125f, /*0.140625f,*/ 0.15625f, /*0.171875f,*/ 0.1875f, /*0.203125f,*/ 0.21875f, /*0.234375f,*/
+        0.25f, /*0.28125f,*/ 0.3125f, /*0.34375f,*/ 0.375f, /*0.40625f,*/ 0.4375f, /*0.46875f,*/
+        0.5f, /*0.5625f,*/ 0.625f, /*0.6875f,*/ 0.75f, /*0.8125f,*/ 0.875f, /*0.9375f,*/
+        1.f, /*1.125f,*/ 1.25f, /*1.375f,*/ 1.5f, /*1.625f,*/ 1.75f, /*1.875f,*/
+        2.f, /*2.25f,*/ 2.5f, /*2.75f,*/ 3.f, /*3.25f,*/ 3.5f, /*3.75f,*/
+        4.f, /*4.5f,*/ 5.f, /*5.5f,*/ 6.f, /*6.5f,*/ 7.f, /*7.5f,*/
+        8.f, /*9.f,*/ 10.f, /*11.f,*/ 12.f, /*13.f,*/ 14.f, /*15.f,*/
+        16.f, /*18.f,*/ 20.f, /*22.f,*/ 24.f, /*26.f,*/ 28.f, /*30.f,*/
+        32.f, /*36.f,*/ 40.f, /*44.f,*/ 48.f, /*52.f,*/ 56.f, /*60.f,*/
+        64.f, /*72.f,*/ 80.f, /*88.f,*/ 96.f, /*104.f,*/ 112.f, /*120.f,*/
+        128.f, /*144.f,*/ 160.f, /*176.f,*/ 192.f, /*208.f,*/ 224.f, /*240.f,*/
+        256.f, /*288.f,*/ 320.f, /*352.f,*/ 384.f, /*416.f,*/ 448.f,
+
+        512.f, 640.f, 768.f, 896.f,
+        1024.f, 1280.f, 1536.f, 1792.f,
+        2048.f, 2560.f, 3072.f, 3584.f,
+        4096.f, 5120.f, 6144.f, 7168.f,
+        8192.f, 10240.f, 12288.f, 14336.f,
+        16384.f, 20480.f, 24576.f, 28672.f,
+        32768.f, 40960.f, 49152.f, 57344.f
+    };
+    // clang-format on
+
+    const auto data_shape = Shape{input_data.size()};
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::bf16>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<ov::bfloat16>(result), Pointwise(FloatEq(), input_data));
+}
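+
+// The next test feeds f8e4m3 grid points (3 mantissa bits) through an f8e5m2
+// FakeConvert. With only 2 mantissa bits available, each value is expected to
+// round to the nearest f8e5m2 value, ties to even: e.g. 0.017578125f sits
+// exactly halfway between 0.015625f and 0.01953125f and lands on 0.015625f.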
+
+TEST(eval, evaluate_fake_convert_f32_matching_f8e4m3_to_f8e5m2_scale_1) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{
+        0.017578125f, 0.021484375f, 0.025390625f, 0.029296875f, 0.03515625f, 0.0703125f, 0.140625f,
+        0.28125f,     0.5625f,      1.125f,       1.625f,       1.875f,      2.25f,      3.75f,
+        4.5f,         9.f,          18.f,         36.f,         72.f,        144.f,      288.f,
+    };
+    /* Rounded to f8e5m2 vals */
+    std::vector<float> output_data{0.015625f, 0.0234375f, 0.0234375f, 0.03125f, 0.03125f, 0.0625f, 0.125f,
+                                   0.25f,     0.5f,       1.f,        1.5f,     2.f,      2.f,     4.f,
+                                   4.f,       8.f,        16.f,       32.f,     64.f,     128.f,   256.f};
+
+    const auto data_shape = Shape{input_data.size()};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et, Shape{1}, {1.f});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f32>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_seq_to_f8e5m2_scale_shift) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    const auto data_shape = Shape{8};
+    std::vector<float> input_data;
+    std::generate_n(std::back_inserter(input_data), shape_size(data_shape), ov::SeqGen<float>(0.143f));
+
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto max_input_val = std::abs(*std::max_element(input_data.begin(), input_data.end()));
+    auto scale = op::v0::Constant::create(et, Shape{1}, {fp8::MAX_F8E5M2 / max_input_val});
+    auto shift = op::v0::Constant::create(et, Shape{1}, {5.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f32>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result),
+                Pointwise(FloatEq(),
+                          std::vector<float>{0.128176391124725f,
+                                             1.02105140686035f,
+                                             2.04147982597351f,
+                                             3.06190848350525f,
+                                             4.08233690261841f,
+                                             5.10276556015015f,
+                                             6.12319421768188f,
+                                             7.14362287521362f}));
+}
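+
+// A minimal sketch of the computation the scale/shift test above and the
+// per-channel tests below exercise (round_to_f8e5m2 stands in for the emulated
+// down-convert; the names are illustrative, not the reference API):
+//
+//   float fake_convert_ref(float x, float scale, float shift) {
+//       const float downscaled = x * scale - shift;          // map into f8 range
+//       const float emulated = round_to_f8e5m2(downscaled);  // rounds/saturates like f8e5m2
+//       return (emulated + shift) / scale;                   // map back
+//   }
+//
+// In the 4x3 and 3x4 tests scale/shift have shapes {4, 1} and {1, 4}, so this
+// formula is broadcast per channel: each row (or column) is normalized by its
+// own scale before the f8 rounding and denormalized afterwards.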
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_4x3_scale_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E5M2 / 4.f,
+                                  fp8::MAX_F8E5M2 / 3.f,
+                                  fp8::MAX_F8E5M2 / 2.f,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2 * 1.2f,
+                                  fp8::MAX_F8E5M2 * 2.3f,
+                                  fp8::MAX_F8E5M2 * 3.4f,
+                                  fp8::MAX_F8E5M2 * 2.f,
+                                  fp8::MAX_F8E5M2 * 3.f,
+                                  fp8::MAX_F8E5M2 * 4.f};
+
+    std::vector<float> output_data{14336.f,
+                                   20480.f,
+                                   28672.f,
+                                   57344.f,
+                                   57344.f,
+                                   57344.f,
+                                   71680.f,
+                                   143360.f,
+                                   200704.f,
+                                   114688.f,
+                                   163840.f,
+                                   229376.f};
+
+    const auto data_shape = Shape{4, 3};
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    auto scale = op::v0::Constant::create(et,
+                                          Shape{4, 1},
+                                          {fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 / 2.f),
+                                           1.0f,
+                                           fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 * 3.5f),
+                                           fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 * 4.f)});
+    auto shift = op::v0::Constant::create(et, Shape{4, 1}, {0.f, 0.f, 0.f, 0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f32>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
+TEST(eval, evaluate_fake_convert_f32_to_f8e5m2_3x4_scale_big) {
+    using namespace testing;
+    constexpr auto et = element::f32;
+
+    std::vector<float> input_data{fp8::MAX_F8E5M2 / 4.f,
+                                  fp8::MAX_F8E5M2 / 3.f,
+                                  fp8::MAX_F8E5M2 / 2.f,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2,
+                                  fp8::MAX_F8E5M2 * 1.2f,
+                                  fp8::MAX_F8E5M2 * 2.3f,
+                                  fp8::MAX_F8E5M2 * 3.4f,
+                                  fp8::MAX_F8E5M2 * 2.f,
+                                  fp8::MAX_F8E5M2 * 3.f,
+                                  fp8::MAX_F8E5M2 * 4.f};
+
+    std::vector<float> output_data{14336.f,
+                                   57344.f,
+                                   71680.f,
+                                   114688.f,
+                                   20480.f,
+                                   57344.f,
+                                   143360.f,
+                                   163840.f,
+                                   28672.f,
+                                   57344.f,
+                                   200704.f,
+                                   229376.f};
+
+    const auto data_shape = Shape{4, 3};  // To be transposed to 3x4
+    auto data = make_shared<op::v0::Parameter>(et, data_shape);
+    std::vector<int32_t> order{1, 0};
+    auto transpose_order = make_shared<op::v0::Constant>(element::i32, Shape{order.size()}, order);
+    auto transposed_data = make_shared<op::v1::Transpose>(data, transpose_order);
+    const auto transposed_data_shape = Shape{3, 4};
+
+    auto scale = op::v0::Constant::create(et,
+                                          Shape{1, 4},
+                                          {fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 / 2.f),
+                                           1.0f,
+                                           fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 * 3.5f),
+                                           fp8::MAX_F8E5M2 / (fp8::MAX_F8E5M2 * 4.f)});
+    auto shift = op::v0::Constant::create(et, Shape{1, 4}, {0.f, 0.f, 0.f, 0.f});
+
+    auto op = make_shared<op::v13::FakeConvert>(transposed_data, scale, shift, "f8e5m2");
+    auto model = make_shared<Model>(OutputVector{op}, ParameterVector{data});
+
+    auto result = ov::Tensor();
+    auto out_vector = ov::TensorVector{result};
+    auto in_vector = ov::TensorVector{make_tensor<element::Type_t::f32>(data_shape, input_data)};
+    ASSERT_TRUE(model->evaluate(out_vector, in_vector));
+    result = out_vector.at(0);
+
+    EXPECT_EQ(result.get_element_type(), et);
+    EXPECT_EQ(result.get_shape(), transposed_data_shape);
+    EXPECT_THAT(read_vector<float>(result), Pointwise(FloatEq(), output_data));
+}
+
 TEST(eval, evaluate_cum_sum_v0) {
     auto data = make_shared<op::v0::Parameter>(element::f32, Shape{2, 3});
     auto axis = ov::op::v0::Constant::create(element::i32, Shape{1}, {1});