Improve ConvertBroadcast3 pass to avoid extra Multiply operations for BIDIRECTIONAL mode (#3113)

* Fixed ConvertBroadcast3 pass for BIDIRECTIONAL mode to avoid excess Multiply operations

* Added funcitonal tests for new decompositions

* Return false if mode is unknown; avoid usign node in replace_node

* Added functional tests for cases when TargetShape input is not a Constant
This commit is contained in:
Gleb Kazantaev 2020-11-13 14:39:07 +03:00 committed by GitHub
parent c3683341f3
commit e79298fb40
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 395 additions and 49 deletions

View File

@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertBroadcast3;
} // namespace pass
} // namespace ngraph
class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite {
class ngraph::pass::ConvertBroadcast3: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
ConvertBroadcast3() : GraphRewrite() {
convert_broadcast3();
}
private:
void convert_broadcast3();
ConvertBroadcast3();
};

View File

@ -11,50 +11,101 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcast3, "ConvertBroadcast3", 0);
void ngraph::pass::ConvertBroadcast3::convert_broadcast3() {
auto broadcast = std::make_shared<pattern::op::Label>(element::f32, Shape {}, pattern::has_class<opset3::Broadcast>());
bool make_compatible_shape(const ngraph::PartialShape & input_shape, std::vector<size_t> & target_shape) {
if (input_shape.rank().is_dynamic()) {
return false;
}
const int64_t & input_shape_rank = input_shape.rank().get_length();
if (input_shape_rank > target_shape.size()) {
// target_shape rank must greater or equal to input_shape rank, so in case when it's less we
// insert missing input_shape dimensions to the beginning of the target_shape.
const int64_t & dims_to_add_count = input_shape_rank - target_shape.size();
std::vector<size_t> dims_to_add(dims_to_add_count);
for (size_t dim = 0; dim < dims_to_add_count; ++dim) {
if (input_shape[dim].is_dynamic()) {
return false;
}
dims_to_add[dim] = input_shape[dim].get_length();
}
target_shape.insert(target_shape.begin(), dims_to_add.begin(), dims_to_add.end());
}
for (int64_t i_dim = input_shape_rank - 1, t_dim = target_shape.size() - 1; i_dim >= 0 && t_dim >= 0; --i_dim, --t_dim) {
if (input_shape[i_dim].is_static()) {
const auto & input_dim = input_shape[i_dim].get_length();
if (input_dim != target_shape[t_dim] && input_dim != 1 && target_shape[t_dim] != 1) {
// this dimensions are not broadcastable
return false;
}
target_shape[t_dim] = std::max(target_shape[t_dim], static_cast<size_t>(input_dim));
} else {
if (target_shape[t_dim] == 1) {
// For example: |
// \/
// input_shape [DYN, 3, 4]
// target_shape [ 1, 3, 4] - broadcasted first dimension is unknown
return false;
}
}
}
return true;
}
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
auto broadcast = std::dynamic_pointer_cast<ngraph::opset3::Broadcast>(m.get_match_root());
ngraph::pass::ConvertBroadcast3::ConvertBroadcast3() {
auto broadcast = pattern::wrap_type<opset3::Broadcast>();
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto broadcast = std::dynamic_pointer_cast<opset3::Broadcast>(m.get_match_root());
if (!broadcast) {
return false;
}
auto input = broadcast->input_value(0);
auto target_shape = broadcast->input_value(1);
auto last_node = input.get_node_shared_ptr();
auto broadcast_type = broadcast->get_broadcast_spec();
auto target_shape_input = broadcast->input_value(1);
const auto & broadcast_type = broadcast->get_broadcast_spec();
const auto & input_element_type = input.get_element_type();
if (broadcast_type == op::BroadcastType::NUMPY) {
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::NUMPY);
ngraph::copy_runtime_info(broadcast, last_node);
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::NUMPY);
} else if (broadcast_type == op::BroadcastType::PDPD) {
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::PDPD);
ngraph::copy_runtime_info(broadcast, last_node);
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::PDPD);
} else if (broadcast_type == op::BroadcastType::NONE) {
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, broadcast->input_value(2), op::AutoBroadcastType::NONE);
ngraph::copy_runtime_info(broadcast, last_node);
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, broadcast->input_value(2), op::AutoBroadcastType::NONE);
} else if (broadcast_type == op::BroadcastType::BIDIRECTIONAL) {
auto constant_one = std::make_shared<ngraph::opset1::Constant>(input.get_element_type(), Shape({1}), std::vector<int>{1});
auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, op::AutoBroadcastType::NUMPY);
if (input.get_element_type() == element::boolean) {
last_node = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
if (auto const_target_shape = std::dynamic_pointer_cast<opset1::Constant>(target_shape_input.get_node_shared_ptr())) {
const auto & input_shape = input.get_partial_shape();
const auto & target_shape = const_target_shape->cast_vector<size_t>();
std::vector<size_t> aligned_target_shape{target_shape};
if (make_compatible_shape(input_shape, aligned_target_shape)) {
input = std::make_shared<opset1::Broadcast>(input,
opset1::Constant::create(element::i64, Shape({aligned_target_shape.size()}), aligned_target_shape));
} else {
input = std::make_shared<opset1::Multiply>(input,
opset1::Constant::create(input_element_type, target_shape, {1}));
}
} else {
last_node = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
auto constant_one = opset1::Constant::create(input_element_type, {1}, {1});
auto broadcast_ones = std::make_shared<opset1::Broadcast>(constant_one, target_shape_input);
if (input_element_type == element::boolean) {
input = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
} else {
input = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
}
copy_runtime_info(broadcast, broadcast_ones);
}
ngraph::copy_runtime_info(broadcast, {last_node, broadcast_ones, constant_one});
} else {
return false;
}
last_node->set_friendly_name(broadcast->get_friendly_name());
ngraph::replace_node(m.get_match_root(), last_node);
input.get_node_shared_ptr()->set_friendly_name(broadcast->get_friendly_name());
copy_runtime_info(broadcast, input.get_node_shared_ptr());
replace_node(broadcast, {input});
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(broadcast, "ConvertBroadcast3");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
auto m = std::make_shared<pattern::Matcher>(broadcast, "ConvertBroadcast3");
register_matcher(m, callback);
}

