openvino/docs/template_plugin/src/template_function_transformation.cpp

// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "template_function_transformation.hpp"

#include <iostream>
#include <memory>
#include <vector>

using namespace ngraph;
// ! [function_pass:template_transformation_cpp]
// template_function_transformation.cpp
bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
    // Example transformation code
    std::vector<std::shared_ptr<Node>> nodes;

    // Traverse the nGraph Function in topological order
    for (const auto& node : f->get_ordered_ops()) {
        // Check that the node has exactly one input port and one output port
        if (node->inputs().size() == 1 && node->outputs().size() == 1) {
            // Check that the input and output shapes are fully defined (not dynamic)
            // and that the output has exactly one consumer
            Input<Node> input = node->input(0);
            Output<Node> output = node->output(0);
            if (input.get_partial_shape().is_static() && output.get_partial_shape().is_static() &&
                output.get_target_inputs().size() == 1) {
                nodes.push_back(node);
            }
        }
    }

    // Print the type and friendly name of each collected node
    for (const auto& node : nodes) {
        std::cout << "Type: " << node->get_type_info().name << std::endl
                  << "Name: " << node->get_friendly_name() << std::endl;
    }

    // Return false because the nGraph Function was not modified
    return false;
}
// ! [function_pass:template_transformation_cpp]
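
// The selection rule used by the pass above can be tried without nGraph.
// The sketch below is only an illustration under stated assumptions: MockNode,
// is_candidate and collect_candidates are hypothetical names that are not part
// of the nGraph API, and the port counts, shape status and consumer count are
// plain fields instead of values queried from a real graph.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for an nGraph node (assumption: not an nGraph type).
struct MockNode {
    std::string type;           // operation type, e.g. "Relu"
    std::string name;           // friendly name
    std::size_t num_inputs;     // number of input ports
    std::size_t num_outputs;    // number of output ports
    bool static_shapes;         // true if input and output shapes are fully defined
    std::size_t num_consumers;  // number of consumers of output port 0
};

// Mirrors the condition checked in the pass: one input port, one output port,
// static shapes, and exactly one consumer of the output.
bool is_candidate(const MockNode& node) {
    return node.num_inputs == 1 && node.num_outputs == 1 &&
           node.static_shapes && node.num_consumers == 1;
}

// Collects the matching nodes in the order they are given, the way the pass
// collects them while walking the ops in topological order.
std::vector<MockNode> collect_candidates(const std::vector<MockNode>& ops) {
    std::vector<MockNode> selected;
    for (const auto& node : ops) {
        if (is_candidate(node)) {
            selected.push_back(node);
        }
    }
    return selected;
}
```

// Like the pass itself, this sketch only inspects nodes and never rewrites
// them, which is why the pass reports that the function was left unchanged.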