From 761a6d10d0e606b129b074638a9a52e27276d93c Mon Sep 17 00:00:00 2001 From: Szymon Irzabek Date: Wed, 27 Jul 2022 10:14:09 +0200 Subject: [PATCH] [GNA] Add missing support for batch normalization with weights broadcasting. Add unit tests. (#12301) --- .../transformations/broadcast_const.cpp | 45 +++++++++++----- .../transformations/broadcast_const.hpp | 4 +- .../transformations/gna_broadcast_const.cpp | 52 ++++++++++++++----- 3 files changed, 75 insertions(+), 26 deletions(-) diff --git a/src/plugins/intel_gna/transformations/broadcast_const.cpp b/src/plugins/intel_gna/transformations/broadcast_const.cpp index 0c1deeee977..70abfc8f155 100644 --- a/src/plugins/intel_gna/transformations/broadcast_const.cpp +++ b/src/plugins/intel_gna/transformations/broadcast_const.cpp @@ -12,6 +12,7 @@ #include #include #include "legacy/ngraph_ops/eltwise.hpp" +#include "legacy/ngraph_ops/scaleshift.hpp" #include #include @@ -57,19 +58,27 @@ ov::op::BroadcastModeSpec GetBroadcastType(Node eltwise_node) { return ov::op::BroadcastType::NONE; } -bool DoTransformation(Node const_node, Node eltwise_node) { - if (HasDynamicShape(const_node) || HasDynamicShape(eltwise_node)) +bool DoTransformation(Node const_node_1, Node const_node_2, Node eltwise_node) { + if (HasDynamicShape(const_node_1) || (const_node_2 != nullptr && HasDynamicShape(const_node_2)) || HasDynamicShape(eltwise_node)) return false; const ngraph::Shape & eltwise_out_shape = eltwise_node->get_output_tensor(0).get_shape(); auto broadcast_const = ngraph::opset8::Constant::create(ngraph::element::Type_t::i64, ngraph::Shape{eltwise_out_shape.size()}, eltwise_out_shape); - auto new_const_node = ngraph::op::util::make_try_fold(const_node, + auto new_const_node_1 = ngraph::op::util::make_try_fold(const_node_1, broadcast_const, GetBroadcastType(eltwise_node)); - ngraph::replace_node(const_node, new_const_node); + ngraph::replace_node(const_node_1, new_const_node_1); + + if (const_node_2) { + auto new_const_node_2 = ngraph::op::util::make_try_fold(const_node_2, + broadcast_const, + GetBroadcastType(eltwise_node)); + + ngraph::replace_node(const_node_2, new_const_node_2); + } return true; } @@ -92,34 +101,46 @@ bool IsEltwiseAcceptable(const ngraph::Output& output) { BroadcastAddMultiplyConst::BroadcastAddMultiplyConst() { MATCHER_SCOPE(BroadcastAddMultiplyConst); - auto constant = ngraph::pattern::wrap_type(); - auto fake_quantize = ngraph::pattern::wrap_type({constant, + auto constant_1 = ngraph::pattern::wrap_type(); + auto constant_2 = ngraph::pattern::wrap_type(); + auto fake_quantize_1 = ngraph::pattern::wrap_type({constant_1, ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type(), ngraph::pattern::wrap_type()}); - auto eltwise_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + auto fake_quantize_2 = ngraph::pattern::wrap_type({constant_2, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto input1 = std::make_shared(ngraph::OutputVector{constant_1, fake_quantize_1}); + auto input2 = std::make_shared(ngraph::OutputVector{constant_2, fake_quantize_2}); auto eltwise_left_const = ngraph::pattern::wrap_type({eltwise_input, ngraph::pattern::any_input()}, IsEltwiseAcceptable); + ngraph::op::Eltwise>({input1, ngraph::pattern::any_input()}, IsEltwiseAcceptable); auto eltwise_right_const = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), eltwise_input}, IsEltwiseAcceptable); - auto eltwise = std::make_shared(ngraph::OutputVector{eltwise_left_const, eltwise_right_const}); + ngraph::op::Eltwise>({ngraph::pattern::any_input(), input1}, IsEltwiseAcceptable); + auto scaleshift = ngraph::pattern::wrap_type({ngraph::pattern::any_input(), input1, input2}, IsEltwiseAcceptable); + auto eltwise = std::make_shared(ngraph::OutputVector{eltwise_left_const, eltwise_right_const, scaleshift}); ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); - auto const_node = pattern_map.at(constant).get_node_shared_ptr(); + auto const_node_1 = pattern_map.at(constant_1).get_node_shared_ptr(); + auto const_it_2 = pattern_map.find(constant_2); + auto const_node_2 = (const_it_2 == std::end(pattern_map) ? nullptr : const_it_2->second.get_node_shared_ptr()); auto eltwise_node_it = pattern_map.find(eltwise_left_const); if (eltwise_node_it == pattern_map.end()) eltwise_node_it = pattern_map.find(eltwise_right_const); + if (eltwise_node_it == pattern_map.end()) + eltwise_node_it = pattern_map.find(scaleshift); if (eltwise_node_it == pattern_map.end()) return false; - return DoTransformation(const_node, eltwise_node_it->second.get_node_shared_ptr()); + return DoTransformation(const_node_1, const_node_2, eltwise_node_it->second.get_node_shared_ptr()); }; auto m = std::make_shared(eltwise, matcher_name); diff --git a/src/plugins/intel_gna/transformations/broadcast_const.hpp b/src/plugins/intel_gna/transformations/broadcast_const.hpp index 17ecd98ce15..77143531a09 100644 --- a/src/plugins/intel_gna/transformations/broadcast_const.hpp +++ b/src/plugins/intel_gna/transformations/broadcast_const.hpp @@ -10,7 +10,7 @@ namespace GNAPluginNS { /** * @brief Brodcast data in Const layer - * Transformation recognizes the next patterns + * Transformation recognizes the following patterns * * Constant Any * | | @@ -22,7 +22,7 @@ namespace GNAPluginNS { * | | * Eltwise * - * Where Eltwise node is one of the: Multiply, Substract and Add + * Where Eltwise node is one of the: Multiply, Substract, Add or ScaleShiftIE * There are different types of broadcasting: NONE/EXPLICIT, NUMPY and PDPD * * If eltwise node inputs have different shapes and one the inputs is Constant node diff --git a/src/tests/unit/gna/ngraph/transformations/gna_broadcast_const.cpp b/src/tests/unit/gna/ngraph/transformations/gna_broadcast_const.cpp index ff23673a70f..baea0d8c008 100644 --- a/src/tests/unit/gna/ngraph/transformations/gna_broadcast_const.cpp +++ b/src/tests/unit/gna/ngraph/transformations/gna_broadcast_const.cpp @@ -12,6 +12,7 @@ #include #include #include "legacy/ngraph_ops/eltwise.hpp" +#include "legacy/ngraph_ops/scaleshift.hpp" namespace testing { @@ -122,22 +123,40 @@ std::shared_ptr CreateFunction(const ngraph::Shape& data_shape bool add_input_fake_quantize, bool add_const_fake_quantize, bool swap_outputs, + bool add_scaleshift, EltwiseFactoryPtr eltwise_factory) { - auto input_params = std::make_shared(ngraph::element::Type_t::f32, data_shape); + const auto input_params_1 = std::make_shared(ngraph::element::Type_t::f32, data_shape); + ngraph::ParameterVector params{input_params_1}; - auto constant = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32, + const auto constant_1 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32, ngraph::Shape{const_shape_dims}, const_shape_values); - Node const_last_node = constant; + + Node const_last_node = constant_1; + + if (add_scaleshift) { + const auto input_params_2 = std::make_shared(ngraph::element::Type_t::f32, data_shape); + params.push_back(input_params_2); + + const auto constant_2 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32, + ngraph::Shape{const_shape_dims}, + const_shape_values); + + const_last_node = std::make_shared(input_params_2, + constant_1, + constant_2, + ngraph::element::Type_t::f32); + } if (add_const_fake_quantize) { - auto fake_quantize = createFakeQuantizeNode(const_last_node); + const auto fake_quantize = createFakeQuantizeNode(const_last_node); const_last_node = fake_quantize; } - Node input_last_node = input_params; + Node input_last_node = input_params_1; + if (add_input_fake_quantize) { - auto fake_quantize = createFakeQuantizeNode(input_params); + const auto fake_quantize = createFakeQuantizeNode(input_last_node); input_last_node = fake_quantize; } @@ -147,11 +166,10 @@ std::shared_ptr CreateFunction(const ngraph::Shape& data_shape if (swap_outputs) left_node.swap(right_node); - auto add = eltwise_factory->CreateNode(left_node, right_node); + const auto add = eltwise_factory->CreateNode(left_node, right_node); - auto result = std::make_shared(add); - return std::make_shared(ngraph::ResultVector{result}, - ngraph::ParameterVector{input_params}); + const auto result = std::make_shared(add); + return std::make_shared(ngraph::ResultVector{result}, params); } } // namespace @@ -163,6 +181,7 @@ class BroadcastConstTestFixture: public CommonTestUtils::TestsCommon, bool /* add_input_fake_quantize */, bool /* add_const_fake_quantize */, bool /* swap_outputs */, + bool /* add_scaleshift */, ov::op::AutoBroadcastType>> { public: void SetUp() override; @@ -176,8 +195,9 @@ void BroadcastConstTestFixture::SetUp() { bool add_input_fake_quantize; bool add_const_fake_quantize; bool swap_outputs; + bool add_scaleshift; ov::op::AutoBroadcastType broadcast_type; - std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam(); + std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam(); eltwise_factory->SetBroadcastType(broadcast_type); @@ -189,6 +209,7 @@ void BroadcastConstTestFixture::SetUp() { add_input_fake_quantize, add_const_fake_quantize, swap_outputs, + add_scaleshift, eltwise_factory); reference_function = CreateFunction(shape_info.data_shape, shape_info.data_shape, @@ -196,6 +217,7 @@ void BroadcastConstTestFixture::SetUp() { add_input_fake_quantize, add_const_fake_quantize, swap_outputs, + add_scaleshift, eltwise_factory); } @@ -261,6 +283,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestNumpySuite, BroadcastConstTestFixture ::testing::Bool(), ::testing::Bool(), ::testing::Bool(), + ::testing::Bool(), ::testing::Values(ov::op::AutoBroadcastType::NUMPY))); INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture, @@ -268,6 +291,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture, ::testing::Bool(), ::testing::Bool(), ::testing::Values(false), + ::testing::Bool(), ::testing::Values(ov::op::AutoBroadcastType::PDPD))); // ------------------------------------------------------------------------------------------------ @@ -277,6 +301,7 @@ class BroadcastConstTestPassedFixture: public CommonTestUtils::TestsCommon, bool /* add_input_fake_quantize */, bool /* add_const_fake_quantize */, bool /* swap_outputs */, + bool /* add_scaleshift */, ov::op::AutoBroadcastType>> { public: void SetUp() override; @@ -290,8 +315,9 @@ void BroadcastConstTestPassedFixture::SetUp() { bool add_input_fake_quantize; bool add_const_fake_quantize; bool swap_outputs; + bool add_scaleshift; ov::op::AutoBroadcastType broadcast_type; - std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam(); + std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam(); eltwise_factory->SetBroadcastType(broadcast_type); @@ -303,6 +329,7 @@ void BroadcastConstTestPassedFixture::SetUp() { add_input_fake_quantize, add_const_fake_quantize, swap_outputs, + add_scaleshift, eltwise_factory); } @@ -315,6 +342,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPassedSuite, BroadcastConstTestPassed ::testing::Bool(), ::testing::Bool(), ::testing::Bool(), + ::testing::Bool(), ::testing::ValuesIn(broadcast_passed_types))); } // namespace testing