Updated Mul->Add conversion to support dynamic shapes (#512)

* Updated Mul Add conversion to support dynamic shapes

* Keep changes

* Fix for cases when eltwise performs broadcasting via Constant

* Added comments; Fixed eltwise shape infer; Updated tests
This commit is contained in:
Gleb Kazantaev 2020-05-26 10:24:52 +03:00 committed by GitHub
parent e835a4cf58
commit d3764a7563
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 410 additions and 49 deletions

View File

@ -35,5 +35,13 @@ enum class CONVERSION_RESULT {
NONE
};
/*
* check_constant checks how the given constant participates in an elementwise operation with the given input
* CONVERSION_RESULT has several types:
* SCALE_SHIFT - constant applies only per-channel
* POWER - constant applies as single value
* NONE - default return value
*/
INFERENCE_ENGINE_API_CPP(CONVERSION_RESULT)
check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::Shape & shape);
check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::PartialShape & shape);

View File

@ -70,10 +70,13 @@ ngraph::graph_rewrite_callback get_callback() {
"Unsupported template parameter. Only Add or Multiply allowed!");
auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
if (!lin_op) {
if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
return false;
}
const auto output_shape = lin_op->output(0).get_partial_shape();
const auto output_shape_rank = output_shape.rank().get_length();
if (!lin_op->get_element_type().is_real()) {
return convert_to_eltwise<T>(lin_op,
lin_op->input(0).get_source_output(),
@ -93,39 +96,58 @@ ngraph::graph_rewrite_callback get_callback() {
}
}
// Check that eltwise is not useless otherwise we remove it
if ((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) {
bool has_result_output = false;
for (const auto & output : lin_op->output(0).get_target_inputs()) {
if (dynamic_cast<ngraph::op::Result*>(output.get_node())) {
has_result_output = true;
}
/* This lambda checks data and constant shapes for broadcasting
For example:
1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant does not broadcast data_shape
3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
*/
auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
if (data_pshape.rank().is_dynamic() || const_shape.size() > data_pshape.rank().get_length()) {
return true;
}
auto parent = data_node.get_node_shared_ptr();
size_t consumers_count = 0;
for (const auto &output : parent->outputs()) {
consumers_count += output.get_target_inputs().size();
std::vector<ngraph::Dimension> data_shape(data_pshape);
auto const_shape_it = const_shape.rbegin();
auto data_shape_it = data_shape.rbegin();
while (const_shape_it != const_shape.rend()) {
auto data_dim = *data_shape_it;
auto const_dim = *const_shape_it;
/* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
DYN - 64 - TRUE
DYN - 1 - FALSE
64 - 1 - FALSE
1 - 64 - TRUE
64 - 64 - FALSE
*/
if ((data_dim.is_dynamic() && const_dim != 1) ||
(data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
return true;
}
++const_shape_it;
++data_shape_it;
}
if (!has_result_output || consumers_count == 1) {
if (!std::dynamic_pointer_cast<ngraph::op::Parameter>(parent)) {
parent->set_friendly_name(lin_op->get_friendly_name());
}
// TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input
// using replace_source_output method
for (auto &input : lin_op->output(0).get_target_inputs()) {
input.replace_source_output(data_node);
}
return false;
};
// Check that eltwise is not useless and do not broadcast output otherwise we remove it
if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
!constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
if (ret_status) {
return true;
}
}
auto res = check_constant(const_node, data_node.get_partial_shape());
auto res = check_constant(const_node, data_node.get_shape());
if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && lin_op->get_shape().size() < 4)) {
if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4)) {
return convert_to_eltwise<T>(lin_op,
lin_op->input(0).get_source_output(),
lin_op->input(1).get_source_output());
@ -140,12 +162,12 @@ ngraph::graph_rewrite_callback get_callback() {
std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
if (std::is_same<T, ngraph::opset1::Add>()) {
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, lin_op->get_shape()),
ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()));
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, output_shape),
ngraph::op::util::normalize_constant(const_node, output_shape));
} else {
auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()),
ngraph::op::util::normalize_constant(bias, lin_op->get_shape()));
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, output_shape),
ngraph::op::util::normalize_constant(bias, output_shape));
}
scaleshift->set_friendly_name(lin_op->get_friendly_name());

