[IE] Add RTTI macro to ReshapeFullyConnectedFusion ngrap pass (#2837)
This commit is contained in:
parent
15c10e74fe
commit
4021e144b5
@ -25,13 +25,14 @@
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class ReshapeFullyConnectedFusion;
|
||||
class INFERENCE_ENGINE_API_CLASS(ReshapeFullyConnectedFusion);
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
class ngraph::pass::ReshapeFullyConnectedFusion : public ngraph::pass::GraphRewrite {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
ReshapeFullyConnectedFusion() : GraphRewrite() {
|
||||
construct_reshape_fc();
|
||||
}
|
||||
@ -44,43 +45,5 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
void construct_reshape_fc() {
|
||||
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
|
||||
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
|
||||
pattern::any_input(),
|
||||
pattern::any_input()});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
|
||||
auto & pattern_to_output = m.get_pattern_value_map();
|
||||
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
|
||||
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
|
||||
|
||||
// Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
|
||||
auto shape_in = reshape->input_value(0).get_shape();
|
||||
auto shape_out = reshape->get_shape();
|
||||
if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
|
||||
auto shape_w = fc->input_value(1).get_shape();
|
||||
if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
|
||||
fc->input_value(1),
|
||||
fc->input_value(2),
|
||||
fc->get_shape(),
|
||||
fc->output(0).get_element_type());
|
||||
|
||||
new_fc->set_friendly_name(fc->get_friendly_name());
|
||||
ngraph::copy_runtime_info({reshape, fc}, new_fc);
|
||||
ngraph::replace_node(fc, new_fc);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
void construct_reshape_fc();
|
||||
};
|
||||
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "legacy/transformations/convert_opset1_to_legacy/reshape_fc_fusion.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::ReshapeFullyConnectedFusion, "ReshapeFullyConnectedFusion", 0);
|
||||
|
||||
void ngraph::pass::ReshapeFullyConnectedFusion::construct_reshape_fc() {
|
||||
auto m_reshape = pattern::wrap_type<opset1::Reshape>(pattern::has_static_shape());
|
||||
auto m_fc = pattern::wrap_type<op::FullyConnected>({m_reshape,
|
||||
pattern::any_input(),
|
||||
pattern::any_input()});
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher &m) {
|
||||
auto & pattern_to_output = m.get_pattern_value_map();
|
||||
auto fc = pattern_to_output[m_fc].get_node_shared_ptr();
|
||||
auto reshape = pattern_to_output[m_reshape].get_node_shared_ptr();
|
||||
|
||||
// Check that Reshape reshapes 4D tensor to 2D or input shape = output shape
|
||||
auto shape_in = reshape->input_value(0).get_shape();
|
||||
auto shape_out = reshape->get_shape();
|
||||
if (!((shape_in.size() == 4 && reshape->get_shape().size() == 2) || (shape_in == shape_out && !shape_in.empty()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check that Weights[O, C*H*W] consistent with Input[N, C, H, W]
|
||||
auto shape_w = fc->input_value(1).get_shape();
|
||||
if (shape_in[0] != shape_out[0] || std::accumulate(shape_in.begin() + 1, shape_in.end(), size_t{1}, std::multiplies<size_t>()) != shape_w[1]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto new_fc = std::make_shared<op::FullyConnected>(reshape->input_value(0),
|
||||
fc->input_value(1),
|
||||
fc->input_value(2),
|
||||
fc->get_shape(),
|
||||
fc->output(0).get_element_type());
|
||||
|
||||
new_fc->set_friendly_name(fc->get_friendly_name());
|
||||
ngraph::copy_runtime_info({reshape, fc}, new_fc);
|
||||
ngraph::replace_node(fc, new_fc);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(m_fc, "ReshapeFullyConnectedFusion");
|
||||
this->add_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
Loading…
Reference in New Issue
Block a user