Updated Transformation development doc (#2391)

This commit is contained in:
Gleb Kazantaev 2020-09-23 17:26:12 +03:00 committed by GitHub
parent 4a7f9ff86f
commit 30eeb1a5a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 76 additions and 25 deletions

View File

@ -63,6 +63,8 @@ Below you can find examples how `ngraph::Function` can be created:
nGraph has tree main transformation types: `ngraph::pass::FunctionPass` - strait forward way to work with `ngraph::Function` directly;
`ngraph::pass::MatcherPass` - pattern based transformation approach; `ngraph::pass::GraphRewrite` - container for matcher passes.
![transformations_structure]
###1. ngraph::pass::FunctionPass <a name="function_pass"></a>
`ngraph::pass::FunctionPass` is used for transformations that take entire `ngraph::Function` as input and process it.
@ -131,7 +133,7 @@ The last step is to register Matcher and callback inside MatcherPass pass. And t
```cpp
// Register matcher and callback
this->register_matcher(m, callback);
register_matcher(m, callback);
```
### Matcher pass execution
MatcherPass has multiple ways to be executed:
@ -154,21 +156,32 @@ In addition GraphRewrite handles nodes that were registered by MatcherPasses dur
> **Note**: when using `pass::Manager` temporary GraphRewrite is used to execute single MatcherPass.
GraphRewrite has two algorithms for MatcherPasses execution. First algorithm is a straight-forward. It applies each MatcherPass in registraion order to current node.
![graph_rewrite_execution]
But it is nor really efficient when you have a lot of registered passes. So first of all GraphRewrite check that all MatcherPass patterns has type based root node (it means that type of this node is not hidden into predicate).
And then creates map from registered MatcherPases. That helps to avoid additional cost of applying each MatcherPass for each node.
![graph_rewrite_efficient_search]
## Pattern matching <a name="pattern_matching"></a>
Sometimes patterns can't be expressed via regular nGraph operations. For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
Sometimes patterns can't be expressed via regular nGraph operations or it is too complicated.
For example if you want to detect Convolution->Add sub-graph without specifying particular input type for Convolution operation or you want to create pattern where some of operations can have different types.
And for these cases nGraph provides additional helpers to construct patterns for GraphRewrite transformations.
There are two main helpers:
1. `ngraph::pattern::op::Label` - helps to express inputs if their type is undefined.
2. `ngraph::pattern::op::Any` - helps to express intermediate nodes of pattern if their type is unknown.
1. `ngraph::pattern::any_input` - helps to express inputs if their types are undefined.
2. `ngraph::pattern::wrap_type<T>` - helps to express nodes of pattern without specifying node attributes.
Let's go through example to have better understanding how it works:
> **Note**: node attributes do not participate in pattern matching and needed only for operations creation. Only operation types participate in pattern matching.
Example below shows basic usage of `pattern::op::Label` class.
Here we construct Multiply pattern with arbitrary first input and Constant as a second input.
Example below shows basic usage of `pattern::any_input`.
Here we construct Multiply pattern with arbitrary first input and Constant as a second input.
Also as Multiply is commutative operation it does not matter in which order we set inputs (any_input/Constant or Constant/any_input) because both cases will be matched.
@snippet example_ngraph_utils.cpp pattern:label_example
@ -176,7 +189,7 @@ This example show how we can construct pattern when operation has arbitrary numb
@snippet example_ngraph_utils.cpp pattern:concat_example
This example shows how to use predicate to construct pattern where operation has two different types. Also it shows how to match pattern manually on given node.
This example shows how to use predicate to construct pattern. Also it shows how to match pattern manually on given node.
@snippet example_ngraph_utils.cpp pattern:predicate_example
@ -321,9 +334,11 @@ ngraph::copy_runtime_info({a, b, c}, {e, f});
When transformation has multiple fusions or decompositions `ngraph::copy_runtime_info` must be called multiple times for each case.
> **Note**: copy_runtime_info removes rt_info from destination nodes. If you want to keep it you need to specify them in source nodes like this: copy_runtime_info({a, b, c}, {a, b})
###5. Constant Folding
If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation.
If your transformation inserts constant sub-graphs that needs to be folded do not forget to use `ngraph::pass::ConstantFolding()` after your transformation or call constant folding directly for operation.
Example below shows how constant sub-graph can be constructed.
```cpp
@ -334,6 +349,12 @@ auto pow = std::make_shared<ngraph::opset3::Power>(
auto mul = std::make_shared<ngraph::opset3::Multiply>(input /* not constant input */, pow);
```
Manual constant folding is more preferable than `ngraph::pass::ConstantFolding()` because it is much faster.
Below you can find an example of manual constant folding:
@snippet src/template_pattern_transformation.cpp manual_constant_folding
## Common mistakes in transformations <a name="common_mistakes"></a>
In transformation development process
@ -427,4 +448,8 @@ The basic transformation test looks like this:
[ngraph_replace_node]: ../images/ngraph_replace_node.png
[ngraph_insert_node]: ../images/ngraph_insert_node.png
[ngraph_insert_node]: ../images/ngraph_insert_node.png
[transformations_structure]: ../images/transformations_structure.png
[register_new_node]: ../images/register_new_node.png
[graph_rewrite_execution]: ../images/graph_rewrite_execution.png
[graph_rewrite_efficient_search]: ../images/graph_rewrite_efficient_search.png

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:05eb8600d2c905975674f3a0a5dc676107d22f65f2a1f78ee1cfabc1771721ea
size 41307

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17cd470c6d04d7aabbdb4a08e31f9c97eab960cf7ef5bbd3a541df92db38f26b
size 40458

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:80297287c81a2f27b7e74895738afd90844354a8dd745757e8321e2fb6ed547e
size 31246

View File

@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0b206c602626f17ba5787810b9a28f9cde511448c3e63a5c7ba976cee7868bdb
size 14907

View File

@ -4,6 +4,8 @@
#include <memory>
#include <ngraph/pattern/op/wrap_type.hpp>
// ! [ngraph:include]
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset3.hpp>
@ -89,7 +91,7 @@ ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
// ! [pattern:label_example]
// Detect Multiply with arbitrary first input and second as Constant
// ngraph::pattern::op::Label - represent arbitrary input
auto input = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{1});
auto input = ngraph::pattern::any_input();
auto value = ngraph::opset3::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0.5});
auto mul = std::make_shared<ngraph::opset3::Multiply>(input, value);
auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");
@ -99,20 +101,17 @@ auto m = std::make_shared<ngraph::pattern::Matcher>(mul, "MultiplyMatcher");
{
// ! [pattern:concat_example]
// Detect Concat operation with arbitrary number of inputs
auto concat = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset3::Concat>());
auto concat = ngraph::pattern::wrap_type<ngraph::opset3::Concat>();
auto m = std::make_shared<ngraph::pattern::Matcher>(concat, "ConcatMatcher");
// ! [pattern:concat_example]
}
{
// ! [pattern:predicate_example]
// Detect Multiply or Add operation
auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32, ngraph::Shape{},
[](const std::shared_ptr<ngraph::Node> & node) -> bool {
return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
});
auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, "MultiplyOrAddMatcher");
// Detect Multiply->Add sequence where mul has exactly one consumer
auto mul = ngraph::pattern::wrap_type<ngraph::opset3::Multiply>(ngraph::pattern::consumers_count(1)/*сheck consumers count*/);
auto add = ngraph::pattern::wrap_type<ngraph::opset3::Add>({mul, ngraph::pattern::any_input()});
auto m = std::make_shared<ngraph::pattern::Matcher>(add, "MultiplyAddMatcher");
// Matcher can be used to match pattern manually on given node
if (m->match(node->output(0))) {
// Successfully matched

View File

@ -10,7 +10,7 @@ using namespace ngraph;
// template_function_transformation.cpp
bool pass::MyFunctionTransformation::run_on_function(std::shared_ptr<ngraph::Function> f) {
// Example transformation code
std::vector<std::shared_ptr<Node> > nodes;
NodeVector nodes;
// Traverse nGraph Function in topological order
for (auto & node : f->get_ordered_ops()) {

View File

@ -18,8 +18,6 @@ class MyFunctionTransformation;
// template_function_transformation.hpp
class ngraph::pass::MyFunctionTransformation: public ngraph::pass::FunctionPass {
public:
MyFunctionTransformation() : FunctionPass() {}
bool run_on_function(std::shared_ptr<ngraph::Function> f) override;
};
// ! [function_pass:template_transformation_hpp]

View File

@ -16,8 +16,8 @@ using namespace ngraph;
// template_pattern_transformation.cpp
ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
// Pattern example
auto input0 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto input1 = std::make_shared<pattern::op::Label>(element::f32, Shape{});
auto input0 = pattern::any_input();
auto input1 = pattern::any_input();
auto div = std::make_shared<ngraph::opset3::Divide>(input0, input1);
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
@ -49,7 +49,7 @@ ngraph::pass::DecomposeDivideMatcher::DecomposeDivideMatcher() {
// Register pattern with Divide operation as a pattern root node
auto m = std::make_shared<ngraph::pattern::Matcher>(div, "ConvertDivide");
// Register Matcher
this->register_matcher(m, callback);
register_matcher(m, callback);
}
// ! [graph_rewrite:template_transformation_cpp]
@ -82,7 +82,7 @@ ngraph::pass::ReluReluFusionMatcher::ReluReluFusionMatcher() {
// Register pattern with Relu operation as a pattern root node
auto m = std::make_shared<ngraph::pattern::Matcher>(m_relu2, "ReluReluFusion");
// Register Matcher
this->register_matcher(m, callback);
register_matcher(m, callback);
}
// ! [matcher_pass:relu_fusion]
@ -137,3 +137,16 @@ pass.add_matcher<ngraph::pass::ReluReluFusionMatcher>();
pass.run_on_function(f);
// ! [matcher_pass:graph_rewrite]
}
// ! [manual_constant_folding]
template <class T>
Output<Node> eltwise_fold(const Output<Node> & input0, const Output<Node> & input1) {
auto eltwise = std::make_shared<T>(input0, input1);
OutputVector output(eltwise->get_output_size());
// If constant folding wasn't successful return eltwise output
if (!eltwise->constant_fold(output, {input0, input1})) {
return eltwise->output(0);
}
return output[0];
}
// ! [manual_constant_folding]

View File

@ -17,6 +17,10 @@ class ReluReluFusionMatcher;
// ! [graph_rewrite:template_transformation_hpp]
// template_pattern_transformation.hpp
/**
* @ingroup ie_transformation_common_api
* @brief Add transformation description.
*/
class ngraph::pass::DecomposeDivideMatcher: public ngraph::pass::MatcherPass {
public:
DecomposeDivideMatcher();