Transformations Python API (#10971)

* Keep changes

* Update tests

* Keep changes

* Cleanup

* Add predicates support; new pattern ops; new tests

* support for public passes; added tests

* Fix compilation warning

* Fix code style

* Added docstrings; code cleanup

* Update python API tests

* Fix build on Windows

* Revert back pass registration logic

* Fix flake8 errors

* Update docstrings; fix utils.hpp

* Cleanup

* Cleanup

* Fix flake errors

* Fix mypy

* Skip mypy for passes
This commit is contained in:
Gleb Kazantaev 2022-03-29 11:41:23 +03:00 committed by GitHub
parent 02c60c76ab
commit 866f006a83
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
34 changed files with 1811 additions and 83 deletions

View File

@ -1,6 +1,11 @@
# Copyright (C) 2018-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# type: ignore
# flake8: noqa
from openvino.pyopenvino.passes import Manager
from openvino.pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput
from openvino.pyopenvino.passes import consumers_count, has_static_dim, has_static_dims, has_static_shape,\
has_static_rank, rank_equals, type_matches, type_matches_any
from openvino.pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16
from openvino.runtime.passes.manager import Manager
from openvino.runtime.passes.graph_rewrite import GraphRewrite, BackwardGraphRewrite

View File

@ -0,0 +1,30 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# type: ignore
from openvino.pyopenvino.passes import MatcherPass
from openvino.pyopenvino.passes import GraphRewrite as GraphRewriteBase
from openvino.pyopenvino.passes import BackwardGraphRewrite as BackwardGraphRewriteBase
class GraphRewrite(GraphRewriteBase):
"""GraphRewrite that additionally holds python transformations objects."""
def __init__(self) -> None:
super().__init__()
self.passes_list = [] # need to keep python instances alive
def add_matcher(self, transformation: MatcherPass) -> MatcherPass:
"""Append MatcherPass instance to the end of execution list."""
self.passes_list.append(transformation)
return super().add_matcher(transformation)
class BackwardGraphRewrite(BackwardGraphRewriteBase):
"""BackwardGraphRewriteBase that additionally holds python transformations objects."""
def __init__(self) -> None:
super().__init__()
self.passes_list = [] # need to keep python instances alive
def add_matcher(self, transformation: MatcherPass) -> MatcherPass:
"""Append MatcherPass instance to the end of execution list."""
self.passes_list.append(transformation)
return super().add_matcher(transformation)

View File

@ -0,0 +1,24 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
# type: ignore
from openvino.pyopenvino.passes import Manager as ManagerBase
from openvino.pyopenvino.passes import PassBase
class Manager(ManagerBase):
"""Manager that additionally holds transformations objects."""
def __init__(self) -> None:
super().__init__()
self.passes_list = [] # need to keep python instances alive
def register_pass(self, *args, **kwargs) -> PassBase:
"""Register transformation for further execution."""
for arg in args:
if isinstance(arg, PassBase):
self.passes_list.append(arg)
for arg in kwargs.values():
if isinstance(arg, PassBase):
self.passes_list.append(arg)
return super().register_pass(*args, **kwargs)

View File

@ -4,4 +4,4 @@
"""Generic utilities. Factor related functions out to separate files."""
from openvino.pyopenvino.util import numpy_to_c
from openvino.pyopenvino.util import get_constant_from_source
from openvino.pyopenvino.util import get_constant_from_source, replace_node, replace_output_update_name

View File

@ -62,7 +62,7 @@ endif()
# create target
file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp pyopenvino.cpp)
file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp utils/*cpp pyopenvino.cpp)
list(FILTER SOURCES EXCLUDE REGEX frontend/onnx|tensorflow|paddle/* )
pybind11_add_module(${PROJECT_NAME} MODULE ${SOURCES})

View File

@ -0,0 +1,75 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/graph_rewrite.hpp"
#include <pybind11/pybind11.h>
#include <openvino/pass/graph_rewrite.hpp>
#include <openvino/pass/pass.hpp>
namespace py = pybind11;
void regclass_passes_GraphRewrite(py::module m) {
py::class_<ov::pass::GraphRewrite, std::shared_ptr<ov::pass::GraphRewrite>, ov::pass::ModelPass, ov::pass::PassBase>
graph_rewrite(m, "GraphRewrite");
graph_rewrite.doc() =
"openvino.runtime.passes.GraphRewrite executes sequence of MatcherPass transformations in topological order";
graph_rewrite.def(py::init<>());
graph_rewrite.def(py::init([](const std::shared_ptr<ov::pass::MatcherPass>& pass) {
return std::make_shared<ov::pass::GraphRewrite>(pass);
}),
py::arg("pass"),
R"(
Register single MatcherPass pass inside GraphRewrite.
:param pass: openvino.runtime.passes.MatcherPass instance
:type pass: openvino.runtime.passes.MatcherPass
)");
graph_rewrite.def("add_matcher",
static_cast<std::shared_ptr<ov::pass::MatcherPass> (ov::pass::GraphRewrite::*)(
const std::shared_ptr<ov::pass::MatcherPass>&)>(&ov::pass::GraphRewrite::add_matcher),
py::arg("pass"),
R"(
Register single MatcherPass pass inside GraphRewrite.
:param pass: openvino.runtime.passes.MatcherPass instance
:type pass: openvino.runtime.passes.MatcherPass
)");
py::class_<ov::pass::BackwardGraphRewrite,
std::shared_ptr<ov::pass::BackwardGraphRewrite>,
ov::pass::GraphRewrite,
ov::pass::ModelPass,
ov::pass::PassBase>
back_graph_rewrite(m, "BackwardGraphRewrite");
back_graph_rewrite.doc() = "openvino.runtime.passes.BackwardGraphRewrite executes sequence of MatcherPass "
"transformations in reversed topological order";
back_graph_rewrite.def(py::init<>());
back_graph_rewrite.def(py::init([](const std::shared_ptr<ov::pass::MatcherPass>& pass) {
return std::make_shared<ov::pass::BackwardGraphRewrite>(pass);
}),
py::arg("pass"),
R"(
Register single MatcherPass pass inside BackwardGraphRewrite.
:param pass: openvino.runtime.passes.MatcherPass instance
:type pass: openvino.runtime.passes.MatcherPass
)");
back_graph_rewrite.def(
"add_matcher",
static_cast<std::shared_ptr<ov::pass::MatcherPass> (ov::pass::BackwardGraphRewrite::*)(
const std::shared_ptr<ov::pass::MatcherPass>&)>(&ov::pass::BackwardGraphRewrite::add_matcher),
py::arg("pass"),
R"(
Register single MatcherPass pass inside BackwardGraphRewrite.
:param pass: openvino.runtime.passes.MatcherPass instance
:type pass: openvino.runtime.passes.MatcherPass
)");
}

View File

@ -0,0 +1,11 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_passes_GraphRewrite(py::module m);

View File

@ -31,63 +31,84 @@ inline Version convert_to_version(const std::string& version) {
"'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
}
namespace {
class ManagerWrapper : public ov::pass::Manager {
public:
ManagerWrapper() {}
~ManagerWrapper() {}
void register_pass(const std::string& pass_name) {
if (pass_name == "ConstantFolding")
push_pass<ov::pass::ConstantFolding>();
if (m_per_pass_validation)
push_pass<ov::pass::Validate>();
return;
}
void register_pass(const std::string& pass_name, const FilePaths& file_paths, const std::string& version) {
if (pass_name == "Serialize") {
push_pass<ov::pass::Serialize>(file_paths.first, file_paths.second, convert_to_version(version));
}
return;
}
void register_pass(const std::string& pass_name,
const std::string& xml_path,
const std::string& bin_path,
const std::string& version) {
if (pass_name == "Serialize")
push_pass<ov::pass::Serialize>(xml_path, bin_path, convert_to_version(version));
return;
}
};
} // namespace
void regclass_graph_passes_Manager(py::module m) {
py::class_<ManagerWrapper> manager(m, "Manager");
manager.doc() = "openvino.runtime.passes.Manager wraps ov::pass::Manager using ManagerWrapper";
void regclass_passes_Manager(py::module m) {
py::class_<ov::pass::Manager> manager(m, "Manager");
manager.doc() = "openvino.runtime.passes.Manager executes sequence of transformation on a given Model";
manager.def(py::init<>());
manager.def("set_per_pass_validation", &ManagerWrapper::set_per_pass_validation);
manager.def("run_passes", &ManagerWrapper::run_passes);
manager.def("register_pass",
(void (ManagerWrapper::*)(const std::string&)) & ManagerWrapper::register_pass,
py::arg("pass_name"),
manager.def("set_per_pass_validation",
&ov::pass::Manager::set_per_pass_validation,
py::arg("new_state"),
R"(
Set the type of register pass for pass manager.
Enables or disables Model validation after each pass execution.
:param pass_name : String to set the type of a pass.
:type pass_name: str
// )");
:param new_state: flag which enables or disables model validation.
:type new_state: bool
)");
manager.def("run_passes",
&ov::pass::Manager::run_passes,
py::arg("model"),
R"(
Executes sequence of transformations on given Model.
:param model: openvino.runtime.Model to be transformed.
:type model: openvino.runtime.Model
)");
manager.def("register_pass",
(void (ManagerWrapper::*)(const std::string&, const FilePaths&, const std::string&)) &
ManagerWrapper::register_pass,
py::arg("pass_name"),
py::arg("output_files"),
py::arg("version") = "UNSPECIFIED",
&ov::pass::Manager::register_pass_instance,
py::arg("transformation"),
R"(
Register pass instance for execution. Execution order matches the registration order.
:param transformation: transformation instance.
:type transformation: openvino.runtime.passes.PassBase
)");
manager.def(
"register_pass",
[](ov::pass::Manager& self, const std::string& pass_name) -> void {
PyErr_WarnEx(PyExc_DeprecationWarning,
"register_pass with this arguments is deprecated! "
"Please use register_pass(ConstantFolding()) instead.",
1);
if (pass_name == "ConstantFolding") {
self.register_pass<ov::pass::ConstantFolding>();
}
},
py::arg("pass_name"),
R"(
This method is deprecated. Please use m.register_pass(ConstantFolding()) instead.
Register pass by name from the list of predefined passes.
:param pass_name: String to set the type of a pass.
:type pass_name: str
)");
manager.def(
"register_pass",
[](ov::pass::Manager& self,
const std::string& pass_name,
const FilePaths& file_paths,
const std::string& version) -> void {
PyErr_WarnEx(PyExc_DeprecationWarning,
"register_pass with this arguments is deprecated! "
"Please use register_pass(Serialize(xml, bin, version)) instead.",
1);
if (pass_name == "Serialize") {
self.register_pass<ov::pass::Serialize>(file_paths.first,
file_paths.second,
convert_to_version(version));
}
},
py::arg("pass_name"),
py::arg("output_files"),
py::arg("version") = "UNSPECIFIED",
R"(
This method is deprecated. Please use m.register_pass(Serialize(...)) instead.
Set the type of register pass for pass manager.
:param pass_name: String to set the type of a pass.
@ -100,6 +121,7 @@ void regclass_graph_passes_Manager(py::module m) {
- "IR_V10" : v10 IR
- "IR_V11" : v11 IR
:type version: str
Examples
----------
1. Default Version
@ -108,16 +130,30 @@ void regclass_graph_passes_Manager(py::module m) {
2. IR version 11
pass_manager = Manager()
pass_manager.register_pass("Serialize", output_files=("example.xml", "example.bin"), version="IR_V11")
// )");
)");
manager.def(
"register_pass",
(void (ManagerWrapper::*)(const std::string&, const std::string&, const std::string&, const std::string&)) &
ManagerWrapper::register_pass,
[](ov::pass::Manager& self,
const std::string& pass_name,
const std::string& xml_path,
const std::string& bin_path,
const std::string& version) -> void {
PyErr_WarnEx(PyExc_DeprecationWarning,
"register_pass with this arguments is deprecated! "
"Please use register_pass(Serialize(xml, bin, version)) instead.",
1);
if (pass_name == "Serialize") {
self.register_pass<ov::pass::Serialize>(xml_path, bin_path, convert_to_version(version));
}
},
py::arg("pass_name"),
py::arg("xml_path"),
py::arg("bin_path"),
py::arg("version") = "UNSPECIFIED",
R"(
This method is deprecated. Please use m.register_pass(Serialize(...)) instead.
Set the type of register pass for pass manager.
:param pass_name: String to set the type of a pass.
@ -132,6 +168,7 @@ void regclass_graph_passes_Manager(py::module m) {
- "IR_V10" : v10 IR
- "IR_V11" : v11 IR
:type version: str
Examples
----------
1. Default Version
@ -140,5 +177,5 @@ void regclass_graph_passes_Manager(py::module m) {
2. IR version 11
pass_manager = Manager()
pass_manager.register_pass("Serialize", xml_path="example.xml", bin_path="example.bin", version="IR_V11")
// )");
)");
}

View File

@ -8,4 +8,4 @@
namespace py = pybind11;
void regclass_graph_passes_Manager(py::module m);
void regclass_passes_Manager(py::module m);

View File

@ -0,0 +1,201 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/matcher_pass.hpp"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
namespace py = pybind11;
void regclass_passes_Matcher(py::module m) {
py::class_<ov::pass::pattern::Matcher, std::shared_ptr<ov::pass::pattern::Matcher>> matcher(m, "Matcher");
matcher.doc() = "openvino.runtime.passes.Matcher wraps ov::pass::pattern::Matcher";
matcher.def(py::init([](const std::shared_ptr<ov::Node>& node, const std::string& name) {
return std::make_shared<ov::pass::pattern::Matcher>(node, name);
}),
py::arg("node"),
py::arg("name"),
R"(
Creates Matcher object with given pattern root node and matcher name.
Matcher object is used for pattern matching on Model.
:param node: pattern root node.
:type node: openvino.runtime.Node
:param name: pattern name. Usually matches the MatcherPass class name.
:type name: str
)");
matcher.def(py::init([](ov::Output<ov::Node>& output, const std::string& name) {
return std::make_shared<ov::pass::pattern::Matcher>(output, name);
}),
py::arg("output"),
py::arg("name"),
R"(
Creates Matcher object with given pattern root node output and matcher name.
Matcher object is used for pattern matching on Model.
:param node: pattern root node output.
:type node: openvino.runtime.Output
:param name: pattern name. Usually matches the MatcherPass class name.
:type name: str
)");
matcher.def("get_name",
&ov::pass::pattern::Matcher::get_name,
R"(
Get Matcher name.
:return: openvino.runtime.passes.Matcher name.
:rtype: str
)");
matcher.def("get_match_root",
&ov::pass::pattern::Matcher::get_match_root,
R"(
Get matched root node inside Model. Should be used after match() method is called.
:return: matched node.
:rtype: openvino.runtime.Node
)");
matcher.def("get_match_value",
&ov::pass::pattern::Matcher::get_match_value,
R"(
Get matched node output inside Model. Should be used after match() method is called.
:return: matched node output.
:rtype: openvino.runtime.Output
)");
matcher.def("get_match_nodes",
&ov::pass::pattern::Matcher::get_matched_nodes,
R"(
Get NodeVector of matched nodes. Should be used after match() method is called.
:return: matched nodes vector.
:rtype: List[openvino.runtime.Node]
)");
matcher.def("get_match_values",
static_cast<const ov::OutputVector& (ov::pass::pattern::Matcher::*)() const>(
&ov::pass::pattern::Matcher::get_matched_values),
R"(
Get OutputVector of matched outputs. Should be used after match() method is called.
:return: matched outputs vector.
:rtype: List[openvino.runtime.Output]
)");
matcher.def("get_pattern_value_map",
&ov::pass::pattern::Matcher::get_pattern_value_map,
R"(
Get map which can be used to access matched nodes using nodes from pattern.
Should be used after match() method is called.
:return: mapping of pattern nodes to matched nodes.
:rtype: dict
)");
matcher.def("match",
static_cast<bool (ov::pass::pattern::Matcher::*)(const ov::Output<ov::Node>&)>(
&ov::pass::pattern::Matcher::match),
R"(
Matches registered pattern starting from given output.
:return: status of matching.
:rtype: bool
)");
matcher.def("match",
static_cast<bool (ov::pass::pattern::Matcher::*)(std::shared_ptr<ov::Node>)>(
&ov::pass::pattern::Matcher::match),
R"(
Matches registered pattern starting from given Node.
:return: status of matching.
:rtype: bool
)");
}
class PyMatcherPass : public ov::pass::MatcherPass {
public:
void py_register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& matcher,
const ov::matcher_pass_callback& callback) {
register_matcher(matcher, callback);
}
};
void regclass_passes_MatcherPass(py::module m) {
py::class_<ov::pass::MatcherPass, std::shared_ptr<ov::pass::MatcherPass>, ov::pass::PassBase, PyMatcherPass>
matcher_pass(m, "MatcherPass");
matcher_pass.doc() = "openvino.runtime.passes.MatcherPass wraps ov::pass::MatcherPass";
matcher_pass.def(py::init<>());
matcher_pass.def(
py::init([](const std::shared_ptr<ov::pass::pattern::Matcher>& m, ov::matcher_pass_callback callback) {
return std::make_shared<ov::pass::MatcherPass>(m, callback);
}),
py::arg("matcher"),
py::arg("callback"),
R"(
Create MatcherPass from existing Matcher and callback objects.
:param matcher: openvino.runtime.passes.Matcher with registered pattern.
:type matcher: openvino.runtime.passes.Matcher
:param callback: Function that performs transformation on the matched nodes.
:type callback: function
:return: created openvino.runtime.passes.MatcherPass instance.
:rtype: openvino.runtime.passes.MatcherPass
)");
matcher_pass.def("apply",
&ov::pass::MatcherPass::apply,
py::arg("node"),
R"(
Execute MatcherPass on given Node.
:return: callback return code.
:rtype: bool
)");
matcher_pass.def("register_new_node",
&ov::pass::MatcherPass::register_new_node_,
py::arg("node"),
R"(
Register node for additional pattern matching.
:param node: openvino.runtime.Node for matching.
:type node: openvino.runtime.Node
:return: registered node instance
:rtype: openvino.runtime.Node
)");
matcher_pass.def("register_matcher",
static_cast<void (ov::pass::MatcherPass::*)(const std::shared_ptr<ov::pass::pattern::Matcher>&,
const ov::graph_rewrite_callback& callback)>(
&PyMatcherPass::py_register_matcher),
py::arg("matcher"),
py::arg("callback"),
R"(
Initialize matcher and callback for further execution.
:param matcher: openvino.runtime.passes.Matcher with registered pattern.
:type matcher: openvino.runtime.passes.Matcher
:param callback: Function that performs transformation on the matched nodes.
:type callback: function
)");
}

View File

@ -0,0 +1,13 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_passes_Matcher(py::module m);
void regclass_passes_MatcherPass(py::module m);

View File

@ -0,0 +1,47 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/model_pass.hpp"
#include <pybind11/pybind11.h>
#include <openvino/pass/pass.hpp>
#include <string>
namespace py = pybind11;
class PyModelPass : public ov::pass::ModelPass {
public:
/* Inherit the constructors */
using ov::pass::ModelPass::ModelPass;
/* Trampoline (need one for each virtual function) */
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
PYBIND11_OVERRIDE_PURE(bool, /* Return type */
ov::pass::ModelPass, /* Parent class */
run_on_model, /* Name of function in C++ (must match Python name) */
model /* Argument(s) */
);
}
};
void regclass_passes_ModelPass(py::module m) {
py::class_<ov::pass::ModelPass, std::shared_ptr<ov::pass::ModelPass>, ov::pass::PassBase, PyModelPass> model_pass(
m,
"ModelPass");
model_pass.doc() = "openvino.runtime.passes.ModelPass wraps ov::pass::ModelPass";
model_pass.def(py::init<>());
model_pass.def("run_on_model",
&ov::pass::ModelPass::run_on_model,
py::arg("model"),
R"(
run_on_model must be defined in inherited class. This method is used to work with Model directly.
:param model: openvino.runtime.Model to be transformed.
:type model: openvino.runtime.Model
:return: True in case if Model was changed and False otherwise.
:rtype: bool
)");
}

