This PR introduces the following changes: 1. Transformation *_tbl.hpp files were replaced with direct registration in cpp files. 2. Plugins use pass::Manager to call conversion passes. 3. The transformation callback was moved to the PassBase class, as there is no longer any need to keep it in a separate class. 4. All pattern-based transformations must inherit from the MatcherPass class. The GraphRewrite class will be used only for registering matchers and executing them on a function. The MatcherPass class adds new features to the pattern-based transformation approach: * It allows running a matcher pass on a single node. * Operations created inside a transformation callback can be added to the execution list so that they are available for pattern matching within a single GraphRewrite. 5. GraphRewrite MatchClosure was replaced with MatcherPass, so all matchers will be registered as a MatcherPass. 6. Added the pass::Manager::clear_state() method to avoid dependencies on nodes that no longer belong to the function after replacement. 7. Some representative transformations were updated to use MatcherPass as an example. 8. The Mul->Add sequence fusion transformation was replaced with LinOpSequenceFusion. 9. Pattern and callback registration code was moved to class constructors (this will be finished for the remaining passes in another PR). 10. Updated pass::Manager to get pass names only when NGRAPH_PROFILE_PASS_ENABLE is enabled. 11. Moving towards removing PassProperty. 12. Added ngraph::pattern::wrap_type<T>(inputs, pred) to simplify pattern creation. 13. GraphRewrite was updated to execute MatcherPass more efficiently.
120 lines
4.4 KiB
C++
120 lines
4.4 KiB
C++
//*****************************************************************************
|
|
// Copyright 2017-2020 Intel Corporation
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
//*****************************************************************************
|
|
|
|
#include <algorithm>
|
|
#include <cstdio>
|
|
#include <iostream>
|
|
#include <list>
|
|
#include <memory>
|
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
|
#include <ngraph/rt_info.hpp>
|
|
|
|
#include "gtest/gtest.h"
|
|
#include "ngraph/graph_util.hpp"
|
|
#include "ngraph/log.hpp"
|
|
#include "ngraph/ngraph.hpp"
|
|
#include "ngraph/opsets/opset3.hpp"
|
|
#include "ngraph/pass/graph_rewrite.hpp"
|
|
#include "ngraph/pass/manager.hpp"
|
|
|
|
using namespace ngraph;
|
|
using namespace std;
|
|
|
|
class TestMatcherPass : public pass::MatcherPass
|
|
{
|
|
public:
|
|
TestMatcherPass()
|
|
{
|
|
auto m_relu1 =
|
|
ngraph::pattern::wrap_type<ngraph::opset3::Relu>(pattern::consumers_count(1));
|
|
auto m_relu2 = ngraph::pattern::wrap_type<ngraph::opset3::Relu>({m_relu1});
|
|
|
|
ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) {
|
|
// Map that helps to connect labels with matched outputs
|
|
auto& node_to_output = m.get_pattern_value_map();
|
|
|
|
// Create new Relu operation and add register it for additional execution
|
|
auto new_relu = register_new_node<ngraph::opset3::Relu>(
|
|
node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0));
|
|
|
|
// Copy runtime info attributes to newly created operation
|
|
ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu);
|
|
|
|
// Save last Relu name to new Relu operation
|
|
new_relu->set_friendly_name(m.get_match_root()->get_friendly_name());
|
|
|
|
// Replace Relu->Relu with Relu
|
|
ngraph::replace_node(m.get_match_root(), new_relu);
|
|
|
|
// Return true as the root node was changed
|
|
return true;
|
|
};
|
|
|
|
// Register pattern with Divide operation as a pattern root node
|
|
auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
|
|
// Register Matcher
|
|
this->register_matcher(m, callback);
|
|
}
|
|
};
|
|
|
|
TEST(pattern, matcher_pass)
|
|
{
|
|
{
|
|
TestMatcherPass test_matcher;
|
|
auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
|
|
auto b = make_shared<opset3::Relu>(a);
|
|
auto c = make_shared<opset3::Relu>(b);
|
|
auto f = std::make_shared<Function>(ngraph::NodeVector{c}, ParameterVector{a});
|
|
|
|
ASSERT_TRUE(test_matcher.get_matcher()->match(c->output(0)));
|
|
ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().size() == 2);
|
|
test_matcher.get_matcher()->clear_state();
|
|
ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().empty());
|
|
|
|
test_matcher.apply(c);
|
|
ASSERT_TRUE(test_matcher.get_new_nodes().size() == 1);
|
|
test_matcher.apply(test_matcher.get_new_nodes()[0]);
|
|
ASSERT_TRUE(test_matcher.get_new_nodes().empty());
|
|
}
|
|
|
|
{
|
|
TestMatcherPass test_matcher;
|
|
auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
|
|
auto b = make_shared<opset3::Relu>(a);
|
|
auto c = make_shared<opset3::Relu>(b);
|
|
auto f = std::make_shared<Function>(ngraph::NodeVector{b, c}, ParameterVector{a});
|
|
|
|
ASSERT_FALSE(test_matcher.get_matcher()->match(c->output(0)));
|
|
}
|
|
|
|
{
|
|
std::shared_ptr<Function> f;
|
|
{
|
|
auto a = make_shared<opset3::Parameter>(element::f32, Shape{1});
|
|
auto b = make_shared<opset3::Relu>(a);
|
|
auto c = make_shared<opset3::Relu>(b);
|
|
auto d = make_shared<opset3::Relu>(c);
|
|
f = std::make_shared<Function>(ngraph::NodeVector{d}, ParameterVector{a});
|
|
}
|
|
|
|
pass::GraphRewrite pass;
|
|
pass.add_matcher<TestMatcherPass>();
|
|
pass.run_on_function(f);
|
|
|
|
// Parameter->Relu->Result
|
|
ASSERT_TRUE(f->get_ops().size() == 3);
|
|
}
|
|
} |