Add dynamic shape checks to nGraph transformations (#2735)

* Added dynamic shape checks for BatchNormDecomposition pass

* Added dynamic shape checks for FQTranspose fusion pass

* Added pattern::has_static_rank predicate

* Added dynamic shape checks for BroadcastToTiles pass

* Fixed BN inputs order

* Added dynamic shape checks for DepthToSpace/SpaceToDepth passes

* Added dynamic shape check for ReduceToPooling pass

* Updated BN transformation

* Addressed PR review comments

* Changed size_t to int64_t

* Updated ReduceToPooling pattern
Gleb Kazantaev, 2020-10-23 15:39:47 +03:00, committed by GitHub
parent 8c97127aa7
commit c4e0b74fb1
13 changed files with 240 additions and 111 deletions


@@ -47,7 +47,10 @@ public:
 class ngraph::pass::ConvertReduceMeanToPooling: public ConvertReduceBase {
 public:
     ConvertReduceMeanToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMean>(), "ConvertReduceMean");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceMean>({pattern::any_input(pattern::has_static_shape()),
+                                                                pattern::wrap_type<opset1::Constant>()},
+                                                               pattern::has_static_shape()), "ConvertReduceMean");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMean>());
     }
 };
@@ -55,7 +58,10 @@ public:
 class ngraph::pass::ConvertReduceMaxToPooling: public ConvertReduceBase {
 public:
     ConvertReduceMaxToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceMax>(), "ConvertReduceMax");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceMax>({pattern::any_input(pattern::has_static_shape()),
+                                                               pattern::wrap_type<opset1::Constant>()},
+                                                              pattern::has_static_shape()), "ConvertReduceMax");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceMax>());
     }
 };
@@ -63,7 +69,10 @@ public:
 class ngraph::pass::ConvertReduceSumToPooling: public ConvertReduceBase {
 public:
     ConvertReduceSumToPooling() {
-        auto m = std::make_shared<ngraph::pattern::Matcher>(ngraph::pattern::wrap_type<opset1::ReduceSum>(), "ConvertReduceSum");
+        auto m = std::make_shared<ngraph::pattern::Matcher>(
+                ngraph::pattern::wrap_type<opset1::ReduceSum>({pattern::any_input(pattern::has_static_shape()),
+                                                               pattern::wrap_type<opset1::Constant>()},
+                                                              pattern::has_static_shape()), "ConvertReduceSum");
         register_matcher(m, convert_reduce_to_pooling<opset1::ReduceSum>());
     }
 };
@@ -79,12 +88,12 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
         auto input = reduce->input_value(0);
 
-        auto axes_node = reduce->input_value(1).get_node_shared_ptr();
-        if (!ngraph::op::is_constant(axes_node)) {
+        auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(reduce->input_value(1).get_node_shared_ptr());
+        if (!axes_node) {
             return false;
         }
 
-        auto axes_vector = std::dynamic_pointer_cast<ngraph::opset1::Constant>(axes_node)->template cast_vector<int64_t>();
+        auto axes_vector = axes_node->template cast_vector<int64_t>();
         const auto input_rank = input.get_partial_shape().rank().get_length();
         // Transform negative axes into non-negative ones
         for (size_t i = 0; i < axes_vector.size(); ++i) {
@@ -99,10 +108,6 @@ ngraph::matcher_pass_callback ConvertReduceBase::convert_reduce_to_pooling() {
             return replace_output_update_name(reduce->output(0), input);
         }
 
-        // As this transformation requires static input shape we should guaranty it
-        if (input.get_partial_shape().is_dynamic()) {
-            return false;
-        }
         auto input_shape = input.get_shape();
 
         // If Reduce op reduces only 1 dims we replace it with Reshape
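Not part of the diff: a minimal sketch of why the in-callback dynamic-shape check above became removable. With the predicates attached to the pattern itself, a Reduce over a dynamically shaped input never reaches the callback; the match simply fails. The helper name below is hypothetical.

    #include <ngraph/opsets/opset1.hpp>
    #include <ngraph/pattern/matcher.hpp>
    #include <ngraph/pattern/op/wrap_type.hpp>

    using namespace ngraph;

    // Returns true only if `node` is a ReduceMean whose data input and output
    // shapes are static and whose axes input is a Constant.
    bool reduce_mean_is_matchable(const std::shared_ptr<Node>& node) {
        auto label = pattern::wrap_type<opset1::ReduceMean>(
                {pattern::any_input(pattern::has_static_shape()),
                 pattern::wrap_type<opset1::Constant>()},
                pattern::has_static_shape());
        pattern::Matcher matcher(label, "ReduceMeanStaticOnly");
        return matcher.match(node->output(0));
    }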


@@ -9,56 +9,42 @@
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThroughFQUp", 0);
 
 ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
-    auto data1 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data2 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data3 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data4 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto data5 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1});
-    auto fq = std::make_shared<ngraph::opset1::FakeQuantize>(data1, data2, data3, data4, data5, 1);
-    auto transpose_order = std::make_shared<pattern::op::Label>(element::i64, Shape{4});
-    auto transpose = std::make_shared<ngraph::opset1::Transpose>(fq, transpose_order);
+    auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank()),
+                                                          pattern::any_input(pattern::has_static_rank())},
+                                                         pattern::consumers_count(1));
+    auto m_transpose = pattern::wrap_type<opset1::Transpose>({m_fq, pattern::wrap_type<opset1::Constant>()});
 
-    ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
-        auto transpose = ngraph::as_type_ptr<ngraph::opset1::Transpose>(m.get_match_root());
-        if (!transpose) {
-            return false;
-        }
-        auto const_node = transpose->input(1).get_source_output().get_node_shared_ptr();
-        auto const_order = ngraph::as_type_ptr<ngraph::opset1::Constant>(const_node);
-        if (!const_order) {
-            return false;
-        }
-        auto fq_node = transpose->input(0).get_source_output().get_node_shared_ptr();
-        auto fq = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(fq_node);
-        if (!fq || fq->output(0).get_target_inputs().size() != 1) {
-            return false;
-        }
-        auto input_shape = fq->input(0).get_source_output().get_shape();
+    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
+        auto & pattern_map = m.get_pattern_value_map();
+        auto transpose = pattern_map[m_transpose].get_node_shared_ptr();
+        auto fq = pattern_map[m_fq].get_node_shared_ptr();
+
+        auto input_rank = fq->input(0).get_partial_shape().rank().get_length();
 
         ngraph::NodeVector new_ops;
         ngraph::OutputVector fq_inputs;
         for (size_t i = 0; i < fq->inputs().size(); ++i) {
-            std::shared_ptr<ngraph::Node> fq_input;
-            fq_input = fq->input(i).get_source_output().get_node_shared_ptr();
-            auto fq_input_shape = fq_input->get_shape();
+            auto fq_input = fq->input_value(i);
+            auto fq_input_rank = fq_input.get_partial_shape().rank().get_length();
             std::vector<int64_t> unsqueeze_axes;
-            for (size_t j = 0; j < input_shape.size() - fq_input_shape.size(); ++j) {
+            for (size_t j = 0; j < input_rank - fq_input_rank; ++j) {
                 unsqueeze_axes.push_back(j);
             }
             if (!unsqueeze_axes.empty()) {
                 fq_input = std::make_shared<ngraph::opset1::Unsqueeze>(fq_input,
                         opset1::Constant::create(element::i64, Shape{unsqueeze_axes.size()}, unsqueeze_axes));
-                new_ops.push_back(fq_input);
+                new_ops.push_back(fq_input.get_node_shared_ptr());
             }
-            fq_input = transpose->copy_with_new_inputs({fq_input, const_order});
-            ngraph::copy_runtime_info(transpose, fq_input);
+            fq_input = transpose->copy_with_new_inputs({fq_input, transpose->input_value(1)});
+            ngraph::copy_runtime_info(transpose, fq_input.get_node_shared_ptr());
             fq_inputs.push_back(fq_input);
         }
@@ -71,6 +57,6 @@ ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
         return true;
     };
 
-    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose, "PullTransposeThroughFQUp");
+    auto m = std::make_shared<ngraph::pattern::Matcher>(m_transpose, "PullTransposeThroughFQUp");
     this->register_matcher(m, callback);
 }
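Not part of the diff: the rank-alignment step in the loop above, extracted as a standalone sketch. Low-rank FakeQuantize inputs (the per-channel bounds) get leading unit axes until their rank matches the data input, so the same transpose order applies to all of them. The helper name is hypothetical.

    #include <cstdint>
    #include <vector>

    // For input_rank = 4 and fq_input_rank = 2 this yields {0, 1}:
    // a bound tensor of shape {1, 3} is unsqueezed to {1, 1, 1, 3}.
    std::vector<int64_t> leading_unsqueeze_axes(int64_t input_rank, int64_t fq_input_rank) {
        std::vector<int64_t> axes;
        for (int64_t j = 0; j < input_rank - fq_input_rank; ++j) {
            axes.push_back(j);
        }
        return axes;
    }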


@@ -10,41 +10,34 @@
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset5.hpp>
 #include <ngraph/rt_info.hpp>
+#include <ngraph/pattern/op/wrap_type.hpp>
 
 using namespace ngraph;
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormDecomposition, "BatchNormDecomposition", 0);
 
 ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
-    Shape shape{2, 2, 1, 1};
-    auto input = make_shared<pattern::op::Label>(element::f32, shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
-    auto bn = make_shared<opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
-
-    ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
-        auto pattern_map = m.get_pattern_map();
-        auto m_input = pattern_map[input];
-        auto m_gamma = pattern_map[gamma];
-        auto m_beta = pattern_map[beta];
-        auto m_mean = pattern_map[mean];
-        auto m_var = pattern_map[var];
-
-        // TODO: check that all input shapes are static
-
+    auto bn = pattern::wrap_type<opset1::BatchNormInference>({
+        pattern::any_input(pattern::has_static_rank()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape())
+    });
+
+    ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
         auto m_bn = dynamic_pointer_cast<opset1::BatchNormInference>(m.get_match_root());
         if (!m_bn) {
             return false;
         }
 
-        const auto& input_type = m_input->get_element_type();
+        auto m_gamma = m_bn->input_value(0);
+        auto m_beta = m_bn->input_value(1);
+        auto m_input = m_bn->input_value(2);
+        auto m_mean = m_bn->input_value(3);
+        auto m_var = m_bn->input_value(4);
+
+        const auto& input_type = m_input.get_element_type();
         // scale_add = variance + eps
         auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
         // scale = sqrt(variance + eps)
@@ -52,8 +45,10 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
         // Divide `gamma` by `sqrt(variance + eps)`
         auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
 
-        size_t dims_to_add = m_input->get_shape().size() - 2;
-        Shape input_aligned_shape = m_gamma->get_shape();
+        int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
+
+        // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
+        Shape input_aligned_shape = m_gamma.get_shape();
         for (size_t i = 0; i < dims_to_add; ++i)
             input_aligned_shape.push_back(1);
 
         auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
@@ -84,36 +79,29 @@ ngraph::pass::BatchNormDecomposition::BatchNormDecomposition() {
 NGRAPH_RTTI_DEFINITION(ngraph::pass::BatchNormV5Decomposition, "BatchNormDecomposition", 5);
 
+// TODO: this pass will be unified with BatchNormDecomposition pass
 ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
-    Shape shape{2, 2, 1, 1};
-    auto input = make_shared<pattern::op::Label>(element::f32, shape);
-    auto mean_shape = Shape{2};
-    auto mean = make_shared<pattern::op::Label>(element::f32, mean_shape);
-    auto var_shape = Shape{2};
-    auto var = make_shared<pattern::op::Label>(element::f32, var_shape);
-    auto gamma_shape = Shape{2};
-    auto gamma = make_shared<pattern::op::Label>(element::f32, gamma_shape);
-    auto beta_shape = Shape{2};
-    auto beta = make_shared<pattern::op::Label>(element::f32, beta_shape);
-    auto bn = make_shared<opset5::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
-
-    ngraph::graph_rewrite_callback callback = [this, input, gamma, beta, mean, var](ngraph::pattern::Matcher &m) {
-        auto pattern_map = m.get_pattern_map();
-        auto m_input = pattern_map[input];
-        auto m_gamma = pattern_map[gamma];
-        auto m_beta = pattern_map[beta];
-        auto m_mean = pattern_map[mean];
-        auto m_var = pattern_map[var];
-
-        // TODO: check that all input shapes are static
-
+    auto bn = pattern::wrap_type<opset5::BatchNormInference>({
+        pattern::any_input(pattern::has_static_rank()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape()),
+        pattern::any_input(pattern::has_static_shape())
+    });
+
+    ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher &m) {
         auto m_bn = dynamic_pointer_cast<opset5::BatchNormInference>(m.get_match_root());
         if (!m_bn) {
             return false;
         }
-        const auto& input_type = m_input->get_element_type();
+
+        auto m_input = m_bn->input_value(0);
+        auto m_gamma = m_bn->input_value(1);
+        auto m_beta = m_bn->input_value(2);
+        auto m_mean = m_bn->input_value(3);
+        auto m_var = m_bn->input_value(4);
+
+        const auto& input_type = m_input.get_element_type();
         // scale_add = variance + eps
         auto scale_add = make_shared<opset5::Add>(m_var, opset5::Constant::create(input_type, Shape{}, {m_bn->get_eps_value()}));
         // scale = sqrt(variance + eps)
@@ -121,8 +109,10 @@ ngraph::pass::BatchNormV5Decomposition::BatchNormV5Decomposition() {
         // Divide `gamma` by `sqrt(variance + eps)`
         auto gamma_div_scale = std::make_shared<opset5::Divide>(m_gamma, scale);
 
-        size_t dims_to_add = m_input->get_shape().size() - 2;
-        Shape input_aligned_shape = m_gamma->get_shape();
+        int64_t dims_to_add = m_input.get_partial_shape().rank().get_length() - 2;
+
+        // TODO: instead of getting full shape we can concatenate sequence of ones with ShapeOf
+        Shape input_aligned_shape = m_gamma.get_shape();
         for (size_t i = 0; i < dims_to_add; ++i)
            input_aligned_shape.push_back(1);
 
         auto new_shape = opset5::Constant::create(element::i64, Shape{input_aligned_shape.size()}, input_aligned_shape);
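Not part of the diff: a scalar sketch of the algebra the decomposition builds as Multiply/Add nodes. BatchNormInference computes (x - mean) / sqrt(var + eps) * gamma + beta, which folds into one multiply and one add once gamma / sqrt(var + eps) is precomputed. Only the rank of the data input is needed to broadcast the per-channel constants, which is why has_static_rank() suffices for the data input while the other four inputs require static shapes.

    #include <cmath>

    float batch_norm_decomposed(float x, float gamma, float beta,
                                float mean, float var, float eps) {
        float scale = gamma / std::sqrt(var + eps);   // matches gamma_div_scale above
        return x * scale + (beta - mean * scale);     // Multiply followed by Add
    }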


@@ -16,24 +16,28 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertBroadcastToTiles, "ConvertBroadcastT
 ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
     auto broadcast = ngraph::pattern::wrap_type<ngraph::opset1::Broadcast>();
 
-    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
+    ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
         auto broadcast = std::dynamic_pointer_cast<ngraph::opset1::Broadcast>(m.get_match_root());
         if (!broadcast) {
             return false;
         }
 
-        auto data_node = broadcast->input_value(0).get_node_shared_ptr();
+        auto data_node = broadcast->input_value(0);
+        if (data_node.get_partial_shape().is_dynamic()) {
+            return false;
+        }
+
         auto shape_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(1).get_node_shared_ptr());
         auto axes_node = std::dynamic_pointer_cast<ngraph::opset1::Constant>(broadcast->input_value(2).get_node_shared_ptr());
-        if (!data_node || !shape_node || !axes_node) return false;
+        if (!shape_node || !axes_node) return false;
 
         auto output_shape = shape_node->cast_vector<int64_t>();
-        auto input_shape = data_node->get_shape();
+        auto input_shape = data_node.get_shape();
         int64_t cur_dim_id = output_shape.size() - 1;
         size_t dims_count = output_shape.size();
-        auto last_node = std::dynamic_pointer_cast<ngraph::Node>(data_node);
+        auto last_node = data_node;
 
         NodeVector new_ops;
@@ -61,7 +65,7 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
             auto shape_const = std::make_shared<ngraph::opset1::Constant>(element::i64, Shape {shape.size()}, shape);
             auto reshape = std::make_shared<ngraph::opset1::Reshape>(data_node, shape_const, true);
             new_ops.push_back(reshape);
-            last_node = std::dynamic_pointer_cast<ngraph::Node>(reshape);
+            last_node = reshape;
             input_shape = shape;
         }
@@ -87,9 +91,8 @@ ngraph::pass::ConvertBroadcastToTiles::ConvertBroadcastToTiles() {
         new_ops.push_back(tile);
         tile->set_friendly_name(broadcast->get_friendly_name());
-        last_node = std::dynamic_pointer_cast<ngraph::Node>(tile);
 
         ngraph::copy_runtime_info(broadcast, new_ops);
-        ngraph::replace_node(broadcast, last_node);
+        ngraph::replace_node(broadcast, tile);
         return true;
     };
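Not part of the diff: a hedged sketch of the Tile repeats computation implied by the loop this pass runs over the output dimensions (not shown in the hunks above). Once the input is reshaped to the output rank, each repeat is output_dim / input_dim, i.e. output_dim wherever the input dimension is 1; needing these concrete values is why the pass now bails out on dynamic input shapes. The helper name is hypothetical.

    #include <cstdint>
    #include <vector>

    // Assumes aligned_input_shape already has the same rank as output_shape,
    // with 1s in the broadcast positions (e.g. {3, 1, 2} vs {3, 5, 2} -> {1, 5, 1}).
    std::vector<int64_t> tile_repeats(const std::vector<int64_t>& aligned_input_shape,
                                      const std::vector<int64_t>& output_shape) {
        std::vector<int64_t> repeats(output_shape.size(), 1);
        for (size_t i = 0; i < output_shape.size(); ++i) {
            repeats[i] = output_shape[i] / aligned_input_shape[i];
        }
        return repeats;
    }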


@@ -14,7 +14,7 @@
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertDepthToSpace, "ConvertDepthToSpace", 0);
 
 ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
-    auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>();
+    auto dts_node = ngraph::pattern::wrap_type<ngraph::opset1::DepthToSpace>({pattern::any_input(pattern::has_static_shape())});
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto dts_node = std::dynamic_pointer_cast<ngraph::opset1::DepthToSpace> (m.get_match_root());
@@ -22,7 +22,7 @@ ngraph::pass::ConvertDepthToSpace::ConvertDepthToSpace() {
             return false;
         }
 
-        auto input = dts_node->input(0).get_source_output().get_node_shared_ptr();
+        auto input = dts_node->input_value(0);
 
         /*
          * In this transformation we decompose DepthToSpace operation to the next sequence of ops:


@@ -14,7 +14,7 @@
 NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertSpaceToDepth, "ConvertSpaceToDepth", 0);
 
 ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
-    auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>();
+    auto dts = ngraph::pattern::wrap_type<ngraph::opset1::SpaceToDepth>({pattern::any_input(pattern::has_static_shape())});
 
     ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
         auto std_node = std::dynamic_pointer_cast<ngraph::opset1::SpaceToDepth> (m.get_match_root());
@@ -22,7 +22,7 @@ ngraph::pass::ConvertSpaceToDepth::ConvertSpaceToDepth() {
             return false;
         }
 
-        auto input = std_node->input(0).get_source_output().get_node_shared_ptr();
+        auto input = std_node->input_value(0);
 
         /*
          * In this transformation we decompose SpaceToDepth operation to the next sequence of ops:
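Not part of the diff: why ConvertDepthToSpace and ConvertSpaceToDepth require a fully static input shape, unlike the rank-only checks elsewhere in this PR. Their decompositions build Reshape and Transpose constants from concrete dimension values. A hedged sketch of the DepthToSpace shape arithmetic for an NCHW input with block size k:

    #include <array>
    #include <cstdint>

    // DepthToSpace moves k*k blocks from the channel dimension into space:
    // {N, C, H, W} -> {N, C / (k * k), H * k, W * k}.
    std::array<int64_t, 4> depth_to_space_output_shape(const std::array<int64_t, 4>& nchw,
                                                       int64_t k) {
        return {nchw[0], nchw[1] / (k * k), nchw[2] * k, nchw[3] * k};
    }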


@@ -0,0 +1,40 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <transformations/op_conversions/batch_norm_decomposition.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, BatchNormDecompositionDynamic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+        auto gamma = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto beta = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto mean = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto var = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {3});
+        auto broadcast = std::make_shared<ngraph::opset1::BatchNormInference>(input, gamma, beta, mean, var, 0.001);
+        broadcast->set_friendly_name("broadcast");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::BatchNormDecomposition>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+}


@@ -0,0 +1,40 @@
+// Copyright (C) 2020 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include <string>
+#include <memory>
+#include <queue>
+
+#include <ngraph/function.hpp>
+#include <ngraph/opsets/opset1.hpp>
+#include <ngraph/opsets/opset3.hpp>
+#include <transformations/op_conversions/convert_broadcast_to_tiles.hpp>
+#include <transformations/init_node_info.hpp>
+#include <transformations/utils/utils.hpp>
+#include <ngraph/pass/manager.hpp>
+
+#include "common_test_utils/ngraph_test_utils.hpp"
+
+using namespace testing;
+
+TEST(TransformationTests, ConvertBroadcastToTilesDynamic) {
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    {
+        auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 1, 2});
+        auto target_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{3}, std::vector<int64_t>{3, 5, 2});
+        auto broadcast = std::make_shared<ngraph::opset1::Broadcast>(input1, target_shape);
+        broadcast->set_friendly_name("broadcast");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{broadcast}, ngraph::ParameterVector{input1});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertBroadcastToTiles>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+}


@@ -54,8 +54,7 @@ public:
         f_ref = get_reference_function(input_shape, reduce_type, reference_params);
     }
 
-private:
-    std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
+    static std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
                                                            const std::vector<int64_t> & axes,
                                                            const ReduceType & reduce_type,
                                                            const bool keep_dims) {
@@ -72,7 +71,7 @@ private:
         return std::make_shared<ngraph::Function>(ngraph::NodeVector{reduce}, ngraph::ParameterVector{input});
     }
 
-    std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
+    static std::shared_ptr<ngraph::Function> get_reference_function(const ngraph::PartialShape & input_shape,
                                                              const ReduceType & reduce,
                                                              const ReduceToPoolParams & params) {
         auto param = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, input_shape);
@@ -137,6 +136,10 @@ INSTANTIATE_TEST_CASE_P(ReduceToReshapePoolReshape, ConvertReduceToPoolingTests,
         std::make_tuple(MAX, InputShape{2, 9}, ReduceAxes{-1}, KeepDims{true}, ReduceToPoolParams({1, 1, 9, 1}, {9, 1}, {1, 1})),
         std::make_tuple(MAX, InputShape{2, 3, 4, 1}, ReduceAxes{1, 3, 2}, KeepDims{false}, ReduceToPoolParams({1, 1, 12, 1}, {12, 1}, {1}))));
 
+TEST(ConvertReduceToPooling, Negative) {
+    auto f = ConvertReduceToPoolingTests::get_initial_function(ngraph::PartialShape::dynamic(), {3}, MAX, true);
+    ASSERT_NO_THROW(ngraph::pass::ConvertReduceToPooling().run_on_function(f));
+}
+
 #undef MAX


@@ -181,3 +181,29 @@ TEST(TransformationTests, TestSpaceToDepthTransformDepthFirst) {
     std::vector<int64_t> shape_end_value_ref{1, 12 * 4, 1080 / 2, 1616 / 2};
     ASSERT_EQ(shape_end_value, shape_end_value_ref);
 }
+
+TEST(TransformationTests, TestSpaceToDepthDynamic) {
+    auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    std::shared_ptr<ngraph::Function> f(nullptr);
+    {
+        auto space_to_depth = std::make_shared<ngraph::op::SpaceToDepth>(input, ngraph::op::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, 2);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{space_to_depth}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::ConvertSpaceToDepth>();
+        ASSERT_NO_THROW(m.run_passes(f));
+    }
+}
+
+TEST(TransformationTests, TestDepthToSpaceDynamic) {
+    auto input = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    std::shared_ptr<ngraph::Function> f(nullptr);
+    {
+        auto depth_to_space = std::make_shared<ngraph::op::DepthToSpace>(input, ngraph::op::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 2);
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{depth_to_space}, ngraph::ParameterVector{input});
+
+        ngraph::pass::Manager m;
+        m.register_pass<ngraph::pass::ConvertDepthToSpace>();
+        ASSERT_NO_THROW(m.run_passes(f));
+    }
+}


@@ -55,3 +55,29 @@ TEST(TransformationTests, FQTransposeTest1) {
         }
     }
 }
+
+TEST(TransformationTests, FQTransposeDynamic) {
+    auto data1 = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
+    auto data2 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{3}, {1, 2, 3});
+    auto data3 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto data4 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto data5 = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3}, {1, 2, 3});
+    auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
+
+    std::shared_ptr<ngraph::Function> f(nullptr);
+    {
+        auto fq = std::make_shared<ngraph::op::FakeQuantize>(data1, data2, data3, data4, data5, 1);
+        auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data1});
+
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
+        manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
+            check_rt_info(f);
+        });
+        manager.register_pass<ngraph::pass::ConstantFolding>();
+        ASSERT_NO_THROW(manager.run_passes(f));
+    }
+}


@@ -61,6 +61,9 @@ namespace ngraph
             NGRAPH_API
             std::function<bool(Output<Node>)> has_static_shape();
 
+            NGRAPH_API
+            std::function<bool(Output<Node>)> has_static_rank();
+
             NGRAPH_API
             std::function<bool(Output<Node>)> type_matches(const element::Type& type);


@@ -95,6 +95,13 @@ namespace ngraph
                 [=](Output<Node> output) -> bool { return output.get_partial_shape().is_static(); };
         }
 
+        std::function<bool(Output<Node>)> has_static_rank()
+        {
+            return [=](Output<Node> output) -> bool {
+                return output.get_partial_shape().rank().is_static();
+            };
+        }
+
         std::function<bool(Output<Node>)> type_matches(const element::Type& type)
         {
             return [=](Output<Node> output) -> bool { return output.get_element_type() == type; };
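
Not part of the diff: a short usage sketch of the new predicate, assuming the same calling convention as the existing has_static_shape(). It accepts outputs whose rank is known even when individual dimensions are dynamic, which is exactly what the BatchNorm and FakeQuantize patterns above rely on. The Relu pattern and function name below are hypothetical.

    #include <ngraph/opsets/opset1.hpp>
    #include <ngraph/pattern/op/wrap_type.hpp>

    using namespace ngraph;

    // Matches a Relu whose input has a known rank, e.g. PartialShape{?, 3, 224, 224},
    // but rejects fully dynamic inputs such as PartialShape::dynamic().
    std::shared_ptr<Node> make_relu_label() {
        return pattern::wrap_type<opset1::Relu>({pattern::any_input(pattern::has_static_rank())});
    }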