//***************************************************************************** // Copyright 2017-2020 Intel Corporation // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** #include #include #include #include #include #include #include #include "gtest/gtest.h" #include "ngraph/graph_util.hpp" #include "ngraph/log.hpp" #include "ngraph/ngraph.hpp" #include "ngraph/opsets/opset3.hpp" #include "ngraph/pass/graph_rewrite.hpp" #include "ngraph/pass/manager.hpp" using namespace ngraph; using namespace std; class TestMatcherPass : public pass::MatcherPass { public: TestMatcherPass() { auto m_relu1 = ngraph::pattern::wrap_type(pattern::consumers_count(1)); auto m_relu2 = ngraph::pattern::wrap_type({m_relu1}); ngraph::graph_rewrite_callback callback = [=](pattern::Matcher& m) { // Map that helps to connect labels with matched outputs auto& node_to_output = m.get_pattern_value_map(); // Create new Relu operation and add register it for additional execution auto new_relu = register_new_node( node_to_output.at(m_relu1).get_node_shared_ptr()->input_value(0)); // Copy runtime info attributes to newly created operation ngraph::copy_runtime_info(m.get_matched_nodes(), new_relu); // Save last Relu name to new Relu operation new_relu->set_friendly_name(m.get_match_root()->get_friendly_name()); // Replace Relu->Relu with Relu ngraph::replace_node(m.get_match_root(), new_relu); // Return true as the root node was changed return true; }; // Register pattern with Divide operation as a pattern root node auto m = std::make_shared(m_relu2, "ReluReluFusion"); // Register Matcher this->register_matcher(m, callback); } }; TEST(pattern, matcher_pass) { { TestMatcherPass test_matcher; auto a = make_shared(element::f32, Shape{1}); auto b = make_shared(a); auto c = make_shared(b); auto f = std::make_shared(ngraph::NodeVector{c}, ParameterVector{a}); ASSERT_TRUE(test_matcher.get_matcher()->match(c->output(0))); ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().size() == 2); test_matcher.get_matcher()->clear_state(); ASSERT_TRUE(test_matcher.get_matcher()->get_matched_nodes().empty()); test_matcher.apply(c); ASSERT_TRUE(test_matcher.get_new_nodes().size() == 1); test_matcher.apply(test_matcher.get_new_nodes()[0]); ASSERT_TRUE(test_matcher.get_new_nodes().empty()); } { TestMatcherPass test_matcher; auto a = make_shared(element::f32, Shape{1}); auto b = make_shared(a); auto c = make_shared(b); auto f = std::make_shared(ngraph::NodeVector{b, c}, ParameterVector{a}); ASSERT_FALSE(test_matcher.get_matcher()->match(c->output(0))); } { std::shared_ptr f; { auto a = make_shared(element::f32, Shape{1}); auto b = make_shared(a); auto c = make_shared(b); auto d = make_shared(c); f = std::make_shared(ngraph::NodeVector{d}, ParameterVector{a}); } pass::GraphRewrite pass; pass.add_matcher(); pass.run_on_function(f); // Parameter->Relu->Result ASSERT_TRUE(f->get_ops().size() == 3); } }