View File

@ -0,0 +1,11 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_passes_ModelPass(py::module m);

View File

@ -0,0 +1,34 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/pass_base.hpp"
#include <pybind11/pybind11.h>
#include <memory>
#include <openvino/pass/pass.hpp>
namespace py = pybind11;
void regclass_passes_PassBase(py::module m) {
py::class_<ov::pass::PassBase, std::shared_ptr<ov::pass::PassBase>> pass_base(m, "PassBase");
pass_base.doc() = "openvino.runtime.passes.PassBase wraps ov::pass::PassBase";
pass_base.def("set_name",
&ov::pass::PassBase::set_name,
py::arg("name"),
R"(
Set transformation name.
:param name: Transformation name.
:type name: str
)");
pass_base.def("get_name",
&ov::pass::PassBase::get_name,
R"(
Get transformation name.
:return: Transformation name.
:rtype: str
)");
}

View File

@ -0,0 +1,11 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_passes_PassBase(py::module m);

View File

@ -0,0 +1,504 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/pattern_ops.hpp"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <iterator>
#include <sstream>
#include <string>
#include "ngraph/opsets/opset.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/pattern.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
ov::NodeTypeInfo get_type(const std::string& type_name) {
// Supported types: opsetX.OpName or opsetX::OpName
std::string opset_type;
auto it = type_name.cbegin();
while (it != type_name.cend() && *it != '.' && *it != ':') {
opset_type += *it;
++it;
}
// Skip delimiter
while (it != type_name.cend() && (*it == '.' || *it == ':')) {
++it;
}
// Get operation type name
std::string operation_type(it, type_name.end());
// TODO: create generic opset factory in Core so it can be reused
const std::unordered_map<std::string, std::function<const ngraph::OpSet&()>> get_opset{
{"opset1", ngraph::get_opset1},
{"opset2", ngraph::get_opset2},
{"opset3", ngraph::get_opset3},
{"opset4", ngraph::get_opset4},
{"opset5", ngraph::get_opset5},
{"opset6", ngraph::get_opset6},
{"opset7", ngraph::get_opset7},
{"opset8", ngraph::get_opset8},
};
if (!get_opset.count(opset_type)) {
throw std::runtime_error("Unsupported opset type: " + opset_type);
}
const ngraph::OpSet& m_opset = get_opset.at(opset_type)();
if (!m_opset.contains_type(operation_type)) {
throw std::runtime_error("Unrecognized operation type: " + operation_type);
}
return m_opset.create(operation_type)->get_type_info();
}
std::vector<ov::NodeTypeInfo> get_types(const std::vector<std::string>& type_names) {
std::vector<ov::NodeTypeInfo> types;
for (const auto& type_name : type_names) {
types.emplace_back(get_type(type_name));
}
return types;
}
using Predicate = const ov::pass::pattern::op::ValuePredicate;
void reg_pattern_wrap_type(py::module m) {
py::class_<ov::pass::pattern::op::WrapType, std::shared_ptr<ov::pass::pattern::op::WrapType>, ov::Node> wrap_type(
m,
"WrapType");
wrap_type.doc() = "openvino.runtime.passes.WrapType wraps ov::pass::pattern::op::WrapType";
wrap_type.def(py::init([](const std::string& type_name) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name));
}),
py::arg("type_name"),
R"(
Create WrapType with given node type.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
)");
wrap_type.def(py::init([](const std::string& type_name, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), pred);
}),
py::arg("type_name"),
py::arg("pred"),
R"(
Create WrapType with given node type and predicate.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::OutputVector{input});
}),
py::arg("type_name"),
py::arg("input"),
R"(
Create WrapType with given node type and input node.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param input: Node output.
:type input: openvino.runtime.Output
)");
wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::OutputVector{input});
}),
py::arg("type_name"),
py::arg("input"),
R"(
Create WrapType with given node type and input node.
:param type_name: node type. For example: opset8.Abs
:type type_name: str
:param input: Input node.
:type input: openvino.runtime.Node
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
pred,
ov::OutputVector{input});
}),
py::arg("type_name"),
py::arg("input"),
py::arg("predicate"),
R"(
Create WrapType with given node type, input node and predicate.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param input: Node output.
:type input: openvino.runtime.Output
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(
py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
pred,
ov::OutputVector{input});
}),
py::arg("type_name"),
py::arg("input"),
py::arg("predicate"),
R"(
Create WrapType with given node type, input node and predicate.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param input: Input node.
:type input: openvino.runtime.Node
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), nullptr, inputs);
}),
py::arg("type_name"),
py::arg("inputs"),
R"(
Create WrapType with given node type and input nodes.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param inputs: Node outputs.
:type inputs: List[openvino.runtime.Output]
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
nullptr,
ov::as_output_vector(inputs));
}),
py::arg("type_name"),
py::arg("inputs"),
R"(
Create WrapType with given node type and input nodes.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param inputs: Input nodes.
:type inputs: List[openvino.runtime.Node]
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), pred, inputs);
}),
py::arg("type_name"),
py::arg("inputs"),
py::arg("predicate"),
R"(
Create WrapType with given node type, input nodes and predicate.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param inputs: Node outputs.
:type inputs: List[openvino.runtime.Output]
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
pred,
ov::as_output_vector(inputs));
}),
py::arg("type_name"),
py::arg("inputs"),
py::arg("predicate"),
R"(
Create WrapType with given node type, input nodes and predicate.
:param type_name: node type. For example: "opset8.Abs"
:type type_name: str
:param inputs: Input nodes.
:type inputs: List[openvino.runtime.Node]
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names));
}),
py::arg("type_names"),
R"(
Create WrapType with given node types.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), pred);
}),
py::arg("type_names"),
py::arg("predicate"),
R"(
Create WrapType with given node types and predicate.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
R"(
Create WrapType with given node types and input.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param input: Node output.
:type input: openvino.runtime.Output
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
R"(
Create WrapType with given node types and input.
:param type_name: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_name: List[str]
:param input: Input node.
:type input: openvino.runtime.Node
)");
wrap_type.def(
py::init(
[](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
pred,
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
py::arg("predicate"),
R"(
Create WrapType with given node types, input and predicate.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param input: Node output.
:type input: openvino.runtime.Output
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names,
const std::shared_ptr<ov::Node>& input,
const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
pred,
ov::OutputVector{input});
}),
py::arg("type_names"),
py::arg("input"),
py::arg("predicate"),
R"(
Create WrapType with given node types, input and predicate.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param input: Input node.
:type input: openvino.runtime.Node
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), nullptr, inputs);
}),
py::arg("type_names"),
py::arg("inputs"),
R"(
Create WrapType with given node types and input.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param inputs: Nodes outputs.
:type inputs: List[openvino.runtime.Output]
)");
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
nullptr,
ov::as_output_vector(inputs));
}),
py::arg("type_names"),
py::arg("inputs"),
R"(
Create WrapType with given node types and inputs.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param inputs: Input nodes.
:type inputs: List[openvino.runtime.Node]
)");
wrap_type.def(
py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), pred, inputs);
}),
py::arg("type_names"),
py::arg("inputs"),
py::arg("predicate"),
R"(
Create WrapType with given node types, inputs and predicate.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param inputs: Nodes outputs.
:type inputs: List[openvino.runtime.Output]
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
wrap_type.def(
py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs, const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
pred,
ov::as_output_vector(inputs));
}),
py::arg("type_names"),
py::arg("inputs"),
py::arg("predicate"),
R"(
Create WrapType with given node types, inputs and predicate.
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
:type type_names: List[str]
:param inputs: Input nodes.
:type inputs: List[openvino.runtime.Node]
:param predicate: Function that performs additional checks for matching.
:type predicate: function
)");
}
void reg_pattern_or(py::module m) {
py::class_<ov::pass::pattern::op::Or, std::shared_ptr<ov::pass::pattern::op::Or>, ov::Node> or_type(m, "Or");
or_type.doc() = "openvino.runtime.passes.Or wraps ov::pass::pattern::op::Or";
or_type.def(py::init([](const ov::OutputVector& inputs) {
return std::make_shared<ov::pass::pattern::op::Or>(inputs);
}),
py::arg("inputs"),
R"(
Create pattern Or operation which is used to match any of given inputs.
:param inputs: Operation inputs.
:type inputs: List[openvino.runtime.Output]
)");
or_type.def(py::init([](const ov::NodeVector& inputs) {
return std::make_shared<ov::pass::pattern::op::Or>(ov::as_output_vector(inputs));
}),
py::arg("inputs"),
R"(
Create pattern Or operation which is used to match any of given inputs.
:param inputs: Operation inputs.
:type inputs: List[openvino.runtime.Node]
)");
}
void reg_pattern_any_input(py::module m) {
py::class_<ov::pass::pattern::op::Label, std::shared_ptr<ov::pass::pattern::op::Label>, ov::Node> any_input(
m,
"AnyInput");
any_input.doc() = "openvino.runtime.passes.AnyInput wraps ov::pass::pattern::op::Label";
any_input.def(py::init([]() {
return std::make_shared<ov::pass::pattern::op::Label>();
}),
R"(
Create pattern AnyInput operation which is used to match any type of node.
)");
any_input.def(py::init([](const Predicate& pred) {
return std::make_shared<ov::pass::pattern::op::Label>(ov::element::dynamic,
ov::PartialShape::dynamic(),
pred);
}),
py::arg("predicate"),
R"(
Create pattern AnyInput operation which is used to match any type of node.
:param pred: Function that performs additional checks for matching.
:type pred: function
)");
}
void reg_predicates(py::module m) {
m.def("consumers_count", &ov::pass::pattern::consumers_count);
m.def("has_static_dim", &ov::pass::pattern::has_static_dim);
m.def("has_static_dims", &ov::pass::pattern::has_static_dims);
m.def("has_static_shape", &ov::pass::pattern::has_static_shape);
m.def("has_static_rank", &ov::pass::pattern::has_static_rank);
m.def("rank_equals", &ov::pass::pattern::rank_equals);
m.def("type_matches", &ov::pass::pattern::type_matches);
m.def("type_matches_any", &ov::pass::pattern::type_matches_any);
}
void reg_passes_pattern_ops(py::module m) {
reg_pattern_any_input(m);
reg_pattern_wrap_type(m);
reg_pattern_or(m);
reg_predicates(m);
}

