ConvMulFusion - handle ConvolutionBackpropData with 3 inputs (#17145)

* ConvMulFusion - handle ConvolutionBackpropData with 3 inputs

Ticket: 98769

* add using

* use compare functions
This commit is contained in:
Mateusz Tabaka 2023-04-26 09:37:31 +02:00 committed by GitHub
parent 3c485feea8
commit da4316845f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 217 additions and 124 deletions

View File

@ -4,14 +4,11 @@
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/pattern/matcher.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <openvino/opsets/opset4.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include <vector>
#include "itt.hpp"
@ -19,9 +16,9 @@ ov::pass::ConvolutionMultiplyFusion::ConvolutionMultiplyFusion() {
MATCHER_SCOPE(ConvolutionMultiplyFusion);
auto input = pattern::any_input();
auto weights = pass::pattern::any_input(pattern::has_static_dim(0) /* has OIYX layout */);
auto conv = ngraph::pattern::wrap_type<opset4::Convolution>({input, weights}, pattern::consumers_count(1));
auto mul_const = ngraph::pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = ngraph::pattern::wrap_type<opset4::Multiply>({conv, mul_const});
auto conv = pattern::wrap_type<opset4::Convolution>({input, weights}, pattern::consumers_count(1));
auto mul_const = pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = pattern::wrap_type<opset4::Multiply>({conv, mul_const});
matcher_pass_callback callback = [conv, input, weights, mul, mul_const](pattern::Matcher& m) -> bool {
const auto& pattern_to_output = m.get_pattern_value_map();
@ -57,12 +54,10 @@ ov::pass::ConvolutionMultiplyFusion::ConvolutionMultiplyFusion() {
if (!is_scalar_multiplier) {
auto final_const_shape = Shape(weights_rank, 1);
final_const_shape[0] = channel_dim;
final_const =
std::make_shared<opset4::Reshape>(m_const,
opset4::Constant::create(ngraph::element::i64,
ngraph::Shape{final_const_shape.size()},
final_const_shape),
true);
final_const = std::make_shared<opset4::Reshape>(
m_const,
opset4::Constant::create(element::i64, Shape{final_const_shape.size()}, final_const_shape),
true);
}
// Multiply convolution weights with aligned Constant values
@ -76,7 +71,7 @@ ov::pass::ConvolutionMultiplyFusion::ConvolutionMultiplyFusion() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
auto m = std::make_shared<pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}
@ -84,9 +79,9 @@ ov::pass::GroupConvolutionMultiplyFusion::GroupConvolutionMultiplyFusion() {
MATCHER_SCOPE(GroupConvolutionMultiplyFusion);
auto input = pattern::any_input();
auto weights = pass::pattern::any_input(pattern::has_static_dims({0, 1}) /* has GOIYX layout */);
auto conv = ngraph::pattern::wrap_type<opset4::GroupConvolution>({input, weights}, pattern::consumers_count(1));
auto mul_const = ngraph::pattern::wrap_type<opset4::Constant>(); // pattern::has_static_shape());
auto mul = ngraph::pattern::wrap_type<opset4::Multiply>({conv, mul_const});
auto conv = pattern::wrap_type<opset4::GroupConvolution>({input, weights}, pattern::consumers_count(1));
auto mul_const = pattern::wrap_type<opset4::Constant>(); // pattern::has_static_shape());
auto mul = pattern::wrap_type<opset4::Multiply>({conv, mul_const});
matcher_pass_callback callback = [conv, input, weights, mul, mul_const](pattern::Matcher& m) -> bool {
const auto& pattern_to_output = m.get_pattern_value_map();
@ -142,12 +137,10 @@ ov::pass::GroupConvolutionMultiplyFusion::GroupConvolutionMultiplyFusion() {
final_const_shape[0] = G;
final_const_shape[1] = O;
}
final_const =
std::make_shared<opset4::Reshape>(m_const,
opset4::Constant::create(ngraph::element::i64,
ngraph::Shape{final_const_shape.size()},
final_const_shape),
true);
final_const = std::make_shared<opset4::Reshape>(
m_const,
opset4::Constant::create(element::i64, Shape{final_const_shape.size()}, final_const_shape),
true);
}
// Multiply convolution weights with aligned Constant values
@ -164,7 +157,7 @@ ov::pass::GroupConvolutionMultiplyFusion::GroupConvolutionMultiplyFusion() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
auto m = std::make_shared<pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}
@ -172,18 +165,20 @@ ov::pass::ConvolutionBackpropDataMultiplyFusion::ConvolutionBackpropDataMultiply
MATCHER_SCOPE(ConvolutionBackpropDataMultiplyFusion);
auto input = pattern::any_input();
auto weights = pass::pattern::any_input(pattern::has_static_dim(1) /* has IOYX layout */);
auto conv =
ngraph::pattern::wrap_type<opset4::ConvolutionBackpropData>({input, weights}, pattern::consumers_count(1));
auto mul_const = ngraph::pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = ngraph::pattern::wrap_type<opset4::Multiply>({conv, mul_const});
auto conv_2_inputs =
pattern::wrap_type<opset4::ConvolutionBackpropData>({input, weights}, pattern::consumers_count(1));
auto conv_3_inputs = pattern::wrap_type<opset4::ConvolutionBackpropData>({input, weights, pattern::any_input()},
pattern::consumers_count(1));
auto conv = std::make_shared<pattern::op::Or>(OutputVector{conv_2_inputs, conv_3_inputs});
auto mul_const = pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = pattern::wrap_type<opset4::Multiply>({conv, mul_const});
matcher_pass_callback callback = [conv, input, weights, mul, mul_const](pattern::Matcher& m) -> bool {
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto& pattern_to_output = m.get_pattern_value_map();
const auto& m_weights = pattern_to_output.at(weights);
const auto& m_const = pattern_to_output.at(mul_const);
const auto& m_input = pattern_to_output.at(input);
const auto& m_conv = pattern_to_output.at(conv).get_node_shared_ptr();
const auto& m_mul = pattern_to_output.at(mul).get_node_shared_ptr();
const auto& channel_dim = m_weights.get_partial_shape()[1].get_length();
@ -211,26 +206,33 @@ ov::pass::ConvolutionBackpropDataMultiplyFusion::ConvolutionBackpropDataMultiply
if (!is_scalar_multiplier) {
auto final_const_shape = Shape(weights_rank - 1, 1);
final_const_shape[0] = channel_dim;
final_const =
std::make_shared<opset4::Reshape>(m_const,
opset4::Constant::create(ngraph::element::i64,
ngraph::Shape{final_const_shape.size()},
final_const_shape),
true);
final_const = std::make_shared<opset4::Reshape>(
m_const,
opset4::Constant::create(element::i64, Shape{final_const_shape.size()}, final_const_shape),
true);
}
// Multiply convolution weights with aligned Constant values
auto weights_multiply = std::make_shared<opset4::Multiply>(m_weights, final_const);
// Replace Convolution->Multiply with Convolution with new inputs
auto new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply});
std::shared_ptr<Node> new_conv;
std::shared_ptr<Node> m_conv;
auto it = pattern_to_output.find(conv_2_inputs);
if (it != pattern_to_output.end()) {
m_conv = it->second.get_node_shared_ptr();
new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply});
} else {
m_conv = pattern_to_output.at(conv_3_inputs).get_node_shared_ptr();
new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply, m_conv->input_value(2)});
}
new_conv->set_friendly_name(m_mul->get_friendly_name());
copy_runtime_info({m_conv, m_mul}, {new_conv, final_const.get_node_shared_ptr(), weights_multiply});
replace_node(m_mul, new_conv);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
auto m = std::make_shared<pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}
@ -238,18 +240,21 @@ ov::pass::GroupConvolutionBackpropDataMultiplyFusion::GroupConvolutionBackpropDa
MATCHER_SCOPE(GroupConvolutionBackpropDataMultiplyFusion);
auto input = pattern::any_input();
auto weights = pass::pattern::any_input(pattern::has_static_dims({0, 2}) /* has GIOYX layout */);
auto conv =
ngraph::pattern::wrap_type<opset4::GroupConvolutionBackpropData>({input, weights}, pattern::consumers_count(1));
auto mul_const = ngraph::pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = ngraph::pattern::wrap_type<opset4::Multiply>({conv, mul_const});
auto conv_2_inputs =
pattern::wrap_type<opset4::GroupConvolutionBackpropData>({input, weights}, pattern::consumers_count(1));
auto conv_3_inputs =
pattern::wrap_type<opset4::GroupConvolutionBackpropData>({input, weights, pattern::any_input()},
pattern::consumers_count(1));
auto conv = std::make_shared<pattern::op::Or>(OutputVector{conv_2_inputs, conv_3_inputs});
auto mul_const = pattern::wrap_type<opset4::Constant>(pattern::has_static_shape());
auto mul = pattern::wrap_type<opset4::Multiply>({conv, mul_const});
matcher_pass_callback callback = [conv, input, weights, mul, mul_const](pattern::Matcher& m) -> bool {
matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto& pattern_to_output = m.get_pattern_value_map();
const auto& m_weights = pattern_to_output.at(weights);
const auto& m_const = pattern_to_output.at(mul_const);
const auto& m_input = pattern_to_output.at(input);
const auto& m_conv = pattern_to_output.at(conv).get_node_shared_ptr();
const auto& m_mul = pattern_to_output.at(mul).get_node_shared_ptr();
const auto& G = m_weights.get_partial_shape()[0].get_length();
@ -279,25 +284,32 @@ ov::pass::GroupConvolutionBackpropDataMultiplyFusion::GroupConvolutionBackpropDa
auto final_const_shape = Shape(weights_rank, 1);
final_const_shape[0] = G;
final_const_shape[2] = O;
final_const =
std::make_shared<opset4::Reshape>(m_const,
opset4::Constant::create(ngraph::element::i64,
ngraph::Shape{final_const_shape.size()},
final_const_shape),
true);
final_const = std::make_shared<opset4::Reshape>(
m_const,
opset4::Constant::create(element::i64, Shape{final_const_shape.size()}, final_const_shape),
true);
}
// Multiply convolution weights with aligned Constant values
auto weights_multiply = std::make_shared<opset4::Multiply>(m_weights, final_const);
// Replace Convolution->Multiply with Convolution with new inputs
auto new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply});
std::shared_ptr<Node> new_conv;
std::shared_ptr<Node> m_conv;
auto it = pattern_to_output.find(conv_2_inputs);
if (it != pattern_to_output.end()) {
m_conv = it->second.get_node_shared_ptr();
new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply});
} else {
m_conv = pattern_to_output.at(conv_3_inputs).get_node_shared_ptr();
new_conv = m_conv->clone_with_new_inputs({m_input, weights_multiply, m_conv->input_value(2)});
}
new_conv->set_friendly_name(m_mul->get_friendly_name());
copy_runtime_info({m_conv, m_mul}, {new_conv, final_const.get_node_shared_ptr(), weights_multiply});
replace_node(m_mul, new_conv);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, matcher_name);
auto m = std::make_shared<pattern::Matcher>(mul, matcher_name);
register_matcher(m, callback);
}