View File

@ -4,20 +4,314 @@
#include <gtest/gtest.h>
#include "common_test_utils/test_common.hpp"
#include <string>
#include <sstream>
#include <fstream>
#include <memory>
#include <queue>
#include <map>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <transformations/op_conversions/convert_broadcast3.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/visualize_tree.hpp>
#include <transformations/op_conversions/convert_broadcast3.hpp>
#include <ngraph_ops/convolution_ie.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
using InputShape = PartialShape;
using TargetShape = Shape;
void convert_broadcast3_test(std::shared_ptr<Function> f, std::shared_ptr<Function> f_ref) {
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertBroadcast3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
}
class ConvertBroadcast3NUMPYTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& input_shape = std::get<0>(GetParam());
const auto& target_shape = std::get<1>(GetParam());
f = get_initial_function(input_shape, target_shape);
f_ref = get_reference_broadcast(input_shape, target_shape);
}
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::NUMPY);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
};
class ConvertBroadcast3BIDIRECTMulTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& input_shape = std::get<0>(GetParam());
const auto& target_shape = std::get<1>(GetParam());
f = get_initial_function(input_shape, target_shape);
f_ref = get_reference_broadcast(input_shape, target_shape);
}
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto const_node = ngraph::opset1::Constant::create(ngraph::element::f32, Shape{target_shape}, {1});
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, const_node);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
}
};
class ConvertBroadcast3BIDIRECTBroadcastTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& input_shape = std::get<0>(GetParam());
const auto& target_shape = std::get<1>(GetParam());
const auto& aligned_target_shape = std::get<2>(GetParam());
f = get_initial_function(input_shape, target_shape);
f_ref = get_reference_broadcast(input_shape, aligned_target_shape);
}
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
const TargetShape & aligned_target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, Shape{aligned_target_shape.size()}, aligned_target_shape);
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
};
class ConvertBroadcast3BIDIRECTBroadcastMultiplyTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& input_shape = std::get<0>(GetParam());
const auto& target_shape = std::get<1>(GetParam());
f = get_initial_function(input_shape, target_shape);
f_ref = get_reference_broadcast(input_shape, target_shape);
}
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
}
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
auto constant_one = opset1::Constant::create(ngraph::element::f32, {1}, {1});
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, broadcast);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
}
};
class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
public:
std::shared_ptr<Function> f, f_ref;
void SetUp() override {
const auto& input_shape = std::get<0>(GetParam());
const auto& target_shape = std::get<1>(GetParam());
f = get_initial_function(input_shape, target_shape);
f_ref = get_reference_broadcast(input_shape, target_shape);
}
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
}
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
const TargetShape & target_shape) {
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
}
};
TEST_P(ConvertBroadcast3NUMPYTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
TEST_P(ConvertBroadcast3BIDIRECTMulTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
TEST_P(ConvertBroadcast3BIDIRECTBroadcastTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
convert_broadcast3_test(f, f_ref);
}
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3NUMPY, ConvertBroadcast3NUMPYTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{8, 3, 64, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 64}),
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 64, 64}),
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}),
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{6, 3, 64, 64}),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{8, 3, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 3, 64}),
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 4}),
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{3, 3, 10}),
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 3, 9}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 3})));
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTMulTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{1, 3, 64, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 1, 64, 64, 64}),
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 1, 64, 64}),
std::make_tuple(InputShape{DYN, 1, DYN, 64, DYN}, TargetShape{3, 3, 3, 64, 1}),
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 1, 64}),
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 1}),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{1, 3, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 1, 64, 64}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 1, 64}),
std::make_tuple(InputShape{DYN, 3, DYN, 64}, TargetShape{3, 3, 64}),
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 1}),
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{1, 3, 10}),
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{10}),
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 1, 9}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 1})));
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{2, 2, 3, 4, 5}, TargetShape{2, 2, 3, 4, 5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 1}, TargetShape{2, 3, 64, 64, 64}),
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{1, 3, 3, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
std::make_tuple(InputShape{3, 1, DYN, 64, DYN}, TargetShape{1, 3, 3, 64, 3}, TargetShape{3, 3, 3, 64, 3}),
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1, 1, 1, 2, 64}, TargetShape{3, 3, 64, 2, 64}),
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}, TargetShape{3, 3, 64, 64, 3}),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 2, 3}, TargetShape{7, 3, 2, 3}),
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3, 3, 64, 64}, TargetShape{3, 3, 64, 64}),
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}, TargetShape{2, 3, 64, 64}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1, 3, 1}, TargetShape{3, 3, 3, 64}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 64}, TargetShape{3, 3, 3, 64}),
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{64}, TargetShape{3, 3, 64, 64}),
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 3}, TargetShape{5, 3, 3}),
std::make_tuple(InputShape{1, 3, DYN}, TargetShape{3, 3, 10}, TargetShape{3, 3, 10}),
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 2, 1}, TargetShape{2, 2, 9}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3}, TargetShape{3, 3, 3})));
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastMultiplyTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
// Broadcast-3 is converted directly to Broadcast-1 for modes NUMPY, NONE and PDPD
TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
@ -30,8 +324,10 @@ TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertBroadcast3().run_on_function(f);
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertBroadcast3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -63,8 +359,10 @@ TEST(TransformationTests, ConvertBroadcast3WithPDPDModeToBroadcast1) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertBroadcast3().run_on_function(f);
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertBroadcast3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -97,8 +395,10 @@ TEST(TransformationTests, ConvertBroadcast3WithExplicitModeToBroadcast1) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertBroadcast3().run_on_function(f);
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertBroadcast3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
@ -131,20 +431,20 @@ TEST(TransformationTests, ConvertBroadcast3WithBidirectionalModeToBroadcast1) {
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
ngraph::pass::InitNodeInfo().run_on_function(f);
ngraph::pass::ConvertBroadcast3().run_on_function(f);
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertBroadcast3>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 2});
auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 1});
auto constant_one = std::make_shared<ngraph::opset1::Constant>(input->get_output_element_type(0), ngraph::Shape({1}), std::vector<int>{1});
auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
auto multiply = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
multiply->set_friendly_name("broadcast");
auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
broadcast->set_friendly_name("broadcast");
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{multiply}, ngraph::ParameterVector{input});
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
}
auto res = compare_functions(f, f_ref);