View File

@ -0,0 +1,11 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void reg_passes_pattern_ops(py::module m);

View File

@ -6,9 +6,24 @@
#include <pybind11/pybind11.h>
#include "pyopenvino/graph/passes/graph_rewrite.hpp"
#include "pyopenvino/graph/passes/manager.hpp"
#include "pyopenvino/graph/passes/matcher_pass.hpp"
#include "pyopenvino/graph/passes/model_pass.hpp"
#include "pyopenvino/graph/passes/pass_base.hpp"
#include "pyopenvino/graph/passes/pattern_ops.hpp"
#include "pyopenvino/graph/passes/transformations.hpp"
namespace py = pybind11;
void regmodule_graph_passes(py::module m) {
py::module m_passes = m.def_submodule("passes", "Package openvino.runtime.passes wraps ov::passes");
regclass_graph_passes_Manager(m_passes);
regclass_passes_PassBase(m_passes);
regclass_passes_ModelPass(m_passes);
regclass_passes_GraphRewrite(m_passes);
regclass_passes_Matcher(m_passes);
regclass_passes_MatcherPass(m_passes);
regclass_transformations(m_passes);
regclass_passes_Manager(m_passes);
reg_passes_pattern_ops(m_passes);
}

View File

@ -5,7 +5,6 @@
#pragma once
#include <pybind11/pybind11.h>
#include "pyopenvino/graph/passes/manager.hpp"
namespace py = pybind11;

