[IE] Add RTTI macro to ReshapeFullyConnectedFusion ngrap pass (#2837)

This commit is contained in:
Zoran Zomborat 2020-10-29 06:33:55 +02:00 committed by GitHub
parent 15c10e74fe
commit 4021e144b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 59 additions and 40 deletions

View File

@ -25,13 +25,14 @@
namespace ngraph {
namespace pass {
class ReshapeFullyConnectedFusion;
class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion);
} // namespace pass
} // namespace ngraph
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite {
public:
NGRAPH_RTTI_DECLARATION;
ReshapeFullyConnectedFusion() : GraphRewrite() {
construct_reshape_fc();
}
@ -44,43 +45,5 @@ public:
}
private:
void construct_reshape_fc() {
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
pattern::any_input(),
pattern::any_input()});
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
// Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
auto shape_in = reshape->input_value(0).get_shape();
auto shape_out = reshape->get_shape();
if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
return false;
}
// Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
auto shape_w = fc->input_value(1).get_shape();
if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
return false;
}
auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
fc->input_value(1),
fc->input_value(2),
fc->get_shape(),
fc->output(0).get_element_type());
new_fc->set_friendly_name(fc->get_friendly_name());
ngraph::copy_runtime_info({reshape, fc}, new_fc);
ngraph::replace_node(fc, new_fc);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
void construct_reshape_fc();
};

View File

@ -0,0 +1,56 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp"
#include <memory>
#include <vector>
#include <string>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
pattern::any_input(),
pattern::any_input()});
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
auto & pattern_to_output = m.get_pattern_value_map();
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
// Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
auto shape_in = reshape->input_value(0).get_shape();
auto shape_out = reshape->get_shape();
if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
return false;
}
// Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
auto shape_w = fc->input_value(1).get_shape();
if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
return false;
}
auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
fc->input_value(1),
fc->input_value(2),
fc->get_shape(),
fc->output(0).get_element_type());
new_fc->set_friendly_name(fc->get_friendly_name());
ngraph::copy_runtime_info({reshape, fc}, new_fc);
ngraph::replace_node(fc, new_fc);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}