diff --git a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp index 342f4e8ae34..53ce548e8bf 100644 --- a/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp +++ b/inference-engine/src/legacy_api/src/transformations/convert_opset1_to_legacy/convert_opset1_to_legacy.cpp @@ -61,7 +61,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet1ToLegacy, "ConvertOpSet1ToLega bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr f) { OV_ITT_SCOPED_TASK(InferenceEngine::itt::domains::IELegacy, "ngraph::pass::ConvertOpSet1ToLegacy"); - ngraph::pass::Manager manager; + ngraph::pass::Manager manager(get_pass_config()); manager.register_pass(); @@ -148,7 +148,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr(); - manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp index bba3fcb5d93..4f3a264c262 100644 --- a/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp +++ b/inference-engine/src/transformations/src/transformations/common_optimizations/common_optimizations.cpp @@ -53,7 +53,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::CommonOptimizations, "CommonOptimizations", bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr f) { OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::CommonOptimizations"); - ngraph::pass::Manager manager; + ngraph::pass::Manager manager(get_pass_config()); // This pass must be called first in pipeline manager.register_pass(); @@ -118,8 +118,6 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptradd_matcher(); fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions"); - // Propagate local PassConfig to internal pass::Manager - manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset2_to_opset1.cpp b/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset2_to_opset1.cpp index 558cbe82a89..df71f2bd37c 100644 --- a/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset2_to_opset1.cpp +++ b/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset2_to_opset1.cpp @@ -18,12 +18,11 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet2ToOpSet1, "ConvertOpSet2ToOpSe bool ngraph::pass::ConvertOpSet2ToOpSet1::run_on_function(std::shared_ptr f) { OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet2ToOpSet1"); - ngraph::pass::Manager manager; + ngraph::pass::Manager manager(get_pass_config()); manager.register_pass(); manager.register_pass(); - manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset3_to_opset2.cpp b/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset3_to_opset2.cpp index cedbb5bbf33..85f54d1e433 100644 --- a/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset3_to_opset2.cpp +++ b/inference-engine/src/transformations/src/transformations/opset_conversions/convert_opset3_to_opset2.cpp @@ -22,7 +22,7 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertOpSet3ToOpSet2, "ConvertOpSet3ToOpSe bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr f) { OV_ITT_SCOPED_TASK(itt::domains::IETransform, "ngraph::pass::ConvertOpSet3ToOpSet2"); - ngraph::pass::Manager manager; + ngraph::pass::Manager manager(get_pass_config()); manager.register_pass(); manager.register_pass(); @@ -31,7 +31,6 @@ bool ngraph::pass::ConvertOpSet3ToOpSet2::run_on_function(std::shared_ptr(); manager.register_pass(); - manager.set_pass_config(get_pass_config()); manager.run_passes(f); return true; } diff --git a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp index 4b19bb5894c..9cb572d1c8e 100644 --- a/ngraph/core/include/ngraph/pass/graph_rewrite.hpp +++ b/ngraph/core/include/ngraph/pass/graph_rewrite.hpp @@ -150,7 +150,7 @@ public: auto pass = std::make_shared(std::forward(args)...); auto pass_config = get_pass_config(); pass->set_pass_config(pass_config); - if (!Enabled) + if (!Enabled && !pass_config->is_enabled()) { pass_config->disable(); } @@ -168,6 +168,8 @@ public: bool run_on_function(std::shared_ptr f) override; + void set_pass_config(const std::shared_ptr& pass_config) override; + protected: bool m_enable_shape_inference = false; diff --git a/ngraph/core/include/ngraph/pass/manager.hpp b/ngraph/core/include/ngraph/pass/manager.hpp index 6a0060f406b..508c6ea93af 100644 --- a/ngraph/core/include/ngraph/pass/manager.hpp +++ b/ngraph/core/include/ngraph/pass/manager.hpp @@ -38,6 +38,9 @@ public: Manager(); ~Manager(); + //// \brief Construct Manager with shared PassConfig instance + explicit Manager(std::shared_ptr pass_config); + /// \brief Register given transformation class type to execution list /// Example below show the basic usage of pass::Manager /// @@ -59,7 +62,7 @@ public: { push_pass(); } - if (!Enable) + if (!Enable && !m_pass_config->is_enabled()) { m_pass_config->disable(); } @@ -99,12 +102,6 @@ public: /// This object allows to disable/enable transformations execution, set callback to particular /// transformation. For mo details see PassConfig class. std::shared_ptr get_pass_config() { return m_pass_config; } - /// \brief Set external PassConfig object. - void set_pass_config(const std::shared_ptr& pass_config) - { - *m_pass_config = *pass_config; - } - protected: template std::shared_ptr push_pass(Args&&... args) diff --git a/ngraph/core/include/ngraph/pass/pass.hpp b/ngraph/core/include/ngraph/pass/pass.hpp index 100d45438e1..d279d822aee 100644 --- a/ngraph/core/include/ngraph/pass/pass.hpp +++ b/ngraph/core/include/ngraph/pass/pass.hpp @@ -62,7 +62,7 @@ namespace ngraph /// \brief Set PassConfig for particular transformation instance /// \param pass_config is a PassConfig shared_ptr - void set_pass_config(const std::shared_ptr& pass_config) + virtual void set_pass_config(const std::shared_ptr& pass_config) { m_pass_config = pass_config; } diff --git a/ngraph/core/include/ngraph/pass/pass_config.hpp b/ngraph/core/include/ngraph/pass/pass_config.hpp index 7d85d1a26f3..d292d18a19e 100644 --- a/ngraph/core/include/ngraph/pass/pass_config.hpp +++ b/ngraph/core/include/ngraph/pass/pass_config.hpp @@ -72,7 +72,7 @@ namespace ngraph public: /// \brief Disable transformation by its type_info /// \param type_info Transformation type_info - void disable(const DiscreteTypeInfo& type_info) { m_disabled.insert(type_info); } + void disable(const DiscreteTypeInfo& type_info); /// \brief Disable transformation by its class type (based on type_info) template void disable() @@ -82,7 +82,7 @@ namespace ngraph /// \brief Enable transformation by its type_info /// \param type_info Transformation type_info - void enable(const DiscreteTypeInfo& type_info) { m_disabled.erase(type_info); } + void enable(const DiscreteTypeInfo& type_info); /// \brief Enable transformation by its class type (based on type_info) template void enable() @@ -161,12 +161,31 @@ namespace ngraph return is_disabled(T::type_info); } + /// \brief Check either transformation type is force enabled or not + /// \param type_info Transformation type_info + /// \return true if transformation type was force enabled and false otherwise + bool is_enabled(const DiscreteTypeInfo& type_info) const + { + return m_enabled.count(type_info); + } + + /// \brief Check either transformation class type is force enabled or not + /// \return true if transformation type was force enabled and false otherwise + template + bool is_enabled() const + { + return is_enabled(T::type_info); + } + + void add_disabled_passes(const PassConfig& rhs); + private: param_callback m_callback = [](const std::shared_ptr&) { return false; }; param_callback_map m_callback_map; std::unordered_set m_disabled; + std::unordered_set m_enabled; }; } } \ No newline at end of file diff --git a/ngraph/core/src/pass/graph_rewrite.cpp b/ngraph/core/src/pass/graph_rewrite.cpp index 66993ebd05a..08020f441fd 100644 --- a/ngraph/core/src/pass/graph_rewrite.cpp +++ b/ngraph/core/src/pass/graph_rewrite.cpp @@ -272,6 +272,35 @@ void pass::GraphRewrite::add_matcher(const shared_ptr& m, NGRAPH_SUPPRESS_DEPRECATED_END } +void pass::GraphRewrite::set_pass_config(const std::shared_ptr& rhs) +{ + auto pass_config = get_pass_config(); + // We have to preserve disabled passes because in case when we register matchers inside + // GraphRewrite c-tor we work with local PassConfig instance. + // For example: + // + // class ExampleGraphRewrite: public pass::GraphRewrite { + // ExampleGraphRewrite() { + // add_mather(); + // add_mather(); + // } + // }; + // + // When we call add_matcher inside c-tor we automatically work with locally created PassConfig + // instance that is not shared. So when instance of this pass is being created in pass::Manager + // we set shared PassConfig but we will override already existing rules inside local config. To + // resolve this we have to copy disabled passes from local PassConfig to shared but we take into + // account that if passes were manually enabled we do not add them. + rhs->add_disabled_passes(*pass_config); + PassBase::set_pass_config(rhs); + + // update nested transformations with new shared pass_config + for (auto& pass : m_matchers) + { + pass->set_pass_config(rhs); + } +} + void pass::RecurrentGraphRewrite::add_matcher( const std::shared_ptr& m, const ngraph::recurrent_graph_rewrite_callback& callback, diff --git a/ngraph/core/src/pass/manager.cpp b/ngraph/core/src/pass/manager.cpp index e4c7044897f..7c7a7f2b97c 100644 --- a/ngraph/core/src/pass/manager.cpp +++ b/ngraph/core/src/pass/manager.cpp @@ -44,6 +44,11 @@ pass::Manager::~Manager() { } +pass::Manager::Manager(std::shared_ptr pass_config) + : m_pass_config(std::move(pass_config)) +{ +} + void pass::Manager::run_passes(shared_ptr func) { OV_ITT_SCOPED_TASK(itt::domains::nGraph, "pass::Manager::run_passes"); diff --git a/ngraph/core/src/pass/pass_config.cpp b/ngraph/core/src/pass/pass_config.cpp index c123d4b18dd..e55d061f087 100644 --- a/ngraph/core/src/pass/pass_config.cpp +++ b/ngraph/core/src/pass/pass_config.cpp @@ -30,3 +30,25 @@ pass::param_callback pass::PassConfig::get_callback(const DiscreteTypeInfo& type return m_callback; } } + +void pass::PassConfig::enable(const ngraph::DiscreteTypeInfo& type_info) +{ + m_disabled.erase(type_info); + m_enabled.insert(type_info); +} + +void pass::PassConfig::disable(const ngraph::DiscreteTypeInfo& type_info) +{ + m_enabled.erase(type_info); + m_disabled.insert(type_info); +} + +void pass::PassConfig::add_disabled_passes(const PassConfig& rhs) +{ + for (const auto& pass : rhs.m_disabled) + { + if (is_enabled(pass)) + continue; + disable(pass); + } +} diff --git a/ngraph/test/CMakeLists.txt b/ngraph/test/CMakeLists.txt index 04f72d08d84..8e486b6e865 100644 --- a/ngraph/test/CMakeLists.txt +++ b/ngraph/test/CMakeLists.txt @@ -88,6 +88,7 @@ set(SRC op_is.cpp opset1.cpp partial_shape.cpp + pass_config.cpp pass_liveness.cpp pass_manager.cpp pass_shape_relevance.cpp diff --git a/ngraph/test/graph_rewrite.cpp b/ngraph/test/graph_rewrite.cpp index e5c435d7e2a..5cb3f5da222 100644 --- a/ngraph/test/graph_rewrite.cpp +++ b/ngraph/test/graph_rewrite.cpp @@ -298,13 +298,13 @@ TEST(PassConfigTest, Test1) { auto f = get_function(); - pass::Manager manager; + auto pass_config = std::make_shared(); + pass::Manager manager(pass_config); + manager.register_pass(); - auto pass_config = std::make_shared(); pass_config->set_callback(get_callback()); - manager.set_pass_config(pass_config); manager.run_passes(f); ASSERT_EQ(count_ops_of_type(f), 1); @@ -343,11 +343,9 @@ TEST(PassConfigTest, Test1) { auto pass_config = std::make_shared(); - pass::Manager manager1; - pass::Manager manager2; - manager1.set_pass_config(pass_config); - manager2.set_pass_config(pass_config); - ASSERT_EQ(pass_config.use_count(), 1); + pass::Manager manager1(pass_config); + pass::Manager manager2(pass_config); + ASSERT_EQ(pass_config.use_count(), 3); } { diff --git a/ngraph/test/pass_config.cpp b/ngraph/test/pass_config.cpp new file mode 100644 index 00000000000..f350c4a5658 --- /dev/null +++ b/ngraph/test/pass_config.cpp @@ -0,0 +1,381 @@ +// Copyright (C) 2018-2020 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include +#include +#include +#include +#include + +using namespace ::testing; +using namespace std; +using namespace ngraph; + +class RenameReLU : public ngraph::pass::MatcherPass +{ +public: + NGRAPH_RTTI_DECLARATION; + RenameReLU() + : MatcherPass() + { + auto relu = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { + auto relu = m.get_match_root(); + relu->set_friendly_name("renamed"); + return false; + }; + + auto m = std::make_shared(relu, "RenameReLU"); + this->register_matcher(m, callback); + } +}; + +NGRAPH_RTTI_DEFINITION(RenameReLU, "RenameReLU", 0); + +class RenameSigmoid : public ngraph::pass::MatcherPass +{ +public: + NGRAPH_RTTI_DECLARATION; + RenameSigmoid() + : MatcherPass() + { + auto sigmoid = pattern::wrap_type(); + ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) { + auto sigmoid = m.get_match_root(); + sigmoid->set_friendly_name("renamed"); + return false; + }; + + auto m = std::make_shared(sigmoid, "RenameSigmoid"); + this->register_matcher(m, callback); + } +}; + +NGRAPH_RTTI_DEFINITION(RenameSigmoid, "RenameSigmoid", 0); + +class TestFunctionPass : public ngraph::pass::FunctionPass +{ +public: + NGRAPH_RTTI_DECLARATION; + + bool run_on_function(std::shared_ptr f) override + { + pass::Manager manager(get_pass_config()); + + manager.register_pass(); + manager.register_pass(); + + manager.run_passes(f); + return true; + } +}; + +NGRAPH_RTTI_DEFINITION(TestFunctionPass, "TestFunctionPass", 0); + +class TestGraphRewritePass : public ngraph::pass::GraphRewrite +{ +public: + NGRAPH_RTTI_DECLARATION; + TestGraphRewritePass() + { + add_matcher(); + add_matcher(); + } +}; + +NGRAPH_RTTI_DEFINITION(TestGraphRewritePass, "TestGraphRewritePass", 0); + +std::tuple, std::shared_ptr, std::shared_ptr> + get_test_function() +{ + auto data = + std::make_shared(ngraph::element::f32, ngraph::Shape{3, 1, 2}); + auto relu = std::make_shared(data); + relu->set_friendly_name("relu"); + auto sigmoid = std::make_shared(relu); + sigmoid->set_friendly_name("sigmoid"); + auto f = std::make_shared(ngraph::NodeVector{sigmoid}, + ngraph::ParameterVector{data}); + return std::tuple, std::shared_ptr, std::shared_ptr>( + f, relu, sigmoid); +} + +TEST(PassConfig, EnableDisablePasses1) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses2) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses3) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses4) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->enable(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + manager.run_passes(f); + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses5) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses6) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses7) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses8) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->enable(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses9) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + anchor->add_matcher(); + + auto pass_config = manager.get_pass_config(); + pass_config->enable(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +class TestNestedMatcher : public ngraph::pass::MatcherPass +{ +public: + NGRAPH_RTTI_DECLARATION; + TestNestedMatcher() + : MatcherPass() + { + auto any_op = pattern::any_input(); + ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) { + auto root = m.get_match_root(); + auto pass_config = this->get_pass_config(); + if (std::dynamic_pointer_cast(root) && + !pass_config->is_disabled()) + { + auto pass = std::make_shared(); + pass->set_pass_config(pass_config); + pass->apply(root); + } + else if (std::dynamic_pointer_cast(root) && + !pass_config->is_disabled()) + { + auto pass = std::make_shared(); + pass->set_pass_config(pass_config); + pass->apply(root); + } + return false; + }; + + auto m = std::make_shared(any_op, "TestNestedMatcher"); + this->register_matcher(m, callback); + } +}; + +NGRAPH_RTTI_DEFINITION(TestNestedMatcher, "TestNestedMatcher", 0); + +class TestNestedGraphRewrite : public pass::GraphRewrite +{ +public: + NGRAPH_RTTI_DECLARATION; + TestNestedGraphRewrite() { add_matcher(); } +}; + +NGRAPH_RTTI_DEFINITION(TestNestedGraphRewrite, "TestNestedGraphRewrite", 0); + +TEST(PassConfig, EnableDisablePasses10) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + manager.register_pass(); + + auto pass_config = manager.get_pass_config(); + pass_config->disable(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} + +TEST(PassConfig, EnableDisablePasses11) +{ + std::shared_ptr f; + std::shared_ptr relu, sigmoid; + std::tie(f, relu, sigmoid) = get_test_function(); + + pass::Manager manager; + auto anchor = manager.register_pass(); + anchor->add_matcher(); + + auto pass_config = manager.get_pass_config(); + pass_config->disable(); + pass_config->disable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "relu"); + ASSERT_EQ(sigmoid->get_friendly_name(), "sigmoid"); + + pass_config->enable(); + pass_config->enable(); + manager.run_passes(f); + + ASSERT_EQ(relu->get_friendly_name(), "renamed"); + ASSERT_EQ(sigmoid->get_friendly_name(), "renamed"); +} \ No newline at end of file