View File

@ -0,0 +1,111 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "pyopenvino/graph/passes/transformations.hpp"
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <openvino/pass/constant_folding.hpp>
#include <openvino/pass/convert_fp32_to_fp16.hpp>
#include <openvino/pass/low_latency.hpp>
#include <openvino/pass/make_stateful.hpp>
#include <openvino/pass/pass.hpp>
#include <openvino/pass/serialize.hpp>
#include <openvino/pass/visualize_tree.hpp>
void regclass_transformations(py::module m) {
py::class_<ov::pass::Serialize, std::shared_ptr<ov::pass::Serialize>, ov::pass::ModelPass, ov::pass::PassBase>
serialize(m, "Serialize");
serialize.doc() = "openvino.runtime.passes.Serialize transformation";
serialize.def(py::init([](const std::string& path_to_xml, const std::string& path_to_bin) {
return std::make_shared<ov::pass::Serialize>(path_to_xml, path_to_bin);
}),
py::arg("path_to_xml"),
py::arg("path_to_bin"),
R"(
Create Serialize pass which is used for Model to IR serialization.
:param path_to_xml: Path where *.xml file will be saved.
:type path_to_xml: str
:param path_to_xml: Path where *.bin file will be saved.
:type path_to_xml: str
)");
serialize.def(
py::init(
[](const std::string& path_to_xml, const std::string& path_to_bin, ov::pass::Serialize::Version version) {
return std::make_shared<ov::pass::Serialize>(path_to_xml, path_to_bin, version);
}),
py::arg("path_to_xml"),
py::arg("path_to_bin"),
py::arg("version"),
R"(
Create Serialize pass which is used for Model to IR serialization.
:param path_to_xml: Path where *.xml file will be saved.
:type path_to_xml: str
:param path_to_xml: Path where *.bin file will be saved.
:type path_to_xml: str
:param version: serialized IR version.
:type version: int
)");
py::class_<ov::pass::ConstantFolding,
std::shared_ptr<ov::pass::ConstantFolding>,
ov::pass::ModelPass,
ov::pass::PassBase>
cf(m, "ConstantFolding");
cf.doc() = "openvino.runtime.passes.ConstantFolding transformation";
cf.def(py::init<>());
py::class_<ov::pass::VisualizeTree,
std::shared_ptr<ov::pass::VisualizeTree>,
ov::pass::ModelPass,
ov::pass::PassBase>
visualize(m, "VisualizeTree");
visualize.doc() = "openvino.runtime.passes.VisualizeTree transformation";
visualize.def(py::init<const std::string&, ov::pass::VisualizeTree::node_modifiers_t, bool>(),
py::arg("file_name"),
py::arg("nm") = nullptr,
py::arg("don_only") = false,
R"(
Create VisualizeTree pass which is used for Model to dot serialization.
:param file_name: Path where serialized model will be saved. For example: /tmp/out.svg
:type file_name: str
:param nm: Node modifier function.
:type nm: function
:param don_only: Enable only dot file generation.
:type don_only: bool
)");
py::class_<ov::pass::MakeStateful, std::shared_ptr<ov::pass::MakeStateful>, ov::pass::ModelPass, ov::pass::PassBase>
make_stateful(m, "MakeStateful");
make_stateful.doc() = "openvino.runtime.passes.MakeStateful transformation";
// TODO: update docstrings for c-tors below
make_stateful.def(py::init<const ov::pass::MakeStateful::ParamResPairs&>(), py::arg("pairs_to_replace"));
make_stateful.def(py::init<const std::map<std::string, std::string>&>());
py::class_<ov::pass::LowLatency2, std::shared_ptr<ov::pass::LowLatency2>, ov::pass::ModelPass, ov::pass::PassBase>
low_latency(m, "LowLatency2");
low_latency.doc() = "openvino.runtime.passes.LowLatency2 transformation";
// TODO: update docstrings for c-tor below
low_latency.def(py::init<bool>(), py::arg("use_const_initializer") = true);
py::class_<ov::pass::ConvertFP32ToFP16,
std::shared_ptr<ov::pass::ConvertFP32ToFP16>,
ov::pass::ModelPass,
ov::pass::PassBase>
convert(m, "ConvertFP32ToFP16");
convert.doc() = "openvino.runtime.passes.ConvertFP32ToFP16 transformation";
convert.def(py::init<>());
}

