Improve ConvertBroadcast3 pass to avoid extra Multiply operations for BIDIRECTIONAL mode (#3113)
* Fixed ConvertBroadcast3 pass for BIDIRECTIONAL mode to avoid excess Multiply operations * Added funcitonal tests for new decompositions * Return false if mode is unknown; avoid usign node in replace_node * Added functional tests for cases when TargetShape input is not a Constant
This commit is contained in:
parent
c3683341f3
commit
e79298fb40
@ -19,13 +19,8 @@ class TRANSFORMATIONS_API ConvertBroadcast3;
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ConvertBroadcast3: public ngraph::pass::GraphRewrite {
|
||||
class ngraph::pass::ConvertBroadcast3: public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ConvertBroadcast3() : GraphRewrite() {
|
||||
convert_broadcast3();
|
||||
}
|
||||
|
||||
private:
|
||||
void convert_broadcast3();
|
||||
ConvertBroadcast3();
|
||||
};
|
||||
|
@ -11,50 +11,101 @@
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcast3, "ConvertBroadcast3", 0);
|
||||
|
||||
void ngraph::pass::ConvertBroadcast3::convert_broadcast3() {
|
||||
auto broadcast = std::make_shared<pattern::op::Label>(element::f32, Shape {}, pattern::has_class<opset3::Broadcast>());
|
||||
bool make_compatible_shape(const ngraph::PartialShape & input_shape, std::vector<size_t> & target_shape) {
|
||||
if (input_shape.rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const int64_t & input_shape_rank = input_shape.rank().get_length();
|
||||
if (input_shape_rank > target_shape.size()) {
|
||||
// target_shape rank must greater or equal to input_shape rank, so in case when it's less we
|
||||
// insert missing input_shape dimensions to the beginning of the target_shape.
|
||||
const int64_t & dims_to_add_count = input_shape_rank - target_shape.size();
|
||||
std::vector<size_t> dims_to_add(dims_to_add_count);
|
||||
for (size_t dim = 0; dim < dims_to_add_count; ++dim) {
|
||||
if (input_shape[dim].is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
dims_to_add[dim] = input_shape[dim].get_length();
|
||||
}
|
||||
target_shape.insert(target_shape.begin(), dims_to_add.begin(), dims_to_add.end());
|
||||
}
|
||||
for (int64_t i_dim = input_shape_rank - 1, t_dim = target_shape.size() - 1; i_dim >= 0 && t_dim >= 0; --i_dim, --t_dim) {
|
||||
if (input_shape[i_dim].is_static()) {
|
||||
const auto & input_dim = input_shape[i_dim].get_length();
|
||||
if (input_dim != target_shape[t_dim] && input_dim != 1 && target_shape[t_dim] != 1) {
|
||||
// this dimensions are not broadcastable
|
||||
return false;
|
||||
}
|
||||
target_shape[t_dim] = std::max(target_shape[t_dim], static_cast<size_t>(input_dim));
|
||||
} else {
|
||||
if (target_shape[t_dim] == 1) {
|
||||
// For example: |
|
||||
// \/
|
||||
// input_shape [DYN, 3, 4]
|
||||
// target_shape [ 1, 3, 4] - broadcasted first dimension is unknown
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
||||
auto broadcast = std::dynamic_pointer_cast<ngraph::opset3::Broadcast>(m.get_match_root());
|
||||
ngraph::pass::ConvertBroadcast3::ConvertBroadcast3() {
|
||||
auto broadcast = pattern::wrap_type<opset3::Broadcast>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
|
||||
auto broadcast = std::dynamic_pointer_cast<opset3::Broadcast>(m.get_match_root());
|
||||
if (!broadcast) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto input = broadcast->input_value(0);
|
||||
auto target_shape = broadcast->input_value(1);
|
||||
|
||||
auto last_node = input.get_node_shared_ptr();
|
||||
auto broadcast_type = broadcast->get_broadcast_spec();
|
||||
auto target_shape_input = broadcast->input_value(1);
|
||||
const auto & broadcast_type = broadcast->get_broadcast_spec();
|
||||
const auto & input_element_type = input.get_element_type();
|
||||
|
||||
if (broadcast_type == op::BroadcastType::NUMPY) {
|
||||
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::NUMPY);
|
||||
ngraph::copy_runtime_info(broadcast, last_node);
|
||||
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::NUMPY);
|
||||
} else if (broadcast_type == op::BroadcastType::PDPD) {
|
||||
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, op::AutoBroadcastType::PDPD);
|
||||
ngraph::copy_runtime_info(broadcast, last_node);
|
||||
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, op::AutoBroadcastType::PDPD);
|
||||
} else if (broadcast_type == op::BroadcastType::NONE) {
|
||||
last_node = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, broadcast->input_value(2), op::AutoBroadcastType::NONE);
|
||||
ngraph::copy_runtime_info(broadcast, last_node);
|
||||
input = std::make_shared<opset1::Broadcast>(input, target_shape_input, broadcast->input_value(2), op::AutoBroadcastType::NONE);
|
||||
} else if (broadcast_type == op::BroadcastType::BIDIRECTIONAL) {
|
||||
auto constant_one = std::make_shared<ngraph::opset1::Constant>(input.get_element_type(), Shape({1}), std::vector<int>{1});
|
||||
auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, op::AutoBroadcastType::NUMPY);
|
||||
if (input.get_element_type() == element::boolean) {
|
||||
last_node = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
|
||||
if (auto const_target_shape = std::dynamic_pointer_cast<opset1::Constant>(target_shape_input.get_node_shared_ptr())) {
|
||||
const auto & input_shape = input.get_partial_shape();
|
||||
const auto & target_shape = const_target_shape->cast_vector<size_t>();
|
||||
std::vector<size_t> aligned_target_shape{target_shape};
|
||||
if (make_compatible_shape(input_shape, aligned_target_shape)) {
|
||||
input = std::make_shared<opset1::Broadcast>(input,
|
||||
opset1::Constant::create(element::i64, Shape({aligned_target_shape.size()}), aligned_target_shape));
|
||||
} else {
|
||||
input = std::make_shared<opset1::Multiply>(input,
|
||||
opset1::Constant::create(input_element_type, target_shape, {1}));
|
||||
}
|
||||
} else {
|
||||
last_node = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
|
||||
auto constant_one = opset1::Constant::create(input_element_type, {1}, {1});
|
||||
auto broadcast_ones = std::make_shared<opset1::Broadcast>(constant_one, target_shape_input);
|
||||
if (input_element_type == element::boolean) {
|
||||
input = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast_ones);
|
||||
} else {
|
||||
input = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
|
||||
}
|
||||
copy_runtime_info(broadcast, broadcast_ones);
|
||||
}
|
||||
ngraph::copy_runtime_info(broadcast, {last_node, broadcast_ones, constant_one});
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
last_node->set_friendly_name(broadcast->get_friendly_name());
|
||||
|
||||
ngraph::replace_node(m.get_match_root(), last_node);
|
||||
input.get_node_shared_ptr()->set_friendly_name(broadcast->get_friendly_name());
|
||||
copy_runtime_info(broadcast, input.get_node_shared_ptr());
|
||||
replace_node(broadcast, {input});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(broadcast, "ConvertBroadcast3");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
auto m = std::make_shared<pattern::Matcher>(broadcast, "ConvertBroadcast3");
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
@ -4,20 +4,314 @@
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <map>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <transformations/op_conversions/convert_broadcast3.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/visualize_tree.hpp>
|
||||
#include <transformations/op_conversions/convert_broadcast3.hpp>
|
||||
#include <ngraph_ops/convolution_ie.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
using InputShape = PartialShape;
|
||||
using TargetShape = Shape;
|
||||
|
||||
void convert_broadcast3_test(std::shared_ptr<Function> f, std::shared_ptr<Function> f_ref) {
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertBroadcast3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
class ConvertBroadcast3NUMPYTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& target_shape = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, target_shape);
|
||||
f_ref = get_reference_broadcast(input_shape, target_shape);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::NUMPY);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertBroadcast3BIDIRECTMulTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& target_shape = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, target_shape);
|
||||
f_ref = get_reference_broadcast(input_shape, target_shape);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto const_node = ngraph::opset1::Constant::create(ngraph::element::f32, Shape{target_shape}, {1});
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, const_node);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input});
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertBroadcast3BIDIRECTBroadcastTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& target_shape = std::get<1>(GetParam());
|
||||
const auto& aligned_target_shape = std::get<2>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, target_shape);
|
||||
f_ref = get_reference_broadcast(input_shape, aligned_target_shape);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{target_shape.size()}, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
|
||||
const TargetShape & aligned_target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = ngraph::opset1::Constant::create(ngraph::element::i64, Shape{aligned_target_shape.size()}, aligned_target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape_node, op::AutoBroadcastType::NUMPY);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertBroadcast3BIDIRECTBroadcastMultiplyTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& target_shape = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, target_shape);
|
||||
f_ref = get_reference_broadcast(input_shape, target_shape);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
|
||||
auto constant_one = opset1::Constant::create(ngraph::element::f32, {1}, {1});
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
|
||||
auto mul = std::make_shared<ngraph::opset1::Multiply>(input, broadcast);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
|
||||
}
|
||||
};
|
||||
|
||||
class ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest: public CommonTestUtils::TestsCommon,
|
||||
public testing::WithParamInterface<std::tuple<InputShape, TargetShape>> {
|
||||
public:
|
||||
std::shared_ptr<Function> f, f_ref;
|
||||
|
||||
void SetUp() override {
|
||||
const auto& input_shape = std::get<0>(GetParam());
|
||||
const auto& target_shape = std::get<1>(GetParam());
|
||||
|
||||
f = get_initial_function(input_shape, target_shape);
|
||||
f_ref = get_reference_broadcast(input_shape, target_shape);
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_initial_function(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
|
||||
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
|
||||
auto broadcast = std::make_shared<ngraph::opset3::Broadcast>(input, target_shape_node, op::BroadcastType::BIDIRECTIONAL);
|
||||
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input, target_shape_node});
|
||||
}
|
||||
|
||||
std::shared_ptr<Function> get_reference_broadcast(const InputShape & input_shape,
|
||||
const TargetShape & target_shape) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::boolean, input_shape);
|
||||
auto target_shape_node = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::i64, target_shape);
|
||||
auto constant_one = opset1::Constant::create(ngraph::element::boolean, {1}, {1});
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape_node, op::AutoBroadcastType::NUMPY);
|
||||
auto mul = std::make_shared<ngraph::opset1::LogicalOr>(input, broadcast);
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{mul}, ngraph::ParameterVector{input, target_shape_node});
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConvertBroadcast3NUMPYTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
TEST_P(ConvertBroadcast3BIDIRECTMulTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
TEST_P(ConvertBroadcast3BIDIRECTBroadcastTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
TEST_P(ConvertBroadcast3BIDIRECTBroadcastMultiplyTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
TEST_P(ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest, CompareFunctions) {
|
||||
convert_broadcast3_test(f, f_ref);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3NUMPY, ConvertBroadcast3NUMPYTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{8, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}),
|
||||
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{6, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{8, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 3, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 4}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{3, 3, 10}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 3, 9}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 3})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTMulTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{1, 2, 3, 4, 5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{1, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 1, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{3, 3, 1, 64, 64}),
|
||||
std::make_tuple(InputShape{DYN, 1, DYN, 64, DYN}, TargetShape{3, 3, 3, 64, 1}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{3, 3, 64, 1, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 1}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 1, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{1, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 1, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 1, 64}),
|
||||
std::make_tuple(InputShape{DYN, 3, DYN, 64}, TargetShape{3, 3, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{3, 3, 64, 1}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 1}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{1, 3, 10}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{10}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 1, 9}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3, 3, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{2, 2, 3, 4, 5}, TargetShape{2, 2, 3, 4, 5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}, TargetShape{3, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{2, 3, 64, 64, 1}, TargetShape{2, 3, 64, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{1, 3, 3, 64, 64}, TargetShape{3, 3, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, DYN}, TargetShape{1, 3, 3, 64, 3}, TargetShape{3, 3, 3, 64, 3}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1, 1, 1, 2, 64}, TargetShape{3, 3, 64, 2, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, 64, DYN}, TargetShape{3, 3, 64, 64, 3}, TargetShape{3, 3, 64, 64, 3}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{7, 3, 2, 3}, TargetShape{7, 3, 2, 3}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3, 3, 64, 64}, TargetShape{3, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2, 3, 64, 64}, TargetShape{2, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1, 3, 1}, TargetShape{3, 3, 3, 64}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{3, 3, 64}, TargetShape{3, 3, 3, 64}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN}, TargetShape{64}, TargetShape{3, 3, 64, 64}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5, 3, 3}, TargetShape{5, 3, 3}),
|
||||
std::make_tuple(InputShape{1, 3, DYN}, TargetShape{3, 3, 10}, TargetShape{3, 3, 10}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{2, 2, 1}, TargetShape{2, 2, 9}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{3}, TargetShape{3, 3, 3})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastMultiplyTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
|
||||
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ConvertBroadcast3BIDIRECT, ConvertBroadcast3BIDIRECTBroadcastLogicalOrTest,
|
||||
testing::Values(std::make_tuple(InputShape{DYN, DYN, DYN, DYN, DYN}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64, 64}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64, 64}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{3, 1, DYN, 64, 64}, TargetShape{2}),
|
||||
std::make_tuple(InputShape{3, 3, 64, DYN, 64}, TargetShape{1}),
|
||||
std::make_tuple(InputShape{1, 3, 64, 64}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN, DYN}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{DYN, 3, 64, 64}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{2, DYN, 64, 64}, TargetShape{2}),
|
||||
std::make_tuple(InputShape{3, 3, DYN, 64}, TargetShape{1}),
|
||||
std::make_tuple(InputShape{DYN, DYN, DYN}, TargetShape{5}),
|
||||
std::make_tuple(InputShape{DYN, 3, 10}, TargetShape{4}),
|
||||
std::make_tuple(InputShape{2, DYN, 9}, TargetShape{3}),
|
||||
std::make_tuple(InputShape{3, 3, DYN}, TargetShape{2})));
|
||||
|
||||
|
||||
// Broadcast-3 is converted directly to Broadcast-1 for modes NUMPY, NONE and PDPD
|
||||
TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
|
||||
@ -30,8 +324,10 @@ TEST(TransformationTests, ConvertBroadcast3WithNumpyModeToBroadcast1) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertBroadcast3().run_on_function(f);
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertBroadcast3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -63,8 +359,10 @@ TEST(TransformationTests, ConvertBroadcast3WithPDPDModeToBroadcast1) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertBroadcast3().run_on_function(f);
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertBroadcast3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -97,8 +395,10 @@ TEST(TransformationTests, ConvertBroadcast3WithExplicitModeToBroadcast1) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertBroadcast3().run_on_function(f);
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertBroadcast3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
@ -131,20 +431,20 @@ TEST(TransformationTests, ConvertBroadcast3WithBidirectionalModeToBroadcast1) {
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
|
||||
|
||||
ngraph::pass::InitNodeInfo().run_on_function(f);
|
||||
ngraph::pass::ConvertBroadcast3().run_on_function(f);
|
||||
pass::Manager manager;
|
||||
manager.register_pass<pass::InitNodeInfo>();
|
||||
manager.register_pass<pass::ConvertBroadcast3>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 1, 2});
|
||||
auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 1});
|
||||
auto constant_one = std::make_shared<ngraph::opset1::Constant>(input->get_output_element_type(0), ngraph::Shape({1}), std::vector<int>{1});
|
||||
auto broadcast_ones = std::make_shared<ngraph::opset1::Broadcast>(constant_one, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
|
||||
auto multiply = std::make_shared<ngraph::opset1::Multiply>(input, broadcast_ones);
|
||||
multiply->set_friendly_name("broadcast");
|
||||
auto target_shape = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
|
||||
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input, target_shape, ngraph::op::AutoBroadcastType::NUMPY);
|
||||
broadcast->set_friendly_name("broadcast");
|
||||
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{multiply}, ngraph::ParameterVector{input});
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
|
Loading…
Reference in New Issue
Block a user