Add dynamic shape checks for legacy transformations (#2783)

* Added dynamic shape checks for ConvertInterpolate pass

* Added dynamic checks for ConvertLRNToLegacy pass

* Added dynamic checks for ConvertMatMul* pass

* Added dynamic checks for ConvertPadToLegacy pass

* Updated TileIE; added dynamic checks to ConvertTileToLegacy pass

* Added dynamic checks to FCBiasFusion pass

* Added dynamic checks to Reshape1DOps pass

* Added dynamic checks to ReshapeFCFusion pass

* Added dynamic checks to ReshapeFC pass

* Updaed Reshape1DConvolution pattern
This commit is contained in:
Gleb Kazantaev 2020-10-28 10:36:16 +03:00 committed by GitHub
parent c7661078d9
commit 91afa14901
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 320 additions and 98 deletions

View File

@ -17,6 +17,7 @@
#include <ngraph/ngraph.hpp> #include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pass/graph_rewrite.hpp> #include <ngraph/pass/graph_rewrite.hpp>
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
@ -44,25 +45,15 @@ public:
private: private:
void construct_reshape_fc() { void construct_reshape_fc() {
auto input = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 4}); auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
pattern::any_input(),
pattern::any_input()});
auto reshape_shape = std::make_shared<pattern::op::Label>(element::i64, Shape{4}); ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
auto reshape = std::make_shared<ngraph::opset1::Reshape>(input, reshape_shape, true); auto & pattern_to_output = m.get_pattern_value_map();
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
auto weights = std::make_shared<pattern::op::Label>(element::f32, Shape{2, 4}); auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
auto biases = std::make_shared<pattern::op::Label>(element::f32, Shape{2});
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, weights, biases, Shape{1, 2});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) {
auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected>(m.get_match_root());
if (!fc) {
return false;
}
auto reshape = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(fc->input_value(0).get_node_shared_ptr());
if (!reshape) {
return false;
}
// Check that Reshape reshapes 4D tensor to 2D or input shape = output shape // Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
auto shape_in = reshape->input_value(0).get_shape(); auto shape_in = reshape->input_value(0).get_shape();
@ -89,7 +80,7 @@ private:
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(fc, "ReshapeFullyConnectedFusion"); auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
} }
}; };

View File

@ -21,18 +21,23 @@ op::TileIE::TileIE(const Output<ngraph::Node>& data1, const int64_t axis, const
} }
std::shared_ptr<Node> op::TileIE::clone_with_new_inputs(const OutputVector& new_args) const { std::shared_ptr<Node> op::TileIE::clone_with_new_inputs(const OutputVector& new_args) const {
if (new_args.size() != 1) { check_new_args_count(this, new_args);
throw ngraph_error("Incorrect number of new arguments");
}
return make_shared<TileIE>(new_args.at(0), axis, tiles); return make_shared<TileIE>(new_args.at(0), axis, tiles);
} }
void op::TileIE::validate_and_infer_types() { void op::TileIE::validate_and_infer_types() {
auto input_shape = get_input_partial_shape(0).to_shape(); const auto & input_pshape = get_input_partial_shape(0);
auto output_pshape = PartialShape::dynamic();
if (input_pshape.rank().is_static()) {
const auto & rank = input_pshape.rank().get_length();
NODE_VALIDATION_CHECK(this,
axis >= 0 && axis < rank,
"Axis: ", axis, " must be >= 0 and less than ", rank, "(input rank)");
output_pshape = input_pshape;
if (output_pshape[axis].is_static()) {
output_pshape[axis] *= tiles;
}
}
ngraph::Shape output_shape(input_shape); set_output_type(0, get_input_element_type(0), output_pshape);
output_shape[axis] *= tiles;
set_output_type(0, get_input_element_type(0), PartialShape(output_shape));
} }

View File

