[GNA] Add missing support for batch normalization with weights broadcasting. Add unit tests. (#12301)

This commit is contained in:
Szymon Irzabek 2022-07-27 10:14:09 +02:00 committed by GitHub
parent d7cf585485
commit 761a6d10d0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 75 additions and 26 deletions

View File

@ -12,6 +12,7 @@
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include "legacy/ngraph_ops/eltwise.hpp"
#include "legacy/ngraph_ops/scaleshift.hpp"
#include <transformations/utils/utils.hpp>
#include <vector>
@ -57,19 +58,27 @@ ov::op::BroadcastModeSpec GetBroadcastType(Node eltwise_node) {
return ov::op::BroadcastType::NONE;
}
bool DoTransformation(Node const_node, Node eltwise_node) {
if (HasDynamicShape(const_node) || HasDynamicShape(eltwise_node))
bool DoTransformation(Node const_node_1, Node const_node_2, Node eltwise_node) {
if (HasDynamicShape(const_node_1) || (const_node_2 != nullptr && HasDynamicShape(const_node_2)) || HasDynamicShape(eltwise_node))
return false;
const ngraph::Shape & eltwise_out_shape = eltwise_node->get_output_tensor(0).get_shape();
auto broadcast_const = ngraph::opset8::Constant::create(ngraph::element::Type_t::i64,
ngraph::Shape{eltwise_out_shape.size()}, eltwise_out_shape);
auto new_const_node = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node,
auto new_const_node_1 = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node_1,
broadcast_const,
GetBroadcastType(eltwise_node));
ngraph::replace_node(const_node, new_const_node);
ngraph::replace_node(const_node_1, new_const_node_1);
if (const_node_2) {
auto new_const_node_2 = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node_2,
broadcast_const,
GetBroadcastType(eltwise_node));
ngraph::replace_node(const_node_2, new_const_node_2);
}
return true;
}
@ -92,34 +101,46 @@ bool IsEltwiseAcceptable(const ngraph::Output<ngraph::Node>& output) {
BroadcastAddMultiplyConst::BroadcastAddMultiplyConst() {
MATCHER_SCOPE(BroadcastAddMultiplyConst);
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
auto constant_1 = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto constant_2 = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
auto fake_quantize_1 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant_1,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto eltwise_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto fake_quantize_2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant_2,
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
auto input1 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant_1, fake_quantize_1});
auto input2 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant_2, fake_quantize_2});
auto eltwise_left_const = ngraph::pattern::wrap_type<ngraph::opset8::Add,
ngraph::opset8::Subtract,
ngraph::opset8::Multiply,
ngraph::op::Eltwise>({eltwise_input, ngraph::pattern::any_input()}, IsEltwiseAcceptable);
ngraph::op::Eltwise>({input1, ngraph::pattern::any_input()}, IsEltwiseAcceptable);
auto eltwise_right_const = ngraph::pattern::wrap_type<ngraph::opset8::Add,
ngraph::opset8::Subtract,
ngraph::opset8::Multiply,
ngraph::op::Eltwise>({ngraph::pattern::any_input(), eltwise_input}, IsEltwiseAcceptable);
auto eltwise = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{eltwise_left_const, eltwise_right_const});
ngraph::op::Eltwise>({ngraph::pattern::any_input(), input1}, IsEltwiseAcceptable);
auto scaleshift = ngraph::pattern::wrap_type<ngraph::op::ScaleShiftIE>({ngraph::pattern::any_input(), input1, input2}, IsEltwiseAcceptable);
auto eltwise = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{eltwise_left_const, eltwise_right_const, scaleshift});
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto const_node = pattern_map.at(constant).get_node_shared_ptr();
auto const_node_1 = pattern_map.at(constant_1).get_node_shared_ptr();
auto const_it_2 = pattern_map.find(constant_2);
auto const_node_2 = (const_it_2 == std::end(pattern_map) ? nullptr : const_it_2->second.get_node_shared_ptr());
auto eltwise_node_it = pattern_map.find(eltwise_left_const);
if (eltwise_node_it == pattern_map.end())
eltwise_node_it = pattern_map.find(eltwise_right_const);
if (eltwise_node_it == pattern_map.end())
eltwise_node_it = pattern_map.find(scaleshift);
if (eltwise_node_it == pattern_map.end())
return false;
return DoTransformation(const_node, eltwise_node_it->second.get_node_shared_ptr());
return DoTransformation(const_node_1, const_node_2, eltwise_node_it->second.get_node_shared_ptr());
};
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, matcher_name);

View File

