[Opset13][FP8] FakeConvert-13 input validation improvements (#21226)

* Improve shape type infer for FakeConvert-13

* Fix formatting

* Formatting + missing test

* Implement requested changes

* Improve type merge

* Fix formatting

* Add requested type_prop tests

* Implement requested changes

* Fix formatting

* Implement requested changes

* Use template type
This commit is contained in:
Mateusz Mikolajczyk 2023-11-27 08:38:08 +01:00 committed by GitHub
parent 6b898fc8d9
commit ed061d5179
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 185 additions and 4 deletions

View File

@ -33,7 +33,7 @@ public:
const std::string& get_destination_type() const;
private:
void validate_type() const;
void validate_destination_type() const;
std::string m_destination_type = "f8e4m3";
};

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/op/fake_convert.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace v13 {
// Shape inference for FakeConvert-13.
//
// input_shapes holds the data shape at index 0, the scale shape at index 1 and,
// optionally, the shift shape at index 2. The returned vector contains a single
// shape: the data shape merged (NumPy broadcast rules) with the scale shape.
//
// Validation failures are reported through NODE_VALIDATION_CHECK /
// NODE_SHAPE_INFER_CHECK when:
//   - the number of inputs is neither 2 nor 3,
//   - the scale shape is not compatible with the shift shape,
//   - the scale shape cannot be NumPy-broadcast-merged with the data shape,
//   - the merged shape is no longer compatible with the data shape, i.e. the
//     data itself would have to be broadcast (only scale/shift may be).
template <class T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const FakeConvert* op, const std::vector<T>& input_shapes) {
    const auto inputs_count = input_shapes.size();
    const auto has_shifts_input = inputs_count == 3;
    NODE_VALIDATION_CHECK(op, inputs_count == 2 || has_shifts_input);

    // The scale/shift shapes are only read here, so bind them by reference
    // instead of copying whole shape objects on every inference call.
    const auto& scales_shape = input_shapes[1];
    if (has_shifts_input) {
        const auto& shifts_shape = input_shapes[2];
        NODE_SHAPE_INFER_CHECK(op,
                               input_shapes,
                               scales_shape.compatible(shifts_shape),
                               "FakeConvert scale shape is not compatible with shift shape.");
    }

    // Start from a copy of the data shape; broadcast_merge_into refines it in place.
    TRShape output_pshape = input_shapes[0];
    NODE_SHAPE_INFER_CHECK(op,
                           input_shapes,
                           T::broadcast_merge_into(output_pshape, scales_shape, op::AutoBroadcastType::NUMPY),
                           "Argument shapes are inconsistent.");

    // If the merge changed the data shape incompatibly, the data would have
    // been broadcast to the scale shape, which is rejected.
    NODE_SHAPE_INFER_CHECK(
        op,
        input_shapes,
        input_shapes[0].compatible(output_pshape),
        "FakeConvert support only unidirectional broadcasting, inputs cannot be broadcastd into data.");
    return {output_pshape};
}
} // namespace v13
} // namespace op
} // namespace ov

View File