@ -12,15 +12,15 @@
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <legacy/ngraph_ops/interp.hpp> #include <legacy/ngraph_ops/interp.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher, "ConvertInterpolateToInterpOrResampleMatcher", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher, "ConvertInterpolateToInterpOrResampleMatcher", 0);
ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher::ConvertInterpolateToInterpOrResampleMatcher() { ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher::ConvertInterpolateToInterpOrResampleMatcher() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto interpolate = pattern::wrap_type<opset1::Interpolate>({pattern::any_input(pattern::has_static_shape()),
auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{2}); pattern::wrap_type<opset1::Constant>()});
auto interpolate = std::make_shared<ngraph::opset1::Interpolate>(data, shp, ngraph::op::v0::InterpolateAttrs());
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto interpolate = std::dynamic_pointer_cast<ngraph::opset1::Interpolate> (m.get_match_root()); auto interpolate = std::dynamic_pointer_cast<ngraph::opset1::Interpolate> (m.get_match_root());

View File

@ -10,15 +10,16 @@
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <legacy/ngraph_ops/lrn_ie.hpp> #include <legacy/ngraph_ops/lrn_ie.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertLRNToLegacyMatcher, "ConvertLRNToLegacyMatcher", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertLRNToLegacyMatcher, "ConvertLRNToLegacyMatcher", 0);
ngraph::pass::ConvertLRNToLegacyMatcher::ConvertLRNToLegacyMatcher() { ngraph::pass::ConvertLRNToLegacyMatcher::ConvertLRNToLegacyMatcher() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto lrn = pattern::wrap_type<opset1::LRN>({pattern::any_input(),
auto input_1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1}); pattern::wrap_type<opset1::Constant>()},
auto lrn = std::make_shared<ngraph::opset1::LRN>(input_0, input_1, 1, 1, 1, 1); pattern::has_static_rank());
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto lrn = std::dynamic_pointer_cast<ngraph::opset1::LRN> (m.get_match_root()); auto lrn = std::dynamic_pointer_cast<ngraph::opset1::LRN> (m.get_match_root());
@ -36,7 +37,7 @@ ngraph::pass::ConvertLRNToLegacyMatcher::ConvertLRNToLegacyMatcher() {
if (axis_value.size() == 1 && axis_value[0] == 1) { if (axis_value.size() == 1 && axis_value[0] == 1) {
region = "across"; region = "across";
} else { } else {
std::vector<bool> norm(lrn->get_shape().size(), false); std::vector<bool> norm(lrn->get_output_partial_shape(0).rank().get_length(), false);
for (auto & axis : axis_value) { for (auto & axis : axis_value) {
if (axis < 0 || static_cast<size_t>(axis) >= norm.size()) { if (axis < 0 || static_cast<size_t>(axis) >= norm.size()) {
return false; return false;

View File

@ -13,6 +13,7 @@
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <legacy/ngraph_ops/fully_connected.hpp> #include <legacy/ngraph_ops/fully_connected.hpp>
#include <transformations/utils/utils.hpp> #include <transformations/utils/utils.hpp>
@ -20,9 +21,9 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMatMulToFC, "ConvertMatMulToFC", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMatMulToFC, "ConvertMatMulToFC", 0);
ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() { ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); auto matmul = pattern::wrap_type<opset1::MatMul>({pattern::any_input(pattern::has_static_shape()),
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); pattern::any_input(pattern::has_static_shape())},
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1); pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root()); auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());
@ -163,9 +164,9 @@ ngraph::pass::ConvertMatMulToFC::ConvertMatMulToFC() {
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMatMulToGemm, "ConvertMatMulToGemm", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertMatMulToGemm, "ConvertMatMulToGemm", 0);
ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() { ngraph::pass::ConvertMatMulToGemm::ConvertMatMulToGemm() {
auto input_0 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); auto matmul = pattern::wrap_type<opset1::MatMul>({pattern::any_input(pattern::has_static_shape()),
auto input_1 = std::make_shared<pattern::op::Label>(element::f32, Shape {1, 1}); pattern::any_input(pattern::has_static_shape())},
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input_0, input_1); pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root()); auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(m.get_match_root());

View File

@ -14,7 +14,7 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPadToLegacyMatcher, "ConvertPadToLegacyMatcher", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertPadToLegacyMatcher, "ConvertPadToLegacyMatcher", 0);
ngraph::pass::ConvertPadToLegacyMatcher::ConvertPadToLegacyMatcher() { ngraph::pass::ConvertPadToLegacyMatcher::ConvertPadToLegacyMatcher() {
auto m_pad = ngraph::pattern::wrap_type<ngraph::opset1::Pad>(); auto m_pad = ngraph::pattern::wrap_type<ngraph::opset1::Pad>(pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto pad = std::dynamic_pointer_cast<ngraph::opset1::Pad> (m.get_match_root()); auto pad = std::dynamic_pointer_cast<ngraph::opset1::Pad> (m.get_match_root());

View File

@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <legacy/ngraph_ops/tile_ie.hpp> #include <legacy/ngraph_ops/tile_ie.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
@ -15,9 +16,8 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTileToLegacyMatcher, "ConvertTileToLegacyMatcher", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertTileToLegacyMatcher, "ConvertTileToLegacyMatcher", 0);
ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() { ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
auto data = std::make_shared<pattern::op::Label>(element::f32, Shape{1, 1, 1, 1}); auto tile = pattern::wrap_type<ngraph::opset1::Tile>({pattern::any_input(pattern::has_static_rank()),
auto shp = std::make_shared<pattern::op::Label>(element::i64, Shape{4}); pattern::wrap_type<opset1::Constant>()});
auto tile = std::make_shared<ngraph::opset1::Tile>(data, shp);
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto tile = std::dynamic_pointer_cast<ngraph::opset1::Tile> (m.get_match_root()); auto tile = std::dynamic_pointer_cast<ngraph::opset1::Tile> (m.get_match_root());
@ -25,15 +25,14 @@ ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
return false; return false;
} }
auto data_node = tile->input_value(0).get_node_shared_ptr();
auto tiles_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (tile->input_value(1).get_node_shared_ptr()); auto tiles_node = std::dynamic_pointer_cast<ngraph::opset1::Constant> (tile->input_value(1).get_node_shared_ptr());
if (!data_node || !tiles_node) return false; if (!tiles_node) return false;
auto tiles = tiles_node->cast_vector<int64_t>(); auto tiles = tiles_node->cast_vector<int64_t>();
auto input_shape = data_node->get_shape(); auto input_shape_rank = tile->get_input_partial_shape(0).rank().get_length();
int64_t cur_dim_id = tiles.size() - 1; int64_t cur_dim_id = tiles.size() - 1;
if (tiles.size() != input_shape.size()) return false; if (tiles.size() != input_shape_rank) return false;
// IE Tile operations supports only one axis to be tiled // IE Tile operations supports only one axis to be tiled
// bool already_set = false; // bool already_set = false;
@ -48,9 +47,7 @@ ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
// } // }
// //
// if (!already_set) return false; // if (!already_set) return false;
auto last_node = std::dynamic_pointer_cast<ngraph::Node>(data_node); auto last_node = tile->input_value(0);
if (!last_node)
return false;
auto friendly_name = tile->get_friendly_name(); auto friendly_name = tile->get_friendly_name();
int num_of_tile_dims = 0; int num_of_tile_dims = 0;
@ -78,17 +75,17 @@ ngraph::pass::ConvertTileToLegacyMatcher::ConvertTileToLegacyMatcher() {
auto ie_tile = std::make_shared<ngraph::op::TileIE>(last_node, cur_dim_id, tile_dim); auto ie_tile = std::make_shared<ngraph::op::TileIE>(last_node, cur_dim_id, tile_dim);
ie_tile->set_friendly_name(friendly_name); ie_tile->set_friendly_name(friendly_name);
friendly_name += "_" + std::to_string(cur_dim_id); friendly_name += "_" + std::to_string(cur_dim_id);
new_ops.push_back(ie_tile);
last_node = std::dynamic_pointer_cast<ngraph::Node>(ie_tile); last_node = ie_tile;
new_ops.push_back(last_node);
} }
--cur_dim_id; --cur_dim_id;
++tiles_it; ++tiles_it;
} }
last_node->set_friendly_name(tile->get_friendly_name()); last_node.get_node_shared_ptr()->set_friendly_name(tile->get_friendly_name());
ngraph::copy_runtime_info(tile, new_ops); ngraph::copy_runtime_info(tile, new_ops);
ngraph::replace_node(tile, last_node); ngraph::replace_node(tile, {last_node});
return true; return true;
}; };

