Files
openvino/docs/snippets/ov_model_snippets.cpp
Ilya Churaev 9da124544a Transformation guide (#10628)
* Fixed some comments about transformations

* Changed transformation guide

* Fixed typo

* Moved transformation doc to extensibility

* Moved images to Extensibility_UG

* Added separate document for each pass

* Added see also section

* Fixed comments
2022-03-01 09:03:59 +03:00

230 lines
8.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/openvino.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/core/rt_info.hpp>
// Code samples for the pattern-matching part of the transformation guide.
// The "// ! [pattern:...]" comment pairs appear to be doxygen snippet markers referenced
// from the documentation — keep them unchanged.
// @param node  node that the Multiply->Add matcher is applied to manually (pattern:predicate_example)
void pattern_matcher_examples(std::shared_ptr<ov::Node> node) {
    {
        // ! [pattern:simple_example]
        // Pattern example
        auto input = std::make_shared<ov::opset8::Parameter>(ov::element::i64, ov::Shape{1});
        auto shapeof = std::make_shared<ov::opset8::ShapeOf>(input);

        // Create Matcher with Parameter->ShapeOf pattern
        auto m = std::make_shared<ov::pass::pattern::Matcher>(shapeof, "MyPatternBasedTransformation");
        // ! [pattern:simple_example]

        // ! [pattern:callback_example]
        ov::graph_rewrite_callback callback = [](ov::pass::pattern::Matcher& m) {
            // Get root node
            std::shared_ptr<ov::Node> root_node = m.get_match_root();

            // Get all nodes matched by pattern
            ov::NodeVector nodes = m.get_matched_nodes();

            // Transformation code
            return false;
        };
        // ! [pattern:callback_example]
    }
    {
        // ! [pattern:label_example]
        // Detect Multiply with arbitrary first input and second as Constant
        // ov::pass::pattern::any_input() - represents an arbitrary input
        auto input = ov::pass::pattern::any_input();
        auto value = ov::opset8::Constant::create(ov::element::f32, ov::Shape{1}, {0.5});
        auto mul = std::make_shared<ov::opset8::Multiply>(input, value);
        auto m = std::make_shared<ov::pass::pattern::Matcher>(mul, "MultiplyMatcher");
        // ! [pattern:label_example]
    }
    {
        // ! [pattern:concat_example]
        // Detect Concat operation with arbitrary number of inputs
        auto concat = ov::pass::pattern::wrap_type<ov::opset8::Concat>();
        auto m = std::make_shared<ov::pass::pattern::Matcher>(concat, "ConcatMatcher");
        // ! [pattern:concat_example]
    }
    {
        // ! [pattern:predicate_example]
        // Detect Multiply->Add sequence where mul has exactly one consumer
        auto mul = ov::pass::pattern::wrap_type<ov::opset8::Multiply>(ov::pass::pattern::consumers_count(1)/*check consumers count*/);
        auto add = ov::pass::pattern::wrap_type<ov::opset8::Add>({mul, ov::pass::pattern::any_input()});
        auto m = std::make_shared<ov::pass::pattern::Matcher>(add, "MultiplyAddMatcher");
        // Matcher can be used to match pattern manually on given node
        if (m->match(node->output(0))) {
            // Successfully matched
        }
        // ! [pattern:predicate_example]
    }
}
// Code samples for working with node ports and (partial) shapes.
// @param node  node to inspect; the ports sample reads input(0), input(1) and output(0),
//              so it assumes a node with at least two inputs (e.g. opset8::Convolution)
// @return false if the shape of input 0 has a dynamic rank, has rank < 2, or has a
//         dynamic dimension 1; true otherwise
bool openvino_api_examples(std::shared_ptr<ov::Node> node) {
    {
        // ! [ov:ports_example]
        // Let's suppose that node is opset8::Convolution operation
        // as we know opset8::Convolution has two input ports (data, weights) and one output port
        ov::Input<ov::Node> data = node->input(0);
        ov::Input<ov::Node> weights = node->input(1);
        ov::Output<ov::Node> output = node->output(0);

        // Getting shape and type
        auto pshape = data.get_partial_shape();
        auto el_type = data.get_element_type();

        // Getting parent for input port
        ov::Output<ov::Node> parent_output;
        parent_output = data.get_source_output();

        // Another short way to get parent for input port
        parent_output = node->input_value(0);

        // Getting all consumers for output port
        auto consumers = output.get_target_inputs();
        // ! [ov:ports_example]
    }
    {
        // ! [ngraph:shape_check]
        auto partial_shape = node->input(0).get_partial_shape(); // get zero input partial shape

        // Check that input shape rank is static
        if (partial_shape.rank().is_dynamic()) {
            return false;
        }
        auto rank_size = partial_shape.rank().get_length();

        // Check that the second dimension exists and is not dynamic
        if (rank_size < 2 || partial_shape[1].is_dynamic()) {
            return false;
        }
        auto dim = partial_shape[1].get_length();
        // ! [ngraph:shape_check]
    }
    return true;
}
// ! [ov:replace_node]
bool ov_replace_node(std::shared_ptr<ov::Node> node) {
    // Step 1. Only opset8::Negative nodes are handled; anything else is left untouched
    auto negative = std::dynamic_pointer_cast<ov::opset8::Negative>(node);
    if (negative == nullptr) {
        return false;
    }

    // Step 2. Express Negative(x) as Multiply(x, -1): x is the Negative input and -1 is a
    //         Constant of the same element type
    auto minus_one = ov::opset8::Constant::create(negative->get_element_type(), ov::Shape{1}, {-1});
    auto multiply = std::make_shared<ov::opset8::Multiply>(negative->input_value(0), minus_one);
    multiply->set_friendly_name(negative->get_friendly_name());
    ov::copy_runtime_info(negative, multiply);

    // Step 3. Replace Negative operation with Multiply operation
    ov::replace_node(negative, multiply);

    // Step 4. Negative operation will be removed automatically because all its consumers were moved to Multiply operation
    return true;
}
// ! [ov:replace_node]
// Same Negative -> Multiply(x, -1) rewrite as ov_replace_node, but the consumers are
// re-pointed explicitly through Output::replace instead of ov::replace_node.
// @return false if `node` is not an opset8::Negative, true after the rewrite
bool ov_manual_replace_node(std::shared_ptr<ov::Node> node) {
    auto neg = std::dynamic_pointer_cast<ov::opset8::Negative>(node);
    if (neg == nullptr) {
        return false;
    }

    // -1 constant of the Negative's element type
    auto minus_one = ov::opset8::Constant::create(neg->get_element_type(), ov::Shape{1}, {-1});
    auto mul = std::make_shared<ov::opset8::Multiply>(neg->input_value(0), minus_one);
    mul->set_friendly_name(neg->get_friendly_name());
    ov::copy_runtime_info(neg, mul);

    // ! [ov:manual_replace]
    // All neg->output(0) consumers will be moved to mul->output(0) port
    neg->output(0).replace(mul->output(0));
    // ! [ov:manual_replace]
    return true;
}
// ! [ov:insert_node]
// Step 1. Let's suppose that we have a node with single output port and we want to insert additional operation new_node after it
void insert_example(std::shared_ptr<ov::Node> node) {
    // Collect the consumers of the node's output first: new_node (created below) consumes
    // that output too and must not be reconnected to itself
    const auto consumers = node->output(0).get_target_inputs();

    // Step 2. Create new node. Let it be opset8::Relu.
    auto new_node = std::make_shared<ov::opset8::Relu>(node);

    // Step 3. Reconnect all consumers to new_node
    for (auto consumer : consumers) {
        consumer.replace_source_output(new_node);
    }
}
// ! [ov:insert_node]
// ! [ov:insert_node_with_copy]
void insert_example_with_copy(std::shared_ptr<ov::Node> node) {
    // Duplicate the node, feeding the copy from the same producers as the original
    const ov::OutputVector original_inputs = node->input_values();
    auto node_copy = node->clone_with_new_inputs(original_inputs);

    // Put Relu on top of the copy and let it take the original node's place in the graph
    auto new_node = std::make_shared<ov::opset8::Relu>(node_copy);
    ov::replace_node(node, new_node);
}
// ! [ov:insert_node_with_copy]
void eliminate_example(std::shared_ptr<ov::Node> node) {
    // ! [ov:eliminate_node]
    // Suppose we have a node that we want to remove: consumers of its first output
    // are reconnected to the producer of its first input
    const ov::Output<ov::Node> node_output = node->output(0);
    const ov::Output<ov::Node> producer_output = node->input_value(0);
    bool success = ov::replace_output_update_name(node_output, producer_output);
    // ! [ov:eliminate_node]
}
// Code sample: replace Divide(a, b) with Multiply(a, Power(b, -1)) and keep the original
// friendly name on the last node of the new sub-graph.
// NOTE: `div` is a default-constructed placeholder with no inputs, so this function only
// illustrates the API; it is presumably never executed (div->input(1) has no real source).
void replace_friendly_name() {
    auto div = std::make_shared<ov::opset8::Divide>();
    // ! [ov:replace_friendly_name]
    // Replace Div operation with Power and Multiply sub-graph and set original friendly name to Multiply operation
    auto pow = std::make_shared<ov::opset8::Power>(div->input(1).get_source_output(),
                                                   ov::opset8::Constant::create(div->get_input_element_type(1), ov::Shape{1}, {-1}));
    auto mul = std::make_shared<ov::opset8::Multiply>(div->input(0).get_source_output(), pow);
    mul->set_friendly_name(div->get_friendly_name());
    // 1:N replacement: propagate the Divide runtime info to every node of the new sub-graph
    ov::copy_runtime_info(div, {pow, mul});
    ov::replace_node(div, mul);
    // ! [ov:replace_friendly_name]
}
void constant_subgraph() {
    // ! [ov:constant_subgraph]
    // Power(2, 3) depends only on Constants, so after ConstantFolding pass it will be replaced with Constant
    auto input = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::Shape{1});
    auto base = ov::opset8::Constant::create(ov::element::f32, ov::Shape{1}, {2});
    auto exponent = ov::opset8::Constant::create(ov::element::f32, ov::Shape{1}, {3});
    auto pow = std::make_shared<ov::opset8::Power>(base, exponent);
    auto mul = std::make_shared<ov::opset8::Multiply>(input /* not constant input */, pow);
    // ! [ov:constant_subgraph]
}
// Code sample showing ov::copy_runtime_info for 1:1, 1:N, N:1 and N:M replacements.
// All shared_ptr variables below are default-constructed (null) placeholders, so this function
// only illustrates call shapes and is presumably never executed.
// NOTE(review): the braced lists (e.g. the one-element {conv_fused}) take part in overload
// resolution between the single-node and NodeVector forms — keep the calls exactly as written.
void copy_runtime_info_snippet() {
std::shared_ptr<ov::Node> transpose, reshape, div, pow, mul, conv, bias, conv_fused, a, b, c, e, f;
// ! [ov:copy_runtime_info]
// Replace Transpose with Reshape operation (1:1)
ov::copy_runtime_info(transpose, reshape);
// Replace Div operation with Power and Multiply sub-graph (1:N)
ov::copy_runtime_info(div, {pow, mul});
// Fuse Convolution with Add operation (N:1)
ov::copy_runtime_info({conv, bias}, {conv_fused});
// Any other transformation that replaces one sub-graph with another sub-graph (N:M)
ov::copy_runtime_info({a, b, c}, {e, f});
// ! [ov:copy_runtime_info]
}