@ -5,6 +5,7 @@
#include "openvino/op/fake_convert.hpp"
#include "element_visitor.hpp"
#include "fake_convert_shape_inference.hpp"
#include "itt.hpp"
#include "openvino/reference/convert.hpp"
#include "openvino/reference/fake_convert.hpp"
@ -69,8 +70,26 @@ const std::string& FakeConvert::get_destination_type() const {
// Validates the destination type and the input element types, runs shape
// inference and sets the element type and shape of output 0.
//
// NOTE(review): this listing looks like a rendered diff with the +/- markers
// stripped. The `validate_type()` call and the first `set_output_type(...)`
// below are presumably the removed pre-change lines — confirm against the
// actual source file before relying on them.
void FakeConvert::validate_and_infer_types() {
OV_OP_SCOPE(v13_FakeConvert_validate_and_infer_types);
validate_type();
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
validate_destination_type();
// Start from the element type of input 0 and merge in the type of every other
// input; a failed merge is reported as mixed input types.
auto out_type = get_input_element_type(0);
for (size_t i = 1; i < get_input_size(); i++) {
OPENVINO_ASSERT(element::Type::merge(out_type, out_type, get_input_element_type(i)),
"Mixed input types are not supported.");
}
// Only bf16, f16, f32 and dynamic pass; any other merged type throws.
switch (out_type) {
case element::bf16:
case element::f16:
case element::f32:
case element::dynamic:
break;
default:
OPENVINO_THROW("The element type of the input tensor must be a bf16, f16, f32 but got: ", out_type);
}
// The helper call is wrapped in deprecation-suppression macros
// (presumably because get_node_input_partial_shapes is deprecated — TODO confirm).
OPENVINO_SUPPRESS_DEPRECATED_START
const auto input_shapes = get_node_input_partial_shapes(*this);
OPENVINO_SUPPRESS_DEPRECATED_END
// The output takes the merged element type and the first inferred shape.
const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, out_type, output_shapes[0]);
}
std::shared_ptr<ov::Node> FakeConvert::clone_with_new_inputs(const ov::OutputVector& new_args) const {
@ -90,7 +109,7 @@ bool FakeConvert::visit_attributes(ov::AttributeVisitor& visitor) {
return true;
}
void FakeConvert::validate_type() const {
void FakeConvert::validate_destination_type() const {
const auto& valid_types = fake_convert_details::get_valid_types();
const auto is_supported_type =
std::find(valid_types.begin(), valid_types.end(), m_destination_type) != valid_types.end();

View File

@ -6,6 +6,7 @@
#include <gtest/gtest.h>
#include "common_test_utils/test_assertions.hpp"
#include "common_test_utils/type_prop.hpp"
using namespace ov;
@ -40,6 +41,26 @@ TEST(type_prop, fake_convert_basic_f16) {
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
}
// bf16 data with scalar bf16 scale and shift: the output is expected to keep
// the data element type and shape.
TEST(type_prop, fake_convert_basic_bf16) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scalar_shape{};

    const auto data_node = std::make_shared<Parameter>(element::bf16, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::bf16, scalar_shape);
    const auto shift_node = std::make_shared<Parameter>(element::bf16, scalar_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::bf16);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
}
// All inputs use the dynamic element type: the output element type is expected
// to stay dynamic while the static data shape is propagated.
TEST(type_prop, fake_convert_basic_dynamic) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scalar_shape{};

    const auto data_node = std::make_shared<Parameter>(element::dynamic, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::dynamic, scalar_shape);
    const auto shift_node = std::make_shared<Parameter>(element::dynamic, scalar_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::dynamic);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 3, 8, 6}));
}
TEST(type_prop, fake_convert_dynamic_shape) {
const auto data = std::make_shared<Parameter>(element::f32, PartialShape::dynamic());
const auto scale = std::make_shared<Parameter>(element::f32, PartialShape{});
@ -49,3 +70,105 @@ TEST(type_prop, fake_convert_dynamic_shape) {
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape::dynamic()));
}
// Labels assigned to the data shape (starting at 10) are expected to appear
// unchanged on the output shape when scale and shift are scalars.
TEST(type_prop, fake_convert_label) {
    PartialShape labeled_shape{2, 1, Dimension::dynamic(), 6};
    set_shape_labels(labeled_shape, 10);
    const PartialShape scalar_shape{};

    const auto data_node = std::make_shared<Parameter>(element::f32, labeled_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scalar_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scalar_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::f32);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 1, Dimension::dynamic(), 6}));
    EXPECT_THAT(get_shape_labels(fake_convert->get_output_partial_shape(0)), testing::ElementsAre(10, 11, 12, 13));
}
// Scale and shift have the same rank as the data and differ from it only in
// dimensions of size 1: the output shape is expected to equal the data shape.
TEST(type_prop, fake_convert_example_0) {
    const PartialShape data_shape{2, 1, 3, 6};
    const PartialShape scale_shift_shape{1, 1, 3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::f32);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
}
// Scale and shift have a lower rank ({3, 1}) than the 4D data: the output shape
// is expected to equal the data shape.
TEST(type_prop, fake_convert_example_1) {
    const PartialShape data_shape{2, 1, 3, 6};
    const PartialShape scale_shift_shape{3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::f32);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
}
// The data has an interval dimension {3, 5} aligned with a static 3 in the
// scale/shift shape: the output dimension is expected to be refined to 3.
TEST(type_prop, fake_convert_example_2) {
    const PartialShape data_shape{2, 1, {3, 5}, 6};
    const PartialShape scale_shift_shape{3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::f32);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{2, 1, 3, 6}));
}
// The data has an interval dimension {2, 5} aligned with a static 3 in a
// same-rank scale/shift shape: the output dimension is expected to become 3.
TEST(type_prop, fake_convert_example_3) {
    const PartialShape data_shape{4, {2, 5}, 6};
    const PartialShape scale_shift_shape{4, 3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node);

    EXPECT_EQ(fake_convert->get_output_element_type(0), element::f32);
    EXPECT_EQ(fake_convert->get_output_partial_shape(0), (PartialShape{4, 3, 6}));
}
// f64 inputs: constructing the op is expected to throw with the
// unsupported-element-type message.
TEST(type_prop, fake_convert_basic_unsupported_type) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scalar_shape{};

    const auto data_node = std::make_shared<Parameter>(element::f64, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f64, scalar_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f64, scalar_shape);

    OV_EXPECT_THROW(
        const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node),
        Exception,
        testing::HasSubstr("The element type of the input tensor must be a bf16, f16, f32 but got: f64"));
}
// Scale is rank 1 ({1}) while shift is a scalar: constructing the op is
// expected to fail the scale/shift compatibility check.
TEST(type_prop, fake_convert_basic_unsupported_shape_scale_shift) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scale_shape{1};
    const PartialShape shift_shape{};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, shift_shape);

    OV_EXPECT_THROW(
        const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node),
        AssertFailure,
        testing::HasSubstr("FakeConvert scale shape is not compatible with shift shape."));
}
// Scale/shift dimension 3 against data dimension 8: constructing the op is
// expected to fail with the inconsistent-shapes message.
TEST(type_prop, fake_convert_basic_unsupported_shape_not_broadcastable) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scale_shift_shape{1, 1, 3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);

    OV_EXPECT_THROW(
        const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node),
        AssertFailure,
        testing::HasSubstr("Argument shapes are inconsistent."));
}
// The data has size 1 where scale/shift have size 3, so the data itself would
// need broadcasting: constructing the op is expected to be rejected.
TEST(type_prop, fake_convert_basic_unsupported_shape_not_unidirectional_broadcastable) {
    const PartialShape data_shape{2, 3, 1, 6};
    const PartialShape scale_shift_shape{1, 1, 3, 1};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f32, scale_shift_shape);

    OV_EXPECT_THROW(
        const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node),
        AssertFailure,
        testing::HasSubstr(
            "FakeConvert support only unidirectional broadcasting, inputs cannot be broadcastd into data."));
}
// f32 data, dynamic scale and f16 shift: constructing the op is expected to
// fail with the mixed-input-types message.
TEST(type_prop, fake_convert_basic_unsupported_mixed_types) {
    const PartialShape data_shape{2, 3, 8, 6};
    const PartialShape scalar_shape{};

    const auto data_node = std::make_shared<Parameter>(element::f32, data_shape);
    const auto scale_node = std::make_shared<Parameter>(element::dynamic, scalar_shape);
    const auto shift_node = std::make_shared<Parameter>(element::f16, scalar_shape);

    OV_EXPECT_THROW(
        const auto fake_convert = std::make_shared<op::v13::FakeConvert>(data_node, scale_node, shift_node),
        AssertFailure,
        testing::HasSubstr("Mixed input types are not supported."));
}