View File

@ -15,38 +15,33 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::FullyConnectedBiasFusion, "FullyConnectedBiasFusion", 0);
ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() { ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
auto fc = ngraph::pattern::wrap_type<op::FullyConnected>(); auto m_fc = ngraph::pattern::wrap_type<op::FullyConnected>([](Output<Node> output) {
auto add = ngraph::pattern::wrap_type<opset1::Add>({fc, std::make_shared<pattern::op::Label>()}); return pattern::consumers_count(1)(output) &&
pattern::has_static_shape()(output);
});
auto m_bias = pattern::any_input();
auto m_add = ngraph::pattern::wrap_type<opset1::Add>({m_fc, m_bias});
ngraph::graph_rewrite_callback callback = [](pattern::Matcher &m) { ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
auto add = m.get_match_root(); auto & pattern_to_output = m.get_pattern_value_map();
auto add_input_0 = add->input(0).get_source_output().get_node_shared_ptr();
auto add_input_1 = add->input(1).get_source_output().get_node_shared_ptr();
auto m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_0); auto add = pattern_to_output[m_add].get_node_shared_ptr();
auto m_bias = add_input_1; auto bias = pattern_to_output[m_bias].get_node_shared_ptr();
auto fc = std::dynamic_pointer_cast<op::FullyConnected>(pattern_to_output[m_fc].get_node_shared_ptr());
if (m_fc == nullptr) { if (!fc) {
m_fc = std::dynamic_pointer_cast<op::FullyConnected>(add_input_1);
if (m_fc == nullptr)
return false;
m_bias = add_input_0;
}
if (auto bcast_m = std::dynamic_pointer_cast<opset1::Broadcast>(m_bias)) {
m_bias = bcast_m->input(0).get_source_output().get_node_shared_ptr();
}
if (!std::dynamic_pointer_cast<opset1::Constant>(m_bias)) {
return false;
}
Shape bias_shape(m_bias->get_shape());
if (m_fc->output(0).get_target_inputs().size() != 1) {
return false; return false;
} }
Shape output_shape(m_fc->get_shape()); if (auto bcast = std::dynamic_pointer_cast<opset1::Broadcast>(bias)) {
bias = bcast->input_value(0).get_node_shared_ptr();
}
if (!std::dynamic_pointer_cast<opset1::Constant>(bias)) {
return false;
}
Shape bias_shape(bias->get_shape());
Shape output_shape(fc->get_shape());
size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), size_t{1}, std::multiplies<int64_t>()); size_t bias_size = std::accumulate(bias_shape.begin(), bias_shape.end(), size_t{1}, std::multiplies<int64_t>());
if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) { if (bias_shape.empty() || bias_shape.back() != output_shape.back() || bias_shape.back() != bias_size) {
return false; return false;
@ -54,7 +49,7 @@ ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
NodeVector new_ops; NodeVector new_ops;
auto new_bias = std::make_shared<opset1::Add>(m_fc->input(2).get_source_output(), m_bias); auto new_bias = std::make_shared<opset1::Add>(fc->input(2).get_source_output(), bias);
new_ops.push_back(new_bias); new_ops.push_back(new_bias);
std::shared_ptr<Node> final_bias = new_bias; std::shared_ptr<Node> final_bias = new_bias;
if (new_bias->get_shape().size() >= 2) { if (new_bias->get_shape().size() >= 2) {
@ -62,19 +57,19 @@ ngraph::pass::FullyConnectedBiasFusion::FullyConnectedBiasFusion() {
new_ops.push_back(final_bias); new_ops.push_back(final_bias);
} }
auto new_fc = std::make_shared<op::FullyConnected>(m_fc->input(0).get_source_output(), auto new_fc = std::make_shared<op::FullyConnected>(fc->input(0).get_source_output(),
m_fc->input(1).get_source_output(), fc->input(1).get_source_output(),
final_bias, final_bias,
m_fc->get_shape(), fc->get_shape(),
m_fc->get_output_type()); fc->get_output_type());
new_ops.push_back(new_fc); new_ops.push_back(new_fc);
new_fc->set_friendly_name(add->get_friendly_name()); new_fc->set_friendly_name(add->get_friendly_name());
ngraph::copy_runtime_info({m_fc, add}, new_ops); ngraph::copy_runtime_info({fc, add}, new_ops);
ngraph::replace_node(add, new_fc); ngraph::replace_node(add, new_fc);
return true; return true;
}; };
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "FullyConnectedBiasFusion"); auto m = std::make_shared<ngraph::pattern::Matcher>(m_add, "FullyConnectedBiasFusion");
this->register_matcher(m, callback); this->register_matcher(m, callback);
} }

