From 8ba339d16fb021bc825c7e6b0b267ead80674898 Mon Sep 17 00:00:00 2001 From: Mateusz Tabaka Date: Thu, 13 Jan 2022 18:05:47 +0100 Subject: [PATCH] Eliminate no-op elementwise operations (#8840) * Eliminate no-op elementwise operations This change adds EliminateEltwise pass to NopElimination. EliminateEltwise removes: - Subtract with zero - Multiply with one - Divide by one * add Add support to EliminateEltwise * fix unit test * use get_all_data_elements_bitwise_identical instead of get_single_value * fix are_all_data_elements_bitwise_identical for constant created from HostTensor * fix lpt tests * check for mul in is_dequantization_subgraph function * optimize fetching constant value * apply review comments --- .../common_optimizations/nop_elimination.hpp | 10 + .../include/transformations/utils/utils.hpp | 4 + .../common_optimizations/mul_conv_fusion.cpp | 34 +-- .../common_optimizations/nop_elimination.cpp | 30 +++ .../op_conversions/convert_divide.cpp | 23 +- .../src/transformations/utils/utils.cpp | 126 ++++++++++ src/core/src/op/constant.cpp | 2 +- src/core/tests/visitors/op/constant.cpp | 33 +++ .../transformations/convert_divide.cpp | 27 +++ .../transformations/nop_elimination.cpp | 216 +++++++++++++++++- .../common_test_utils/ngraph_test_utils.cpp | 8 +- 11 files changed, 471 insertions(+), 42 deletions(-) diff --git a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp index ca5028d5126..aeffc8a392b 100644 --- a/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/nop_elimination.hpp @@ -21,6 +21,7 @@ class TRANSFORMATIONS_API EliminateConvertNonZero; class TRANSFORMATIONS_API EliminateConcat; class TRANSFORMATIONS_API EliminateSplit; class TRANSFORMATIONS_API EliminateTranspose; +class TRANSFORMATIONS_API EliminateEltwise; class TRANSFORMATIONS_API NopElimination; } // namespace pass @@ -86,6 +87,15 @@ public: EliminateTranspose(); }; +/** + * @ingroup ie_transformation_common_api + * @brief EliminateEltwise eliminates eltwise ops that do nothing + */ +class ngraph::pass::EliminateEltwise: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + EliminateEltwise(); +}; class ngraph::pass::NopElimination: public GraphRewrite { public: diff --git a/src/common/transformations/include/transformations/utils/utils.hpp b/src/common/transformations/include/transformations/utils/utils.hpp index 576f7d127dd..77acf4841f3 100644 --- a/src/common/transformations/include/transformations/utils/utils.hpp +++ b/src/common/transformations/include/transformations/utils/utils.hpp @@ -210,6 +210,10 @@ TRANSFORMATIONS_API std::shared_ptr node_to_get_shape_value_of_ind TRANSFORMATIONS_API std::shared_ptr node_to_get_shape_value_of_indices_from_shape_source( const ngraph::Output& shape_source, const std::vector& indices); + +TRANSFORMATIONS_API bool is_dequantization_subgraph(const Output& node); + +TRANSFORMATIONS_API bool can_eliminate_eltwise_node(const std::shared_ptr& eltwise, const Output& constant, const Output& non_constant_input); } // namespace util } // namespace op } // namespace ngraph diff --git a/src/common/transformations/src/transformations/common_optimizations/mul_conv_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/mul_conv_fusion.cpp index 8e034e465e7..06fc2d73a3a 100644 --- a/src/common/transformations/src/transformations/common_optimizations/mul_conv_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/mul_conv_fusion.cpp @@ -18,32 +18,6 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::MultiplyConvolutionFusion, "MultiplyConvolutionFusion", 0); -static bool is_dequantization_subgraph(const ngraph::Output& multiply) { - auto inputs = multiply.get_node()->input_values(); - const auto subtract = std::find_if(inputs.begin(), inputs.end(), - [] (const ngraph::Output& n) -> bool { - return ov::is_type(n.get_node()); - }); - if (subtract != inputs.end()) - inputs = subtract->get_node()->input_values(); - const auto first_convert = std::find_if(inputs.begin(), inputs.end(), - [] (const ngraph::Output& n) -> bool { - if (ov::is_type(n.get_node())) { - const auto input = n.get_node()->input_value(0); - return ov::is_type(input.get_node()); - } - return false; - }); - if (first_convert == inputs.end()) - return false; - const auto second_convert = first_convert->get_node()->input_value(0); - const auto& first_convert_src_type = second_convert.get_element_type(); - const auto& first_convert_dest_type = first_convert->get_element_type(); - const auto second_convert_src_type = second_convert.get_node()->input_value(0).get_element_type(); - return (first_convert_src_type == ngraph::element::i8 || first_convert_src_type == ngraph::element::u8) && - first_convert_dest_type == second_convert_src_type; -} - ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() { MATCHER_SCOPE(MultiplyConvolutionFusion); auto input_pattern = pattern::any_input(); @@ -57,7 +31,7 @@ ngraph::pass::MultiplyConvolutionFusion::MultiplyConvolutionFusion() { // Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph // since that breaks low precision transformations - if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) + if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) return false; const auto& weights = pattern_to_output.at(weights_pattern); @@ -110,7 +84,7 @@ ngraph::pass::MultiplyGroupConvolutionFusion::MultiplyGroupConvolutionFusion() { // Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph // since that breaks low precision transformations - if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) + if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) return false; const auto& weights = pattern_to_output.at(weights_pattern); @@ -174,7 +148,7 @@ ngraph::pass::MultiplyConvolutionBackpropDataFusion::MultiplyConvolutionBackprop // Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph // since that breaks low precision transformations - if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) + if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) return false; const auto& weights = pattern_to_output.at(weights_pattern); @@ -240,7 +214,7 @@ ngraph::pass::MultiplyGroupConvolutionBackpropDataFusion::MultiplyGroupConvoluti // Can't fuse Multiply to Convolution if that Multiply is part of dequantization subgraph // since that breaks low precision transformations - if (is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) + if (op::util::is_dequantization_subgraph(pattern_to_output.at(mul_pattern))) return false; const auto& weights = pattern_to_output.at(weights_pattern); diff --git a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp index c8825ff650a..33f94c85142 100644 --- a/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/nop_elimination.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include using namespace std; @@ -503,6 +504,34 @@ pass::EliminateTranspose::EliminateTranspose() { this->register_matcher(m, callback); } +NGRAPH_RTTI_DEFINITION(pass::EliminateEltwise, "EliminateEltwise", 0); + +pass::EliminateEltwise::EliminateEltwise() { + MATCHER_SCOPE(EliminateEltwise); + auto input = pattern::any_input(); + auto constant_pattern = pattern::wrap_type(); + auto eltwise_pattern = pattern::wrap_type({input, constant_pattern}); + + matcher_pass_callback callback = [=](pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto eltwise = pattern_map.at(eltwise_pattern).get_node_shared_ptr(); + auto non_const_input = pattern_map.at(input); + auto constant = pattern_map.at(constant_pattern); + + if (!op::util::can_eliminate_eltwise_node(eltwise, constant, non_const_input)) { + return false; + } + + return replace_output_update_name(eltwise->output(0), non_const_input); + }; + + auto m = std::make_shared(eltwise_pattern, matcher_name); + this->register_matcher(m, callback); +} + NGRAPH_RTTI_DEFINITION(ngraph::pass::NopElimination, "NopElimination", 0); ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { @@ -513,6 +542,7 @@ ngraph::pass::NopElimination::NopElimination(bool use_shape_for_elimination) { add_matcher(); add_matcher(); add_matcher(); + add_matcher(); // shape-dependent transformations if (use_shape_for_elimination) { diff --git a/src/common/transformations/src/transformations/op_conversions/convert_divide.cpp b/src/common/transformations/src/transformations/op_conversions/convert_divide.cpp index 79862b3186e..777dafc8a54 100644 --- a/src/common/transformations/src/transformations/op_conversions/convert_divide.cpp +++ b/src/common/transformations/src/transformations/op_conversions/convert_divide.cpp @@ -4,6 +4,8 @@ #include "itt.hpp" #include "transformations/op_conversions/convert_divide.hpp" +#include "transformations/utils/utils.hpp" + #include #include @@ -28,25 +30,30 @@ bool convert_divide(std::shared_ptr node) { return false; } - ngraph::Output pow = std::make_shared(div->input_value(1), + std::shared_ptr pow = std::make_shared(div->input_value(1), ngraph::op::Constant::create(div->get_input_element_type(1), ngraph::Shape{}, {-1})); if (std::dynamic_pointer_cast(div->get_input_node_shared_ptr(1))) { if (auto const_pow = ngraph::get_constant_from_source(pow)) { pow = const_pow; } else { - NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get_node(); + NGRAPH_DEBUG << "ConvertDivide has failed due to unsupported evaluate type in " << pow.get(); return false; } } else { - ngraph::copy_runtime_info(div, pow.get_node_shared_ptr()); + ngraph::copy_runtime_info(div, pow); } auto mul = std::make_shared(div->input(0).get_source_output(), pow); - - mul->set_friendly_name(div->get_friendly_name()); - ngraph::copy_runtime_info(div, mul); - ngraph::replace_node(div, mul); + // if Divide is an inverse, then we don't need the Multiply + if (ngraph::op::util::can_eliminate_eltwise_node(mul, mul->input_value(0), mul->input_value(1))) { + pow->set_friendly_name(div->get_friendly_name()); + ngraph::replace_node(div, pow); + } else { + mul->set_friendly_name(div->get_friendly_name()); + ngraph::copy_runtime_info(div, mul); + ngraph::replace_node(div, mul); + } return true; } } // namespace @@ -74,4 +81,4 @@ ngraph::pass::ConvertDivideWithConstant::ConvertDivideWithConstant() { auto m = std::make_shared(div, matcher_name); this->register_matcher(m, callback); -} \ No newline at end of file +} diff --git a/src/common/transformations/src/transformations/utils/utils.cpp b/src/common/transformations/src/transformations/utils/utils.cpp index 5f98f8f2992..da7cc581c87 100644 --- a/src/common/transformations/src/transformations/utils/utils.cpp +++ b/src/common/transformations/src/transformations/utils/utils.cpp @@ -203,6 +203,132 @@ void visit_shape_path(const std::shared_ptr& node, } } +bool is_dequantization_subgraph(const Output& node) { + if (!is_type(node.get_node())) { + return false; + } + + auto mul_inputs = node.get_node()->input_values(); + Node* sub = nullptr; + Node* convert = nullptr; + + if (is_type(mul_inputs[0].get_node())) { + sub = mul_inputs[0].get_node(); + } else if (is_type(mul_inputs[0].get_node())) { + convert = mul_inputs[0].get_node(); + } else { + return false; + } + + if (sub) { + auto sub_inputs = sub->input_values(); + if (is_type(sub_inputs[0].get_node())) { + convert = sub_inputs[0].get_node(); + } + } + + if (!convert) { + return false; + } + + auto input_type = convert->get_input_element_type(0); + auto output_type = convert->get_output_element_type(0); + return input_type.is_integral() && output_type.is_real(); +} + +bool can_eliminate_eltwise_node(const std::shared_ptr& eltwise, const Output& constant, const Output& non_constant_input) { + if (!is_type(eltwise) && + !is_type(eltwise) && + !is_type(eltwise) && + !is_type(eltwise)) { + return false; + } + + if (is_dequantization_subgraph(eltwise)) { + return false; + } + + // check if constant has a single value with either 0 (for Add, Subtract) or 1 (for Multiply, Divide) + auto constant_ptr = std::dynamic_pointer_cast(constant.get_node_shared_ptr()); + if (!constant_ptr) { + return false; + } + if (!constant_ptr->get_all_data_elements_bitwise_identical()) { + return false; + } + float actual_const = 0; + const void* data_ptr = constant_ptr->get_data_ptr(); + switch (constant_ptr->get_element_type()) { + case element::f32: + actual_const = reinterpret_cast(data_ptr)[0]; + break; + case element::i32: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::u32: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::i64: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::u64: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::i8: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::u8: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::i16: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::u16: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + case element::f64: + actual_const = static_cast(reinterpret_cast(data_ptr)[0]); + break; + default: + return false; + } + float expected_const = 0; + if (is_type(eltwise) || + is_type(eltwise)) { + expected_const = 1; + } + if (actual_const != expected_const) { + return false; + } + + // fuse uncoditionally if constant is a scalar + const auto& constant_shape = constant.get_shape(); + if (ov::is_scalar(constant_shape)) { + return true; + } + + const auto& input_shape = non_constant_input.get_partial_shape(); + if (input_shape.rank().is_dynamic()) { + return false; + } + + // cannot fuse if constant extends input's rank + auto input_rank = static_cast(input_shape.rank().get_length()); + auto constant_rank = constant_shape.size(); + if (input_rank < constant_rank) { + return false; + } + + // cannot fuse if constant makes input to be broadcasted, e.g. + // Multiply(input{2, 1, 5}, constant{1, 5, 1}) -> {2, 5, 5} + for (size_t i = 0; i < constant_rank; i++) { + auto constant_dim = constant_shape[constant_rank - i - 1]; + if (constant_dim != 1 && input_shape[input_rank - i - 1] != constant_dim) { + return false; + } + } + return true; +} } // namespace util } // namespace op } // namespace ngraph diff --git a/src/core/src/op/constant.cpp b/src/core/src/op/constant.cpp index 8fd55284055..eb31e3059a5 100644 --- a/src/core/src/op/constant.cpp +++ b/src/core/src/op/constant.cpp @@ -48,8 +48,8 @@ ov::op::v0::Constant::Constant(const shared_ptr& tensor constructor_validate_and_infer_types(); allocate_buffer(); tensor->read(get_data_ptr_nc(), tensor->get_size_in_bytes()); - m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical(); } + m_all_elements_bitwise_identical = are_all_data_elements_bitwise_identical(); constructor_validate_and_infer_types(); } diff --git a/src/core/tests/visitors/op/constant.cpp b/src/core/tests/visitors/op/constant.cpp index 3f6dec98aac..64cf98dfa09 100644 --- a/src/core/tests/visitors/op/constant.cpp +++ b/src/core/tests/visitors/op/constant.cpp @@ -9,6 +9,7 @@ #include "ngraph/opsets/opset3.hpp" #include "ngraph/opsets/opset4.hpp" #include "ngraph/opsets/opset5.hpp" +#include "ngraph/runtime/host_tensor.hpp" #include "util/visitor.hpp" using namespace std; @@ -56,3 +57,35 @@ TEST(attributes, constant_op_identical_elements) { EXPECT_EQ(data, g_data); ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical()); } + +TEST(attributes, constant_op_from_host_tensor_different_elements) { + vector data{5, 4, 3, 2, 1, 0}; + auto tensor = std::make_shared(element::i64, Shape{2, 3}, &data[0]); + auto k = make_shared(tensor); + ASSERT_FALSE(k->get_all_data_elements_bitwise_identical()); + NodeBuilder builder(k); + auto g_k = ov::as_type_ptr(builder.create()); + g_k->validate_and_infer_types(); + ASSERT_TRUE(g_k); + EXPECT_EQ(k->get_element_type(), g_k->get_element_type()); + EXPECT_EQ(k->get_shape(), g_k->get_shape()); + vector g_data = g_k->get_vector(); + EXPECT_EQ(data, g_data); + ASSERT_FALSE(g_k->get_all_data_elements_bitwise_identical()); +} + +TEST(attributes, constant_op_from_host_tensor_identical_elements) { + vector data{5, 5, 5, 5, 5, 5}; + auto tensor = std::make_shared(element::i64, Shape{2, 3}, &data[0]); + auto k = make_shared(tensor); + ASSERT_TRUE(k->get_all_data_elements_bitwise_identical()); + NodeBuilder builder(k); + auto g_k = ov::as_type_ptr(builder.create()); + g_k->validate_and_infer_types(); + ASSERT_TRUE(g_k); + EXPECT_EQ(k->get_element_type(), g_k->get_element_type()); + EXPECT_EQ(k->get_shape(), g_k->get_shape()); + vector g_data = g_k->get_vector(); + EXPECT_EQ(data, g_data); + ASSERT_TRUE(g_k->get_all_data_elements_bitwise_identical()); +} diff --git a/src/tests/functional/inference_engine/transformations/convert_divide.cpp b/src/tests/functional/inference_engine/transformations/convert_divide.cpp index 0b9d96a8ac2..55e37172b53 100644 --- a/src/tests/functional/inference_engine/transformations/convert_divide.cpp +++ b/src/tests/functional/inference_engine/transformations/convert_divide.cpp @@ -38,8 +38,31 @@ TEST_F(TransformationTestsF, ConvertDivide) { function_ref = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{data}); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } +TEST_F(TransformationTestsF, ConvertDivideInverse) { + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto divide_constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto divide = std::make_shared(divide_constant, data); + + function = std::make_shared(ngraph::NodeVector{divide}, ngraph::ParameterVector{data}); + + manager.register_pass(); + } + + { + auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto constant = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{}, {-1.0}); + auto pow = std::make_shared(data, constant); + + function_ref = std::make_shared(ngraph::NodeVector{pow}, ngraph::ParameterVector{data}); + } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); +} + + TEST_F(TransformationTestsF, ConvertDivideNegative) { { auto data = std::make_shared(ngraph::element::i32, ngraph::Shape{3, 1, 2}); @@ -58,6 +81,7 @@ TEST_F(TransformationTestsF, ConvertDivideNegative) { function_ref = std::make_shared(ngraph::NodeVector{divide}, ngraph::ParameterVector{data}); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } TEST_F(TransformationTestsF, ConvertDivideScalar) { @@ -84,6 +108,7 @@ TEST_F(TransformationTestsF, ConvertDivideScalar) { NGRAPH_CHECK(mul->get_output_partial_shape(0).rank().get_length() == 0); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) { @@ -103,6 +128,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantPositive) { function_ref = std::make_shared(ngraph::NodeVector{mul}, ngraph::ParameterVector{data}); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) { @@ -122,6 +148,7 @@ TEST_F(TransformationTestsF, ConvertDivideWithConstantNegative) { function_ref = std::make_shared(ngraph::NodeVector{divide}, ngraph::ParameterVector{data1, data2}); } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); } TEST_F(TransformationTestsF, ConvertDivideFP16ShapeOfSubgraphNegative) { diff --git a/src/tests/functional/inference_engine/transformations/nop_elimination.cpp b/src/tests/functional/inference_engine/transformations/nop_elimination.cpp index ef654188194..bfe7db2dbe7 100644 --- a/src/tests/functional/inference_engine/transformations/nop_elimination.cpp +++ b/src/tests/functional/inference_engine/transformations/nop_elimination.cpp @@ -18,8 +18,11 @@ #include #include #include +#include +#include #include "common_test_utils/ngraph_test_utils.hpp" +#include "common_test_utils/common_utils.hpp" #include @@ -817,4 +820,215 @@ TEST(nop_elimination, gather_3d_indices_constant_axis_1) { check_usecase(PartialShape{3, 2, 1}, i32, multiout, std::vector{1}, 2, 0); check_usecase(PartialShape{1, 16}, i32, multiout, std::vector{0, 0}, 0, 1); } -} \ No newline at end of file +} + +using namespace helpers; + +struct ShapeParams { + PartialShape shape1; + Shape shape2; + bool swap_inputs; + bool can_fuse; +}; + +enum class ConstantKind { + ZERO, + ONE, + RANDOM, +}; + +static std::ostream& operator<<(std::ostream& os, ConstantKind kind) { + switch (kind) { + case ConstantKind::ZERO: + os << "zero"; + break; + case ConstantKind::ONE: + os << "one"; + break; + case ConstantKind::RANDOM: + os << "random"; + break; + } + return os; +} + +struct TypeParams { + EltwiseTypes op_type; + ConstantKind constant_kind; + bool can_fuse; +}; + +using EliminateEltwiseParams = std::tuple; + +class EliminateEltwiseTests: public testing::WithParamInterface, virtual public TransformationTestsF { + public: + static std::string get_test_case_name(testing::TestParamInfo info) { + const auto& shape_params = std::get<0>(info.param); + const auto& type_params = std::get<1>(info.param); + const auto& element_type = std::get<2>(info.param); + std::ostringstream result; + result << type_params.op_type + << "_input1=" << shape_params.shape1 + << "_input2=" << shape_params.shape2 + << "_swap_inputs=" << std::boolalpha << shape_params.swap_inputs + << "_constant=" << type_params.constant_kind + << "_type=" << element_type; + return result.str(); + } +}; + +TEST_P(EliminateEltwiseTests, eliminate_eltwise) { + auto params = GetParam(); + const auto& shape_params = std::get<0>(params); + const auto& type_params = std::get<1>(params); + const auto& type = std::get<2>(params); + const auto& shape1 = shape_params.shape1; + const auto& shape2 = shape_params.shape2; + bool swap_inputs = shape_params.swap_inputs; + bool can_fuse = shape_params.can_fuse && type_params.can_fuse; + + auto parameter = make_shared(type, shape1); + std::shared_ptr constant; + switch (type_params.constant_kind) { + case ConstantKind::ZERO: + constant = op::Constant::create(type, shape2, {0}); + break; + case ConstantKind::ONE: + constant = op::Constant::create(type, shape2, {1}); + break; + case ConstantKind::RANDOM: + constant = builder::makeConstant(type, shape2, {}, true, 20 /* upTo */, 2 /* startFrom */); + break; + } + + shared_ptr A = parameter; + shared_ptr B = constant; + if (swap_inputs) { + std::swap(A, B); + if (type_params.op_type == EltwiseTypes::SUBTRACT || + type_params.op_type == EltwiseTypes::DIVIDE) { + can_fuse = false; + } + } + + shared_ptr node; + switch (type_params.op_type) { + case EltwiseTypes::ADD: + node = make_shared(A, B); + break; + case EltwiseTypes::SUBTRACT: + node = make_shared(A, B); + break; + case EltwiseTypes::MULTIPLY: + node = make_shared(A, B); + break; + case EltwiseTypes::DIVIDE: + node = make_shared(A, B); + break; + default: + ASSERT_FALSE(true) << "Invalid EltwiseType"; + } + auto abs = make_shared(node); + function = make_shared(abs, ParameterVector{parameter}); + + manager.register_pass(); + + if (can_fuse) { + auto abs = make_shared(parameter); + function_ref = make_shared(abs, ParameterVector{parameter}); + } + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + if (type == element::f32) { + enable_accuracy_check(); + } +} + +std::vector shape_params = { + // input 1, input 2, swap inputs, can fuse + { Shape{}, Shape{}, false, true }, + { Shape{}, Shape{}, true, true }, + { Shape{5}, Shape{}, false, true }, + { Shape{5}, Shape{1}, false, true }, + { Shape{5}, Shape{5}, false, true }, + { Shape{5}, Shape{5}, true, true }, + { Shape{2, 3, 5}, Shape{}, false, true }, + { Shape{2, 3, 5}, Shape{1}, false, true }, + { Shape{2, 3, 5}, Shape{1, 1}, false, true }, + { Shape{2, 3, 5}, Shape{1, 1, 1}, false, true }, + { Shape{2, 3, 5}, Shape{5}, false, true }, + { Shape{2, 3, 5}, Shape{1, 5}, false, true }, + { Shape{2, 3, 5}, Shape{1, 1, 5}, false, true }, + { Shape{2, 3, 5}, Shape{3, 5}, false, true }, + { Shape{2, 3, 5}, Shape{1, 3, 5}, false, true }, + { Shape{2, 3, 5}, Shape{2, 3, 5}, false, true }, + { Shape{2, 3, 5}, Shape{2, 3, 5}, true, true }, + { PartialShape::dynamic(), Shape{}, false, true }, + { PartialShape::dynamic(3), Shape{}, false, true }, + { PartialShape::dynamic(3), Shape{1}, false, true }, + { PartialShape::dynamic(3), Shape{1, 1}, false, true }, + { PartialShape::dynamic(3), Shape{1, 1, 1}, false, true }, + { PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 1}, false, true }, + { PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{3, 1}, false, true }, + { PartialShape{Dimension::dynamic(), 3, Dimension::dynamic()}, Shape{1, 3, 1}, false, true }, + // negative cases + { Shape{}, Shape{1}, false, false }, + { Shape{}, Shape{2, 3}, false, false }, + { Shape{5}, Shape{1, 1}, false, false }, + { Shape{4, 1, 3}, Shape{2, 3}, false, false }, + { Shape{1, 2, 3}, Shape{4, 2, 3}, false, false }, + { Shape{1, 1, 3}, Shape{2, 1, 3}, false, false }, + { Shape{1, 2, 3}, Shape{1, 1, 1, 1}, false, false }, + { PartialShape::dynamic(), Shape{2, 3, 4}, false, false }, + { PartialShape::dynamic(3), Shape{1, 2, 1}, false, false }, + { PartialShape::dynamic(3), Shape{1, 1, 1, 1}, false, false }, +}; + +std::vector type_params = { + // op type, constant value, can fuse + { EltwiseTypes::ADD, ConstantKind::ZERO, true }, + { EltwiseTypes::ADD, ConstantKind::RANDOM, false }, + { EltwiseTypes::SUBTRACT, ConstantKind::ZERO, true }, + { EltwiseTypes::SUBTRACT, ConstantKind::RANDOM, false }, + { EltwiseTypes::MULTIPLY, ConstantKind::ONE, true }, + { EltwiseTypes::MULTIPLY, ConstantKind::RANDOM, false }, + { EltwiseTypes::DIVIDE, ConstantKind::ONE, true }, + { EltwiseTypes::DIVIDE, ConstantKind::RANDOM, false }, +}; + +std::vector types{ + element::f32, element::f64, + element::i32, element::u32, + element::i64, element::u64, + element::i8, element::u8, + element::i16, element::u16, +}; + +INSTANTIATE_TEST_SUITE_P(EliminateEltwise, EliminateEltwiseTests, + ::testing::Combine( + ::testing::ValuesIn(shape_params), + ::testing::ValuesIn(type_params), + ::testing::ValuesIn(types)), + EliminateEltwiseTests::get_test_case_name); + + +TEST_F(TransformationTestsF, eliminate_eltwise_dequantization_subgraph) { + { + auto constant = opset8::Constant::create(element::i8, Shape{}, {2}); + auto convert = make_shared(constant, element::f32); + auto sub = make_shared(convert, opset8::Constant::create(element::f32, Shape{}, {0})); + auto mul = make_shared(sub, opset8::Constant::create(element::f32, Shape{}, {1})); + function = make_shared(mul, ParameterVector{}); + } + { + auto constant = opset8::Constant::create(element::i8, Shape{}, {2}); + auto convert = make_shared(constant, element::f32); + auto mul = make_shared(convert, opset8::Constant::create(element::f32, Shape{}, {1})); + function_ref = make_shared(mul, ParameterVector{}); + } + + manager.register_pass(); + + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + enable_accuracy_check(); +} diff --git a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp index 88c8e42cdeb..bbbd3a08389 100644 --- a/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp +++ b/src/tests/ie_test_utils/common_test_utils/ngraph_test_utils.cpp @@ -201,7 +201,11 @@ void TransformationTestsF::accuracy_check(std::shared_ptr ref_functio for (auto param : ref_function->get_parameters()) { types.push_back(param->get_element_type()); - InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), InferenceEngine::Layout::ANY); + auto layout = InferenceEngine::Layout::ANY; + if (ov::is_scalar(param->get_shape())) { + layout = InferenceEngine::Layout::SCALAR; + } + InferenceEngine::TensorDesc td(InferenceEngine::Precision::FP32, param->get_shape(), layout); const auto &input = FuncTestUtils::createAndFillBlob(td); const auto &input_size = input->byteSize(); @@ -227,8 +231,8 @@ void TransformationTestsF::accuracy_check(std::shared_ptr ref_functio IE_ASSERT(ref_outputs[i].second.size() == outputs[i].second.size()); auto * ref = reinterpret_cast(ref_outputs[i].second.data()); auto * out = reinterpret_cast(outputs[i].second.data()); - IE_ASSERT(ref_outputs[i].second.size() / 8); size_t size = ref_outputs[i].second.size() / sizeof(float); + IE_ASSERT(size > 0); Compare(ref, out, size, 1e-5); } }