// Copyright (C) 2018-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // #include #include #include #include #include #include using namespace ::testing; using namespace std; using namespace ngraph; class RenameReLU : public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; RenameReLU() : MatcherPass() { auto relu = pattern::wrap_type(); ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { auto relu = m.get_match_root(); relu->set_friendly_name("renamed"); return false; }; auto m = std::make_shared(relu, "RenameReLU"); this->register_matcher(m, callback); } }; NGRAPH_RTTI_DEFINITION(RenameReLU, "RenameReLU", 0); class RenameSigmoid : public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; RenameSigmoid() : MatcherPass() { auto sigmoid = pattern::wrap_type(); ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { auto sigmoid = m.get_match_root(); sigmoid->set_friendly_name("renamed"); return false; }; auto m = std::make_shared(sigmoid, "RenameSigmoid"); this->register_matcher(m, callback); } }; NGRAPH_RTTI_DEFINITION(RenameSigmoid, "RenameSigmoid", 0); class TestFunctionPass : public ngraph::pass::FunctionPass { public: NGRAPH_RTTI_DECLARATION; bool run_on_function(std::shared_ptr f) override { pass::Manager manager(get_pass_config()); manager.register_pass(); manager.register_pass(); manager.run_passes(f); return true; } }; NGRAPH_RTTI_DEFINITION(TestFunctionPass, "TestFunctionPass", 0); class TestGraphRewritePass : public ngraph::pass::GraphRewrite { public: NGRAPH_RTTI_DECLARATION; TestGraphRewritePass() { add_matcher(); add_matcher(); } }; NGRAPH_RTTI_DEFINITION(TestGraphRewritePass, "TestGraphRewritePass", 0); std::tuple, std::shared_ptr, std::shared_ptr> get_test_function() { auto data = std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); auto relu = std::make_shared(data); relu->set_friendly_name("relu"); auto sigmoid = std::make_shared(relu); sigmoid->set_friendly_name("sigmoid"); auto f = std::make_shared(ngraph::NodeVector{sigmoid}, ngraph::ParameterVector{data}); return std::tuple, std::shared_ptr, std::shared_ptr>( f, relu, sigmoid); } TEST(PassConfig, EnableDisablePasses1) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses2) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses3) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses4) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->enable(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses5) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses6) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses7) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses8) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->enable(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses9) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; auto anchor = manager.register_pass(); anchor->add_matcher(); anchor->add_matcher(); auto pass_config = manager.get_pass_config(); pass_config->enable(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } class TestNestedMatcher : public ngraph::pass::MatcherPass { public: NGRAPH_RTTI_DECLARATION; TestNestedMatcher() : MatcherPass() { auto any_op = pattern::any_input(); ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { auto root = m.get_match_root(); auto pass_config = this->get_pass_config(); if (std::dynamic_pointer_cast(root) && !pass_config->is_disabled()) { auto pass = std::make_shared(); pass->set_pass_config(pass_config); pass->apply(root); } else if (std::dynamic_pointer_cast(root) && !pass_config->is_disabled()) { auto pass = std::make_shared(); pass->set_pass_config(pass_config); pass->apply(root); } return false; }; auto m = std::make_shared(any_op, "TestNestedMatcher"); this->register_matcher(m, callback); } }; NGRAPH_RTTI_DEFINITION(TestNestedMatcher, "TestNestedMatcher", 0); class TestNestedGraphRewrite : public pass::GraphRewrite { public: NGRAPH_RTTI_DECLARATION; TestNestedGraphRewrite() { add_matcher(); } }; NGRAPH_RTTI_DEFINITION(TestNestedGraphRewrite, "TestNestedGraphRewrite", 0); TEST(PassConfig, EnableDisablePasses10) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; manager.register_pass(); auto pass_config = manager.get_pass_config(); pass_config->disable(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); } TEST(PassConfig, EnableDisablePasses11) { std::shared_ptr f; std::shared_ptr relu, sigmoid; std::tie(f, relu, sigmoid) = get_test_function(); pass::Manager manager; auto anchor = manager.register_pass(); anchor->add_matcher(); auto pass_config = manager.get_pass_config(); pass_config->disable(); pass_config->disable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "relu"); ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); pass_config->enable(); pass_config->enable(); manager.run_passes(f); ASSERT_EQ(relu->get_friendly_name(), "renamed"); ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); }