View File

@ -109,7 +109,7 @@ std::shared_ptr<Node> convert(const Output<Node> & data, std::shared_ptr<opset1:
matcher_pass_callback get_callback() { matcher_pass_callback get_callback() {
return [](pattern::Matcher& m) { return [](pattern::Matcher& m) {
auto node = m.get_match_root(); auto node = m.get_match_root();
if (!node || node->input(0).get_partial_shape().rank().get_length() != 3) { if (node->input(0).get_partial_shape().rank().get_length() != 3) {
return false; return false;
} }
@ -154,7 +154,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DOps, "Reshape1DOps", 0);
NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DConvolution, "Reshape1DConvolution", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DConvolution, "Reshape1DConvolution", 0);
ngraph::pass::Reshape1DConvolution::Reshape1DConvolution() { ngraph::pass::Reshape1DConvolution::Reshape1DConvolution() {
auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(); auto conv = ngraph::pattern::wrap_type<op::ConvolutionIE>(pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution"); auto m = std::make_shared<ngraph::pattern::Matcher>(conv, "Reshape1DConvolution");
this->register_matcher(m, get_callback()); this->register_matcher(m, get_callback());
} }
@ -162,7 +162,7 @@ ngraph::pass::Reshape1DConvolution::Reshape1DConvolution() {
NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DAvgPool, "Reshape1DAvgPool", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DAvgPool, "Reshape1DAvgPool", 0);
ngraph::pass::Reshape1DAvgPool::Reshape1DAvgPool() { ngraph::pass::Reshape1DAvgPool::Reshape1DAvgPool() {
auto pool = ngraph::pattern::wrap_type<opset1::AvgPool>(); auto pool = ngraph::pattern::wrap_type<opset1::AvgPool>(pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool"); auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DAvgPool");
this->register_matcher(m, get_callback()); this->register_matcher(m, get_callback());
} }
@ -170,7 +170,7 @@ ngraph::pass::Reshape1DAvgPool::Reshape1DAvgPool() {
NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DMaxPool, "Reshape1DMaxPool", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::Reshape1DMaxPool, "Reshape1DMaxPool", 0);
ngraph::pass::Reshape1DMaxPool::Reshape1DMaxPool() { ngraph::pass::Reshape1DMaxPool::Reshape1DMaxPool() {
auto pool = ngraph::pattern::wrap_type<opset1::MaxPool>(); auto pool = ngraph::pattern::wrap_type<opset1::MaxPool>(pattern::has_static_shape());
auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool"); auto m = std::make_shared<ngraph::pattern::Matcher>(pool, "Reshape1DMaxPool");
this->register_matcher(m, get_callback()); this->register_matcher(m, get_callback());
} }

View File

@ -9,6 +9,7 @@
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp> #include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "legacy/ngraph_ops/fully_connected.hpp" #include "legacy/ngraph_ops/fully_connected.hpp"
#include "transformations/utils/utils.hpp" #include "transformations/utils/utils.hpp"
@ -16,10 +17,10 @@
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnected, "ReshapeFullyConnected", 0); NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnected, "ReshapeFullyConnected", 0);
ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() { ngraph::pass::ReshapeFullyConnected::ReshapeFullyConnected() {
auto input0 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1}); auto fc = pattern::wrap_type<op::FullyConnected>({pattern::any_input(pattern::has_static_shape()),
auto input1 = std::make_shared<pattern::op::Label>(element::i64, Shape{1, 1}); pattern::any_input(),
auto input2 = std::make_shared<pattern::op::Label>(element::i64, Shape{1}); pattern::any_input()},
auto fc = std::make_shared<ngraph::op::FullyConnected>(input0, input1, input2, Shape{1, 1}); pattern::has_static_shape());
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root()); auto fc = std::dynamic_pointer_cast<ngraph::op::FullyConnected> (m.get_match_root());

View File

@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_interpolate_to_interp_or_resample.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertInterpolateDynamic) {
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {30, 60});
auto interp = std::make_shared<ngraph::opset1::Interpolate>(data, shape, ngraph::op::v0::InterpolateAttrs());
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{interp}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::ConvertInterpolateToInterpOrResampleMatcher>();
ASSERT_NO_THROW(m.run_passes(f));
}