View File

@ -47,7 +47,7 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
INFERENCE_ENGINE_API_CPP(bool) get_single_value(const std::shared_ptr<op::Constant> & const_node, float & value);
INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) normalize_constant(const std::shared_ptr<op::Constant> & constant,
const Shape & shape);
const PartialShape & shape);
INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) broadcastTo(const Output<Node>& input, const Shape& shape);

View File

@ -37,16 +37,24 @@ void op::Eltwise::validate_and_infer_types() {
NODE_VALIDATION_CHECK(this, element::Type::merge(et_result, data1_et, data2_et),
"Element types for first and second do not match :", data1_et, " and ", data2_et);
auto shape1 = get_input_partial_shape(0).to_shape();
auto shape2 = get_input_partial_shape(1).to_shape();
if (get_input_partial_shape(0).rank().is_dynamic() ||
get_input_partial_shape(1).rank().is_dynamic()) {
set_output_type(0, et_result, PartialShape::dynamic());
return;
}
ngraph::Shape output_shape(std::max(shape1.size(), shape2.size()));
std::vector<Dimension> shape1(get_input_partial_shape(0));
std::vector<Dimension> shape2(get_input_partial_shape(1));
std::vector<Dimension> output_shape(PartialShape::dynamic(std::max(shape1.size(), shape2.size())));
auto output_shape_it = output_shape.rbegin();
auto shape1_it = shape1.rbegin(), shape2_it = shape2.rbegin();
while (shape1_it != shape1.rend() || shape2_it != shape2.rend()) {
if (shape1_it != shape1.rend() && shape2_it != shape2.rend()) {
*output_shape_it = std::max(*shape1_it, *shape2_it);
if (shape1_it->is_static() && shape2_it->is_static()) {
*output_shape_it = (shape1_it->get_length() > shape2_it->get_length() ? *shape1_it : *shape2_it);
}
} else if (shape1_it != shape1.rend()) {
*output_shape_it = *shape1_it;
} else if (shape2_it != shape2.rend()) {
@ -61,5 +69,5 @@ void op::Eltwise::validate_and_infer_types() {
}
}
set_output_type(0, data1_et, PartialShape(output_shape));
set_output_type(0, et_result, output_shape);
}

View File

@ -17,11 +17,11 @@
#include "ngraph_ops/scaleshift.hpp"
CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>& constant,
const ngraph::Shape& shape) {
if (!constant) return CONVERSION_RESULT::NONE;
const ngraph::PartialShape& shape) {
if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE;
auto const_shape = constant->get_shape();
auto input_shape = shape;
std::vector<ngraph::Dimension> input_shape(shape);
// In case of scalar we will convert it to Power
if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) {
@ -47,7 +47,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>
if (idx == feature_index && *in_it == 1) {
is_power = true;
} else if (idx == feature_index && *in_it != *out_it) {
} else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) {
return CONVERSION_RESULT::NONE;
}
}
@ -95,6 +95,11 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_0);
}
if (add_node->get_output_partial_shape(0).rank().is_dynamic() ||
mul_node->get_output_partial_shape(0).rank().is_dynamic()) {
return false;
}
// Check that eltwise is not useless otherwise we remove it
if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) &&
ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) {
@ -124,11 +129,14 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
}
}
auto res1 = check_constant(const_weights_node, data_node.get_shape());
auto res2 = check_constant(const_bias_node, mul_node->get_output_shape(0));
auto res1 = check_constant(const_weights_node, data_node.get_partial_shape());
auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0));
const auto output_shape = add_node->get_output_partial_shape(0);
const auto output_shape_rank = output_shape.rank().get_length();
if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && add_node->get_shape().size() < 4)) {
((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && output_shape_rank < 4)) {
return false;
}
@ -136,8 +144,8 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) {
NodeVector new_ops;
auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, add_node->get_shape());
auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, add_node->get_shape());
auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape);
auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape);
new_ops.push_back(weights_in);
new_ops.push_back(biases_in);

