[GNA] Add missing support for batch normalization with weights broadcasting. Add unit tests. (#12301)
This commit is contained in:
parent
d7cf585485
commit
761a6d10d0
@ -12,6 +12,7 @@
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include "legacy/ngraph_ops/eltwise.hpp"
|
||||
#include "legacy/ngraph_ops/scaleshift.hpp"
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include <vector>
|
||||
@ -57,19 +58,27 @@ ov::op::BroadcastModeSpec GetBroadcastType(Node eltwise_node) {
|
||||
return ov::op::BroadcastType::NONE;
|
||||
}
|
||||
|
||||
bool DoTransformation(Node const_node, Node eltwise_node) {
|
||||
if (HasDynamicShape(const_node) || HasDynamicShape(eltwise_node))
|
||||
bool DoTransformation(Node const_node_1, Node const_node_2, Node eltwise_node) {
|
||||
if (HasDynamicShape(const_node_1) || (const_node_2 != nullptr && HasDynamicShape(const_node_2)) || HasDynamicShape(eltwise_node))
|
||||
return false;
|
||||
const ngraph::Shape & eltwise_out_shape = eltwise_node->get_output_tensor(0).get_shape();
|
||||
|
||||
auto broadcast_const = ngraph::opset8::Constant::create(ngraph::element::Type_t::i64,
|
||||
ngraph::Shape{eltwise_out_shape.size()}, eltwise_out_shape);
|
||||
|
||||
auto new_const_node = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node,
|
||||
auto new_const_node_1 = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node_1,
|
||||
broadcast_const,
|
||||
GetBroadcastType(eltwise_node));
|
||||
|
||||
ngraph::replace_node(const_node, new_const_node);
|
||||
ngraph::replace_node(const_node_1, new_const_node_1);
|
||||
|
||||
if (const_node_2) {
|
||||
auto new_const_node_2 = ngraph::op::util::make_try_fold<ngraph::opset8::Broadcast>(const_node_2,
|
||||
broadcast_const,
|
||||
GetBroadcastType(eltwise_node));
|
||||
|
||||
ngraph::replace_node(const_node_2, new_const_node_2);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -92,34 +101,46 @@ bool IsEltwiseAcceptable(const ngraph::Output<ngraph::Node>& output) {
|
||||
BroadcastAddMultiplyConst::BroadcastAddMultiplyConst() {
|
||||
MATCHER_SCOPE(BroadcastAddMultiplyConst);
|
||||
|
||||
auto constant = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant,
|
||||
auto constant_1 = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto constant_2 = ngraph::pattern::wrap_type<ngraph::opset8::Constant>();
|
||||
auto fake_quantize_1 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant_1,
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
||||
auto eltwise_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
|
||||
auto fake_quantize_2 = ngraph::pattern::wrap_type<ngraph::opset8::FakeQuantize>({constant_2,
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset8::Constant>()});
|
||||
auto input1 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant_1, fake_quantize_1});
|
||||
auto input2 = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant_2, fake_quantize_2});
|
||||
auto eltwise_left_const = ngraph::pattern::wrap_type<ngraph::opset8::Add,
|
||||
ngraph::opset8::Subtract,
|
||||
ngraph::opset8::Multiply,
|
||||
ngraph::op::Eltwise>({eltwise_input, ngraph::pattern::any_input()}, IsEltwiseAcceptable);
|
||||
ngraph::op::Eltwise>({input1, ngraph::pattern::any_input()}, IsEltwiseAcceptable);
|
||||
auto eltwise_right_const = ngraph::pattern::wrap_type<ngraph::opset8::Add,
|
||||
ngraph::opset8::Subtract,
|
||||
ngraph::opset8::Multiply,
|
||||
ngraph::op::Eltwise>({ngraph::pattern::any_input(), eltwise_input}, IsEltwiseAcceptable);
|
||||
auto eltwise = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{eltwise_left_const, eltwise_right_const});
|
||||
ngraph::op::Eltwise>({ngraph::pattern::any_input(), input1}, IsEltwiseAcceptable);
|
||||
auto scaleshift = ngraph::pattern::wrap_type<ngraph::op::ScaleShiftIE>({ngraph::pattern::any_input(), input1, input2}, IsEltwiseAcceptable);
|
||||
auto eltwise = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{eltwise_left_const, eltwise_right_const, scaleshift});
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
const auto& pattern_map = m.get_pattern_value_map();
|
||||
auto const_node = pattern_map.at(constant).get_node_shared_ptr();
|
||||
auto const_node_1 = pattern_map.at(constant_1).get_node_shared_ptr();
|
||||
auto const_it_2 = pattern_map.find(constant_2);
|
||||
auto const_node_2 = (const_it_2 == std::end(pattern_map) ? nullptr : const_it_2->second.get_node_shared_ptr());
|
||||
|
||||
auto eltwise_node_it = pattern_map.find(eltwise_left_const);
|
||||
if (eltwise_node_it == pattern_map.end())
|
||||
eltwise_node_it = pattern_map.find(eltwise_right_const);
|
||||
if (eltwise_node_it == pattern_map.end())
|
||||
eltwise_node_it = pattern_map.find(scaleshift);
|
||||
if (eltwise_node_it == pattern_map.end())
|
||||
return false;
|
||||
|
||||
return DoTransformation(const_node, eltwise_node_it->second.get_node_shared_ptr());
|
||||
return DoTransformation(const_node_1, const_node_2, eltwise_node_it->second.get_node_shared_ptr());
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(eltwise, matcher_name);
|
||||
|
@ -10,7 +10,7 @@ namespace GNAPluginNS {
|
||||
|
||||
/**
|
||||
* @brief Brodcast data in Const layer
|
||||
* Transformation recognizes the next patterns
|
||||
* Transformation recognizes the following patterns
|
||||
*
|
||||
* Constant Any
|
||||
* | |
|
||||
@ -22,7 +22,7 @@ namespace GNAPluginNS {
|
||||
* | |
|
||||
* Eltwise
|
||||
*
|
||||
* Where Eltwise node is one of the: Multiply, Substract and Add
|
||||
* Where Eltwise node is one of the: Multiply, Substract, Add or ScaleShiftIE
|
||||
* There are different types of broadcasting: NONE/EXPLICIT, NUMPY and PDPD
|
||||
*
|
||||
* If eltwise node inputs have different shapes and one the inputs is Constant node
|
||||
|
@ -12,6 +12,7 @@
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include "legacy/ngraph_ops/eltwise.hpp"
|
||||
#include "legacy/ngraph_ops/scaleshift.hpp"
|
||||
|
||||
namespace testing {
|
||||
|
||||
@ -122,22 +123,40 @@ std::shared_ptr<ngraph::Function> CreateFunction(const ngraph::Shape& data_shape
|
||||
bool add_input_fake_quantize,
|
||||
bool add_const_fake_quantize,
|
||||
bool swap_outputs,
|
||||
bool add_scaleshift,
|
||||
EltwiseFactoryPtr eltwise_factory) {
|
||||
auto input_params = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
|
||||
const auto input_params_1 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
|
||||
ngraph::ParameterVector params{input_params_1};
|
||||
|
||||
auto constant = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
|
||||
const auto constant_1 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
|
||||
ngraph::Shape{const_shape_dims},
|
||||
const_shape_values);
|
||||
Node const_last_node = constant;
|
||||
|
||||
Node const_last_node = constant_1;
|
||||
|
||||
if (add_scaleshift) {
|
||||
const auto input_params_2 = std::make_shared<ngraph::opset8::Parameter>(ngraph::element::Type_t::f32, data_shape);
|
||||
params.push_back(input_params_2);
|
||||
|
||||
const auto constant_2 = ngraph::opset8::Constant::create(ngraph::element::Type_t::f32,
|
||||
ngraph::Shape{const_shape_dims},
|
||||
const_shape_values);
|
||||
|
||||
const_last_node = std::make_shared<ngraph::op::ScaleShiftIE>(input_params_2,
|
||||
constant_1,
|
||||
constant_2,
|
||||
ngraph::element::Type_t::f32);
|
||||
}
|
||||
|
||||
if (add_const_fake_quantize) {
|
||||
auto fake_quantize = createFakeQuantizeNode(const_last_node);
|
||||
const auto fake_quantize = createFakeQuantizeNode(const_last_node);
|
||||
const_last_node = fake_quantize;
|
||||
}
|
||||
|
||||
Node input_last_node = input_params;
|
||||
Node input_last_node = input_params_1;
|
||||
|
||||
if (add_input_fake_quantize) {
|
||||
auto fake_quantize = createFakeQuantizeNode(input_params);
|
||||
const auto fake_quantize = createFakeQuantizeNode(input_last_node);
|
||||
input_last_node = fake_quantize;
|
||||
}
|
||||
|
||||
@ -147,11 +166,10 @@ std::shared_ptr<ngraph::Function> CreateFunction(const ngraph::Shape& data_shape
|
||||
if (swap_outputs)
|
||||
left_node.swap(right_node);
|
||||
|
||||
auto add = eltwise_factory->CreateNode(left_node, right_node);
|
||||
const auto add = eltwise_factory->CreateNode(left_node, right_node);
|
||||
|
||||
auto result = std::make_shared<ngraph::opset8::Result>(add);
|
||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
|
||||
ngraph::ParameterVector{input_params});
|
||||
const auto result = std::make_shared<ngraph::opset8::Result>(add);
|
||||
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result}, params);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -163,6 +181,7 @@ class BroadcastConstTestFixture: public CommonTestUtils::TestsCommon,
|
||||
bool /* add_input_fake_quantize */,
|
||||
bool /* add_const_fake_quantize */,
|
||||
bool /* swap_outputs */,
|
||||
bool /* add_scaleshift */,
|
||||
ov::op::AutoBroadcastType>> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
@ -176,8 +195,9 @@ void BroadcastConstTestFixture::SetUp() {
|
||||
bool add_input_fake_quantize;
|
||||
bool add_const_fake_quantize;
|
||||
bool swap_outputs;
|
||||
bool add_scaleshift;
|
||||
ov::op::AutoBroadcastType broadcast_type;
|
||||
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam();
|
||||
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam();
|
||||
|
||||
eltwise_factory->SetBroadcastType(broadcast_type);
|
||||
|
||||
@ -189,6 +209,7 @@ void BroadcastConstTestFixture::SetUp() {
|
||||
add_input_fake_quantize,
|
||||
add_const_fake_quantize,
|
||||
swap_outputs,
|
||||
add_scaleshift,
|
||||
eltwise_factory);
|
||||
reference_function = CreateFunction(shape_info.data_shape,
|
||||
shape_info.data_shape,
|
||||
@ -196,6 +217,7 @@ void BroadcastConstTestFixture::SetUp() {
|
||||
add_input_fake_quantize,
|
||||
add_const_fake_quantize,
|
||||
swap_outputs,
|
||||
add_scaleshift,
|
||||
eltwise_factory);
|
||||
}
|
||||
|
||||
@ -261,6 +283,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestNumpySuite, BroadcastConstTestFixture
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Values(ov::op::AutoBroadcastType::NUMPY)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture,
|
||||
@ -268,6 +291,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPDPDSuite, BroadcastConstTestFixture,
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Values(false),
|
||||
::testing::Bool(),
|
||||
::testing::Values(ov::op::AutoBroadcastType::PDPD)));
|
||||
|
||||
// ------------------------------------------------------------------------------------------------
|
||||
@ -277,6 +301,7 @@ class BroadcastConstTestPassedFixture: public CommonTestUtils::TestsCommon,
|
||||
bool /* add_input_fake_quantize */,
|
||||
bool /* add_const_fake_quantize */,
|
||||
bool /* swap_outputs */,
|
||||
bool /* add_scaleshift */,
|
||||
ov::op::AutoBroadcastType>> {
|
||||
public:
|
||||
void SetUp() override;
|
||||
@ -290,8 +315,9 @@ void BroadcastConstTestPassedFixture::SetUp() {
|
||||
bool add_input_fake_quantize;
|
||||
bool add_const_fake_quantize;
|
||||
bool swap_outputs;
|
||||
bool add_scaleshift;
|
||||
ov::op::AutoBroadcastType broadcast_type;
|
||||
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, broadcast_type) = this->GetParam();
|
||||
std::tie(eltwise_factory, add_input_fake_quantize, add_const_fake_quantize, swap_outputs, add_scaleshift, broadcast_type) = this->GetParam();
|
||||
|
||||
eltwise_factory->SetBroadcastType(broadcast_type);
|
||||
|
||||
@ -303,6 +329,7 @@ void BroadcastConstTestPassedFixture::SetUp() {
|
||||
add_input_fake_quantize,
|
||||
add_const_fake_quantize,
|
||||
swap_outputs,
|
||||
add_scaleshift,
|
||||
eltwise_factory);
|
||||
}
|
||||
|
||||
@ -315,6 +342,7 @@ INSTANTIATE_TEST_SUITE_P(BroadcastConstTestPassedSuite, BroadcastConstTestPassed
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::Bool(),
|
||||
::testing::ValuesIn(broadcast_passed_types)));
|
||||
|
||||
} // namespace testing
|
||||
|
Loading…
Reference in New Issue
Block a user