ReverseInputChannelsFusion - no reverse input channels -> return (#15784)
* ReverseInputChannelsFusion - return early if there is no reverse input channels Ticket: 98067 * run_passes * fix unnecessary validate calls
This commit is contained in:
parent
0988c2b813
commit
62ff31df8a
@ -825,18 +825,25 @@ public:
|
||||
|
||||
bool ov::pass::ReverseInputChannelsFusion::run_on_model(const std::shared_ptr<ov::Model>& model) {
|
||||
RUN_ON_MODEL_SCOPE(ReverseInputChannelsFusion);
|
||||
Manager m;
|
||||
m.set_per_pass_validation(false);
|
||||
|
||||
NodeVector nodes_to_fuse;
|
||||
// First we need to initialize and propagate RIC attributes through entire graph
|
||||
auto ric_prop = m.register_pass<GraphRewrite>();
|
||||
{
|
||||
using namespace init;
|
||||
ADD_MATCHER(ric_prop, SplitConcat, nodes_to_fuse)
|
||||
ADD_MATCHER(ric_prop, Gather, nodes_to_fuse)
|
||||
Manager m;
|
||||
m.set_per_pass_validation(false);
|
||||
auto ric_init = m.register_pass<GraphRewrite>();
|
||||
ADD_MATCHER(ric_init, SplitConcat, nodes_to_fuse)
|
||||
ADD_MATCHER(ric_init, Gather, nodes_to_fuse)
|
||||
if (!m.run_passes(model)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
Manager m;
|
||||
m.set_per_pass_validation(false);
|
||||
|
||||
auto ric_prop = m.register_pass<GraphRewrite>();
|
||||
{
|
||||
using namespace prop;
|
||||
ADD_MATCHER(ric_prop, Convolution)
|
||||
|
@ -60,7 +60,13 @@ public:
|
||||
return pass;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Model>);
|
||||
/// \brief Runs registered transformations on a given model
|
||||
///
|
||||
/// \param model Input model
|
||||
///
|
||||
/// \return Returns true if the model was changed by transformations,
|
||||
/// false otherwise.
|
||||
bool run_passes(std::shared_ptr<Model> model);
|
||||
|
||||
void set_pass_visualization(bool new_state) {
|
||||
m_visualize = new_state;
|
||||
|
@ -59,7 +59,7 @@ void ov::pass::Manager::set_per_pass_validation(bool new_state) {
|
||||
m_per_pass_validation = new_state;
|
||||
}
|
||||
|
||||
void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
bool ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
NGRAPH_SUPPRESS_DEPRECATED_START
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::core, "pass::Manager::run_passes");
|
||||
|
||||
@ -70,7 +70,9 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
ngraph::stopwatch pass_timer;
|
||||
ngraph::stopwatch overall_timer;
|
||||
overall_timer.start();
|
||||
bool pass_applied = false;
|
||||
bool function_changed = false;
|
||||
bool needs_validate = false;
|
||||
for (auto& pass : m_pass_list) {
|
||||
if (m_pass_config->is_disabled(pass->get_type_info())) {
|
||||
NGRAPH_DEBUG << "Pass " << pass->get_name() << " is disabled";
|
||||
@ -91,7 +93,7 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
}
|
||||
// GraphRewrite is a temporary container for MatcherPass to make execution
|
||||
// on on entire ngraph::Function
|
||||
function_changed = GraphRewrite(matcher_pass).run_on_model(func);
|
||||
pass_applied = GraphRewrite(matcher_pass).run_on_model(func);
|
||||
} else if (auto function_pass = dynamic_pointer_cast<ModelPass>(pass)) {
|
||||
// This checks is to skip the graph transformation when the graph pass relies on
|
||||
// static shape but the function state is dynamic.
|
||||
@ -102,12 +104,12 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
}
|
||||
|
||||
if (dynamic_pointer_cast<Validate>(pass)) {
|
||||
if (function_changed) {
|
||||
if (needs_validate) {
|
||||
function_pass->run_on_model(func);
|
||||
function_changed = false;
|
||||
needs_validate = false;
|
||||
}
|
||||
} else {
|
||||
function_changed = function_pass->run_on_model(func);
|
||||
pass_applied = function_pass->run_on_model(func);
|
||||
}
|
||||
} else if (auto node_pass = dynamic_pointer_cast<ngraph::pass::NodePass>(pass)) {
|
||||
if (node_pass->get_property(PassProperty::REQUIRE_STATIC_SHAPE) && func->is_dynamic()) {
|
||||
@ -116,7 +118,7 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
continue;
|
||||
}
|
||||
for (const shared_ptr<Node>& n : func->get_ops()) {
|
||||
function_changed |= node_pass->run_on_node(n);
|
||||
pass_applied |= node_pass->run_on_node(n);
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,9 +140,13 @@ void ov::pass::Manager::run_passes(shared_ptr<ov::Model> func) {
|
||||
if (profile_enabled) {
|
||||
cout << setw(7) << pass_timer.get_milliseconds() << "ms " << pass->get_name() << "\n";
|
||||
}
|
||||
function_changed = function_changed || pass_applied;
|
||||
needs_validate = pass_applied;
|
||||
}
|
||||
if (profile_enabled) {
|
||||
cout << "passes done in " << overall_timer.get_milliseconds() << "ms\n";
|
||||
}
|
||||
NGRAPH_SUPPRESS_DEPRECATED_END
|
||||
|
||||
return function_changed;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user