Updated Mul->Add conversion to support dynamic shapes (#512)
* Updated Mul Add conversion to support dynamic shapes * Keep changes * Fix for cases when eltwise performs broadcasting via Constant * Added comments;Fixed eltwise shape infer; Updated tests
This commit is contained in:
parent
e835a4cf58
commit
d3764a7563
@ -35,5 +35,13 @@ enum class CONVERSION_RESULT {
|
||||
NONE
|
||||
};
|
||||
|
||||
/*
|
||||
* check_constant function checks how given constant performs elementwise operation with given input
|
||||
* CONVERSION_RESULT has several types:
|
||||
* SCALE_SHIFT - constant applies only per-channel
|
||||
* POWER - constant applies as single value
|
||||
* NONE - default return value
|
||||
*/
|
||||
|
||||
INFERENCE_ENGINE_API_CPP(CONVERSION_RESULT)
|
||||
check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::Shape & shape);
|
||||
check_constant(const std::shared_ptr<ngraph::op::Constant> & constant, const ngraph::PartialShape & shape);
|
||||
|
@ -70,10 +70,13 @@ ngraph::graph_rewrite_callback get_callback() {
|
||||
"Unsupported template parameter. Only Add or Multiply allowed!");
|
||||
|
||||
auto lin_op = std::dynamic_pointer_cast<T> (m.get_match_root());
|
||||
if (!lin_op) {
|
||||
if (!lin_op || lin_op->output(0).get_partial_shape().rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto output_shape = lin_op->output(0).get_partial_shape();
|
||||
const auto output_shape_rank = output_shape.rank().get_length();
|
||||
|
||||
if (!lin_op->get_element_type().is_real()) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
@ -93,39 +96,58 @@ ngraph::graph_rewrite_callback get_callback() {
|
||||
}
|
||||
}
|
||||
|
||||
// Check that eltwise is not useless otherwise we remove it
|
||||
if ((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
|
||||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) {
|
||||
bool has_result_output = false;
|
||||
for (const auto & output : lin_op->output(0).get_target_inputs()) {
|
||||
if (dynamic_cast<ngraph::op::Result*>(output.get_node())) {
|
||||
has_result_output = true;
|
||||
}
|
||||
/* This lambda checks data and constant shapes for broadcasting
|
||||
For example:
|
||||
1. data_shape{1, 64, 64} and const_shape{64, 1, 1} - constant broadcasts data_shape zero dimension
|
||||
2. data_shape{DYN, 64, 64} and const_shape{1, 1, 64} - constant do not broadcasts data_shape
|
||||
3. data_shape{64, 64} and const_shape{1, 1, 1} - constant broadcasts data_shape with additional dimension
|
||||
*/
|
||||
auto constant_broadcast_output = [](const ngraph::PartialShape & data_pshape, const ngraph::Shape & const_shape) -> bool {
|
||||
if (data_pshape.rank().is_dynamic() || const_shape.size() > data_pshape.rank().get_length()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto parent = data_node.get_node_shared_ptr();
|
||||
size_t consumers_count = 0;
|
||||
for (const auto &output : parent->outputs()) {
|
||||
consumers_count += output.get_target_inputs().size();
|
||||
std::vector<ngraph::Dimension> data_shape(data_pshape);
|
||||
|
||||
auto const_shape_it = const_shape.rbegin();
|
||||
auto data_shape_it = data_shape.rbegin();
|
||||
|
||||
while (const_shape_it != const_shape.rend()) {
|
||||
auto data_dim = *data_shape_it;
|
||||
auto const_dim = *const_shape_it;
|
||||
|
||||
/* DATA DIM - CONST DIM - CONSTANT BROADCAST OUTPUT
|
||||
DYN - 64 - TRUE
|
||||
DYN - 1 - FALSE
|
||||
64 - 1 - FALSE
|
||||
1 - 64 - TRUE
|
||||
64 - 64 - FALSE
|
||||
*/
|
||||
if ((data_dim.is_dynamic() && const_dim != 1) ||
|
||||
(data_dim.is_static() && data_dim.get_length() == 1 && const_dim != 1)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
++const_shape_it;
|
||||
++data_shape_it;
|
||||
}
|
||||
|
||||
if (!has_result_output || consumers_count == 1) {
|
||||
if (!std::dynamic_pointer_cast<ngraph::op::Parameter>(parent)) {
|
||||
parent->set_friendly_name(lin_op->get_friendly_name());
|
||||
}
|
||||
// TODO: due to ngraph::replace_node function limitations we have to reconnect output port consumers to the new input
|
||||
// using replace_source_output method
|
||||
for (auto &input : lin_op->output(0).get_target_inputs()) {
|
||||
input.replace_source_output(data_node);
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Check that eltwise is not useless and do not broadcast output otherwise we remove it
|
||||
if (((std::is_same<T, ngraph::opset1::Add>() && ngraph::op::util::constantIsEqualTo(const_node, 0)) ||
|
||||
(std::is_same<T, ngraph::opset1::Multiply>() && ngraph::op::util::constantIsEqualTo(const_node, 1))) &&
|
||||
!constant_broadcast_output(data_node.get_partial_shape(), const_node->get_shape())) {
|
||||
bool ret_status = ngraph::replace_output_update_name(lin_op->output(0), data_node);
|
||||
if (ret_status) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
auto res = check_constant(const_node, data_node.get_partial_shape());
|
||||
|
||||
auto res = check_constant(const_node, data_node.get_shape());
|
||||
|
||||
if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && lin_op->get_shape().size() < 4)) {
|
||||
if (res == CONVERSION_RESULT::NONE || (res == CONVERSION_RESULT::SCALE_SHIFT && output_shape_rank < 4)) {
|
||||
return convert_to_eltwise<T>(lin_op,
|
||||
lin_op->input(0).get_source_output(),
|
||||
lin_op->input(1).get_source_output());
|
||||
@ -140,12 +162,12 @@ ngraph::graph_rewrite_callback get_callback() {
|
||||
std::shared_ptr<ngraph::op::ScaleShiftIE> scaleshift;
|
||||
if (std::is_same<T, ngraph::opset1::Add>()) {
|
||||
auto weights = ngraph::opset1::Constant::create(weights_et, weights_shape, {1});
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, lin_op->get_shape()),
|
||||
ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()));
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(weights, output_shape),
|
||||
ngraph::op::util::normalize_constant(const_node, output_shape));
|
||||
} else {
|
||||
auto bias = ngraph::opset1::Constant::create(weights_et, weights_shape, {0});
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, lin_op->get_shape()),
|
||||
ngraph::op::util::normalize_constant(bias, lin_op->get_shape()));
|
||||
scaleshift = std::make_shared<ngraph::op::ScaleShiftIE>(data_node, ngraph::op::util::normalize_constant(const_node, output_shape),
|
||||
ngraph::op::util::normalize_constant(bias, output_shape));
|
||||
}
|
||||
|
||||
scaleshift->set_friendly_name(lin_op->get_friendly_name());
|
||||
|
@ -47,7 +47,7 @@ bool has_op_with_type(const std::shared_ptr<const ngraph::Function> &function) {
|
||||
INFERENCE_ENGINE_API_CPP(bool) get_single_value(const std::shared_ptr<op::Constant> & const_node, float & value);
|
||||
|
||||
INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) normalize_constant(const std::shared_ptr<op::Constant> & constant,
|
||||
const Shape & shape);
|
||||
const PartialShape & shape);
|
||||
|
||||
INFERENCE_ENGINE_API_CPP(std::shared_ptr<ngraph::Node>) broadcastTo(const Output<Node>& input, const Shape& shape);
|
||||
|
||||
|
@ -37,16 +37,24 @@ void op::Eltwise::validate_and_infer_types() {
|
||||
NODE_VALIDATION_CHECK(this, element::Type::merge(et_result, data1_et, data2_et),
|
||||
"Element types for first and second do not match :", data1_et, " and ", data2_et);
|
||||
|
||||
auto shape1 = get_input_partial_shape(0).to_shape();
|
||||
auto shape2 = get_input_partial_shape(1).to_shape();
|
||||
if (get_input_partial_shape(0).rank().is_dynamic() ||
|
||||
get_input_partial_shape(1).rank().is_dynamic()) {
|
||||
set_output_type(0, et_result, PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
|
||||
ngraph::Shape output_shape(std::max(shape1.size(), shape2.size()));
|
||||
std::vector<Dimension> shape1(get_input_partial_shape(0));
|
||||
std::vector<Dimension> shape2(get_input_partial_shape(1));
|
||||
|
||||
std::vector<Dimension> output_shape(PartialShape::dynamic(std::max(shape1.size(), shape2.size())));
|
||||
auto output_shape_it = output_shape.rbegin();
|
||||
|
||||
auto shape1_it = shape1.rbegin(), shape2_it = shape2.rbegin();
|
||||
while (shape1_it != shape1.rend() || shape2_it != shape2.rend()) {
|
||||
if (shape1_it != shape1.rend() && shape2_it != shape2.rend()) {
|
||||
*output_shape_it = std::max(*shape1_it, *shape2_it);
|
||||
if (shape1_it->is_static() && shape2_it->is_static()) {
|
||||
*output_shape_it = (shape1_it->get_length() > shape2_it->get_length() ? *shape1_it : *shape2_it);
|
||||
}
|
||||
} else if (shape1_it != shape1.rend()) {
|
||||
*output_shape_it = *shape1_it;
|
||||
} else if (shape2_it != shape2.rend()) {
|
||||
@ -61,5 +69,5 @@ void op::Eltwise::validate_and_infer_types() {
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, data1_et, PartialShape(output_shape));
|
||||
set_output_type(0, et_result, output_shape);
|
||||
}
|
||||
|
@ -17,11 +17,11 @@
|
||||
#include "ngraph_ops/scaleshift.hpp"
|
||||
|
||||
CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>& constant,
|
||||
const ngraph::Shape& shape) {
|
||||
if (!constant) return CONVERSION_RESULT::NONE;
|
||||
const ngraph::PartialShape& shape) {
|
||||
if (!constant || shape.rank().is_dynamic()) return CONVERSION_RESULT::NONE;
|
||||
|
||||
auto const_shape = constant->get_shape();
|
||||
auto input_shape = shape;
|
||||
std::vector<ngraph::Dimension> input_shape(shape);
|
||||
|
||||
// In case of scalar we will convert it to Power
|
||||
if (const_shape.empty() || (const_shape.size() == 1 && const_shape[0] == 1)) {
|
||||
@ -47,7 +47,7 @@ CONVERSION_RESULT check_constant(const std::shared_ptr<ngraph::opset1::Constant>
|
||||
|
||||
if (idx == feature_index && *in_it == 1) {
|
||||
is_power = true;
|
||||
} else if (idx == feature_index && *in_it != *out_it) {
|
||||
} else if (idx == feature_index && (out_it->is_dynamic() || *in_it != out_it->get_length())) {
|
||||
return CONVERSION_RESULT::NONE;
|
||||
}
|
||||
}
|
||||
@ -95,6 +95,11 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
|
||||
const_weights_node = ngraph::as_type_ptr<ngraph::opset1::Constant>(mul_input_0);
|
||||
}
|
||||
|
||||
if (add_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||
mul_node->get_output_partial_shape(0).rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that eltwise is not useless otherwise we remove it
|
||||
if (ngraph::op::util::constantIsEqualTo(const_weights_node, 1) &&
|
||||
ngraph::op::util::constantIsEqualTo(const_bias_node, 0)) {
|
||||
@ -124,11 +129,14 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
|
||||
}
|
||||
}
|
||||
|
||||
auto res1 = check_constant(const_weights_node, data_node.get_shape());
|
||||
auto res2 = check_constant(const_bias_node, mul_node->get_output_shape(0));
|
||||
auto res1 = check_constant(const_weights_node, data_node.get_partial_shape());
|
||||
auto res2 = check_constant(const_bias_node, mul_node->get_output_partial_shape(0));
|
||||
|
||||
const auto output_shape = add_node->get_output_partial_shape(0);
|
||||
const auto output_shape_rank = output_shape.rank().get_length();
|
||||
|
||||
if (res1 == CONVERSION_RESULT::NONE || res2 == CONVERSION_RESULT::NONE ||
|
||||
((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && add_node->get_shape().size() < 4)) {
|
||||
((res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) && output_shape_rank < 4)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -136,8 +144,8 @@ void ngraph::pass::ConvertMulAddToScaleShiftOrPower::convert_mul_add_to_scaleshi
|
||||
if (res1 == CONVERSION_RESULT::SCALE_SHIFT || res2 == CONVERSION_RESULT::SCALE_SHIFT) {
|
||||
NodeVector new_ops;
|
||||
|
||||
auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, add_node->get_shape());
|
||||
auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, add_node->get_shape());
|
||||
auto weights_in = ngraph::op::util::normalize_constant(const_weights_node, output_shape);
|
||||
auto biases_in = ngraph::op::util::normalize_constant(const_bias_node, output_shape);
|
||||
new_ops.push_back(weights_in);
|
||||
new_ops.push_back(biases_in);
|
||||
|
||||
|
@ -49,12 +49,12 @@ bool get_single_value(const std::shared_ptr<op::Constant>& const_node, float& va
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> normalize_constant(const std::shared_ptr<op::Constant>& constant,
|
||||
const Shape& shape) {
|
||||
const PartialShape& shape) {
|
||||
auto const_shape = constant->get_shape();
|
||||
if (const_shape.size() == shape.size()) {
|
||||
if (const_shape.size() == shape.rank().get_length()) {
|
||||
return constant;
|
||||
}
|
||||
int cnt = shape.size() - const_shape.size();
|
||||
int64_t cnt = shape.rank().get_length() - const_shape.size();
|
||||
for (int i = 0; i < cnt; ++i) {
|
||||
const_shape.insert(const_shape.begin(), 1);
|
||||
}
|
||||
|
@ -1757,7 +1757,7 @@ TEST_F(NGraphReaderTests, RemoveAdd2) {
|
||||
</output>
|
||||
</layer>
|
||||
<layer id="3" name="add" precision="FP32" type="ReLU">
|
||||
<data originalLayersNames="relu"/>
|
||||
<data originalLayersNames="add,relu"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
|
@ -0,0 +1,315 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/conv_bias_fusion.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_mul_add_to_scaleshift_or_power.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_mul_or_add_finally.hpp>
|
||||
#include <ngraph_ops/power.hpp>
|
||||
#include <ngraph_ops/scaleshift.hpp>
|
||||
|
||||
#include "ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
using InputShape = ngraph::PartialShape;
|
||||
struct ConstantParams {
|
||||
ngraph::Shape shape;
|
||||
float value;
|
||||
bool skip;
|
||||
ConstantParams() : skip(true) {}
|
||||
ConstantParams(const ngraph::Shape & shape, float value)
|
||||
: shape(shape), value(value), skip(false) {}
|
||||
};
|
||||
using MulConstant = ConstantParams;
|
||||
using AddConstant = ConstantParams;
|
||||
using RefFunction = std::function<std::shared_ptr<ngraph::Function>(const InputShape&, const MulConstant&, const AddConstant&)>;
|
||||
|
||||
class MulAddConversionTests: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<std::tuple<InputShape, MulConstant, AddConstant>, RefFunction> > {
|
||||
public:
|
||||
std::shared_ptr<ngraph::Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& attrs = std::get<0>(GetParam());
|
||||
const auto& input_shape = std::get<0>(attrs);
|
||||
const auto& mul_const = std::get<1>(attrs);
|
||||
const auto& add_const = std::get<2>(attrs);
|
||||
const auto& get_ref_function = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, mul_const, add_const);
|
||||
f_ref = get_ref_function(input_shape, mul_const, add_const);
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> get_initial_function(const InputShape& input_shape,
|
||||
const MulConstant& mul_const,
|
||||
const AddConstant& add_const) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
ngraph::Output<ngraph::Node> last = input;
|
||||
if (!mul_const.skip) {
|
||||
last = std::make_shared<ngraph::opset1::Multiply>(last, create_constant(mul_const.shape, mul_const.value));
|
||||
}
|
||||
if (!add_const.skip) {
|
||||
last = std::make_shared<ngraph::opset1::Add>(last, create_constant(add_const.shape, add_const.value));
|
||||
}
|
||||
last = std::make_shared<ngraph::opset1::Relu>(last);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{last.get_node_shared_ptr()}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> get_scale_shift_reference(const InputShape& input_shape,
|
||||
const MulConstant& mul_const,
|
||||
const AddConstant& add_const) {
|
||||
if (mul_const.skip && add_const.skip) {
|
||||
throw ngraph::ngraph_error("Invalid arguments");
|
||||
}
|
||||
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto scsh = std::make_shared<ngraph::op::ScaleShiftIE>(input, (!mul_const.skip ? create_constant(mul_const.shape, mul_const.value)
|
||||
: create_constant(add_const.shape, 1)),
|
||||
(!add_const.skip ? create_constant(add_const.shape, add_const.value)
|
||||
: create_constant(mul_const.shape, 0)));
|
||||
auto relu = std::make_shared<ngraph::opset1::Relu>(scsh);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> get_power_reference(const InputShape& input_shape,
|
||||
const MulConstant& mul_const,
|
||||
const AddConstant& add_const) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
float scale(1), shift(0);
|
||||
if (!mul_const.skip) scale = mul_const.value;
|
||||
if (!add_const.skip) shift = add_const.value;
|
||||
auto pow = std::make_shared<ngraph::op::PowerIE>(input, 1., scale, shift);
|
||||
auto relu = std::make_shared<ngraph::opset1::Relu>(pow);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> get_eltwise_add_reference(const InputShape& input_shape,
|
||||
const MulConstant& mul_const,
|
||||
const AddConstant& add_const) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto add = std::make_shared<ngraph::op::Eltwise>(input, create_constant(add_const.shape, add_const.value), ELTWISE_TYPE::Sum);
|
||||
auto relu = std::make_shared<ngraph::opset1::Relu>(add);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::Function> get_eltwise_mul_reference(const InputShape& input_shape,
|
||||
const MulConstant& mul_const,
|
||||
const AddConstant& add_const) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto mul = std::make_shared<ngraph::op::Eltwise>(input, create_constant(mul_const.shape, mul_const.value), ELTWISE_TYPE::Prod);
|
||||
auto relu = std::make_shared<ngraph::opset1::Relu>(mul);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{relu}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
static
|
||||
std::shared_ptr<ngraph::opset1::Constant> create_constant(const ngraph::Shape & shape, float init_value) {
|
||||
return ngraph::opset1::Constant::create(ngraph::element::f32, shape, {init_value});
|
||||
}
|
||||
};
|
||||
|
||||
class MulOrAddConversionTests: public MulAddConversionTests {};
|
||||
|
||||
TEST_P(MulAddConversionTests, CompareFunctions) {
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertMulAddToScaleShiftOrPower().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
f->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST_P(MulOrAddConversionTests, CompareFunctions) {
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertMulOrAddFinally().run_on_function(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ngraph::pass::ConstantFolding().run_on_function(f);
|
||||
f->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
#define CONST(A, B) ConstantParams(A, B)
|
||||
#define NONE ConstantParams()
|
||||
#define SCALESHIFT MulAddConversionTests::get_scale_shift_reference
|
||||
#define POWER MulAddConversionTests::get_power_reference
|
||||
#define SAME MulAddConversionTests::get_initial_function
|
||||
#define ELTWISE_SUM MulAddConversionTests::get_eltwise_add_reference
|
||||
#define ELTWISE_PROD MulAddConversionTests::get_eltwise_mul_reference
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulAddToScaleShift, MulAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
|
||||
testing::Values(SCALESHIFT)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulToScaleShift, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
NONE)),
|
||||
testing::Values(SCALESHIFT)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AddToScaleShift, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
|
||||
testing::Values(SCALESHIFT)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulAddToPower, MulAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
CONST(ngraph::Shape({1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
CONST(ngraph::Shape({1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
CONST(ngraph::Shape({1}), 0.5))),
|
||||
testing::Values(POWER)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulToPower, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
CONST(ngraph::Shape({1}), 0.5),
|
||||
NONE)),
|
||||
testing::Values(POWER)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AddToPower, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1}), 0.5))),
|
||||
testing::Values(POWER)));
|
||||
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulAddNegative, MulAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1}), 0.5)/*ScaleShift must always be 4D*/),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN},
|
||||
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5),
|
||||
CONST(ngraph::Shape({3, 1}), 0.5)/*detect broadcast case*/),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN},
|
||||
CONST(ngraph::Shape({3, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)/*detect broadcast case*/),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{1, 3, 2},
|
||||
CONST(ngraph::Shape({1, 3, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 2}), 0.5)),
|
||||
std::make_tuple(InputShape{1, DYN, 64, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
|
||||
testing::Values(SAME)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(MulToEltwise, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
|
||||
CONST(ngraph::Shape({1, 1, 64}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN},
|
||||
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{1, 3, 2},
|
||||
CONST(ngraph::Shape({1, 3, 2}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{1, DYN, 64, 64},
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{64, 1, 64},
|
||||
CONST(ngraph::Shape({64, 64, 64}), 1),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{64, 64, 1},
|
||||
CONST(ngraph::Shape({1, 1, 64}), 1),
|
||||
NONE),
|
||||
std::make_tuple(InputShape{DYN, 1, 64},
|
||||
CONST(ngraph::Shape({64, 1, 64}), 1),
|
||||
NONE)),
|
||||
testing::Values(ELTWISE_PROD)));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(AddToEltwise, MulOrAddConversionTests, testing::Combine(
|
||||
testing::Values(std::make_tuple(InputShape{DYN, 3, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 1, 64}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 1, 3, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, DYN},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 2, 1}), 0.5)),
|
||||
std::make_tuple(InputShape{1, 3, 2},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 2}), 0.5)),
|
||||
std::make_tuple(InputShape{1, DYN, 64, 64},
|
||||
NONE,
|
||||
CONST(ngraph::Shape({1, 3, 1, 1}), 0.5))),
|
||||
testing::Values(ELTWISE_SUM)));
|
||||
|
||||
#undef CONST
|
||||
#undef SCALESHIFT
|
||||
#undef POWER
|
||||
#undef SAME
|
||||
#undef ELTWISE_PROD
|
||||
#undef ELTWISE_SUM
|
Loading…
Reference in New Issue
Block a user