ReverseInputChannelsFusion - no reverse input channels -> return (#15784)

* ReverseInputChannelsFusion - return early if there is no reverse input channels

Ticket: 98067

* run_passes

* fix unnecessary validate calls
This commit is contained in:
Mateusz Tabaka 2023-02-28 18:56:21 +01:00 committed by GitHub
parent 0988c2b813
commit 62ff31df8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 12 deletions

View File

@ -825,18 +825,25 @@ public:
bool ov::pass::ReverseInputChannelsFusion::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_MODEL_SCOPE(ReverseInputChannelsFusion);
Manager m;
m.set_per_pass_validation(false);
NodeVector nodes_to_fuse;
// First we need to initialize and propagate RIC attributes through entire graph
auto ric_prop = m.register_pass<GraphRewrite>();
{
using namespace init;
ADD_MATCHER(ric_prop, SplitConcat, nodes_to_fuse)
ADD_MATCHER(ric_prop, Gather, nodes_to_fuse)
Manager m;
m.set_per_pass_validation(false);
auto ric_init = m.register_pass<GraphRewrite>();
ADD_MATCHER(ric_init, SplitConcat, nodes_to_fuse)
ADD_MATCHER(ric_init, Gather, nodes_to_fuse)
if (!m.run_passes(model)) {
return false;
}
}
Manager m;
m.set_per_pass_validation(false);
auto ric_prop = m.register_pass<GraphRewrite>();
{
using namespace prop;
ADD_MATCHER(ric_prop, Convolution)

View File

@ -60,7 +60,13 @@ public:
return pass;
}
void run_passes(std::shared_ptr<Model>);
/// \brief Runs registered transformations on a given model
///
/// \param model Input model
///
/// \return Returns true if the model was changed by transformations,
/// false otherwise.
bool run_passes(std::shared_ptr<Model> model);
void set_pass_visualization(bool new_state) {
m_visualize = new_state;

View File

@ -59,7 +59,7 @@ void ov::pass::Manager::set_per_pass_validation(bool new_state) {
m_per_pass_validation = new_state;
}
void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
bool ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
NGRAPH_SUPPRESS_DEPRECATED_START
OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes");
@ -70,7 +70,9 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
ngraph::stopwatch pass_timer;
ngraph::stopwatch overall_timer;
overall_timer.start();
bool pass_applied = false;
bool function_changed = false;
bool needs_validate = false;
for (auto& pass : m_pass_list) {
if (m_pass_config->is_disabled(pass->get_type_info())) {
NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
@ -91,7 +93,7 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
}
// GraphRewrite is a temporary container for MatcherPass to make execution
// on on entire ngraph::Function
function_changed = GraphRewrite(matcher_pass).run_on_model(func);
pass_applied = GraphRewrite(matcher_pass).run_on_model(func);
} else if (auto function_pass = dynamic_pointer_cast<ModelPass>(pass)) {
// This checks is to skip the graph transformation when the graph pass relies on
// static shape but the function state is dynamic.
@ -102,12 +104,12 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
}
if (dynamic_pointer_cast<Validate>(pass)) {
if (function_changed) {
if (needs_validate) {
function_pass->run_on_model(func);
function_changed = false;
needs_validate = false;
}
} else {
function_changed = function_pass->run_on_model(func);
pass_applied = function_pass->run_on_model(func);
}
} else if (auto node_pass = dynamic_pointer_cast<ngraph::pass::NodePass>(pass)) {
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
@ -116,7 +118,7 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
continue;
}
for (const shared_ptr<Node>& n : func->get_ops()) {
function_changed |= node_pass->run_on_node(n);
pass_applied |= node_pass->run_on_node(n);
}
}
@ -138,9 +140,13 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
if (profile_enabled) {
cout << setw(7) << pass_timer.get_milliseconds() << "ms " << pass->get_name() << "\n";
}
function_changed = function_changed || pass_applied;
needs_validate = pass_applied;
}
if (profile_enabled) {
cout << "passes done in " << overall_timer.get_milliseconds() << "ms\n";
}
NGRAPH_SUPPRESS_DEPRECATED_END
return function_changed;
}