View File

@ -0,0 +1,11 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
void regclass_transformations(py::module m);

View File

@ -6,10 +6,15 @@
#include <pybind11/numpy.h>
#include "openvino/core/graph_util.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/pass/manager.hpp"
namespace py = pybind11;
template <typename... Args>
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
void* numpy_to_c(py::array a) {
py::buffer_info info = a.request();
return info.ptr;
@ -31,4 +36,24 @@ void regmodule_graph_util(py::module m) {
from the resulting bound, otherwise Null.
:rtype: openvino.runtime.op.Constant or openvino.runtime.Node
)");
mod.def("replace_output_update_name", &ov::replace_output_update_name, py::arg("output"), py::arg("target_output"));
mod.def("replace_node",
overload_cast_<const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>&>()(&ov::replace_node),
py::arg("target"),
py::arg("replacement"));
mod.def("replace_node",
overload_cast_<const std::shared_ptr<ov::Node>&, const ov::OutputVector&>()(&ov::replace_node),
py::arg("target"),
py::arg("replacement"));
mod.def("replace_node",
overload_cast_<const std::shared_ptr<ov::Node>&,
const std::shared_ptr<ov::Node>&,
const std::vector<int64_t>&>()(&ov::replace_node),
py::arg("target"),
py::arg("replacement"),
py::arg("outputs_order"));
}