View File

@ -0,0 +1,31 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_lrn_to_lrn_ie.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertLRNToLegacyDynamic) {
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto axis = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
auto lrn = std::make_shared<ngraph::opset1::LRN>(data, axis, 1, 2, 3, 4);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{lrn}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::ConvertLRNToLegacyMatcher>();
ASSERT_NO_THROW(m.run_passes(f));
}

View File

@ -252,3 +252,18 @@ TEST(TransformationTests, ConvertMatMulTest7) {
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
} }
TEST(TransformationTests, ConvertMatMulDynamic) {
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{2, 2}, {1});
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, input2, false, true);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{input1});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<ngraph::pass::ConvertMatMulToFC>();
m.register_pass<ngraph::pass::ConvertMatMulToGemm>();
m.register_pass<ngraph::pass::ReshapeFullyConnected>();
ASSERT_NO_THROW(m.run_passes(f));
}

View File

@ -0,0 +1,32 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_pad_to_pad_ie.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
TEST(TransformationTests, ConvertPadToLegacyDynamic) {
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto pad_begin = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {1, 0});
auto pad_end = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {0, 1});
auto pad = std::make_shared<ngraph::opset1::Pad>(data, pad_begin, pad_end, ngraph::op::PadMode::SYMMETRIC);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{pad}, ngraph::ParameterVector{data});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::ConvertPadToLegacyMatcher>();
ASSERT_NO_THROW(m.run_passes(f));
}

