Add dynamic shape checks to nGraph transformations (#2735)
* Added dynamic shape checks for BatchNormDecompositoin pass * Added dynamic shapes checks for FQTranspose fusion pass * Added patter::has_static_rank predicate * Added dynamic shapes checks for BroadcastToTiles pass * Fixed BN inputs order * Add dynamic shape checks for DepthToSpace/SpaceToDepth passes * Added dynamic check for ReduceToPooling pass * Updated BN transformation * Fix PR comments * size_t to int64_t * Updated reduce to pooling pattern
This commit is contained in:
parent
8c97127aa7
commit
c4e0b74fb1
@ -47,7 +47,10 @@ public:
|
|||||||
class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
|
class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
|
||||||
public:
|
public:
|
||||||
ConvertReduceMeanToPooling() {
|
ConvertReduceMeanToPooling() {
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMean>(), "ConvertReduceMean");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(
|
||||||
|
ngraph::pattern::wrap_type<opset1::ReduceMean>({pattern::any_input(pattern::has_static_shape()),
|
||||||
|
pattern::wrap_type<opset1::Constant>()},
|
||||||
|
pattern::has_static_shape()), "ConvertReduceMean");
|
||||||
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
|
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -55,7 +58,10 @@ public:
|
|||||||
class ngraph::pass::ConvertReduceMaxToPooling: public ConvertReduceBase {
|
class ngraph::pass::ConvertReduceMaxToPooling: public ConvertReduceBase {
|
||||||
public:
|
public:
|
||||||
ConvertReduceMaxToPooling() {
|
ConvertReduceMaxToPooling() {
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMax>(), "ConvertReduceMax");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(
|
||||||
|
ngraph::pattern::wrap_type<opset1::ReduceMax>({pattern::any_input(pattern::has_static_shape()),
|
||||||
|
pattern::wrap_type<opset1::Constant>()},
|
||||||
|
pattern::has_static_shape()), "ConvertReduceMax");
|
||||||
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
|
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -63,7 +69,10 @@ public:
|
|||||||
class ngraph::pass::ConvertReduceSumToPooling: public ConvertReduceBase {
|
class ngraph::pass::ConvertReduceSumToPooling: public ConvertReduceBase {
|
||||||
public:
|
public:
|
||||||
ConvertReduceSumToPooling() {
|
ConvertReduceSumToPooling() {
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceSum>(), "ConvertReduceSum");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(
|
||||||
|
ngraph::pattern::wrap_type<opset1::ReduceSum>({pattern::any_input(pattern::has_static_shape()),
|
||||||
|
pattern::wrap_type<opset1::Constant>()},
|
||||||
|
pattern::has_static_shape()), "ConvertReduceSum");
|
||||||
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
|
register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -79,12 +88,12 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
|
|||||||
|
|
||||||
auto input = reduce->input_value(0);
|
auto input = reduce->input_value(0);
|
||||||
|
|
||||||
auto axes_node = reduce->input_value(1).get_node_shared_ptr();
|
auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(reduce->input_value(1).get_node_shared_ptr());
|
||||||
if (!ngraph::op::is_constant(axes_node)) {
|
if (!axes_node) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
|
auto axes_vector = axes_node->template cast_vector<int64_t>();
|
||||||
const auto input_rank = input.get_partial_shape().rank().get_length();
|
const auto input_rank = input.get_partial_shape().rank().get_length();
|
||||||
// Transform negative axes into non-negative ones
|
// Transform negative axes into non-negative ones
|
||||||
for (size_t i = 0; i < axes_vector.size(); ++i) {
|
for (size_t i = 0; i < axes_vector.size(); ++i) {
|
||||||
@ -99,10 +108,6 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
|
|||||||
return replace_output_update_name(reduce->output(0), input);
|
return replace_output_update_name(reduce->output(0), input);
|
||||||
}
|
}
|
||||||
|
|
||||||
// As this transformation requires static input shape we should guaranty it
|
|
||||||
if (input.get_partial_shape().is_dynamic()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
auto input_shape = input.get_shape();
|
auto input_shape = input.get_shape();
|
||||||
|
|
||||||
// If Reduce op reduces only 1 dims we replace it with Reshape
|
// If Reduce op reduces only 1 dims we replace it with Reshape
|
||||||
|
@ -9,56 +9,42 @@
|
|||||||
|
|
||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThroughFQUp", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThroughFQUp", 0);
|
||||||
|
|
||||||
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
||||||
auto data1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()),
|
||||||
auto data2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
pattern::any_input(pattern::has_static_rank()),
|
||||||
auto data3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
pattern::any_input(pattern::has_static_rank()),
|
||||||
auto data4 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
pattern::any_input(pattern::has_static_rank()),
|
||||||
auto data5 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
|
pattern::any_input(pattern::has_static_rank())},
|
||||||
auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(data1, data2, data3, data4, data5, 1);
|
pattern::consumers_count(1));
|
||||||
auto transpose_order = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
|
auto m_transpose = pattern::wrap_type<opset1::Transpose>({m_fq, pattern::wrap_type<opset1::Constant>()});
|
||||||
auto transpose = std::make_shared<ngraph::opset1::Transpose>(fq, transpose_order);
|
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
|
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||||
auto transpose = ngraph::as_type_ptr<ngraph::opset1::Transpose>(m.get_match_root());
|
auto & pattern_map = m.get_pattern_value_map();
|
||||||
if (!transpose) {
|
auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
|
||||||
return false;
|
auto fq = pattern_map[m_fq].get_node_shared_ptr();
|
||||||
}
|
|
||||||
|
|
||||||
auto const_node = transpose->input(1).get_source_output().get_node_shared_ptr();
|
auto input_rank = fq->input(0).get_partial_shape().rank().get_length();
|
||||||
auto const_order = ngraph::as_type_ptr<ngraph::opset1::Constant>(const_node);
|
|
||||||
if (!const_order) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto fq_node = transpose->input(0).get_source_output().get_node_shared_ptr();
|
|
||||||
auto fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(fq_node);
|
|
||||||
if (!fq || fq->output(0).get_target_inputs().size() != 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input_shape = fq->input(0).get_source_output().get_shape();
|
|
||||||
|
|
||||||
ngraph::NodeVector new_ops;
|
ngraph::NodeVector new_ops;
|
||||||
ngraph::OutputVector fq_inputs;
|
ngraph::OutputVector fq_inputs;
|
||||||
for (size_t i = 0; i < fq->inputs().size(); ++i) {
|
for (size_t i = 0; i < fq->inputs().size(); ++i) {
|
||||||
std::shared_ptr<ngraph::Node> fq_input;
|
auto fq_input = fq->input_value(i);
|
||||||
fq_input = fq->input(i).get_source_output().get_node_shared_ptr();
|
auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
|
||||||
auto fq_input_shape = fq_input->get_shape();
|
|
||||||
std::vector<int64_t> unsqueeze_axes;
|
std::vector<int64_t> unsqueeze_axes;
|
||||||
for (size_t j = 0; j < input_shape.size() - fq_input_shape.size(); ++j) {
|
for (size_t j = 0; j < input_rank - fq_input_rank; ++j) {
|
||||||
unsqueeze_axes.push_back(j);
|
unsqueeze_axes.push_back(j);
|
||||||
}
|
}
|
||||||
if (!unsqueeze_axes.empty()) {
|
if (!unsqueeze_axes.empty()) {
|
||||||
fq_input = std::make_shared<ngraph::opset1::Unsqueeze>(fq_input,
|
fq_input = std::make_shared<ngraph::opset1::Unsqueeze>(fq_input,
|
||||||
opset1::Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes));
|
opset1::Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes));
|
||||||
new_ops.push_back(fq_input);
|
new_ops.push_back(fq_input.get_node_shared_ptr());
|
||||||
}
|
}
|
||||||
fq_input = transpose->copy_with_new_inputs({fq_input, const_order});
|
fq_input = transpose->copy_with_new_inputs({fq_input, transpose->input_value(1)});
|
||||||
ngraph::copy_runtime_info(transpose, fq_input);
|
ngraph::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
|
||||||
fq_inputs.push_back(fq_input);
|
fq_inputs.push_back(fq_input);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -71,6 +57,6 @@ ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
|||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "PullTransposeThroughFQUp");
|
auto m = std::make_shared<ngraph::pattern::Matcher>(m_transpose, "PullTransposeThroughFQUp");
|
||||||
this->register_matcher(m, callback);
|
this->register_matcher(m, callback);
|
||||||
}
|
}
|
||||||
|
@ -10,41 +10,34 @@
|
|||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
#include <ngraph/opsets/opset5.hpp>
|
#include <ngraph/opsets/opset5.hpp>
|
||||||
#include <ngraph/rt_info.hpp>
|
#include <ngraph/rt_info.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0);
|
||||||
|
|
||||||
ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
||||||
Shape shape{2, 2, 1, 1};
|
auto bn = pattern::wrap_type<opset1::BatchNormInference>({
|
||||||
auto input = make_shared<pattern::op::Label>(element::f32, shape);
|
pattern::any_input(pattern::has_static_rank()),
|
||||||
auto mean_shape = Shape{2};
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto var_shape = Shape{2};
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
|
pattern::any_input(pattern::has_static_shape())
|
||||||
auto gamma_shape = Shape{2};
|
});
|
||||||
auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
|
|
||||||
auto beta_shape = Shape{2};
|
|
||||||
auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
|
|
||||||
auto bn = make_shared<opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
|
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
|
|
||||||
auto pattern_map = m.get_pattern_map();
|
|
||||||
|
|
||||||
auto m_input = pattern_map[input];
|
|
||||||
auto m_gamma = pattern_map[gamma];
|
|
||||||
auto m_beta = pattern_map[beta];
|
|
||||||
auto m_mean = pattern_map[mean];
|
|
||||||
auto m_var = pattern_map[var];
|
|
||||||
|
|
||||||
// TODO: check that all input shapes are static
|
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
|
||||||
auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
|
auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
|
||||||
if (!m_bn) {
|
if (!m_bn) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& input_type = m_input->get_element_type();
|
auto m_gamma = m_bn->input_value(0);
|
||||||
|
auto m_beta = m_bn->input_value(1);
|
||||||
|
auto m_input = m_bn->input_value(2);
|
||||||
|
auto m_mean = m_bn->input_value(3);
|
||||||
|
auto m_var = m_bn->input_value(4);
|
||||||
|
|
||||||
|
const auto& input_type = m_input.get_element_type();
|
||||||
// scale_add = variance + eps
|
// scale_add = variance + eps
|
||||||
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
||||||
// scale = sqrt(variance + eps)
|
// scale = sqrt(variance + eps)
|
||||||
@ -52,8 +45,10 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
|||||||
// Divide `gamma` by `sqrt(variance + eps)`
|
// Divide `gamma` by `sqrt(variance + eps)`
|
||||||
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
||||||
|
|
||||||
size_t dims_to_add = m_input->get_shape().size() - 2;
|
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
|
||||||
Shape input_aligned_shape = m_gamma->get_shape();
|
|
||||||
|
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
|
||||||
|
Shape input_aligned_shape = m_gamma.get_shape();
|
||||||
for (size_t i = 0; i < dims_to_add; ++i)
|
for (size_t i = 0; i < dims_to_add; ++i)
|
||||||
input_aligned_shape.push_back(1);
|
input_aligned_shape.push_back(1);
|
||||||
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
||||||
@ -84,36 +79,29 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
|
|||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
|
||||||
|
|
||||||
|
// TODO: this pass will be unified with BatchNormDecomposition pass
|
||||||
ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
|
ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
|
||||||
Shape shape{2, 2, 1, 1};
|
auto bn = pattern::wrap_type<opset5::BatchNormInference>({
|
||||||
auto input = make_shared<pattern::op::Label>(element::f32, shape);
|
pattern::any_input(pattern::has_static_rank()),
|
||||||
auto mean_shape = Shape{2};
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto var_shape = Shape{2};
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
|
pattern::any_input(pattern::has_static_shape())
|
||||||
auto gamma_shape = Shape{2};
|
});
|
||||||
auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
|
|
||||||
auto beta_shape = Shape{2};
|
|
||||||
auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
|
|
||||||
auto bn = make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
|
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
|
|
||||||
auto pattern_map = m.get_pattern_map();
|
|
||||||
|
|
||||||
auto m_input = pattern_map[input];
|
|
||||||
auto m_gamma = pattern_map[gamma];
|
|
||||||
auto m_beta = pattern_map[beta];
|
|
||||||
auto m_mean = pattern_map[mean];
|
|
||||||
auto m_var = pattern_map[var];
|
|
||||||
|
|
||||||
// TODO: check that all input shapes are static
|
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
|
||||||
auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
|
auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
|
||||||
if (!m_bn) {
|
if (!m_bn) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto& input_type = m_input->get_element_type();
|
auto m_input = m_bn->input_value(0);
|
||||||
|
auto m_gamma = m_bn->input_value(1);
|
||||||
|
auto m_beta = m_bn->input_value(2);
|
||||||
|
auto m_mean = m_bn->input_value(3);
|
||||||
|
auto m_var = m_bn->input_value(4);
|
||||||
|
|
||||||
|
const auto& input_type = m_input.get_element_type();
|
||||||
// scale_add = variance + eps
|
// scale_add = variance + eps
|
||||||
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
|
||||||
// scale = sqrt(variance + eps)
|
// scale = sqrt(variance + eps)
|
||||||
@ -121,8 +109,10 @@ ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
|
|||||||
// Divide `gamma` by `sqrt(variance + eps)`
|
// Divide `gamma` by `sqrt(variance + eps)`
|
||||||
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
|
||||||
|
|
||||||
size_t dims_to_add = m_input->get_shape().size() - 2;
|
int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
|
||||||
Shape input_aligned_shape = m_gamma->get_shape();
|
|
||||||
|
// TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
|
||||||
|
Shape input_aligned_shape = m_gamma.get_shape();
|
||||||
for (size_t i = 0; i < dims_to_add; ++i)
|
for (size_t i = 0; i < dims_to_add; ++i)
|
||||||
input_aligned_shape.push_back(1);
|
input_aligned_shape.push_back(1);
|
||||||
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
|
||||||
|
@ -16,24 +16,28 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcastToTiles, "ConvertBroadcastT
|
|||||||
ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
||||||
auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
|
auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
|
auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
|
||||||
|
|
||||||
if (!broadcast) {
|
if (!broadcast) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto data_node = broadcast->input_value(0).get_node_shared_ptr();
|
auto data_node = broadcast->input_value(0);
|
||||||
|
if (data_node.get_partial_shape().is_dynamic()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
auto shape_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(1).get_node_shared_ptr());
|
auto shape_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(1).get_node_shared_ptr());
|
||||||
auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(2).get_node_shared_ptr());
|
auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(2).get_node_shared_ptr());
|
||||||
if (!data_node || !shape_node || !axes_node) return false;
|
if (!shape_node || !axes_node) return false;
|
||||||
|
|
||||||
auto output_shape = shape_node->cast_vector<int64_t>();
|
auto output_shape = shape_node->cast_vector<int64_t>();
|
||||||
auto input_shape = data_node->get_shape();
|
auto input_shape = data_node.get_shape();
|
||||||
int64_t cur_dim_id = output_shape.size() - 1;
|
int64_t cur_dim_id = output_shape.size() - 1;
|
||||||
size_t dims_count = output_shape.size();
|
size_t dims_count = output_shape.size();
|
||||||
|
|
||||||
auto last_node = std::dynamic_pointer_cast<ngraph::Node>(data_node);
|
auto last_node = data_node;
|
||||||
|
|
||||||
NodeVector new_ops;
|
NodeVector new_ops;
|
||||||
|
|
||||||
@ -61,7 +65,7 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
|||||||
auto shape_const = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {shape.size()}, shape);
|
auto shape_const = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {shape.size()}, shape);
|
||||||
auto reshape = std::make_shared<ngraph::opset1::Reshape>(data_node, shape_const, true);
|
auto reshape = std::make_shared<ngraph::opset1::Reshape>(data_node, shape_const, true);
|
||||||
new_ops.push_back(reshape);
|
new_ops.push_back(reshape);
|
||||||
last_node = std::dynamic_pointer_cast<ngraph::Node>(reshape);
|
last_node = reshape;
|
||||||
input_shape = shape;
|
input_shape = shape;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -87,9 +91,8 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
|
|||||||
new_ops.push_back(tile);
|
new_ops.push_back(tile);
|
||||||
tile->set_friendly_name(broadcast->get_friendly_name());
|
tile->set_friendly_name(broadcast->get_friendly_name());
|
||||||
|
|
||||||
last_node = std::dynamic_pointer_cast<ngraph::Node>(tile);
|
|
||||||
ngraph::copy_runtime_info(broadcast, new_ops);
|
ngraph::copy_runtime_info(broadcast, new_ops);
|
||||||
ngraph::replace_node(broadcast, last_node);
|
ngraph::replace_node(broadcast, tile);
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace", 0);
|
||||||
|
|
||||||
ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
|
ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
|
||||||
auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>();
|
auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>({pattern::any_input(pattern::has_static_shape())});
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
|
||||||
@ -22,7 +22,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input = dts_node->input(0).get_source_output().get_node_shared_ptr();
|
auto input = dts_node->input_value(0);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* In this transformation we decompose DepthToSpace operation to the next sequence of ops:
|
* In this transformation we decompose DepthToSpace operation to the next sequence of ops:
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth", 0);
|
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth", 0);
|
||||||
|
|
||||||
ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
|
ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
|
||||||
auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>();
|
auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>({pattern::any_input(pattern::has_static_shape())});
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
|
auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
|
||||||
@ -22,7 +22,7 @@ ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto input = std_node->input(0).get_source_output().get_node_shared_ptr();
|
auto input = std_node->input_value(0);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* In this transformation we decompose SpaceToDepth operation to the next sequence of ops:
|
* In this transformation we decompose SpaceToDepth operation to the next sequence of ops:
|
||||||
|
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
|
#include <transformations/op_conversions/batch_norm_decomposition.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(TransformationTests, BatchNormDecompositionDynamic) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||||
|
auto gamma = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
|
||||||
|
auto beta = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
|
||||||
|
auto mean = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
|
||||||
|
auto var = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
|
||||||
|
auto broadcast = std::make_shared<ngraph::opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
|
||||||
|
broadcast->set_friendly_name("broadcast");
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::BatchNormDecomposition>();
|
||||||
|
ASSERT_NO_THROW(manager.run_passes(f));
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,40 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <transformations/op_conversions/convert_broadcast_to_tiles.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
#include <transformations/utils/utils.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
TEST(TransformationTests, ConvertBroadcastToTilesDynamic) {
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||||
|
{
|
||||||
|
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
|
||||||
|
auto target_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
|
||||||
|
auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input1, target_shape);
|
||||||
|
broadcast->set_friendly_name("broadcast");
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::ConvertBroadcastToTiles>();
|
||||||
|
ASSERT_NO_THROW(manager.run_passes(f));
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -54,8 +54,7 @@ public:
|
|||||||
f_ref = get_reference_function(input_shape, reduce_type, reference_params);
|
f_ref = get_reference_function(input_shape, reduce_type, reference_params);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
static std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
|
||||||
std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
|
|
||||||
const std::vector<int64_t> & axes,
|
const std::vector<int64_t> & axes,
|
||||||
const ReduceType & reduce_type,
|
const ReduceType & reduce_type,
|
||||||
const bool keep_dims) {
|
const bool keep_dims) {
|
||||||
@ -72,7 +71,7 @@ private:
|
|||||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
|
return std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
|
static std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
|
||||||
const ReduceType & reduce,
|
const ReduceType & reduce,
|
||||||
const ReduceToPoolParams & params) {
|
const ReduceToPoolParams & params) {
|
||||||
auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
|
||||||
@ -137,6 +136,10 @@ INSTANTIATE_TEST_CASE_P(ReduceToReshapePoolReshape, ConvertReduceToPoolingTests,
|
|||||||
std::make_tuple(MAX, InputShape{2, 9}, ReduceAxes{-1}, KeepDims{true}, ReduceToPoolParams({1, 1, 9, 1}, {9, 1}, {1, 1})),
|
std::make_tuple(MAX, InputShape{2, 9}, ReduceAxes{-1}, KeepDims{true}, ReduceToPoolParams({1, 1, 9, 1}, {9, 1}, {1, 1})),
|
||||||
std::make_tuple(MAX, InputShape{2, 3, 4, 1}, ReduceAxes{1, 3, 2}, KeepDims{false}, ReduceToPoolParams({1, 1, 12, 1}, {12, 1}, {1}))));
|
std::make_tuple(MAX, InputShape{2, 3, 4, 1}, ReduceAxes{1, 3, 2}, KeepDims{false}, ReduceToPoolParams({1, 1, 12, 1}, {12, 1}, {1}))));
|
||||||
|
|
||||||
|
TEST(ConvertReduceToPooling, Negative) {
|
||||||
|
auto f = ConvertReduceToPoolingTests::get_initial_function(
|
||||||
|
ngraph::PartialShape::dynamic(), {3}, MAX, true);
|
||||||
|
ASSERT_NO_THROW(ngraph::pass::ConvertReduceToPooling().run_on_function(f));
|
||||||
|
}
|
||||||
|
|
||||||
#undef MAX
|
#undef MAX
|
||||||
|
|
||||||
|
|
||||||
|
@ -181,3 +181,29 @@ TEST(TransformationTests, TestSpaceToDepthTransformDepthFirst) {
|
|||||||
std::vector<int64_t> shape_end_value_ref{1, 12 * 4, 1080 / 2, 1616 / 2};
|
std::vector<int64_t> shape_end_value_ref{1, 12 * 4, 1080 / 2, 1616 / 2};
|
||||||
ASSERT_EQ(shape_end_value, shape_end_value_ref);
|
ASSERT_EQ(shape_end_value, shape_end_value_ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, TestSpaceToDepthDynamic) {
|
||||||
|
auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, 2);
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
|
||||||
|
ASSERT_NO_THROW(m.run_passes(f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, TestDepthToSpaceDynamic) {
|
||||||
|
auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||||
|
|
||||||
|
{
|
||||||
|
auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
|
||||||
|
ngraph::pass::Manager m;
|
||||||
|
m.register_pass<ngraph::pass::ConvertDepthToSpace>();
|
||||||
|
ASSERT_NO_THROW(m.run_passes(f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -55,3 +55,29 @@ TEST(TransformationTests, FQTransposeTest1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, FQTransposeDynamic) {
|
||||||
|
auto data1 = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||||
|
auto data2 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {1, 2, 3});
|
||||||
|
auto data3 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
|
||||||
|
auto data4 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
|
||||||
|
auto data5 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
|
||||||
|
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||||
|
{
|
||||||
|
auto fq = std::make_shared<ngraph::op::FakeQuantize>(data1, data2, data3, data4, data5, 1);
|
||||||
|
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
|
||||||
|
|
||||||
|
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data1});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
|
||||||
|
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||||
|
check_rt_info(f);
|
||||||
|
});
|
||||||
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
|
ASSERT_NO_THROW(manager.run_passes(f));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -61,6 +61,9 @@ namespace ngraph
|
|||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::function<bool(Output<Node>)> has_static_shape();
|
std::function<bool(Output<Node>)> has_static_shape();
|
||||||
|
|
||||||
|
NGRAPH_API
|
||||||
|
std::function<bool(Output<Node>)> has_static_rank();
|
||||||
|
|
||||||
NGRAPH_API
|
NGRAPH_API
|
||||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
std::function<bool(Output<Node>)> type_matches(const element::Type& type);
|
||||||
|
|
||||||
|
@ -95,6 +95,13 @@ namespace ngraph
|
|||||||
[=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
|
[=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::function<bool(Output<Node>)> has_static_rank()
|
||||||
|
{
|
||||||
|
return [=](Output<Node> output) -> bool {
|
||||||
|
return output.get_partial_shape().rank().is_static();
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
std::function<bool(Output<Node>)> type_matches(const element::Type& type)
|
std::function<bool(Output<Node>)> type_matches(const element::Type& type)
|
||||||
{
|
{
|
||||||
return [=](Output<Node> output) -> bool { return output.get_element_type() == type; };
|
return [=](Output<Node> output) -> bool { return output.get_element_type() == type; };
|
||||||
|
Loading…
Reference in New Issue
Block a user