View File

@ -0,0 +1,73 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime import opset8
from openvino.runtime.passes import Manager, GraphRewrite, MatcherPass, WrapType, Matcher
from utils.utils import count_ops, get_test_function, PatternReplacement
def test_graph_rewrite():
model = get_test_function()
m = Manager()
# check that register pass returns pass instance
anchor = m.register_pass(GraphRewrite())
anchor.add_matcher(PatternReplacement())
m.run_passes(model)
assert count_ops(model, "Relu") == [2]
def test_register_new_node():
class InsertExp(MatcherPass):
def __init__(self):
MatcherPass.__init__(self)
self.model_changed = False
param = WrapType("opset8.Parameter")
def callback(m: Matcher) -> bool:
# Input->...->Result => Input->Exp->...->Result
root = m.get_match_value()
consumers = root.get_target_inputs()
exp = opset8.exp(root)
for consumer in consumers:
consumer.replace_source_output(exp.output(0))
# For testing purpose
self.model_changed = True
# Use new operation for additional matching
self.register_new_node(exp)
# Root node wasn't replaced or changed
return False
self.register_matcher(Matcher(param, "InsertExp"), callback)
class RemoveExp(MatcherPass):
def __init__(self):
MatcherPass.__init__(self)
self.model_changed = False
param = WrapType("opset8.Exp")
def callback(m: Matcher) -> bool:
root = m.get_match_root()
root.output(0).replace(root.input_value(0))
# For testing purpose
self.model_changed = True
return True
self.register_matcher(Matcher(param, "RemoveExp"), callback)
m = Manager()
ins = m.register_pass(InsertExp())
rem = m.register_pass(RemoveExp())
m.run_passes(get_test_function())
assert ins.model_changed
assert rem.model_changed

View File

@ -0,0 +1,44 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime.passes import Manager, GraphRewrite, BackwardGraphRewrite, Serialize
from utils.utils import MyModelPass, PatternReplacement, expect_exception
def test_registration_and_pass_name():
m = Manager()
a = m.register_pass(PatternReplacement())
a.set_name("PatterReplacement")
b = m.register_pass(MyModelPass())
b.set_name("ModelPass")
c = m.register_pass(GraphRewrite())
c.set_name("Anchor")
d = c.add_matcher(PatternReplacement())
d.set_name("PatterReplacement")
e = m.register_pass(BackwardGraphRewrite())
e.set_name("BackAnchor")
f = e.add_matcher(PatternReplacement())
f.set_name("PatterReplacement")
PatternReplacement().set_name("PatternReplacement")
MyModelPass().set_name("MyModelPass")
GraphRewrite().set_name("Anchor")
BackwardGraphRewrite().set_name("BackAnchor")
# Preserve legacy behaviour when registered pass doesn't exist
# and in this case we shouldn't throw an exception.
m.register_pass("NotExistingPass")
def test_negative_pass_registration():
m = Manager()
expect_exception(lambda: m.register_pass(PatternReplacement))
expect_exception(lambda: m.register_pass("PatternReplacement", PatternReplacement()))
expect_exception(lambda: m.register_pass("Serialize", Serialize("out.xml", "out.bin")))
expect_exception(lambda: m.register_pass("Serialize", "out.xml", "out.bin", "out.wrong"))

View File

