[Opset13][FP8] FakeConvert-13 input validation improvements (#21226)
* Improve shape type infer for FakeConvert-13 * Fix formatting * Formatting + missing test * Implement requested changes * Improve type merge * Fix formatting * Add requested type_prop tests * Implement requested changes * Fix formatting * Implement requested changes * Use template type
This commit is contained in:
parent
6b898fc8d9
commit
ed061d5179
@ -33,7 +33,7 @@ public:
|
||||
const std::string& get_destination_type() const;
|
||||
|
||||
private:
|
||||
void validate_type() const;
|
||||
void validate_destination_type() const;
|
||||
|
||||
std::string m_destination_type = "f8e4m3";
|
||||
};
|
||||
|
@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "openvino/op/fake_convert.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v13 {
|
||||
template <class T, class TRShape = result_shape_t<T>>
|
||||
std::vector<TRShape> shape_infer(const FakeConvert* op, const std::vector<T>& input_shapes) {
|
||||
const auto inputs_count = input_shapes.size();
|
||||
const auto has_shifts_input = inputs_count == 3;
|
||||
NODE_VALIDATION_CHECK(op, inputs_count == 2 || has_shifts_input);
|
||||
const auto scales_shape = input_shapes[1];
|
||||
if (has_shifts_input) {
|
||||
const auto shifts_shape = input_shapes[2];
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
scales_shape.compatible(shifts_shape),
|
||||
"FakeConvert scale shape is not compatible with shift shape.");
|
||||
}
|
||||
TRShape output_pshape = input_shapes[0];
|
||||
NODE_SHAPE_INFER_CHECK(op,
|
||||
input_shapes,
|
||||
T::broadcast_merge_into(output_pshape, scales_shape, op::AutoBroadcastType::NUMPY),
|
||||
"Argument shapes are inconsistent.");
|
||||
NODE_SHAPE_INFER_CHECK(
|
||||
op,
|
||||
input_shapes,
|
||||
input_shapes[0].compatible(output_pshape),
|
||||
"FakeConvert support only unidirectional broadcasting, inputs cannot be broadcastd into data.");
|
||||
return {output_pshape};
|
||||
}
|
||||
} // namespace v13
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -5,6 +5,7 @@
|
||||
#include "openvino/op/fake_convert.hpp"
|
||||
|
||||
#include "element_visitor.hpp"
|
||||
#include "fake_convert_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/reference/convert.hpp"
|
||||
#include "openvino/reference/fake_convert.hpp"
|
||||
@ -69,8 +70,26 @@ const std::string& FakeConvert::get_destination_type() const {
|
||||
|
||||
void FakeConvert::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v13_FakeConvert_validate_and_infer_types);
|
||||
validate_type();
|
||||
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
|
||||
validate_destination_type();
|
||||
auto out_type = get_input_element_type(0);
|
||||
for (size_t i = 1; i < get_input_size(); i++) {
|
||||
OPENVINO_ASSERT(element::Type::merge(out_type, out_type, get_input_element_type(i)),
|
||||
"Mixed input types are not supported.");
|
||||
}
|
||||
switch (out_type) {
|
||||
case element::bf16:
|
||||
case element::f16:
|
||||
case element::f32:
|
||||
case element::dynamic:
|
||||
break;
|
||||
default:
|
||||
OPENVINO_THROW("The element type of the input tensor must be a bf16, f16, f32 but got: ", out_type);
|
||||
}
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
const auto output_shapes = shape_infer(this, input_shapes);
|
||||
set_output_type(0, out_type, output_shapes[0]);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> FakeConvert::clone_with_new_inputs(const ov::OutputVector& new_args) const {
|
||||
@ -90,7 +109,7 @@ bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void FakeConvert::validate_type() const {
|
||||
void FakeConvert::validate_destination_type() const {
|
||||
const auto& valid_types = fake_convert_details::get_valid_types();
|
||||
const auto is_supported_type =
|
||||
std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end();
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "common_test_utils/type_prop.hpp"
|
||||
|
||||
using namespace ov;
|
||||
@ -40,6 +41,26 @@ TEST(type_prop, fake_convert_basic_f16) {
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_bf16) {
|
||||
const auto data = std::make_shared<Parameter>(element::bf16, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::bf16, PartialShape{});
|
||||
const auto shift = std::make_shared<Parameter>(element::bf16, PartialShape{});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::bf16);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_dynamic) {
|
||||
const auto data = std::make_shared<Parameter>(element::dynamic, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::dynamic, PartialShape{});
|
||||
const auto shift = std::make_shared<Parameter>(element::dynamic, PartialShape{});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::dynamic);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_dynamic_shape) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{});
|
||||
@ -49,3 +70,105 @@ TEST(type_prop, fake_convert_dynamic_shape) {
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_label) {
|
||||
PartialShape data_shape{2, 1, Dimension::dynamic(), 6};
|
||||
set_shape_labels(data_shape, 10);
|
||||
const auto data = std::make_shared<Parameter>(element::f32, data_shape);
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 1, Dimension::dynamic(), 6}));
|
||||
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), testing::ElementsAre(10, 11, 12, 13));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_example_0) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 1, 3, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_example_1) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 1, 3, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{3, 1});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
|
||||
}
|
||||
TEST(type_prop, fake_convert_example_2) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 1, {3, 5}, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{3, 1});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
|
||||
}
|
||||
TEST(type_prop, fake_convert_example_3) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{4, {2, 5}, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{4, 3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{4, 3, 1});
|
||||
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{4, 3, 6}));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_unsupported_type) {
|
||||
const auto data = std::make_shared<Parameter>(element::f64, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f64, PartialShape{});
|
||||
const auto shift = std::make_shared<Parameter>(element::f64, PartialShape{});
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift),
|
||||
Exception,
|
||||
testing::HasSubstr("The element type of the input tensor must be a bf16, f16, f32 but got: f64"));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_unsupported_shape_scale_shift) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{});
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift),
|
||||
AssertFailure,
|
||||
testing::HasSubstr("FakeConvert scale shape is not compatible with shift shape."));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_unsupported_shape_not_broadcastable) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift),
|
||||
AssertFailure,
|
||||
testing::HasSubstr("Argument shapes are inconsistent."));
|
||||
}
|
||||
|
||||
TEST(type_prop, fake_convert_basic_unsupported_shape_not_unidirectional_broadcastable) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 1, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
const auto shift = std::make_shared<Parameter>(element::f32, PartialShape{1, 1, 3, 1});
|
||||
|
||||
OV_EXPECT_THROW(
|
||||
const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift),
|
||||
AssertFailure,
|
||||
testing::HasSubstr(
|
||||
"FakeConvert support only unidirectional broadcasting, inputs cannot be broadcastd into data."));
|
||||
}
|
||||
TEST(type_prop, fake_convert_basic_unsupported_mixed_types) {
|
||||
const auto data = std::make_shared<Parameter>(element::f32, PartialShape{2, 3, 8, 6});
|
||||
const auto scale = std::make_shared<Parameter>(element::dynamic, PartialShape{});
|
||||
const auto shift = std::make_shared<Parameter>(element::f16, PartialShape{});
|
||||
|
||||
OV_EXPECT_THROW(const auto op = std::make_shared<op::v13::FakeConvert>(data, scale, shift),
|
||||
AssertFailure,
|
||||
testing::HasSubstr("Mixed input types are not supported."));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user