@ -10,7 +10,7 @@ namespace GNAPluginNS {
/**
* @brief Brodcast data in Const layer
* Transformation recognizes the next patterns
* Transformation recognizes the following patterns
*
* Constant Any
* | |
@ -22,7 +22,7 @@ namespace GNAPluginNS {
* | |
* Eltwise
*
* Where Eltwise node is one of the: Multiply, Substract and Add
* Where Eltwise node is one of the: Multiply, Substract, Add or ScaleShiftIE
* There are different types of broadcasting: NONE/EXPLICIT, NUMPY and PDPD
*
* If eltwise node inputs have different shapes and one the inputs is Constant node

View File

@ -12,6 +12,7 @@
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
#include "legacy/ngraph_ops/eltwise.hpp"
#include "legacy/ngraph_ops/scaleshift.hpp"
namespace testing {
@ -122,22 +123,40 @@ std::shared_ptr<ngraph::Function> CreateFunction(const ngraph::Shape& data_shape
bool add_input_fake_quantize,
bool add_const_fake_quantize,
bool swap_outputs,
bool add_scaleshift,
EltwiseFactoryPtr eltwise_factory) {
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
const auto input_params_1 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
ngraph::ParameterVector params{input_params_1};
auto constant = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
const auto constant_1 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
ngraph::Shape{const_shape_dims},
const_shape_values);
Node const_last_node = constant;
Node const_last_node = constant_1;
if (add_scaleshift) {
const auto input_params_2 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
params.push_back(input_params_2);
const auto constant_2 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
ngraph::Shape{const_shape_dims},
const_shape_values);
const_last_node = std::make_shared<ngraph::op::ScaleShiftIE>(input_params_2,
constant_1,
constant_2,
ngraph::element::Type_t::f32);
}
if (add_const_fake_quantize) {
auto fake_quantize = createFakeQuantizeNode(const_last_node);
const auto fake_quantize = createFakeQuantizeNode(const_last_node);
const_last_node = fake_quantize;
}
Node input_last_node = input_params;
Node input_last_node = input_params_1;
if (add_input_fake_quantize) {
auto fake_quantize = createFakeQuantizeNode(input_params);
const auto fake_quantize = createFakeQuantizeNode(input_last_node);
input_last_node = fake_quantize;
}
@ -147,11 +166,10 @@ std::shared_ptr<ngraph::Function> CreateFunction(const ngraph::Shape& data_shape
if (swap_outputs)
left_node.swap(right_node);
auto add = eltwise_factory->CreateNode(left_node, right_node);
const auto add = eltwise_factory->CreateNode(left_node, right_node);
auto result = std::make_shared<ngraph::opset8::Result>(add);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
const auto result = std::make_shared<ngraph::opset8::Result>(add);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, params);
}
} // namespace
@ -163,6 +181,7 @@ class BroadcastConstTestFixture: public CommonTestUtils::TestsCommon,
bool /* add_input_fake_quantize */,
bool /* add_const_fake_quantize */,
bool /* swap_outputs */,
bool /* add_scaleshift */,
ov::op::AutoBroadcastType>> {
public:
void SetUp() override;
@ -176,8 +195,9 @@ void BroadcastConstTestFixture::SetUp() {
bool add_input_fake_quantize;
bool add_const_fake_quantize;
bool swap_outputs;
bool add_scaleshift;
ov::op::AutoBroadcastType broadcast_type;
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam();
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam();
eltwise_factory->SetBroadcastType(broadcast_type);
@ -189,6 +209,7 @@ void BroadcastConstTestFixture::SetUp() {
add_input_fake_quantize,
add_const_fake_quantize,
swap_outputs,
add_scaleshift,
eltwise_factory);
reference_function = CreateFunction(shape_info.data_shape,
shape_info.data_shape,
@ -196,6 +217,7 @@ void BroadcastConstTestFixture::SetUp() {
add_input_fake_quantize,
add_const_fake_quantize,
swap_outputs,
add_scaleshift,
eltwise_factory);
}
@ -261,6 +283,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestNumpySuite, BroadcastConstTestFixture
::testing::Bool(),
::testing::Bool(),
::testing::Bool(),
::testing::Bool(),
::testing::Values(ov::op::AutoBroadcastType::NUMPY)));
INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture,
@ -268,6 +291,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture,
::testing::Bool(),
::testing::Bool(),
::testing::Values(false),
::testing::Bool(),
::testing::Values(ov::op::AutoBroadcastType::PDPD)));
// ------------------------------------------------------------------------------------------------
@ -277,6 +301,7 @@ class BroadcastConstTestPassedFixture: public CommonTestUtils::TestsCommon,
bool /* add_input_fake_quantize */,
bool /* add_const_fake_quantize */,
bool /* swap_outputs */,
bool /* add_scaleshift */,
ov::op::AutoBroadcastType>> {
public:
void SetUp() override;
@ -290,8 +315,9 @@ void BroadcastConstTestPassedFixture::SetUp() {
bool add_input_fake_quantize;
bool add_const_fake_quantize;
bool swap_outputs;
bool add_scaleshift;
ov::op::AutoBroadcastType broadcast_type;
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam();
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam();
eltwise_factory->SetBroadcastType(broadcast_type);
@ -303,6 +329,7 @@ void BroadcastConstTestPassedFixture::SetUp() {
add_input_fake_quantize,
add_const_fake_quantize,
swap_outputs,
add_scaleshift,
eltwise_factory);
}
@ -315,6 +342,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPassedSuite, BroadcastConstTestPassed
::testing::Bool(),
::testing::Bool(),
::testing::Bool(),
::testing::Bool(),
::testing::ValuesIn(broadcast_passed_types)));
} // namespace testing