View File

@ -0,0 +1,48 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <legacy/transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <legacy/ngraph_ops/tile_ie.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
TEST(TransformationTests, ConvertTileToLegacyDynamic1) {
auto data = std::make_shared<opset1::Parameter>(element::f32, PartialShape{1, Dimension::dynamic()});
auto axes = opset1::Constant::create(element::i64, Shape{1}, {0});
auto tile = std::make_shared<opset1::Tile>(data, axes);
auto f = std::make_shared<Function>(NodeVector{tile}, ParameterVector{data});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertTileToLegacyMatcher>();
ASSERT_NO_THROW(manager.run_passes(f));
ASSERT_NO_THROW(check_rt_info(f));
}
TEST(TransformationTests, ConvertTileToLegacyDynamic2) {
auto data = std::make_shared<opset1::Parameter>(element::f32, PartialShape::dynamic());
auto axes = opset1::Constant::create(element::i64, Shape{1}, {0});
auto tile = std::make_shared<opset1::Tile>(data, axes);
auto f = std::make_shared<Function>(NodeVector{tile}, ParameterVector{data});
pass::Manager manager;
manager.register_pass<pass::InitNodeInfo>();
manager.register_pass<pass::ConvertTileToLegacyMatcher>();
ASSERT_NO_THROW(manager.run_passes(f));
ASSERT_NO_THROW(check_rt_info(f));
}

View File