View File

@ -11,15 +11,14 @@ using namespace SubgraphTestsDefinitions;
namespace {
const std::vector<ngraph::element::Type> types{ngraph::element::f32, ngraph::element::f16};
#define MUL(X) std::tuple<ngraph::NodeTypeInfo, int64_t>(ngraph::opset4::Multiply::get_type_info_static(), X)
#define ADD(X) std::tuple<ngraph::NodeTypeInfo, int64_t>(ngraph::opset4::Add::get_type_info_static(), X)
#define IN std::vector<std::tuple<ngraph::NodeTypeInfo, int64_t>>
const std::vector<ngraph::NodeTypeInfo> eltwise_types{ngraph::opset4::Multiply::get_type_info_static(),
/* ngraph::opset4::Add::get_type_info_static() */};
INSTANTIATE_TEST_SUITE_P(smoke_Convolution_1D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::Convolution::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Values(std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::Convolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 8, 64}),
::testing::Values(ngraph::Shape{64, 8, 1}),
::testing::Values(ngraph::Shape{64, 1}),
@ -29,8 +28,9 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolution_1D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolution::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Values(std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::GroupConvolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 12, 5}),
::testing::Values(ngraph::Shape{4, 5, 3, 2}),
::testing::Values(ngraph::Shape{20, 1}),
@ -40,8 +40,11 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_ConvolutionBackpropData_1D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::ConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Combine(
::testing::Values(ngraph::opset4::ConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(std::vector<size_t>{2, 3})),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 12, 64}),
::testing::Values(ngraph::Shape{12, 20, 1}),
::testing::Values(ngraph::Shape{20, 1}),
@ -51,8 +54,11 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolutionBackpropData_1D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(std::vector<size_t>{2, 3})),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 12, 64}),
::testing::Values(ngraph::Shape{4, 3, 5, 1}),
::testing::Values(ngraph::Shape{1, 20, 1}),
@ -70,8 +76,9 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_Convolution_2D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::Convolution::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Values(std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::Convolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 3, 64, 64}),
::testing::Values(ngraph::Shape{20, 3, 1, 1}),
::testing::ValuesIn(const_shapes_2d),
@ -81,8 +88,9 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolution_2D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolution::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Values(std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::GroupConvolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 12, 64, 64}),
::testing::Values(ngraph::Shape{4, 5, 3, 1, 2}),
::testing::ValuesIn(const_shapes_2d),
@ -92,8 +100,11 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_ConvolutionBackpropData_2D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::ConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Combine(
::testing::Values(ngraph::opset4::ConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(std::vector<size_t>{2, 3})),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 3, 64, 64}),
::testing::Values(ngraph::Shape{3, 20, 3, 3}),
::testing::ValuesIn(const_shapes_2d),
@ -103,8 +114,11 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolutionBackpropData_2D, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({ MUL(4), /* ADD(5) */})),
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(std::vector<size_t>{2, 3})),
::testing::ValuesIn(eltwise_types),
::testing::Values(false),
::testing::Values(ngraph::Shape{1, 12, 64, 64}),
::testing::Values(ngraph::Shape{4, 3, 5, 1, 1}),
::testing::ValuesIn(const_shapes_2d),
@ -119,8 +133,9 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_Convolution_2D_Negative, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::Convolution::get_type_info_static()),
::testing::ValuesIn(IN({MUL(6), /* ADD(6) */})),
::testing::Values(std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::Convolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(true),
::testing::Values(ngraph::Shape{1, 3, 3, 3}),
::testing::Values(ngraph::Shape{3, 3, 1, 1}),
::testing::ValuesIn(neg_const_shapes_2d),
@ -130,8 +145,10 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolution_2D_Negative, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolution::get_type_info_static()),
::testing::ValuesIn(IN({MUL(6), /* ADD(6) */})),
::testing::Values(
std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::GroupConvolution::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(true),
::testing::Values(ngraph::Shape{1, 12, 3, 3}),
::testing::Values(ngraph::Shape{4, 5, 3, 1, 1}),
::testing::ValuesIn(neg_const_shapes_2d),
@ -141,8 +158,10 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_ConvolutionBackpropData_2D_Negative, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::ConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({MUL(6), /* ADD(6) */})),
::testing::Values(
std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::ConvolutionBackpropData::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(true),
::testing::Values(ngraph::Shape{1, 12, 3, 3}),
::testing::Values(ngraph::Shape{12, 3, 1, 1}),
::testing::ValuesIn(neg_const_shapes_2d),
@ -152,8 +171,10 @@ namespace {
INSTANTIATE_TEST_SUITE_P(smoke_GroupConvolutionBackpropData_2D_Negative, ConvEltwiseFusion,
::testing::Combine(
::testing::Values(ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()),
::testing::ValuesIn(IN({MUL(6), /* ADD(6) */})),
::testing::Values(
std::tuple<ngraph::NodeTypeInfo, size_t>{ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static(), 2}),
::testing::ValuesIn(eltwise_types),
::testing::Values(true),
::testing::Values(ngraph::Shape{1, 12, 3, 3}),
::testing::Values(ngraph::Shape{4, 3, 5, 1, 1}),
::testing::ValuesIn(neg_const_shapes_2d),

View File

@ -15,11 +15,12 @@
namespace SubgraphTestsDefinitions {
typedef std::tuple<
ngraph::NodeTypeInfo, // Convolution type
std::tuple<
ngraph::NodeTypeInfo, // Eltwise type
int64_t // Expected number of ops
ngraph::NodeTypeInfo, // Convolution type
size_t // Number of inputs
>,
ngraph::NodeTypeInfo, // Eltwise type
bool, // Is the test negative or not
ngraph::Shape, // Input shape
ngraph::Shape, // Weights shape
ngraph::Shape, // Const shape

View File

@ -3,26 +3,33 @@
//
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
#include "ngraph/pass/constant_folding.hpp"
#include "openvino/core/node.hpp"
#include "openvino/opsets/opset11.hpp"
#include "openvino/pass/constant_folding.hpp"
#include "shared_test_classes/subgraph/conv_eltwise_fusion.hpp"
using namespace ov;
// #include <legacy/transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp>
// #include <legacy/transformations/convert_opset1_to_legacy/convert_convolutions.hpp>
namespace SubgraphTestsDefinitions {
std::string ConvEltwiseFusion::getTestCaseName(const testing::TestParamInfo<ConvEltwiseFusionParams> &obj) {
ngraph::NodeTypeInfo conv_type, eltwise_type;
std::tuple<ngraph::NodeTypeInfo, int64_t> t;
ngraph::Shape input_shape, weights_shape, const_shape;
ngraph::element::Type precision;
std::tuple<NodeTypeInfo, size_t> conv_params;
NodeTypeInfo conv_type, eltwise_type;
bool negative;
Shape input_shape, weights_shape, const_shape;
element::Type precision;
std::string targetName;
int64_t expected_number_of_ops;
std::tie(conv_type, t, input_shape, weights_shape, const_shape, precision, targetName) = obj.param;
std::tie(eltwise_type, expected_number_of_ops) = t;
std::tie(conv_params, eltwise_type, negative, input_shape, weights_shape, const_shape, precision, targetName) = obj.param;
size_t num_inputs;
std::tie(conv_type, num_inputs) = conv_params;
std::ostringstream results;
results << conv_type.name << "_";
results << "NumInputs=" << num_inputs << "_";
results << "Negative=" << std::boolalpha << negative << "_";
results << eltwise_type.name << "_";
results << "Input" << CommonTestUtils::vec2str(input_shape);
results << "Weights" << CommonTestUtils::vec2str(weights_shape);
@ -33,58 +40,110 @@ std::string ConvEltwiseFusion::getTestCaseName(const testing::TestParamInfo<Conv
}
void ConvEltwiseFusion::SetUp() {
ngraph::NodeTypeInfo conv_type, eltwise_type;
std::tuple<ngraph::NodeTypeInfo, int64_t> t;
ngraph::Shape input_shape, weights_shape, const_shape;
ngraph::element::Type precision;
int64_t expected_number_of_ops;
std::tie(conv_type, t, input_shape, weights_shape, const_shape, precision, targetDevice) = this->GetParam();
std::tie(eltwise_type, expected_number_of_ops) = t;
ngraph::pass::Manager manager;
std::tuple<NodeTypeInfo, size_t> conv_params;
NodeTypeInfo conv_type, eltwise_type;
bool negative;
Shape input_shape, weights_shape, const_shape;
element::Type precision;
size_t num_inputs;
std::tie(conv_params, eltwise_type, negative, input_shape, weights_shape, const_shape, precision, targetDevice) = this->GetParam();
std::tie(conv_type, num_inputs) = conv_params;
pass::Manager manager;
{
auto param = std::make_shared<ngraph::opset4::Parameter>(precision, input_shape);
auto param = std::make_shared<opset11::Parameter>(precision, input_shape);
auto spatial_dims = input_shape.size() - 2;
ngraph::Shape strides(spatial_dims, 1);
Shape strides(spatial_dims, 1);
std::vector<ptrdiff_t> pad_begin(spatial_dims, 0), pad_end(spatial_dims, 0);
auto weights = ngraph::builder::makeConstant<float>(precision, weights_shape, {}, true);
auto eltwise_const = ngraph::builder::makeConstant<float>(precision, const_shape, {}, true);
std::shared_ptr<ngraph::Node> conv;
if (conv_type == ngraph::opset4::Convolution::get_type_info_static()) {
conv = std::make_shared<ngraph::opset4::Convolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == ngraph::opset4::GroupConvolution::get_type_info_static()) {
conv = std::make_shared<ngraph::opset4::GroupConvolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == ngraph::opset4::ConvolutionBackpropData::get_type_info_static()) {
conv = std::make_shared<ngraph::opset4::ConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == ngraph::opset4::GroupConvolutionBackpropData::get_type_info_static()) {
conv = std::make_shared<ngraph::opset4::GroupConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
auto weights = ngraph::builder::makeConstant<float>(precision, weights_shape, std::vector<float>(shape_size(weights_shape), 2));
auto eltwise_const = ngraph::builder::makeConstant<float>(precision, const_shape, std::vector<float>(shape_size(const_shape), 3));
std::shared_ptr<Node> conv;
if (conv_type == opset11::Convolution::get_type_info_static()) {
conv = std::make_shared<opset11::Convolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == opset11::GroupConvolution::get_type_info_static()) {
conv = std::make_shared<opset11::GroupConvolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == opset11::ConvolutionBackpropData::get_type_info_static()) {
if (num_inputs == 3) {
auto output_shape = std::make_shared<opset11::Constant>(element::u64, Shape{spatial_dims},
std::vector<size_t>(input_shape.begin() + 2, input_shape.end()));
conv = std::make_shared<opset11::ConvolutionBackpropData>(param, weights, output_shape, strides, pad_begin, pad_end, strides);
} else {
conv = std::make_shared<opset11::ConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
}
} else if (conv_type == opset11::GroupConvolutionBackpropData::get_type_info_static()) {
if (num_inputs == 3) {
auto output_shape = std::make_shared<opset11::Constant>(element::u64, Shape{spatial_dims},
std::vector<size_t>(input_shape.begin() + 2, input_shape.end()));
conv = std::make_shared<opset11::GroupConvolutionBackpropData>(param, weights, output_shape, strides, pad_begin, pad_end, strides);
} else {
conv = std::make_shared<opset11::GroupConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
}
} else {
OPENVINO_THROW("Unsupported type");
}
std::shared_ptr<ngraph::Node> eltwise;
if (eltwise_type == ngraph::opset4::Multiply::get_type_info_static()) {
eltwise = std::make_shared<ngraph::opset4::Multiply>(conv, eltwise_const);
std::shared_ptr<Node> eltwise;
if (eltwise_type == opset11::Multiply::get_type_info_static()) {
eltwise = std::make_shared<opset11::Multiply>(conv, eltwise_const);
manager.register_pass<ov::pass::ConvolutionMultiplyFusion>();
manager.register_pass<ov::pass::GroupConvolutionMultiplyFusion>();
manager.register_pass<ov::pass::ConvolutionBackpropDataMultiplyFusion>();
manager.register_pass<ov::pass::GroupConvolutionBackpropDataMultiplyFusion>();
} else if (eltwise_type == ngraph::opset4::Add::get_type_info_static()) {
eltwise = std::make_shared<ngraph::opset4::Add>(conv, eltwise_const);
// manager.register_pass<ngraph::pass::ConvertConvolutions>();
// manager.register_pass<ngraph::pass::ConvFusion>();
} else if (eltwise_type == opset11::Add::get_type_info_static()) {
eltwise = std::make_shared<opset11::Add>(conv, eltwise_const);
// manager.register_pass<pass::ConvertConvolutions>();
// manager.register_pass<pass::ConvFusion>();
} else {
OPENVINO_THROW("Unsupported type");
}
function = std::make_shared<ngraph::Function>(ngraph::OutputVector{eltwise}, ngraph::ParameterVector{param}, "conv_eltwise");
function = std::make_shared<Model>(eltwise, ParameterVector{param}, "conv_eltwise");
}
manager.register_pass<ngraph::pass::ConstantFolding>();
manager.register_pass<pass::ConstantFolding>();
auto cloned_function = ngraph::clone_function(*function);
std::shared_ptr<Model> function_ref;
if (!negative) {
auto param = std::make_shared<opset11::Parameter>(precision, input_shape);
auto spatial_dims = input_shape.size() - 2;
Shape strides(spatial_dims, 1);
std::vector<ptrdiff_t> pad_begin(spatial_dims, 0), pad_end(spatial_dims, 0);
auto weights = ngraph::builder::makeConstant<float>(precision, weights_shape, std::vector<float>(shape_size(weights_shape), 6));
std::shared_ptr<Node> conv;
if (conv_type == opset11::Convolution::get_type_info_static()) {
conv = std::make_shared<opset11::Convolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == opset11::GroupConvolution::get_type_info_static()) {
conv = std::make_shared<opset11::GroupConvolution>(param, weights, strides, pad_begin, pad_end, strides);
} else if (conv_type == opset11::ConvolutionBackpropData::get_type_info_static()) {
if (num_inputs == 3) {
auto output_shape = std::make_shared<opset11::Constant>(element::u64, Shape{spatial_dims},
std::vector<size_t>(input_shape.begin() + 2, input_shape.end()));
conv = std::make_shared<opset11::ConvolutionBackpropData>(param, weights, output_shape, strides, pad_begin, pad_end, strides);
} else {
conv = std::make_shared<opset11::ConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
}
} else if (conv_type == opset11::GroupConvolutionBackpropData::get_type_info_static()) {
if (num_inputs == 3) {
auto output_shape = std::make_shared<opset11::Constant>(element::u64, Shape{spatial_dims},
std::vector<size_t>(input_shape.begin() + 2, input_shape.end()));
conv = std::make_shared<opset11::GroupConvolutionBackpropData>(param, weights, output_shape, strides, pad_begin, pad_end, strides);
} else {
conv = std::make_shared<opset11::GroupConvolutionBackpropData>(param, weights, strides, pad_begin, pad_end, strides);
}
}
function_ref = std::make_shared<Model>(conv, ParameterVector{param}, "conv_eltwise_ref");
} else {
function_ref = function->clone();
}
auto cloned_function = function->clone();
manager.run_passes(cloned_function);
ASSERT_EQ(cloned_function->get_ops().size(), expected_number_of_ops);
auto res = compare_functions(cloned_function, function_ref);
ASSERT_TRUE(res.first) << res.second;
}
} // namespace SubgraphTestsDefinitions