[Opset13][FP8] FakeConvert evaluate and reference impl (#21034)
* FakeConvert op init
* Update dest types names
* Update op hpp
* Update opset ops number
* Init type_prop tests
* Add attributes tests
* Add op check test
* Update namespace in fc cpp
* Update getters
* Refactor static member
* Make destination_type lower case
* Update type in test
* Move get_valid_types out of class
* Update ops number in opset
* Evaluate init
* Init poc evaluate
* Update types name
* Style update
* FakeConvert eval tests
* Add support for FP16 input
* Remove unused functions
* Use const auto and constexpr
* Create fake_convert reference file and move ref functions
* FakeConvert eval and reference update
* Remove unused is_normal variable
* Add bf16 tests
* Update constructor in tests
* Adjust convert helper name
* Update comments
* Make func params constexpr
* Update types
* Non-scalar scale shape tests
* Use autobroadcast_select
* Use single autobroadcast_select
* Check scale_shape size
* Add var to keep scale_shift lambda
* Use lambda only for autobroadcast
* More checks for per-channel scale
* Remove union half_t
* Change template to fp16 type for emulate
* Use f8e4m3_min_val as constexpr
* Update unsigned short to uint16_t
* Minor style refactor
* Minor comments update
* Update fake_convert_details namespace to func
* Update apply_conversion return type
* Add doxygen text
* Update supported type assert
* Remove duplicated tests
* Update namespace move to cpp
* Fill init with zeroes instead of reserve
* Use add div mul sub reference for applying scale
* Use autobroadcast_select for applying scale
* Bring back opt broadcast cases
* Reuse scale_shift as func
* Adjust s and o var names
* Update multiplication in loop to accumulate
parent 65c3c7c20a
commit 63e7e2dd3e
@@ -27,6 +27,7 @@ public:
     void validate_and_infer_types() override;
     std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
     bool visit_attributes(ov::AttributeVisitor& visitor) override;
+    bool evaluate(TensorVector& outputs, const TensorVector& inputs) const override;
     bool has_evaluate() const override;
     const std::string& get_destination_type() const;
src/core/reference/include/openvino/reference/fake_convert.hpp (new file, 221 lines)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>
#include <cmath>
#include <numeric>
#include <vector>

#include "openvino/reference/add.hpp"
#include "openvino/reference/autobroadcast_binop.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/divide.hpp"
#include "openvino/reference/multiply.hpp"
#include "openvino/reference/subtract.hpp"

namespace ov {
namespace reference {
namespace func {
/**
 * @brief Emulation of conversion of fp16 values to the f8e5m2 format
 *
 * @param arg_f  Pointer to the input data.
 * @param out_f  Pointer to the output data.
 * @param count  Number of elements in the input data.
 */
void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count);

/**
 * @brief Emulation of conversion of fp16 values to the f8e4m3 format
 *
 * @param arg_f  Pointer to the input data.
 * @param out_f  Pointer to the output data.
 * @param count  Number of elements in the input data.
 *
 * Exponent of denormal values:  0      -> -7
 * Exponent of normal values:    1..15  -> -6..8  (stored exponent minus the bias of 7)
 * Exponent of NaN values:       15     ->  8
 */
void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count);
}  // namespace func

namespace fake_convert_details {
template <bool INVERT, typename T>
inline T scale_shift(T val, T scale, T shift) {
    return INVERT ? static_cast<T>((val + shift) / scale) : static_cast<T>(val * scale - shift);
}
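// Example (illustrative, not part of the original file): with scale = 2, shift = 1,
//   scale_shift<false>(3.5, 2, 1) == 3.5 * 2 - 1 == 6.0   (applied before the f8 emulation)
//   scale_shift<true>(6.0, 2, 1)  == (6.0 + 1) / 2 == 3.5  (applied after it)
// so the INVERT=true form undoes the INVERT=false form around the emulated cast.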
/**
 * @brief Apply scale and shift to the data input.
 * The scale_shape and shift_shape should be equal and numpy-broadcastable to the data_shape.
 *
 * @param out         Pointer to the output data.
 * @param data        Pointer to the input data.
 * @param scale       Pointer to the input scale.
 * @param shift       Pointer to the input shift.
 * @param data_shape  Shape of the input data.
 * @param scale_shape Shape of the input scale.
 * @param shift_shape Shape of the input shift.
 * @param invert      Flag denoting whether scale is applied before conversion (false) or after it (true).
 */
template <typename T>
void apply_scale_shift(T* out,
                       const T* data,
                       const T* scale,
                       const T* shift,
                       const Shape& data_shape,
                       const Shape& scale_shape,
                       const Shape& shift_shape,
                       const bool invert = false) {
    const auto scale_shift_func = invert ? scale_shift<true, T> : scale_shift<false, T>;
    const auto scale_size = shape_size(scale_shape);
    const auto data_size = shape_size(data_shape);

    if (scale_size == 1) {  // per-tensor scale, typically for activations
        T scale_val = scale[0];
        T shift_val = shift[0];
        for (size_t j = 0; j < data_size; ++j) {
            out[j] = scale_shift_func(data[j], scale_val, shift_val);
        }
        return;
    }

    if (scale_shape.size() > 1 && scale_shape[0] == 1 && data_shape.size() == scale_shape.size() &&
        data_shape[1] == scale_shape[1] && scale_size == data_shape[1]) {  // per-channel scale for DW activations
        const auto step = shape_size(data_shape.begin() + 2, data_shape.end());
        for (size_t bs = 0; bs < data_shape[0]; ++bs) {
            for (size_t i = 0; i < scale_size; i++) {
                T scale_val = scale[i];
                T shift_val = shift[i];
                for (size_t j = 0; j < step; ++j) {
                    out[j] = scale_shift_func(data[j], scale_val, shift_val);
                }
                data += step;
                out += step;
            }
        }
        return;
    }

    // The cases above are optimized versions of the broadcast for specific shapes;
    // the autobroadcast helper below is the generic fallback.
    autobroadcast_select(data,
                         scale,
                         shift,
                         out,
                         data_shape,
                         scale_shape,
                         shift_shape,
                         op::AutoBroadcastType::NUMPY,
                         scale_shift_func);
}
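// Path selection example (illustrative, not part of the original file):
//   scale_shape {1}                        -> per-tensor loop
//   data {N, C, H, W}, scale {1, C, 1, 1}  -> per-channel fast path
//   any other broadcastable scale_shape    -> generic autobroadcast_select fallback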
/**
 * @brief Call conversion of fp16 values to the desired destination type
 *
 * @param data             Pointer to the input data.
 * @param out              Pointer to the output data.
 * @param element_count    Number of elements in the data input.
 * @param destination_type Name of the destination type.
 */
void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type);
}  // namespace fake_convert_details

/**
 * @brief Reference implementation of the FakeConvert operation specialized for the float16 type.
 * It emulates the format specified by "destination_type" on the input type.
 * FakeConvert performs element-wise quantization of floating-point input values into a set of values
 * corresponding to a target low-precision floating-point type.
 *
 * @param data             Pointer to the input data.
 * @param scale            Pointer to the input scale.
 * @param shift            Pointer to the input shift.
 * @param out              Pointer to the output data.
 * @param data_shape       Shape of the input data.
 * @param scale_shape      Shape of the input scale.
 * @param shift_shape      Shape of the input shift.
 * @param destination_type Name of the destination type.
 */
template <typename T, typename std::enable_if<std::is_same<T, float16>::value, bool>::type = true>
void fake_convert(const T* data,
                  const T* scale,
                  const T* shift,
                  T* out,
                  const Shape& data_shape,
                  const Shape& scale_shape,
                  const Shape& shift_shape,
                  const std::string& destination_type) {
    const size_t element_count = shape_size(data_shape);
    fake_convert_details::apply_scale_shift<float16>(out,
                                                     data,
                                                     scale,
                                                     shift,
                                                     data_shape,
                                                     scale_shape,
                                                     shift_shape,
                                                     false);
    fake_convert_details::apply_conversion(out, out, element_count, destination_type);
    fake_convert_details::apply_scale_shift<float16>(out,
                                                     out,
                                                     scale,
                                                     shift,
                                                     data_shape,
                                                     scale_shape,
                                                     shift_shape,
                                                     true);
}

/**
 * @brief Reference implementation of the FakeConvert operation for floating-point types, with intermediate
 * conversion to float16. It emulates the format specified by "destination_type" on the input type.
 * FakeConvert performs element-wise quantization of floating-point input values into a set of values
 * corresponding to a target low-precision floating-point type.
 *
 * @param data             Pointer to the input data.
 * @param scale            Pointer to the input scale.
 * @param shift            Pointer to the input shift.
 * @param out              Pointer to the output data.
 * @param data_shape       Shape of the input data.
 * @param scale_shape      Shape of the input scale.
 * @param shift_shape      Shape of the input shift.
 * @param destination_type Name of the destination type.
 */
template <typename T, typename std::enable_if<!std::is_same<T, float16>::value, bool>::type = true>
void fake_convert(const T* data,
                  const T* scale,
                  const T* shift,
                  T* out,
                  const Shape& data_shape,
                  const Shape& scale_shape,
                  const Shape& shift_shape,
                  const std::string& destination_type) {
    const size_t element_count = shape_size(data_shape);
    fake_convert_details::apply_scale_shift<T>(out, data, scale, shift, data_shape, scale_shape, shift_shape, false);

    std::vector<ov::float16> tmp_fp16(element_count, 0.f);
    reference::convert(out, tmp_fp16.data(), element_count);
    fake_convert_details::apply_conversion(tmp_fp16.data(), tmp_fp16.data(), element_count, destination_type);
    reference::convert(tmp_fp16.data(), out, element_count);

    fake_convert_details::apply_scale_shift<T>(out, out, scale, shift, data_shape, scale_shape, shift_shape, true);
}

/**
 * @brief Reference implementation of the FakeConvert operation, with the default (all-zero) `shift` input.
 */
template <typename T>
void fake_convert(const T* data,
                  const T* scale,
                  T* out,
                  const Shape& data_shape,
                  const Shape& scale_shape,
                  const std::string& destination_type) {
    const auto shift = std::vector<T>(shape_size(scale_shape), 0.f);
    fake_convert<T>(data, scale, shift.data(), out, data_shape, scale_shape, scale_shape, destination_type);
}

}  // namespace reference
}  // namespace ov
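For orientation, a minimal usage sketch of the reference entry point above (hypothetical standalone snippet, assuming an OpenVINO source build; per-tensor scale of 1 and shift of 0):

#include <vector>

#include "openvino/reference/fake_convert.hpp"

int main() {
    using ov::float16;
    const std::vector<float16> data{float16(1.1875f), float16(448.f), float16(1000.f)};
    const std::vector<float16> scale{float16(1.f)};
    const std::vector<float16> shift{float16(0.f)};
    std::vector<float16> out(data.size());

    // Quantizes each element onto the f8e4m3 grid while keeping fp16 storage;
    // e.g. 1000 saturates to 448, the largest finite f8e4m3 value.
    ov::reference::fake_convert<float16>(data.data(),
                                         scale.data(),
                                         shift.data(),
                                         out.data(),
                                         ov::Shape{3},
                                         ov::Shape{1},
                                         ov::Shape{1},
                                         "f8e4m3");
}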
src/core/reference/src/op/fake_convert.cpp (new file, 163 lines)
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/fake_convert.hpp"

namespace ov {
namespace reference {
namespace func {
void emulate_f8e5m2_on_fp16(const float16* const arg_f, float16* out_f, size_t count) {
    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
    auto out_u = reinterpret_cast<uint16_t*>(out_f);
    uint16_t val_bit_repr;

    constexpr auto exp_bits = 5;
    constexpr auto mbits = 8;
    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
    constexpr auto lshift = 10 - (mbits - non_mant_bits);
    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1100000000
    constexpr uint16_t grs_bitmask = 0x00FF;  /// 0 00000 0011111111, grs denotes guard, round, sticky bits
    constexpr uint16_t rne_tie = 0x0180;      /// 0 00000 0110000000, rne denotes round to nearest even
    constexpr uint16_t fp16_inf = 0x7C00;

    for (size_t i = 0; i < count; ++i) {
        /// rounds the fp16 value to f8e5m2 precision in round-to-nearest-even mode
        val_bit_repr = arg_u[i];
        /// 0x7C00 = 0111110000000000 - exponent mask
        /// s 11111 xxx xxxx xxxx - is nan (if some x is 1) or inf (if all x is 0)
        /// 0x7F00 is 0 11111 1100000000
        /// 0x7B00 is 0 11110 1100000000 - the largest finite f8e5m2 bit pattern;
        /// values at or above it are not rounded further
        const bool can_round = ((val_bit_repr & 0x7F00) < 0x7B00) ? true : false;
        const bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;
        /* nearest rounding masks */
        /// grs_bitmask is 0 00000 0011111111 or 0 00000 00grs11111
        uint16_t rnmask = (val_bit_repr & grs_bitmask);
        /// rne_tie - 0x180 is 0 00000 0110000000 or 384.0
        uint16_t rnmask_tie = (val_bit_repr & rne_tie);

        if (!is_naninf && can_round) {
            /* round to nearest even, if rne_mask is enabled */
            /* 0 00000 0010000000, find grs patterns */
            // 0xx - do nothing
            // 100 - this is a tie: round up if the mantissa bit just before G is 1, else do nothing
            // 101, 110, 111 - round up > 0x0080
            val_bit_repr += (((rnmask > 0x0080) || (rnmask_tie == rne_tie)) << lshift);
        }
        val_bit_repr &= mask_mant; /* truncation */
        out_u[i] = val_bit_repr;
    }
}
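// Worked example (illustrative, not part of the original file):
//   fp16 1.1875 is 0x3CC0 (0 01111 0011000000); rnmask = 0x00C0 > 0x0080, so it
//   rounds up and truncates to 0x3D00, i.e. 1.25 on the f8e5m2 grid.
//   fp16 1.125 (0x3C80) is an exact tie with an even truncated mantissa, so it
//   stays down at 0x3C00, i.e. 1.0 (round to nearest even).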
/**
 * @brief Emulation of conversion of fp16 values to the f8e4m3 format
 *
 * @param arg_f  Pointer to the input data.
 * @param out_f  Pointer to the output data.
 * @param count  Number of elements in the input data.
 *
 * Exponent of denormal values:  0      -> -7
 * Exponent of normal values:    1..15  -> -6..8  (stored exponent minus the bias of 7)
 * Exponent of NaN values:       15     ->  8
 */
void emulate_f8e4m3_on_fp16(const float16* arg_f, float16* out_f, size_t count) {
    const auto arg_u = reinterpret_cast<const uint16_t*>(arg_f);
    auto out_u = reinterpret_cast<uint16_t*>(out_f);
    uint16_t val_bit_repr;

    constexpr auto use_clamp = true;
    constexpr auto exp_bits = 5;
    constexpr auto mbits = 9;
    constexpr auto non_mant_bits = exp_bits + 1;  /// exponent + sign
    constexpr auto lshift = 10 - (mbits - non_mant_bits);
    constexpr auto fp16_exp_bias = 15;
    constexpr auto f8e4m3_min_val = 0.001953125f;  /// 2**-9

    constexpr uint16_t mask_mant = static_cast<uint16_t>(0xFFFF << lshift);  /// 1111111111111111 -> 1 11111 1111000000
    constexpr uint16_t grs_bitmask = 0x007F;  /// 0 00000 0001111111, grs denotes guard, round, sticky bits
    constexpr uint16_t rne_tie = 0x00C0;      /// 0 00000 0011000000, rne denotes round to nearest even
    constexpr uint16_t rne_mask = 1;
    constexpr uint16_t fp16_inf = 0x7C00;

    for (size_t i = 0; i < count; ++i) {
        val_bit_repr = arg_u[i];
        /* flush values below the 1-4-3 (offset=4) subnormal range to zero */
        if (std::abs(static_cast<float>(arg_f[i])) < f8e4m3_min_val) {
            val_bit_repr = 0;
        }

        short exp_h = static_cast<short>((val_bit_repr & fp16_inf) >> 10) -
                      fp16_exp_bias;  /// 0111110000000000 -> 0000000000011111 - 15, unbiased exponent for fp16
        const short sign_h = (val_bit_repr & 0x8000);  /// & 1 00000 0000000000
        short mantissa_h = (val_bit_repr & 0x03FF);    /// & 0 00000 1111111111
        /// (val_bit_repr & 0111111111111111) < 0 10111 1110000000 (24448, i.e. 480.0, the f8e4m3 rounding limit)
        const bool can_round = ((val_bit_repr & 0x7FFF) < 0b101111110000000) ? true : false;
        bool is_naninf = ((val_bit_repr & fp16_inf) == fp16_inf) ? true : false;

        int dshift = 0;
        if (exp_h > 8) {  // too large, set it to NaN or inf
            is_naninf = true;
            if (use_clamp) {
                exp_h = 8;
                mantissa_h = 0b0000001100000000;
            } else {
                mantissa_h = 0;
                exp_h = 16;
            }
        } else if (exp_h < -9) {  /// -13, -12, -11 are left for rounding
            /* flush values below the 1-4-3 (offset=4) subnormal range to zero */
            exp_h = -15;
            mantissa_h = 0;
        }

        if (exp_h == 8 && mantissa_h >= 0b0000001100000000) {
            mantissa_h = 0b0000001100000000;
        }
        /* nearest rounding masks, & 0 00000 000 111 1111 - mantissa bits below f8e4m3 (grs) */
        const uint16_t rnmask = (mantissa_h & grs_bitmask);
        /* & 0 00000 0011000000 - edge between f8e4m3 and fp16 mantissa */
        const uint16_t rnmask_tie = (mantissa_h & rne_tie);
        if (!is_naninf && can_round && rne_mask) {
            /* round to nearest even, if rne_mask is enabled */
            /// rnmask > 0 00000 0001000000 (64), or both tie bits 0 00000 0011000000 are set
            /// += 0 00000 0010000000
            mantissa_h += (((rnmask > 0x0040) || (rnmask_tie == rne_tie)) << lshift);
        }
        if (exp_h < -6) { /* handle denormals -11, -10, -9, dshift 1, 2, 3 */
            dshift = (-6 - exp_h);
            mantissa_h = mantissa_h >> dshift;
        }
        mantissa_h &= mask_mant; /* truncation */
        mantissa_h <<= dshift;
        mantissa_h += ((exp_h + 15) << 10);
        val_bit_repr = mantissa_h | sign_h;
        out_u[i] = val_bit_repr;
    }
}
}  // namespace func
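// Note (illustrative, not part of the original file): with use_clamp set, exp_h = 8 and
// mantissa 0b0000001100000000 encode 1.75 * 2^8 = 448, the largest finite f8e4m3 magnitude,
// so finite inputs at or above the rounding limit of 480 saturate to +/-448 instead of becoming inf.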
namespace fake_convert_details {
/**
 * @brief Call conversion of fp16 values to the desired destination type
 *
 * @param data             Pointer to the input data.
 * @param out              Pointer to the output data.
 * @param element_count    Number of elements in the data input.
 * @param destination_type Name of the destination type.
 */
void apply_conversion(const float16* data, float16* out, size_t element_count, const std::string& destination_type) {
    if (destination_type == "f8e5m2") {
        reference::func::emulate_f8e5m2_on_fp16(data, out, element_count);
    } else if (destination_type == "f8e4m3") {
        reference::func::emulate_f8e4m3_on_fp16(data, out, element_count);
    } else {
        OPENVINO_THROW("Unsupported destination type.");
    }
}
}  // namespace fake_convert_details
}  // namespace reference
}  // namespace ov
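A quick check of the emulation entry points defined above (hypothetical standalone snippet, assuming an OpenVINO source build; the expected values follow the worked f8e5m2 example in the comments):

#include <cstdio>

#include "openvino/reference/fake_convert.hpp"

int main() {
    using ov::float16;
    const float16 in[2] = {float16(1.1875f), float16(1.125f)};
    float16 out[2];
    ov::reference::func::emulate_f8e5m2_on_fp16(in, out, 2);
    // Round-to-nearest-even puts 1.1875 up at 1.25 and the tie 1.125 down at 1.0.
    std::printf("%g %g\n", static_cast<float>(out[0]), static_cast<float>(out[1]));
    return 0;
}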
@@ -4,17 +4,47 @@
 
 #include "openvino/op/fake_convert.hpp"
 
+#include "element_visitor.hpp"
 #include "itt.hpp"
+#include "openvino/reference/convert.hpp"
+#include "openvino/reference/fake_convert.hpp"
 
 namespace ov {
 namespace op {
 namespace v13 {
-namespace fake_convert {
+namespace fake_convert_details {
 static const std::vector<std::string>& get_valid_types() {
     static const std::vector<std::string> valid_types{"f8e4m3", "f8e5m2"};
     return valid_types;
 }
-}  // namespace fake_convert
+
+struct Evaluate : element::NoAction<bool> {
+    using element::NoAction<bool>::visit;
+
+    template <element::Type_t ET, class T = fundamental_type_for<ET>>
+    static result_type visit(ov::TensorVector& outputs,
+                             const ov::TensorVector& inputs,
+                             const std::string& destination_type) {
+        if (inputs.size() == 2) {  // default shift
+            reference::fake_convert<T>(inputs[0].data<const T>(),
+                                       inputs[1].data<const T>(),
+                                       outputs[0].data<T>(),
+                                       inputs[0].get_shape(),
+                                       inputs[1].get_shape(),
+                                       destination_type);
+        } else {
+            reference::fake_convert<T>(inputs[0].data<const T>(),
+                                       inputs[1].data<const T>(),
+                                       inputs[2].data<const T>(),
+                                       outputs[0].data<T>(),
+                                       inputs[0].get_shape(),
+                                       inputs[1].get_shape(),
+                                       inputs[2].get_shape(),
+                                       destination_type);
+        }
+        return true;
+    }
+};
+}  // namespace fake_convert_details
 
 FakeConvert::FakeConvert(const ov::Output<ov::Node>& arg,
                          const ov::Output<ov::Node>& scale,
@@ -61,15 +91,40 @@ bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) {
 }
 
 void FakeConvert::validate_type() const {
-    const auto& valid_types = fake_convert::get_valid_types();
-    OPENVINO_ASSERT(std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end(),
-                    "Bad format for f8 conversion type: " + m_destination_type);
+    const auto& valid_types = fake_convert_details::get_valid_types();
+    const auto is_supported_type =
+        std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end();
+    OPENVINO_ASSERT(is_supported_type, "Bad format for f8 conversion type: ", m_destination_type);
 }
 
 bool FakeConvert::has_evaluate() const {
-    return false;
+    OV_OP_SCOPE(v13_FakeConvert_has_evaluate);
+    switch (get_input_element_type(0)) {
+    case element::f16:
+    case element::bf16:
+    case element::f32:
+        return true;
+    default:
+        return false;
+    }
 }
 
+bool FakeConvert::evaluate(ov::TensorVector& outputs, const ov::TensorVector& inputs) const {
+    OV_OP_SCOPE(v13_FakeConvert_evaluate);
+
+    OPENVINO_ASSERT(outputs.size() == 1);
+    OPENVINO_ASSERT(inputs.size() == 2 || inputs.size() == 3);
+
+    outputs[0].set_shape(inputs[0].get_shape());
+
+    using namespace ov::element;
+    return IfTypeOf<bf16, f16, f32>::apply<fake_convert_details::Evaluate>(inputs[0].get_element_type(),
+                                                                           outputs,
+                                                                           inputs,
+                                                                           get_destination_type());
+}
 }  // namespace v13
 }  // namespace op
 }  // namespace ov
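With has_evaluate and evaluate wired up, the op can be computed directly on tensors. A minimal sketch (hypothetical snippet against a source build; the four-argument constructor and the Tensor host-pointer wrapping are assumptions based on the diff above):

#include <memory>

#include "openvino/op/fake_convert.hpp"
#include "openvino/op/parameter.hpp"

int main() {
    using namespace ov;
    const auto data = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 2});
    const auto scale = std::make_shared<op::v0::Parameter>(element::f32, Shape{1});
    const auto shift = std::make_shared<op::v0::Parameter>(element::f32, Shape{1});
    const auto fc = std::make_shared<op::v13::FakeConvert>(data, scale, shift, "f8e5m2");

    float in[4] = {1.1875f, -1.1875f, 100.f, 1e6f};
    float sc[1] = {1.f}, sh[1] = {0.f}, result[4] = {};
    TensorVector outputs{Tensor(element::f32, Shape{2, 2}, result)};
    const TensorVector inputs{Tensor(element::f32, Shape{2, 2}, in),
                              Tensor(element::f32, Shape{1}, sc),
                              Tensor(element::f32, Shape{1}, sh)};
    // Dispatches through fake_convert_details::Evaluate for the f32 input type.
    return fc->evaluate(outputs, inputs) ? 0 : 1;
}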
File diff suppressed because it is too large