@ -94,3 +94,18 @@ TEST(TransformationTests, FullyConnectedBiasFusionTest2D) {
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
} }
TEST(TransformationTests, FullyConnectedBiasFusionDynamic) {
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786, 128}, {1});
auto empty_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{786}, {0});
auto fc = std::make_shared<ngraph::op::FullyConnected>(input1, weights, empty_bias, ngraph::Shape{1, 786});
auto const_bias = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 786}, {1});
auto add = std::make_shared<ngraph::opset1::Add>(fc, const_bias);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{add}, ngraph::ParameterVector{input1});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::FullyConnectedBiasFusion>();
ASSERT_NO_THROW(manager.run_passes(f));
}

View File

@ -18,6 +18,7 @@
#include <legacy/transformations/convert_opset1_to_legacy/reshape_1d_ops.hpp> #include <legacy/transformations/convert_opset1_to_legacy/reshape_1d_ops.hpp>
#include <transformations/init_node_info.hpp> #include <transformations/init_node_info.hpp>
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp" #include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing; using namespace testing;
@ -153,3 +154,46 @@ TEST(TransformationTests, AvgPoolReshapeTest1) {
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
} }
TEST(TransformationTests, ReshapeDynamicTest1) {
{
auto input = std::make_shared<opset1::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
ngraph::Strides strides{1};
ngraph::Shape pads_begin{0}, pads_end{0}, kernel{3};
auto pool = std::make_shared<ngraph::opset1::AvgPool>(input, strides, pads_begin, pads_end, kernel, false, ngraph::op::RoundingType::FLOOR);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{pool}, ngraph::ParameterVector{input});
pass::Manager manager;
manager.register_pass<ngraph::pass::Reshape1DOps>();
ASSERT_NO_THROW(manager.run_passes(f));
}
{
auto input = std::make_shared<opset1::Parameter>(ngraph::element::f32, ngraph::Shape{1, 3, 64});
ngraph::Strides strides{1};
ngraph::Shape pads_begin{0}, pads_end{0}, kernel{3};
auto pool = std::make_shared<ngraph::opset1::MaxPool>(input, strides, pads_begin, pads_end, kernel, ngraph::op::RoundingType::FLOOR);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{pool}, ngraph::ParameterVector{input});
pass::Manager manager;
manager.register_pass<ngraph::pass::Reshape1DOps>();
ASSERT_NO_THROW(manager.run_passes(f));
}
{
auto input = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3, 64}, {1});
auto w = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{6, 3, 3/*OIW*/}, {1});
auto b = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{6}, {1});
ngraph::Strides strides{1}, dilations{1};
ngraph::CoordinateDiff pads_begin{0}, pads_end{0};
ngraph::Shape output_shape{1, 6, 62};
auto conv = std::make_shared<ngraph::op::ConvolutionIE>(input, w, b, strides, dilations, pads_begin, pads_end, 1);
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{conv}, ngraph::ParameterVector{});
pass::Manager manager;
manager.register_pass<ngraph::pass::Reshape1DOps>();
ASSERT_NO_THROW(manager.run_passes(f));
}
}

View File

@ -78,3 +78,18 @@ TEST(TransformationTests, ReshapeFCFusiuonTest3) {
} }
ASSERT_EQ(f->get_ops().size(), 7); ASSERT_EQ(f->get_ops().size(), 7);
} }
TEST(TransformationTests, ReshapeFCFusiuonDynamic) {
auto input = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1, 3, 64, 64}, {1});
auto reshape_shape = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{4}, {1, 3, 64, 64});
auto fc_weights = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{6, 3 * 64 * 64}, {1});
auto fc_biases = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{6}, {1});
auto reshape = std::make_shared<ngraph::op::v1::Reshape>(input, reshape_shape, true);
auto fc = std::make_shared<ngraph::op::FullyConnected>(reshape, fc_weights, fc_biases, ngraph::Shape{1, 6});
auto f = std::make_shared<ngraph::Function>(ngraph::NodeVector{fc}, ngraph::ParameterVector{});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::ReshapeFullyConnectedFusion>();
ASSERT_NO_THROW(manager.run_passes(f));
}