@ -0,0 +1,56 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime import opset8
from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType
from openvino.runtime.utils import replace_node
from utils.utils import count_ops, get_test_function, PatternReplacement
def test_simple_pattern_replacement():
# Simple: for Extensions. Without any classes and inheritance.
def pattern_replacement():
param = WrapType("opset8.Parameter")
relu = WrapType("opset8.Relu", param.output(0))
def callback(m: Matcher) -> bool:
root = m.get_match_root()
# Just to check that capturing works and we can
# link pattern nodes with matched graph nodes.
assert relu in m.get_pattern_value_map()
new_relu = opset8.exp(root.input_value(0)) # ot root.input(0).get_source_output()
replace_node(root, new_relu)
return True
return Matcher(relu, "SimpleReplacement"), callback
model = get_test_function()
m = Manager()
m.register_pass(MatcherPass(*pattern_replacement()))
m.run_passes(model)
assert count_ops(model, ("Relu", "Exp")) == [0, 1]
def test_matcher_pass():
model = get_test_function()
m = Manager()
# check that register pass returns pass instance
p = m.register_pass(PatternReplacement())
m.run_passes(model)
assert p.model_changed
assert count_ops(model, "Relu") == [2]
def test_matcher_pass_apply():
model = get_test_function()
p = PatternReplacement()
p.apply(model.get_result().input_value(0).get_node())
assert count_ops(model, "Relu") == [2]

View File

@ -0,0 +1,13 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime.passes import Manager
from utils.utils import get_test_function, MyModelPass
def test_model_pass():
m = Manager()
p = m.register_pass(MyModelPass())
m.run_passes(get_test_function())
assert p.model_changed

View File

@ -67,26 +67,6 @@ def test_make_stateful_transformations():
assert len(function.get_results()) == 0
def test_serialize_pass():
core = Core()
xml_path = "serialized_function.xml"
bin_path = "serialized_function.bin"
func = get_test_function()
serialize(func, xml_path, bin_path)
assert func is not None
res_func = core.read_model(model=xml_path, weights=bin_path)
assert func.get_parameters() == res_func.get_parameters()
assert func.get_ordered_ops() == res_func.get_ordered_ops()
os.remove(xml_path)
os.remove(bin_path)
def test_serialize_pass_v2():
core = Core()
xml_path = "./serialized_function.xml"

View File

@ -0,0 +1,108 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from openvino.runtime import PartialShape, opset8
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput
from openvino.runtime.passes import consumers_count, has_static_dim, has_static_dims, \
has_static_shape, has_static_rank, type_matches, type_matches_any
from openvino.runtime.utils.types import get_element_type
from utils.utils import expect_exception
def test_wrap_type_pattern_type():
for i in range(1, 9):
WrapType("opset{}.Parameter".format(i))
WrapType("opset{}::Parameter".format(i))
# Negative check not to forget to update opset map in get_type function
expect_exception(lambda: WrapType("opset9.Parameter"), "Unsupported opset type: opset9")
# Generic negative test cases
expect_exception(lambda: WrapType(""))
expect_exception(lambda: WrapType("opset8"))
expect_exception(lambda: WrapType("Parameter"))
expect_exception(lambda: WrapType("opset.Parameter"))
expect_exception(lambda: WrapType("opset8,Parameter"))
expect_exception(lambda: WrapType("Parameter.opset8"))
def test_wrap_type_ctors():
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
relu = opset8.relu(param.output(0))
slope = opset8.parameter(PartialShape([]))
prelu = opset8.prelu(param.output(0), slope.output(0))
m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"]), "FindActivation")
assert m.match(relu)
assert m.match(prelu)
m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"],
WrapType("opset8.Parameter").output(0)), "FindActivation")
assert m.match(relu)
def test_or():
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
relu = opset8.relu(param.output(0))
slope = opset8.parameter(PartialShape([]))
prelu = opset8.prelu(param.output(0), slope.output(0))
m = Matcher(Or([WrapType("opset8.Relu"),
WrapType("opset8.PRelu")]), "FindActivation")
assert m.match(relu)
assert m.match(prelu)
def test_any_input():
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
relu = opset8.relu(param.output(0))
slope = opset8.parameter(PartialShape([]))
prelu = opset8.prelu(param.output(0), slope.output(0))
m = Matcher(WrapType("opset8.PRelu", [AnyInput(), AnyInput()]), "FindActivation")
assert not m.match(relu)
assert m.match(prelu)
def test_any_input_predicate():
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
slope = opset8.parameter(PartialShape([]))
m = Matcher(AnyInput(lambda output: len(output.get_shape()) == 4), "FindActivation")
assert m.match(param)
assert not m.match(slope)
def test_all_predicates():
static_param = opset8.parameter(PartialShape([1, 3, 22, 22]), np.float32)
dynamic_param = opset8.parameter(PartialShape([-1, 6]), np.long)
fully_dynamic_param = opset8.parameter(PartialShape.dynamic())
assert Matcher(WrapType("opset8.Parameter", consumers_count(0)), "Test").match(static_param)
assert not Matcher(WrapType("opset8.Parameter", consumers_count(1)), "Test").match(static_param)
assert Matcher(WrapType("opset8.Parameter", has_static_dim(1)), "Test").match(static_param)
assert not Matcher(WrapType("opset8.Parameter", has_static_dim(0)), "Test").match(dynamic_param)
assert Matcher(WrapType("opset8.Parameter", has_static_dims([0, 3])), "Test").match(static_param)
assert not Matcher(WrapType("opset8.Parameter", has_static_dims([0, 1])), "Test").match(dynamic_param)
assert Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(static_param)
assert not Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(dynamic_param)
assert Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(dynamic_param)
assert not Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(fully_dynamic_param)
assert Matcher(WrapType("opset8.Parameter",
type_matches(get_element_type(np.float32))), "Test").match(static_param)
assert not Matcher(WrapType("opset8.Parameter",
type_matches(get_element_type(np.float32))), "Test").match(dynamic_param)
assert Matcher(WrapType("opset8.Parameter",
type_matches_any([get_element_type(np.float32),
get_element_type(np.long)])), "Test").match(static_param)
assert Matcher(WrapType("opset8.Parameter",
type_matches_any([get_element_type(np.float32),
get_element_type(np.long)])), "Test").match(dynamic_param)

View File

