diff --git a/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp b/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp index a00daa194ff..ecfd6722578 100644 --- a/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/ric_fusion.cpp @@ -825,18 +825,25 @@ public: bool ov::pass::ReverseInputChannelsFusion::run_on_model(const std::shared_ptr& model) { RUN_ON_MODEL_SCOPE(ReverseInputChannelsFusion); - Manager m; - m.set_per_pass_validation(false); NodeVector nodes_to_fuse; // First we need to initialize and propagate RIC attributes through entire graph - auto ric_prop = m.register_pass(); { using namespace init; - ADD_MATCHER(ric_prop, SplitConcat, nodes_to_fuse) - ADD_MATCHER(ric_prop, Gather, nodes_to_fuse) + Manager m; + m.set_per_pass_validation(false); + auto ric_init = m.register_pass(); + ADD_MATCHER(ric_init, SplitConcat, nodes_to_fuse) + ADD_MATCHER(ric_init, Gather, nodes_to_fuse) + if (!m.run_passes(model)) { + return false; + } } + Manager m; + m.set_per_pass_validation(false); + + auto ric_prop = m.register_pass(); { using namespace prop; ADD_MATCHER(ric_prop, Convolution) diff --git a/src/core/include/openvino/pass/manager.hpp b/src/core/include/openvino/pass/manager.hpp index 1cdf01c22a5..40f14d83890 100644 --- a/src/core/include/openvino/pass/manager.hpp +++ b/src/core/include/openvino/pass/manager.hpp @@ -60,7 +60,13 @@ public: return pass; } - void run_passes(std::shared_ptr); + /// \brief Runs registered transformations on a given model + /// + /// \param model Input model + /// + /// \return Returns true if the model was changed by transformations, + /// false otherwise. + bool run_passes(std::shared_ptr model); void set_pass_visualization(bool new_state) { m_visualize = new_state; diff --git a/src/core/src/pass/manager.cpp b/src/core/src/pass/manager.cpp index e4dfd7b8cc7..bac12e6c544 100644 --- a/src/core/src/pass/manager.cpp +++ b/src/core/src/pass/manager.cpp @@ -59,7 +59,7 @@ void ov::pass::Manager::set_per_pass_validation(bool new_state) { m_per_pass_validation = new_state; } -void ov::pass::Manager::run_passes(shared_ptr func) { +bool ov::pass::Manager::run_passes(shared_ptr func) { NGRAPH_SUPPRESS_DEPRECATED_START OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes"); @@ -70,7 +70,9 @@ void ov::pass::Manager::run_passes(shared_ptr func) { ngraph::stopwatch pass_timer; ngraph::stopwatch overall_timer; overall_timer.start(); + bool pass_applied = false; bool function_changed = false; + bool needs_validate = false; for (auto& pass : m_pass_list) { if (m_pass_config->is_disabled(pass->get_type_info())) { NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled"; @@ -91,7 +93,7 @@ void ov::pass::Manager::run_passes(shared_ptr func) { } // GraphRewrite is a temporary container for MatcherPass to make execution // on on entire ngraph::Function - function_changed = GraphRewrite(matcher_pass).run_on_model(func); + pass_applied = GraphRewrite(matcher_pass).run_on_model(func); } else if (auto function_pass = dynamic_pointer_cast(pass)) { // This checks is to skip the graph transformation when the graph pass relies on // static shape but the function state is dynamic. @@ -102,12 +104,12 @@ void ov::pass::Manager::run_passes(shared_ptr func) { } if (dynamic_pointer_cast(pass)) { - if (function_changed) { + if (needs_validate) { function_pass->run_on_model(func); - function_changed = false; + needs_validate = false; } } else { - function_changed = function_pass->run_on_model(func); + pass_applied = function_pass->run_on_model(func); } } else if (auto node_pass = dynamic_pointer_cast(pass)) { if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) { @@ -116,7 +118,7 @@ void ov::pass::Manager::run_passes(shared_ptr func) { continue; } for (const shared_ptr& n : func->get_ops()) { - function_changed |= node_pass->run_on_node(n); + pass_applied |= node_pass->run_on_node(n); } } @@ -138,9 +140,13 @@ void ov::pass::Manager::run_passes(shared_ptr func) { if (profile_enabled) { cout << setw(7) << pass_timer.get_milliseconds() << "ms " << pass->get_name() << "\n"; } + function_changed = function_changed || pass_applied; + needs_validate = pass_applied; } if (profile_enabled) { cout << "passes done in " << overall_timer.get_milliseconds() << "ms\n"; } NGRAPH_SUPPRESS_DEPRECATED_END + + return function_changed; }