View File

@ -49,12 +49,12 @@ bool get_single_value(const std::shared_ptr<op::Constant>& const_node, float& va
}
std::shared_ptr<Node> normalize_constant(const std::shared_ptr<op::Constant>& constant,
const Shape& shape) {
const PartialShape& shape) {
auto const_shape = constant->get_shape();
if (const_shape.size() == shape.size()) {
if (const_shape.size() == shape.rank().get_length()) {
return constant;
}
int cnt = shape.size() - const_shape.size();
int64_t cnt = shape.rank().get_length() - const_shape.size();
for (int i = 0; i < cnt; ++i) {
const_shape.insert(const_shape.begin(), 1);
}

View File

@ -1757,7 +1757,7 @@ TEST_F(NGraphReaderTests, RemoveAdd2) {
</output>
</layer>
<layer id="3" name="add" precision="FP32" type="ReLU">
<data originalLayersNames="relu"/>
<data originalLayersNames="add,relu"/>
<input>
<port id="0">
<dim>1</dim>

View File

@ -0,0 +1,315 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp>
#include <transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
#include <ngraph_ops/power.hpp>
#include <ngraph_ops/scaleshift.hpp>
#include "ngraph_test_utils.hpp"
using namespace testing;
using InputShape = ngraph::PartialShape;
struct ConstantParams {
ngraph::Shape shape;
float value;
bool skip;
ConstantParams() : skip(true) {}
ConstantParams(const ngraph::Shape & shape, float value)
: shape(shape), value(value), skip(false) {}
};
using MulConstant = ConstantParams;
using AddConstant = ConstantParams;
using RefFunction = std::function<std::shared_ptr<ngraph::Function>(const InputShape&, const MulConstant&, const AddConstant&)>;
// Value-parameterized fixture. Each parameter is a tuple of
// (input shape, optional Mul constant, optional Add constant) plus a factory
// that builds the reference function the transformed graph is compared against.
class MulAddConversionTests: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<std::tuple<InputShape, MulConstant, AddConstant>, RefFunction> > {
public:
// f - function the transformation pass runs on; f_ref - expected result
std::shared_ptr<ngraph::Function> f, f_ref;
void SetUp() override {
const auto& attrs = std::get<0>(GetParam());
const auto& input_shape = std::get<0>(attrs);
const auto& mul_const = std::get<1>(attrs);
const auto& add_const = std::get<2>(attrs);
const auto& get_ref_function = std::get<1>(GetParam());
f = get_initial_function(input_shape, mul_const, add_const);
f_ref = get_ref_function(input_shape, mul_const, add_const);
}
// Builds Parameter -> [Multiply] -> [Add] -> Relu. The Multiply/Add nodes are
// inserted only when the corresponding ConstantParams is not skipped.
static
std::shared_ptr<ngraph::Function> get_initial_function(const InputShape& input_shape,
const MulConstant& mul_const,
const AddConstant& add_const) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
ngraph::Output<ngraph::Node> last = input;
if (!mul_const.skip) {
last = std::make_shared<ngraph::opset1::Multiply>(last, create_constant(mul_const.shape, mul_const.value));
}
if (!add_const.skip) {
last = std::make_shared<ngraph::opset1::Add>(last, create_constant(add_const.shape, add_const.value));
}
last = std::make_shared<ngraph::opset1::Relu>(last);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{last.get_node_shared_ptr()}, ngraph::ParameterVector{input});
}
// Reference: Mul+Add fused into a single ScaleShiftIE. A skipped constant is
// replaced with the neutral element (weights filled with 1, biases with 0).
static
std::shared_ptr<ngraph::Function> get_scale_shift_reference(const InputShape& input_shape,
const MulConstant& mul_const,
const AddConstant& add_const) {
if (mul_const.skip && add_const.skip) {
throw ngraph::ngraph_error("Invalid arguments");
}
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto scsh = std::make_shared<ngraph::op::ScaleShiftIE>(input, (!mul_const.skip ? create_constant(mul_const.shape, mul_const.value)
: create_constant(add_const.shape, 1)),
(!add_const.skip ? create_constant(add_const.shape, add_const.value)
: create_constant(mul_const.shape, 0)));
auto relu = std::make_shared<ngraph::opset1::Relu>(scsh);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
}
// Reference: scalar Mul/Add converted to PowerIE (power=1, scale, shift).
static
std::shared_ptr<ngraph::Function> get_power_reference(const InputShape& input_shape,
const MulConstant& mul_const,
const AddConstant& add_const) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
float scale(1), shift(0);
if (!mul_const.skip) scale = mul_const.value;
if (!add_const.skip) shift = add_const.value;
auto pow = std::make_shared<ngraph::op::PowerIE>(input, 1., scale, shift);
auto relu = std::make_shared<ngraph::opset1::Relu>(pow);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
}
// Reference: Add converted to an Eltwise(Sum) node (only add_const is used).
static
std::shared_ptr<ngraph::Function> get_eltwise_add_reference(const InputShape& input_shape,
const MulConstant& mul_const,
const AddConstant& add_const) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto add = std::make_shared<ngraph::op::Eltwise>(input, create_constant(add_const.shape, add_const.value), ELTWISE_TYPE::Sum);
auto relu = std::make_shared<ngraph::opset1::Relu>(add);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
}
// Reference: Multiply converted to an Eltwise(Prod) node (only mul_const is used).
static
std::shared_ptr<ngraph::Function> get_eltwise_mul_reference(const InputShape& input_shape,
const MulConstant& mul_const,
const AddConstant& add_const) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto mul = std::make_shared<ngraph::op::Eltwise>(input, create_constant(mul_const.shape, mul_const.value), ELTWISE_TYPE::Prod);
auto relu = std::make_shared<ngraph::opset1::Relu>(mul);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
}
// Helper: f32 constant of `shape` with every element set to `init_value`.
static
std::shared_ptr<ngraph::opset1::Constant> create_constant(const ngraph::Shape & shape, float init_value) {
return ngraph::opset1::Constant::create(ngraph::element::f32, shape, {init_value});
}
};
// Same fixture, but the test body runs ConvertMulOrAddFinally instead of
// ConvertMulAddToScaleShiftOrPower (see TEST_P below).
class MulOrAddConversionTests: public MulAddConversionTests {};
// Applies ConvertMulAddToScaleShiftOrPower to `f`, folds constants, and
// compares the result against the reference function `f_ref`.
TEST_P(MulAddConversionTests, CompareFunctions) {
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertMulAddToScaleShiftOrPower().run_on_function(f);
// Runtime info (e.g. fused names) must survive the transformation.
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
f->validate_nodes_and_infer_types();
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
// Applies ConvertMulOrAddFinally (single Mul or Add conversion) to `f`, folds
// constants, and compares the result against the reference function `f_ref`.
TEST_P(MulOrAddConversionTests, CompareFunctions) {
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertMulOrAddFinally().run_on_function(f);
// Runtime info (e.g. fused names) must survive the transformation.
ASSERT_NO_THROW(check_rt_info(f));
ngraph::pass::ConstantFolding().run_on_function(f);
f->validate_nodes_and_infer_types();
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
// Shorthand for the parameter tables below: CONST(shape, value) is a present
// constant, NONE an absent one; the remaining macros name the expected-result
// builders defined on the fixture.
#define CONST(A, B) ConstantParams(A, B)
#define NONE ConstantParams()
#define SCALESHIFT MulAddConversionTests::get_scale_shift_reference
#define POWER MulAddConversionTests::get_power_reference
#define SAME MulAddConversionTests::get_initial_function
#define ELTWISE_SUM MulAddConversionTests::get_eltwise_add_reference
#define ELTWISE_PROD MulAddConversionTests::get_eltwise_mul_reference
// Per-channel (1,3,1,1) constants on partially-dynamic 4D inputs must fuse
// into ScaleShiftIE — Mul+Add, Mul-only, and Add-only variants.
INSTANTIATE_TEST_CASE_P(MulAddToScaleShift, MulAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
testing::Values(SCALESHIFT)));
INSTANTIATE_TEST_CASE_P(MulToScaleShift, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
NONE)),
testing::Values(SCALESHIFT)));
INSTANTIATE_TEST_CASE_P(AddToScaleShift, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
NONE,
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
NONE,
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
NONE,
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
testing::Values(SCALESHIFT)));
// Scalar (shape {1}) constants on partially-dynamic inputs must convert to a
// single PowerIE — Mul+Add, Mul-only, and Add-only variants.
INSTANTIATE_TEST_CASE_P(MulAddToPower, MulAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
CONST(ngraph::Shape({1}), 0.5),
CONST(ngraph::Shape({1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
CONST(ngraph::Shape({1}), 0.5),
CONST(ngraph::Shape({1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
CONST(ngraph::Shape({1}), 0.5),
CONST(ngraph::Shape({1}), 0.5))),
testing::Values(POWER)));
INSTANTIATE_TEST_CASE_P(MulToPower, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
CONST(ngraph::Shape({1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
CONST(ngraph::Shape({1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
CONST(ngraph::Shape({1}), 0.5),
NONE)),
testing::Values(POWER)));
INSTANTIATE_TEST_CASE_P(AddToPower, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
NONE,
CONST(ngraph::Shape({1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, 64},
NONE,
CONST(ngraph::Shape({1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
NONE,
CONST(ngraph::Shape({1}), 0.5))),
testing::Values(POWER)));
// Negative cases: the MulAdd pass must leave the graph untouched (reference is
// the initial function) — non-4D outputs, broadcasting constants, fully/partly
// dynamic shapes where the conversion would be unsafe.
INSTANTIATE_TEST_CASE_P(MulAddNegative, MulAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
CONST(ngraph::Shape({1, 3, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1}), 0.5)/*ScaleShift must always be 4D*/),
std::make_tuple(InputShape{DYN, 3, DYN},
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5),
CONST(ngraph::Shape({3, 1}), 0.5)/*detect broadcast case*/),
std::make_tuple(InputShape{DYN, 3, DYN},
CONST(ngraph::Shape({3, 1}), 0.5),
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)/*detect broadcast case*/),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{1, 3, 2},
CONST(ngraph::Shape({1, 3, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 2}), 0.5)),
std::make_tuple(InputShape{1, DYN, 64, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
testing::Values(SAME)));
// Cases that cannot become ScaleShift/Power (non-4D shapes, broadcasting or
// dynamic dims) must fall back to an Eltwise node: Prod for Multiply,
// Sum for Add.
INSTANTIATE_TEST_CASE_P(MulToEltwise, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
CONST(ngraph::Shape({1, 1, 64}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN},
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
NONE),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5),
NONE),
std::make_tuple(InputShape{1, 3, 2},
CONST(ngraph::Shape({1, 3, 2}), 0.5),
NONE),
std::make_tuple(InputShape{1, DYN, 64, 64},
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
NONE),
std::make_tuple(InputShape{64, 1, 64},
CONST(ngraph::Shape({64, 64, 64}), 1),
NONE),
std::make_tuple(InputShape{64, 64, 1},
CONST(ngraph::Shape({1, 1, 64}), 1),
NONE),
std::make_tuple(InputShape{DYN, 1, 64},
CONST(ngraph::Shape({64, 1, 64}), 1),
NONE)),
testing::Values(ELTWISE_PROD)));
INSTANTIATE_TEST_CASE_P(AddToEltwise, MulOrAddConversionTests, testing::Combine(
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
NONE,
CONST(ngraph::Shape({1, 1, 64}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN},
NONE,
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
NONE,
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
NONE,
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5)),
std::make_tuple(InputShape{1, 3, 2},
NONE,
CONST(ngraph::Shape({1, 3, 2}), 0.5)),
std::make_tuple(InputShape{1, DYN, 64, 64},
NONE,
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
testing::Values(ELTWISE_SUM)));
// Undefine the table shorthands so they do not leak into other test TUs.
#undef CONST
#undef SCALESHIFT
#undef POWER
#undef SAME
#undef ELTWISE_PROD
#undef ELTWISE_SUM