@ -0,0 +1,114 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import os
import numpy as np
from openvino.runtime import Model, PartialShape, Shape, opset8, Core
from openvino.runtime.passes import Manager, ConstantFolding, MakeStateful,\
ConvertFP32ToFP16, LowLatency2, Serialize
from utils.utils import count_ops, get_test_function
def get_model():
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
param.get_output_tensor(0).set_names({"parameter"})
relu = opset8.relu(param)
reshape = opset8.reshape(relu, opset8.shape_of(relu), False)
res = opset8.result(reshape, name="result")
res.get_output_tensor(0).set_names({"result"})
return Model([res], [param], "test")
def test_make_stateful():
model = get_model()
m = Manager()
p = MakeStateful({"parameter": "result"})
m.register_pass(p)
m.run_passes(model)
assert model is not None
assert len(model.get_parameters()) == 0
assert len(model.get_results()) == 0
def test_constant_folding():
model = get_model()
m = Manager()
m.register_pass(ConstantFolding())
m.run_passes(model)
assert model is not None
assert count_ops(model, "ShapeOf") == [0]
def test_convert_precision():
model = get_model()
m = Manager()
m.register_pass(ConvertFP32ToFP16())
m.run_passes(model)
assert model is not None
# TODO: fix bug 82773 with float16 type comparison
# assert model.get_parameters()[0].get_element_type() == np.float16
def test_low_latency2():
X = opset8.parameter(Shape([32, 40, 10]), np.float32, "X")
Y = opset8.parameter(Shape([32, 40, 10]), np.float32, "Y")
M = opset8.parameter(Shape([32, 2, 10]), np.float32, "M")
X_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "X_i")
Y_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "Y_i")
M_body = opset8.parameter(Shape([32, 2, 10]), np.float32, "M_body")
sum = opset8.add(X_i, Y_i)
Zo = opset8.multiply(sum, M_body)
body = Model([Zo], [X_i, Y_i, M_body], "body_function")
ti = opset8.tensor_iterator()
ti.set_body(body)
ti.set_sliced_input(X_i, X.output(0), 0, 2, 2, 39, 1)
ti.set_sliced_input(Y_i, Y.output(0), 0, 2, 2, -1, 1)
ti.set_invariant_input(M_body, M.output(0))
out0 = ti.get_iter_value(Zo.output(0), -1)
out1 = ti.get_concatenated_slices(Zo.output(0), 0, 2, 2, 39, 1)
result0 = opset8.result(out0)
result1 = opset8.result(out1)
model = Model([result0, result1], [X, Y, M])
m = Manager()
m.register_pass(LowLatency2())
m.run_passes(model)
# TODO: create TI which will be transformed by LowLatency2
assert count_ops(model, "TensorIterator") == [1]
def test_serialize_pass():
core = Core()
xml_path = "serialized_function.xml"
bin_path = "serialized_function.bin"
func = get_test_function()
m = Manager()
m.register_pass(Serialize(xml_path, bin_path))
m.run_passes(func)
assert func is not None
res_func = core.read_model(model=xml_path, weights=bin_path)
assert func.get_parameters() == res_func.get_parameters()
assert func.get_ordered_ops() == res_func.get_ordered_ops()
os.remove(xml_path)
os.remove(bin_path)

View File

@ -0,0 +1,59 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime import Model, PartialShape, opset8
from openvino.runtime.utils import replace_node, replace_output_update_name
def get_test_function():
# Parameter->Relu->Result
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
res = opset8.result(relu.output(0), name="result")
return Model([res], [param], "test")
def test_output_replace():
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
res = opset8.result(relu.output(0), name="result")
exp = opset8.exp(param.output(0))
relu.output(0).replace(exp.output(0))
assert res.input_value(0).get_node() == exp
def test_replace_source_output():
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
res = opset8.result(relu.output(0), name="result")
exp = opset8.exp(param.output(0))
res.input(0).replace_source_output(exp.output(0))
assert len(exp.output(0).get_target_inputs()) == 1
assert len(relu.output(0).get_target_inputs()) == 0
assert next(iter(exp.output(0).get_target_inputs())).get_node() == res
def test_replace_node():
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
res = opset8.result(relu.output(0), name="result")
exp = opset8.exp(param.output(0))
replace_node(relu, exp)
assert res.input_value(0).get_node() == exp
def test_replace_output_update_name():
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
exp = opset8.exp(relu.output(0))
res = opset8.result(exp.output(0), name="result")
replace_output_update_name(exp.output(0), exp.input_value(0))
assert res.input_value(0).get_node() == exp

View File

@ -0,0 +1,75 @@
# Copyright (C) 2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from openvino.runtime import Model, PartialShape, opset8
from openvino.runtime.passes import ModelPass, Matcher, MatcherPass, WrapType
def get_test_function():
# Parameter->Relu->Result
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
relu = opset8.relu(param.output(0))
res = opset8.result(relu.output(0), name="result")
return Model([res], [param], "test")
def count_ops(model, op_types):
if isinstance(op_types, str):
op_types = [op_types]
cnt = [0] * len(op_types)
types = {op_types[index]: index for index in range(len(op_types))}
for op in model.get_ops():
op_type = op.get_type_info().name
if op_type in types:
cnt[types[op_type]] += 1
return cnt
def expect_exception(func, message=""):
def check():
try:
func()
return None
except Exception as e:
return str(e)
res = check()
if res is None:
raise AssertionError("Exception is not thrown!")
assert message in res
class PatternReplacement(MatcherPass):
def __init__(self):
MatcherPass.__init__(self)
self.model_changed = False
relu = WrapType("opset8::Relu")
def callback(m: Matcher) -> bool:
self.applied = True
root = m.get_match_root()
new_relu = opset8.relu(root.input(0).get_source_output())
# For testing purpose
self.model_changed = True
# # Use new operation for additional matching
# self.register_new_node(new_relu)
# Input->Relu->Result => Input->Relu->Relu->Result
root.input(0).replace_source_output(new_relu.output(0))
return True
self.register_matcher(Matcher(relu, "PatternReplacement"), callback)
class MyModelPass(ModelPass):
def __init__(self):
super().__init__()
self.model_changed = False
def run_on_model(self, model):
for op in model.get_ops():
if op.get_type_info().name == "Relu":
self.model_changed = True

View File

@ -61,6 +61,10 @@ public:
set_property(property, true);
}
MatcherPass(const std::shared_ptr<pattern::Matcher>& m, const matcher_pass_callback& callback) : PassBase() {
register_matcher(m, callback);
}
bool apply(std::shared_ptr<ov::Node> node);
template <typename T, class... Args>
@ -76,6 +80,10 @@ public:
return node;
}
std::shared_ptr<ov::Node> register_new_node_(const std::shared_ptr<ov::Node>& node) {
return register_new_node(node);
}
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
return m_new_nodes;
}
@ -88,8 +96,10 @@ public:
protected:
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
const matcher_pass_callback& callback,
const PassPropertyMask& property);
void register_matcher(const std::shared_ptr<pattern::Matcher>& m, const matcher_pass_callback& callback);
private:
handler_callback m_handler;
@ -195,6 +205,13 @@ public:
}
}
std::shared_ptr<MatcherPass> add_matcher(const std::shared_ptr<MatcherPass>& pass) {
auto pass_config = get_pass_config();
pass->set_pass_config(pass_config);
m_matchers.push_back(pass);
return pass;
}
OPENVINO_DEPRECATED("Use MatcherPass instead")
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
const graph_rewrite_callback& callback,

View File

@ -51,6 +51,15 @@ public:
return rc;
}
std::shared_ptr<PassBase> register_pass_instance(std::shared_ptr<PassBase> pass) {
pass->set_pass_config(m_pass_config);
m_pass_list.push_back(pass);
if (m_per_pass_validation) {
push_pass<Validate>();
}
return pass;
}
void run_passes(std::shared_ptr<Model>);
void set_pass_visualization(bool new_state) {

View File

@ -308,6 +308,11 @@ void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pat
};
}
void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& m,
const ov::graph_rewrite_callback& callback) {
register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
}
bool ov::pass::MatcherPass::apply(std::shared_ptr<ov::Node> node) {
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::perf_counters_graph_rewrite()[get_type_info()]);
m_new_nodes.clear();