Transformations Python API (#10971)
* Keep changes * Update tests * Keep changes * Cleanup * Add predicates support; new pattern ops; new tests * support for public passes; added tests * Fix compilation warning * Fix code style * Added docstrings; code cleanup * Update python API tests * Fix build on Windows * Revert back pass registration logic * Fix flake8 errors * Update docstrings; fix utils.hpp * Cleanup * Cleanup * Fix flake errors * Fix mypy * Skip mypy for passes
This commit is contained in:
parent
02c60c76ab
commit
866f006a83
@ -1,6 +1,11 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# type: ignore
|
||||
# flake8: noqa
|
||||
|
||||
from openvino.pyopenvino.passes import Manager
|
||||
from openvino.pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput
|
||||
from openvino.pyopenvino.passes import consumers_count, has_static_dim, has_static_dims, has_static_shape,\
|
||||
has_static_rank, rank_equals, type_matches, type_matches_any
|
||||
from openvino.pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16
|
||||
from openvino.runtime.passes.manager import Manager
|
||||
from openvino.runtime.passes.graph_rewrite import GraphRewrite, BackwardGraphRewrite
|
||||
|
@ -0,0 +1,30 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# type: ignore
|
||||
from openvino.pyopenvino.passes import MatcherPass
|
||||
from openvino.pyopenvino.passes import GraphRewrite as GraphRewriteBase
|
||||
from openvino.pyopenvino.passes import BackwardGraphRewrite as BackwardGraphRewriteBase
|
||||
|
||||
|
||||
class GraphRewrite(GraphRewriteBase):
|
||||
"""GraphRewrite that additionally holds python transformations objects."""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.passes_list = [] # need to keep python instances alive
|
||||
|
||||
def add_matcher(self, transformation: MatcherPass) -> MatcherPass:
|
||||
"""Append MatcherPass instance to the end of execution list."""
|
||||
self.passes_list.append(transformation)
|
||||
return super().add_matcher(transformation)
|
||||
|
||||
|
||||
class BackwardGraphRewrite(BackwardGraphRewriteBase):
|
||||
"""BackwardGraphRewriteBase that additionally holds python transformations objects."""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.passes_list = [] # need to keep python instances alive
|
||||
|
||||
def add_matcher(self, transformation: MatcherPass) -> MatcherPass:
|
||||
"""Append MatcherPass instance to the end of execution list."""
|
||||
self.passes_list.append(transformation)
|
||||
return super().add_matcher(transformation)
|
24
src/bindings/python/src/openvino/runtime/passes/manager.py
Normal file
24
src/bindings/python/src/openvino/runtime/passes/manager.py
Normal file
@ -0,0 +1,24 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# type: ignore
|
||||
from openvino.pyopenvino.passes import Manager as ManagerBase
|
||||
from openvino.pyopenvino.passes import PassBase
|
||||
|
||||
|
||||
class Manager(ManagerBase):
|
||||
"""Manager that additionally holds transformations objects."""
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.passes_list = [] # need to keep python instances alive
|
||||
|
||||
def register_pass(self, *args, **kwargs) -> PassBase:
|
||||
"""Register transformation for further execution."""
|
||||
for arg in args:
|
||||
if isinstance(arg, PassBase):
|
||||
self.passes_list.append(arg)
|
||||
|
||||
for arg in kwargs.values():
|
||||
if isinstance(arg, PassBase):
|
||||
self.passes_list.append(arg)
|
||||
|
||||
return super().register_pass(*args, **kwargs)
|
@ -4,4 +4,4 @@
|
||||
"""Generic utilities. Factor related functions out to separate files."""
|
||||
|
||||
from openvino.pyopenvino.util import numpy_to_c
|
||||
from openvino.pyopenvino.util import get_constant_from_source
|
||||
from openvino.pyopenvino.util import get_constant_from_source, replace_node, replace_output_update_name
|
||||
|
@ -62,7 +62,7 @@ endif()
|
||||
|
||||
# create target
|
||||
|
||||
file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp pyopenvino.cpp)
|
||||
file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp utils/*cpp pyopenvino.cpp)
|
||||
list(FILTER SOURCES EXCLUDE REGEX frontend/onnx|tensorflow|paddle/* )
|
||||
|
||||
pybind11_add_module(${PROJECT_NAME} MODULE ${SOURCES})
|
||||
|
@ -0,0 +1,75 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/graph_rewrite.hpp"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <openvino/pass/graph_rewrite.hpp>
|
||||
#include <openvino/pass/pass.hpp>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_GraphRewrite(py::module m) {
|
||||
py::class_<ov::pass::GraphRewrite, std::shared_ptr<ov::pass::GraphRewrite>, ov::pass::ModelPass, ov::pass::PassBase>
|
||||
graph_rewrite(m, "GraphRewrite");
|
||||
graph_rewrite.doc() =
|
||||
"openvino.runtime.passes.GraphRewrite executes sequence of MatcherPass transformations in topological order";
|
||||
|
||||
graph_rewrite.def(py::init<>());
|
||||
graph_rewrite.def(py::init([](const std::shared_ptr<ov::pass::MatcherPass>& pass) {
|
||||
return std::make_shared<ov::pass::GraphRewrite>(pass);
|
||||
}),
|
||||
py::arg("pass"),
|
||||
R"(
|
||||
Register single MatcherPass pass inside GraphRewrite.
|
||||
|
||||
:param pass: openvino.runtime.passes.MatcherPass instance
|
||||
:type pass: openvino.runtime.passes.MatcherPass
|
||||
)");
|
||||
|
||||
graph_rewrite.def("add_matcher",
|
||||
static_cast<std::shared_ptr<ov::pass::MatcherPass> (ov::pass::GraphRewrite::*)(
|
||||
const std::shared_ptr<ov::pass::MatcherPass>&)>(&ov::pass::GraphRewrite::add_matcher),
|
||||
py::arg("pass"),
|
||||
R"(
|
||||
Register single MatcherPass pass inside GraphRewrite.
|
||||
|
||||
:param pass: openvino.runtime.passes.MatcherPass instance
|
||||
:type pass: openvino.runtime.passes.MatcherPass
|
||||
)");
|
||||
|
||||
py::class_<ov::pass::BackwardGraphRewrite,
|
||||
std::shared_ptr<ov::pass::BackwardGraphRewrite>,
|
||||
ov::pass::GraphRewrite,
|
||||
ov::pass::ModelPass,
|
||||
ov::pass::PassBase>
|
||||
back_graph_rewrite(m, "BackwardGraphRewrite");
|
||||
back_graph_rewrite.doc() = "openvino.runtime.passes.BackwardGraphRewrite executes sequence of MatcherPass "
|
||||
"transformations in reversed topological order";
|
||||
|
||||
back_graph_rewrite.def(py::init<>());
|
||||
back_graph_rewrite.def(py::init([](const std::shared_ptr<ov::pass::MatcherPass>& pass) {
|
||||
return std::make_shared<ov::pass::BackwardGraphRewrite>(pass);
|
||||
}),
|
||||
py::arg("pass"),
|
||||
R"(
|
||||
Register single MatcherPass pass inside BackwardGraphRewrite.
|
||||
|
||||
:param pass: openvino.runtime.passes.MatcherPass instance
|
||||
:type pass: openvino.runtime.passes.MatcherPass
|
||||
)");
|
||||
|
||||
back_graph_rewrite.def(
|
||||
"add_matcher",
|
||||
static_cast<std::shared_ptr<ov::pass::MatcherPass> (ov::pass::BackwardGraphRewrite::*)(
|
||||
const std::shared_ptr<ov::pass::MatcherPass>&)>(&ov::pass::BackwardGraphRewrite::add_matcher),
|
||||
py::arg("pass"),
|
||||
R"(
|
||||
Register single MatcherPass pass inside BackwardGraphRewrite.
|
||||
|
||||
:param pass: openvino.runtime.passes.MatcherPass instance
|
||||
:type pass: openvino.runtime.passes.MatcherPass
|
||||
)");
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_GraphRewrite(py::module m);
|
@ -31,63 +31,84 @@ inline Version convert_to_version(const std::string& version) {
|
||||
"'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'.");
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ManagerWrapper : public ov::pass::Manager {
|
||||
public:
|
||||
ManagerWrapper() {}
|
||||
~ManagerWrapper() {}
|
||||
|
||||
void register_pass(const std::string& pass_name) {
|
||||
if (pass_name == "ConstantFolding")
|
||||
push_pass<ov::pass::ConstantFolding>();
|
||||
|
||||
if (m_per_pass_validation)
|
||||
push_pass<ov::pass::Validate>();
|
||||
return;
|
||||
}
|
||||
|
||||
void register_pass(const std::string& pass_name, const FilePaths& file_paths, const std::string& version) {
|
||||
if (pass_name == "Serialize") {
|
||||
push_pass<ov::pass::Serialize>(file_paths.first, file_paths.second, convert_to_version(version));
|
||||
}
|
||||
return;
|
||||
}
|
||||
void register_pass(const std::string& pass_name,
|
||||
const std::string& xml_path,
|
||||
const std::string& bin_path,
|
||||
const std::string& version) {
|
||||
if (pass_name == "Serialize")
|
||||
push_pass<ov::pass::Serialize>(xml_path, bin_path, convert_to_version(version));
|
||||
return;
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void regclass_graph_passes_Manager(py::module m) {
|
||||
py::class_<ManagerWrapper> manager(m, "Manager");
|
||||
manager.doc() = "openvino.runtime.passes.Manager wraps ov::pass::Manager using ManagerWrapper";
|
||||
void regclass_passes_Manager(py::module m) {
|
||||
py::class_<ov::pass::Manager> manager(m, "Manager");
|
||||
manager.doc() = "openvino.runtime.passes.Manager executes sequence of transformation on a given Model";
|
||||
|
||||
manager.def(py::init<>());
|
||||
|
||||
manager.def("set_per_pass_validation", &ManagerWrapper::set_per_pass_validation);
|
||||
manager.def("run_passes", &ManagerWrapper::run_passes);
|
||||
manager.def("register_pass",
|
||||
(void (ManagerWrapper::*)(const std::string&)) & ManagerWrapper::register_pass,
|
||||
py::arg("pass_name"),
|
||||
manager.def("set_per_pass_validation",
|
||||
&ov::pass::Manager::set_per_pass_validation,
|
||||
py::arg("new_state"),
|
||||
R"(
|
||||
Set the type of register pass for pass manager.
|
||||
Enables or disables Model validation after each pass execution.
|
||||
|
||||
:param pass_name : String to set the type of a pass.
|
||||
:type pass_name: str
|
||||
// )");
|
||||
:param new_state: flag which enables or disables model validation.
|
||||
:type new_state: bool
|
||||
)");
|
||||
|
||||
manager.def("run_passes",
|
||||
&ov::pass::Manager::run_passes,
|
||||
py::arg("model"),
|
||||
R"(
|
||||
Executes sequence of transformations on given Model.
|
||||
|
||||
:param model: openvino.runtime.Model to be transformed.
|
||||
:type model: openvino.runtime.Model
|
||||
)");
|
||||
|
||||
manager.def("register_pass",
|
||||
(void (ManagerWrapper::*)(const std::string&, const FilePaths&, const std::string&)) &
|
||||
ManagerWrapper::register_pass,
|
||||
py::arg("pass_name"),
|
||||
py::arg("output_files"),
|
||||
py::arg("version") = "UNSPECIFIED",
|
||||
&ov::pass::Manager::register_pass_instance,
|
||||
py::arg("transformation"),
|
||||
R"(
|
||||
Register pass instance for execution. Execution order matches the registration order.
|
||||
|
||||
:param transformation: transformation instance.
|
||||
:type transformation: openvino.runtime.passes.PassBase
|
||||
)");
|
||||
|
||||
manager.def(
|
||||
"register_pass",
|
||||
[](ov::pass::Manager& self, const std::string& pass_name) -> void {
|
||||
PyErr_WarnEx(PyExc_DeprecationWarning,
|
||||
"register_pass with this arguments is deprecated! "
|
||||
"Please use register_pass(ConstantFolding()) instead.",
|
||||
1);
|
||||
if (pass_name == "ConstantFolding") {
|
||||
self.register_pass<ov::pass::ConstantFolding>();
|
||||
}
|
||||
},
|
||||
py::arg("pass_name"),
|
||||
R"(
|
||||
This method is deprecated. Please use m.register_pass(ConstantFolding()) instead.
|
||||
|
||||
Register pass by name from the list of predefined passes.
|
||||
|
||||
:param pass_name: String to set the type of a pass.
|
||||
:type pass_name: str
|
||||
)");
|
||||
|
||||
manager.def(
|
||||
"register_pass",
|
||||
[](ov::pass::Manager& self,
|
||||
const std::string& pass_name,
|
||||
const FilePaths& file_paths,
|
||||
const std::string& version) -> void {
|
||||
PyErr_WarnEx(PyExc_DeprecationWarning,
|
||||
"register_pass with this arguments is deprecated! "
|
||||
"Please use register_pass(Serialize(xml, bin, version)) instead.",
|
||||
1);
|
||||
if (pass_name == "Serialize") {
|
||||
self.register_pass<ov::pass::Serialize>(file_paths.first,
|
||||
file_paths.second,
|
||||
convert_to_version(version));
|
||||
}
|
||||
},
|
||||
py::arg("pass_name"),
|
||||
py::arg("output_files"),
|
||||
py::arg("version") = "UNSPECIFIED",
|
||||
R"(
|
||||
This method is deprecated. Please use m.register_pass(Serialize(...)) instead.
|
||||
|
||||
Set the type of register pass for pass manager.
|
||||
|
||||
:param pass_name: String to set the type of a pass.
|
||||
@ -100,6 +121,7 @@ void regclass_graph_passes_Manager(py::module m) {
|
||||
- "IR_V10" : v10 IR
|
||||
- "IR_V11" : v11 IR
|
||||
:type version: str
|
||||
|
||||
Examples
|
||||
----------
|
||||
1. Default Version
|
||||
@ -108,16 +130,30 @@ void regclass_graph_passes_Manager(py::module m) {
|
||||
2. IR version 11
|
||||
pass_manager = Manager()
|
||||
pass_manager.register_pass("Serialize", output_files=("example.xml", "example.bin"), version="IR_V11")
|
||||
// )");
|
||||
)");
|
||||
|
||||
manager.def(
|
||||
"register_pass",
|
||||
(void (ManagerWrapper::*)(const std::string&, const std::string&, const std::string&, const std::string&)) &
|
||||
ManagerWrapper::register_pass,
|
||||
[](ov::pass::Manager& self,
|
||||
const std::string& pass_name,
|
||||
const std::string& xml_path,
|
||||
const std::string& bin_path,
|
||||
const std::string& version) -> void {
|
||||
PyErr_WarnEx(PyExc_DeprecationWarning,
|
||||
"register_pass with this arguments is deprecated! "
|
||||
"Please use register_pass(Serialize(xml, bin, version)) instead.",
|
||||
1);
|
||||
if (pass_name == "Serialize") {
|
||||
self.register_pass<ov::pass::Serialize>(xml_path, bin_path, convert_to_version(version));
|
||||
}
|
||||
},
|
||||
py::arg("pass_name"),
|
||||
py::arg("xml_path"),
|
||||
py::arg("bin_path"),
|
||||
py::arg("version") = "UNSPECIFIED",
|
||||
R"(
|
||||
This method is deprecated. Please use m.register_pass(Serialize(...)) instead.
|
||||
|
||||
Set the type of register pass for pass manager.
|
||||
|
||||
:param pass_name: String to set the type of a pass.
|
||||
@ -132,6 +168,7 @@ void regclass_graph_passes_Manager(py::module m) {
|
||||
- "IR_V10" : v10 IR
|
||||
- "IR_V11" : v11 IR
|
||||
:type version: str
|
||||
|
||||
Examples
|
||||
----------
|
||||
1. Default Version
|
||||
@ -140,5 +177,5 @@ void regclass_graph_passes_Manager(py::module m) {
|
||||
2. IR version 11
|
||||
pass_manager = Manager()
|
||||
pass_manager.register_pass("Serialize", xml_path="example.xml", bin_path="example.bin", version="IR_V11")
|
||||
// )");
|
||||
)");
|
||||
}
|
||||
|
@ -8,4 +8,4 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_graph_passes_Manager(py::module m);
|
||||
void regclass_passes_Manager(py::module m);
|
||||
|
201
src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.cpp
Normal file
201
src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.cpp
Normal file
@ -0,0 +1,201 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/matcher_pass.hpp"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_Matcher(py::module m) {
|
||||
py::class_<ov::pass::pattern::Matcher, std::shared_ptr<ov::pass::pattern::Matcher>> matcher(m, "Matcher");
|
||||
matcher.doc() = "openvino.runtime.passes.Matcher wraps ov::pass::pattern::Matcher";
|
||||
matcher.def(py::init([](const std::shared_ptr<ov::Node>& node, const std::string& name) {
|
||||
return std::make_shared<ov::pass::pattern::Matcher>(node, name);
|
||||
}),
|
||||
py::arg("node"),
|
||||
py::arg("name"),
|
||||
R"(
|
||||
Creates Matcher object with given pattern root node and matcher name.
|
||||
Matcher object is used for pattern matching on Model.
|
||||
|
||||
:param node: pattern root node.
|
||||
:type node: openvino.runtime.Node
|
||||
|
||||
:param name: pattern name. Usually matches the MatcherPass class name.
|
||||
:type name: str
|
||||
)");
|
||||
|
||||
matcher.def(py::init([](ov::Output<ov::Node>& output, const std::string& name) {
|
||||
return std::make_shared<ov::pass::pattern::Matcher>(output, name);
|
||||
}),
|
||||
py::arg("output"),
|
||||
py::arg("name"),
|
||||
R"(
|
||||
Creates Matcher object with given pattern root node output and matcher name.
|
||||
Matcher object is used for pattern matching on Model.
|
||||
|
||||
:param node: pattern root node output.
|
||||
:type node: openvino.runtime.Output
|
||||
|
||||
:param name: pattern name. Usually matches the MatcherPass class name.
|
||||
:type name: str
|
||||
)");
|
||||
|
||||
matcher.def("get_name",
|
||||
&ov::pass::pattern::Matcher::get_name,
|
||||
R"(
|
||||
Get Matcher name.
|
||||
|
||||
:return: openvino.runtime.passes.Matcher name.
|
||||
:rtype: str
|
||||
)");
|
||||
|
||||
matcher.def("get_match_root",
|
||||
&ov::pass::pattern::Matcher::get_match_root,
|
||||
R"(
|
||||
Get matched root node inside Model. Should be used after match() method is called.
|
||||
|
||||
:return: matched node.
|
||||
:rtype: openvino.runtime.Node
|
||||
)");
|
||||
|
||||
matcher.def("get_match_value",
|
||||
&ov::pass::pattern::Matcher::get_match_value,
|
||||
R"(
|
||||
Get matched node output inside Model. Should be used after match() method is called.
|
||||
|
||||
:return: matched node output.
|
||||
:rtype: openvino.runtime.Output
|
||||
)");
|
||||
|
||||
matcher.def("get_match_nodes",
|
||||
&ov::pass::pattern::Matcher::get_matched_nodes,
|
||||
R"(
|
||||
Get NodeVector of matched nodes. Should be used after match() method is called.
|
||||
|
||||
:return: matched nodes vector.
|
||||
:rtype: List[openvino.runtime.Node]
|
||||
)");
|
||||
|
||||
matcher.def("get_match_values",
|
||||
static_cast<const ov::OutputVector& (ov::pass::pattern::Matcher::*)() const>(
|
||||
&ov::pass::pattern::Matcher::get_matched_values),
|
||||
R"(
|
||||
Get OutputVector of matched outputs. Should be used after match() method is called.
|
||||
|
||||
:return: matched outputs vector.
|
||||
:rtype: List[openvino.runtime.Output]
|
||||
)");
|
||||
|
||||
matcher.def("get_pattern_value_map",
|
||||
&ov::pass::pattern::Matcher::get_pattern_value_map,
|
||||
R"(
|
||||
Get map which can be used to access matched nodes using nodes from pattern.
|
||||
Should be used after match() method is called.
|
||||
|
||||
:return: mapping of pattern nodes to matched nodes.
|
||||
:rtype: dict
|
||||
)");
|
||||
|
||||
matcher.def("match",
|
||||
static_cast<bool (ov::pass::pattern::Matcher::*)(const ov::Output<ov::Node>&)>(
|
||||
&ov::pass::pattern::Matcher::match),
|
||||
R"(
|
||||
Matches registered pattern starting from given output.
|
||||
|
||||
:return: status of matching.
|
||||
:rtype: bool
|
||||
)");
|
||||
|
||||
matcher.def("match",
|
||||
static_cast<bool (ov::pass::pattern::Matcher::*)(std::shared_ptr<ov::Node>)>(
|
||||
&ov::pass::pattern::Matcher::match),
|
||||
R"(
|
||||
Matches registered pattern starting from given Node.
|
||||
|
||||
:return: status of matching.
|
||||
:rtype: bool
|
||||
)");
|
||||
}
|
||||
|
||||
class PyMatcherPass : public ov::pass::MatcherPass {
|
||||
public:
|
||||
void py_register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& matcher,
|
||||
const ov::matcher_pass_callback& callback) {
|
||||
register_matcher(matcher, callback);
|
||||
}
|
||||
};
|
||||
|
||||
void regclass_passes_MatcherPass(py::module m) {
|
||||
py::class_<ov::pass::MatcherPass, std::shared_ptr<ov::pass::MatcherPass>, ov::pass::PassBase, PyMatcherPass>
|
||||
matcher_pass(m, "MatcherPass");
|
||||
matcher_pass.doc() = "openvino.runtime.passes.MatcherPass wraps ov::pass::MatcherPass";
|
||||
matcher_pass.def(py::init<>());
|
||||
matcher_pass.def(
|
||||
py::init([](const std::shared_ptr<ov::pass::pattern::Matcher>& m, ov::matcher_pass_callback callback) {
|
||||
return std::make_shared<ov::pass::MatcherPass>(m, callback);
|
||||
}),
|
||||
py::arg("matcher"),
|
||||
py::arg("callback"),
|
||||
R"(
|
||||
Create MatcherPass from existing Matcher and callback objects.
|
||||
|
||||
:param matcher: openvino.runtime.passes.Matcher with registered pattern.
|
||||
:type matcher: openvino.runtime.passes.Matcher
|
||||
|
||||
:param callback: Function that performs transformation on the matched nodes.
|
||||
:type callback: function
|
||||
|
||||
:return: created openvino.runtime.passes.MatcherPass instance.
|
||||
:rtype: openvino.runtime.passes.MatcherPass
|
||||
)");
|
||||
|
||||
matcher_pass.def("apply",
|
||||
&ov::pass::MatcherPass::apply,
|
||||
py::arg("node"),
|
||||
R"(
|
||||
Execute MatcherPass on given Node.
|
||||
|
||||
:return: callback return code.
|
||||
:rtype: bool
|
||||
)");
|
||||
|
||||
matcher_pass.def("register_new_node",
|
||||
&ov::pass::MatcherPass::register_new_node_,
|
||||
py::arg("node"),
|
||||
R"(
|
||||
Register node for additional pattern matching.
|
||||
|
||||
:param node: openvino.runtime.Node for matching.
|
||||
:type node: openvino.runtime.Node
|
||||
|
||||
:return: registered node instance
|
||||
:rtype: openvino.runtime.Node
|
||||
)");
|
||||
|
||||
matcher_pass.def("register_matcher",
|
||||
static_cast<void (ov::pass::MatcherPass::*)(const std::shared_ptr<ov::pass::pattern::Matcher>&,
|
||||
const ov::graph_rewrite_callback& callback)>(
|
||||
&PyMatcherPass::py_register_matcher),
|
||||
py::arg("matcher"),
|
||||
py::arg("callback"),
|
||||
R"(
|
||||
Initialize matcher and callback for further execution.
|
||||
|
||||
:param matcher: openvino.runtime.passes.Matcher with registered pattern.
|
||||
:type matcher: openvino.runtime.passes.Matcher
|
||||
|
||||
:param callback: Function that performs transformation on the matched nodes.
|
||||
:type callback: function
|
||||
)");
|
||||
}
|
@ -0,0 +1,13 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_Matcher(py::module m);
|
||||
|
||||
void regclass_passes_MatcherPass(py::module m);
|
@ -0,0 +1,47 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/model_pass.hpp"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <openvino/pass/pass.hpp>
|
||||
#include <string>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
class PyModelPass : public ov::pass::ModelPass {
|
||||
public:
|
||||
/* Inherit the constructors */
|
||||
using ov::pass::ModelPass::ModelPass;
|
||||
|
||||
/* Trampoline (need one for each virtual function) */
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& model) override {
|
||||
PYBIND11_OVERRIDE_PURE(bool, /* Return type */
|
||||
ov::pass::ModelPass, /* Parent class */
|
||||
run_on_model, /* Name of function in C++ (must match Python name) */
|
||||
model /* Argument(s) */
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
void regclass_passes_ModelPass(py::module m) {
|
||||
py::class_<ov::pass::ModelPass, std::shared_ptr<ov::pass::ModelPass>, ov::pass::PassBase, PyModelPass> model_pass(
|
||||
m,
|
||||
"ModelPass");
|
||||
model_pass.doc() = "openvino.runtime.passes.ModelPass wraps ov::pass::ModelPass";
|
||||
model_pass.def(py::init<>());
|
||||
model_pass.def("run_on_model",
|
||||
&ov::pass::ModelPass::run_on_model,
|
||||
py::arg("model"),
|
||||
R"(
|
||||
run_on_model must be defined in inherited class. This method is used to work with Model directly.
|
||||
|
||||
:param model: openvino.runtime.Model to be transformed.
|
||||
:type model: openvino.runtime.Model
|
||||
|
||||
:return: True in case if Model was changed and False otherwise.
|
||||
:rtype: bool
|
||||
)");
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_ModelPass(py::module m);
|
@ -0,0 +1,34 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/pass_base.hpp"
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/pass/pass.hpp>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_PassBase(py::module m) {
|
||||
py::class_<ov::pass::PassBase, std::shared_ptr<ov::pass::PassBase>> pass_base(m, "PassBase");
|
||||
pass_base.doc() = "openvino.runtime.passes.PassBase wraps ov::pass::PassBase";
|
||||
pass_base.def("set_name",
|
||||
&ov::pass::PassBase::set_name,
|
||||
py::arg("name"),
|
||||
R"(
|
||||
Set transformation name.
|
||||
|
||||
:param name: Transformation name.
|
||||
:type name: str
|
||||
)");
|
||||
pass_base.def("get_name",
|
||||
&ov::pass::PassBase::get_name,
|
||||
R"(
|
||||
Get transformation name.
|
||||
|
||||
:return: Transformation name.
|
||||
:rtype: str
|
||||
)");
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_passes_PassBase(py::module m);
|
504
src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp
Normal file
504
src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp
Normal file
@ -0,0 +1,504 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/pattern_ops.hpp"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <iterator>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
#include "ngraph/opsets/opset.hpp"
|
||||
#include "openvino/pass/pattern/op/label.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/pattern.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
|
||||
ov::NodeTypeInfo get_type(const std::string& type_name) {
|
||||
// Supported types: opsetX.OpName or opsetX::OpName
|
||||
std::string opset_type;
|
||||
auto it = type_name.cbegin();
|
||||
while (it != type_name.cend() && *it != '.' && *it != ':') {
|
||||
opset_type += *it;
|
||||
++it;
|
||||
}
|
||||
|
||||
// Skip delimiter
|
||||
while (it != type_name.cend() && (*it == '.' || *it == ':')) {
|
||||
++it;
|
||||
}
|
||||
|
||||
// Get operation type name
|
||||
std::string operation_type(it, type_name.end());
|
||||
|
||||
// TODO: create generic opset factory in Core so it can be reused
|
||||
const std::unordered_map<std::string, std::function<const ngraph::OpSet&()>> get_opset{
|
||||
{"opset1", ngraph::get_opset1},
|
||||
{"opset2", ngraph::get_opset2},
|
||||
{"opset3", ngraph::get_opset3},
|
||||
{"opset4", ngraph::get_opset4},
|
||||
{"opset5", ngraph::get_opset5},
|
||||
{"opset6", ngraph::get_opset6},
|
||||
{"opset7", ngraph::get_opset7},
|
||||
{"opset8", ngraph::get_opset8},
|
||||
};
|
||||
|
||||
if (!get_opset.count(opset_type)) {
|
||||
throw std::runtime_error("Unsupported opset type: " + opset_type);
|
||||
}
|
||||
|
||||
const ngraph::OpSet& m_opset = get_opset.at(opset_type)();
|
||||
if (!m_opset.contains_type(operation_type)) {
|
||||
throw std::runtime_error("Unrecognized operation type: " + operation_type);
|
||||
}
|
||||
|
||||
return m_opset.create(operation_type)->get_type_info();
|
||||
}
|
||||
|
||||
std::vector<ov::NodeTypeInfo> get_types(const std::vector<std::string>& type_names) {
|
||||
std::vector<ov::NodeTypeInfo> types;
|
||||
for (const auto& type_name : type_names) {
|
||||
types.emplace_back(get_type(type_name));
|
||||
}
|
||||
return types;
|
||||
}
|
||||
|
||||
using Predicate = const ov::pass::pattern::op::ValuePredicate;
|
||||
|
||||
void reg_pattern_wrap_type(py::module m) {
|
||||
py::class_<ov::pass::pattern::op::WrapType, std::shared_ptr<ov::pass::pattern::op::WrapType>, ov::Node> wrap_type(
|
||||
m,
|
||||
"WrapType");
|
||||
wrap_type.doc() = "openvino.runtime.passes.WrapType wraps ov::pass::pattern::op::WrapType";
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name));
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
R"(
|
||||
Create WrapType with given node type.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), pred);
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("pred"),
|
||||
R"(
|
||||
Create WrapType with given node type and predicate.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
nullptr,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("input"),
|
||||
R"(
|
||||
Create WrapType with given node type and input node.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param input: Node output.
|
||||
:type input: openvino.runtime.Output
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
nullptr,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("input"),
|
||||
R"(
|
||||
Create WrapType with given node type and input node.
|
||||
|
||||
:param type_name: node type. For example: opset8.Abs
|
||||
:type type_name: str
|
||||
|
||||
:param input: Input node.
|
||||
:type input: openvino.runtime.Node
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::Output<ov::Node>& input, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
pred,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("input"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node type, input node and predicate.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param input: Node output.
|
||||
:type input: openvino.runtime.Output
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(
|
||||
py::init([](const std::string& type_name, const std::shared_ptr<ov::Node>& input, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
pred,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("input"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node type, input node and predicate.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param input: Input node.
|
||||
:type input: openvino.runtime.Node
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), nullptr, inputs);
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create WrapType with given node type and input nodes.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param inputs: Node outputs.
|
||||
:type inputs: List[openvino.runtime.Output]
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
nullptr,
|
||||
ov::as_output_vector(inputs));
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create WrapType with given node type and input nodes.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param inputs: Input nodes.
|
||||
:type inputs: List[openvino.runtime.Node]
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name), pred, inputs);
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("inputs"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node type, input nodes and predicate.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param inputs: Node outputs.
|
||||
:type inputs: List[openvino.runtime.Output]
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_type(type_name),
|
||||
pred,
|
||||
ov::as_output_vector(inputs));
|
||||
}),
|
||||
py::arg("type_name"),
|
||||
py::arg("inputs"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node type, input nodes and predicate.
|
||||
|
||||
:param type_name: node type. For example: "opset8.Abs"
|
||||
:type type_name: str
|
||||
|
||||
:param inputs: Input nodes.
|
||||
:type inputs: List[openvino.runtime.Node]
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names));
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
R"(
|
||||
Create WrapType with given node types.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), pred);
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node types and predicate.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
nullptr,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("input"),
|
||||
R"(
|
||||
Create WrapType with given node types and input.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param input: Node output.
|
||||
:type input: openvino.runtime.Output
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const std::shared_ptr<ov::Node>& input) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
nullptr,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("input"),
|
||||
R"(
|
||||
Create WrapType with given node types and input.
|
||||
|
||||
:param type_name: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_name: List[str]
|
||||
|
||||
:param input: Input node.
|
||||
:type input: openvino.runtime.Node
|
||||
)");
|
||||
|
||||
wrap_type.def(
|
||||
py::init(
|
||||
[](const std::vector<std::string>& type_names, const ov::Output<ov::Node>& input, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
pred,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("input"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node types, input and predicate.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param input: Node output.
|
||||
:type input: openvino.runtime.Output
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names,
|
||||
const std::shared_ptr<ov::Node>& input,
|
||||
const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
pred,
|
||||
ov::OutputVector{input});
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("input"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node types, input and predicate.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param input: Input node.
|
||||
:type input: openvino.runtime.Node
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), nullptr, inputs);
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create WrapType with given node types and input.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param inputs: Nodes outputs.
|
||||
:type inputs: List[openvino.runtime.Output]
|
||||
)");
|
||||
|
||||
wrap_type.def(py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
nullptr,
|
||||
ov::as_output_vector(inputs));
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create WrapType with given node types and inputs.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param inputs: Input nodes.
|
||||
:type inputs: List[openvino.runtime.Node]
|
||||
)");
|
||||
|
||||
wrap_type.def(
|
||||
py::init([](const std::vector<std::string>& type_names, const ov::OutputVector& inputs, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names), pred, inputs);
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("inputs"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node types, inputs and predicate.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param inputs: Nodes outputs.
|
||||
:type inputs: List[openvino.runtime.Output]
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
|
||||
wrap_type.def(
|
||||
py::init([](const std::vector<std::string>& type_names, const ov::NodeVector& inputs, const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::WrapType>(get_types(type_names),
|
||||
pred,
|
||||
ov::as_output_vector(inputs));
|
||||
}),
|
||||
py::arg("type_names"),
|
||||
py::arg("inputs"),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create WrapType with given node types, inputs and predicate.
|
||||
|
||||
:param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"]
|
||||
:type type_names: List[str]
|
||||
|
||||
:param inputs: Input nodes.
|
||||
:type inputs: List[openvino.runtime.Node]
|
||||
|
||||
:param predicate: Function that performs additional checks for matching.
|
||||
:type predicate: function
|
||||
)");
|
||||
}
|
||||
|
||||
void reg_pattern_or(py::module m) {
|
||||
py::class_<ov::pass::pattern::op::Or, std::shared_ptr<ov::pass::pattern::op::Or>, ov::Node> or_type(m, "Or");
|
||||
or_type.doc() = "openvino.runtime.passes.Or wraps ov::pass::pattern::op::Or";
|
||||
|
||||
or_type.def(py::init([](const ov::OutputVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::Or>(inputs);
|
||||
}),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create pattern Or operation which is used to match any of given inputs.
|
||||
|
||||
:param inputs: Operation inputs.
|
||||
:type inputs: List[openvino.runtime.Output]
|
||||
)");
|
||||
|
||||
or_type.def(py::init([](const ov::NodeVector& inputs) {
|
||||
return std::make_shared<ov::pass::pattern::op::Or>(ov::as_output_vector(inputs));
|
||||
}),
|
||||
py::arg("inputs"),
|
||||
R"(
|
||||
Create pattern Or operation which is used to match any of given inputs.
|
||||
|
||||
:param inputs: Operation inputs.
|
||||
:type inputs: List[openvino.runtime.Node]
|
||||
)");
|
||||
}
|
||||
|
||||
void reg_pattern_any_input(py::module m) {
|
||||
py::class_<ov::pass::pattern::op::Label, std::shared_ptr<ov::pass::pattern::op::Label>, ov::Node> any_input(
|
||||
m,
|
||||
"AnyInput");
|
||||
any_input.doc() = "openvino.runtime.passes.AnyInput wraps ov::pass::pattern::op::Label";
|
||||
|
||||
any_input.def(py::init([]() {
|
||||
return std::make_shared<ov::pass::pattern::op::Label>();
|
||||
}),
|
||||
R"(
|
||||
Create pattern AnyInput operation which is used to match any type of node.
|
||||
)");
|
||||
|
||||
any_input.def(py::init([](const Predicate& pred) {
|
||||
return std::make_shared<ov::pass::pattern::op::Label>(ov::element::dynamic,
|
||||
ov::PartialShape::dynamic(),
|
||||
pred);
|
||||
}),
|
||||
py::arg("predicate"),
|
||||
R"(
|
||||
Create pattern AnyInput operation which is used to match any type of node.
|
||||
|
||||
:param pred: Function that performs additional checks for matching.
|
||||
:type pred: function
|
||||
)");
|
||||
}
|
||||
|
||||
void reg_predicates(py::module m) {
|
||||
m.def("consumers_count", &ov::pass::pattern::consumers_count);
|
||||
m.def("has_static_dim", &ov::pass::pattern::has_static_dim);
|
||||
m.def("has_static_dims", &ov::pass::pattern::has_static_dims);
|
||||
m.def("has_static_shape", &ov::pass::pattern::has_static_shape);
|
||||
m.def("has_static_rank", &ov::pass::pattern::has_static_rank);
|
||||
m.def("rank_equals", &ov::pass::pattern::rank_equals);
|
||||
m.def("type_matches", &ov::pass::pattern::type_matches);
|
||||
m.def("type_matches_any", &ov::pass::pattern::type_matches_any);
|
||||
}
|
||||
|
||||
void reg_passes_pattern_ops(py::module m) {
|
||||
reg_pattern_any_input(m);
|
||||
reg_pattern_wrap_type(m);
|
||||
reg_pattern_or(m);
|
||||
reg_predicates(m);
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void reg_passes_pattern_ops(py::module m);
|
@ -6,9 +6,24 @@
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "pyopenvino/graph/passes/graph_rewrite.hpp"
|
||||
#include "pyopenvino/graph/passes/manager.hpp"
|
||||
#include "pyopenvino/graph/passes/matcher_pass.hpp"
|
||||
#include "pyopenvino/graph/passes/model_pass.hpp"
|
||||
#include "pyopenvino/graph/passes/pass_base.hpp"
|
||||
#include "pyopenvino/graph/passes/pattern_ops.hpp"
|
||||
#include "pyopenvino/graph/passes/transformations.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regmodule_graph_passes(py::module m) {
|
||||
py::module m_passes = m.def_submodule("passes", "Package openvino.runtime.passes wraps ov::passes");
|
||||
regclass_graph_passes_Manager(m_passes);
|
||||
regclass_passes_PassBase(m_passes);
|
||||
regclass_passes_ModelPass(m_passes);
|
||||
regclass_passes_GraphRewrite(m_passes);
|
||||
regclass_passes_Matcher(m_passes);
|
||||
regclass_passes_MatcherPass(m_passes);
|
||||
regclass_transformations(m_passes);
|
||||
regclass_passes_Manager(m_passes);
|
||||
reg_passes_pattern_ops(m_passes);
|
||||
}
|
||||
|
@ -5,7 +5,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "pyopenvino/graph/passes/manager.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
@ -0,0 +1,111 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "pyopenvino/graph/passes/transformations.hpp"
|
||||
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include <memory>
|
||||
#include <openvino/pass/constant_folding.hpp>
|
||||
#include <openvino/pass/convert_fp32_to_fp16.hpp>
|
||||
#include <openvino/pass/low_latency.hpp>
|
||||
#include <openvino/pass/make_stateful.hpp>
|
||||
#include <openvino/pass/pass.hpp>
|
||||
#include <openvino/pass/serialize.hpp>
|
||||
#include <openvino/pass/visualize_tree.hpp>
|
||||
|
||||
void regclass_transformations(py::module m) {
|
||||
py::class_<ov::pass::Serialize, std::shared_ptr<ov::pass::Serialize>, ov::pass::ModelPass, ov::pass::PassBase>
|
||||
serialize(m, "Serialize");
|
||||
serialize.doc() = "openvino.runtime.passes.Serialize transformation";
|
||||
serialize.def(py::init([](const std::string& path_to_xml, const std::string& path_to_bin) {
|
||||
return std::make_shared<ov::pass::Serialize>(path_to_xml, path_to_bin);
|
||||
}),
|
||||
py::arg("path_to_xml"),
|
||||
py::arg("path_to_bin"),
|
||||
R"(
|
||||
Create Serialize pass which is used for Model to IR serialization.
|
||||
|
||||
:param path_to_xml: Path where *.xml file will be saved.
|
||||
:type path_to_xml: str
|
||||
|
||||
:param path_to_xml: Path where *.bin file will be saved.
|
||||
:type path_to_xml: str
|
||||
)");
|
||||
|
||||
serialize.def(
|
||||
py::init(
|
||||
[](const std::string& path_to_xml, const std::string& path_to_bin, ov::pass::Serialize::Version version) {
|
||||
return std::make_shared<ov::pass::Serialize>(path_to_xml, path_to_bin, version);
|
||||
}),
|
||||
py::arg("path_to_xml"),
|
||||
py::arg("path_to_bin"),
|
||||
py::arg("version"),
|
||||
R"(
|
||||
Create Serialize pass which is used for Model to IR serialization.
|
||||
|
||||
:param path_to_xml: Path where *.xml file will be saved.
|
||||
:type path_to_xml: str
|
||||
|
||||
:param path_to_xml: Path where *.bin file will be saved.
|
||||
:type path_to_xml: str
|
||||
|
||||
:param version: serialized IR version.
|
||||
:type version: int
|
||||
)");
|
||||
|
||||
py::class_<ov::pass::ConstantFolding,
|
||||
std::shared_ptr<ov::pass::ConstantFolding>,
|
||||
ov::pass::ModelPass,
|
||||
ov::pass::PassBase>
|
||||
cf(m, "ConstantFolding");
|
||||
cf.doc() = "openvino.runtime.passes.ConstantFolding transformation";
|
||||
cf.def(py::init<>());
|
||||
|
||||
py::class_<ov::pass::VisualizeTree,
|
||||
std::shared_ptr<ov::pass::VisualizeTree>,
|
||||
ov::pass::ModelPass,
|
||||
ov::pass::PassBase>
|
||||
visualize(m, "VisualizeTree");
|
||||
visualize.doc() = "openvino.runtime.passes.VisualizeTree transformation";
|
||||
visualize.def(py::init<const std::string&, ov::pass::VisualizeTree::node_modifiers_t, bool>(),
|
||||
py::arg("file_name"),
|
||||
py::arg("nm") = nullptr,
|
||||
py::arg("don_only") = false,
|
||||
R"(
|
||||
Create VisualizeTree pass which is used for Model to dot serialization.
|
||||
|
||||
:param file_name: Path where serialized model will be saved. For example: /tmp/out.svg
|
||||
:type file_name: str
|
||||
|
||||
:param nm: Node modifier function.
|
||||
:type nm: function
|
||||
|
||||
:param don_only: Enable only dot file generation.
|
||||
:type don_only: bool
|
||||
)");
|
||||
|
||||
py::class_<ov::pass::MakeStateful, std::shared_ptr<ov::pass::MakeStateful>, ov::pass::ModelPass, ov::pass::PassBase>
|
||||
make_stateful(m, "MakeStateful");
|
||||
make_stateful.doc() = "openvino.runtime.passes.MakeStateful transformation";
|
||||
// TODO: update docstrings for c-tors below
|
||||
make_stateful.def(py::init<const ov::pass::MakeStateful::ParamResPairs&>(), py::arg("pairs_to_replace"));
|
||||
make_stateful.def(py::init<const std::map<std::string, std::string>&>());
|
||||
|
||||
py::class_<ov::pass::LowLatency2, std::shared_ptr<ov::pass::LowLatency2>, ov::pass::ModelPass, ov::pass::PassBase>
|
||||
low_latency(m, "LowLatency2");
|
||||
low_latency.doc() = "openvino.runtime.passes.LowLatency2 transformation";
|
||||
// TODO: update docstrings for c-tor below
|
||||
low_latency.def(py::init<bool>(), py::arg("use_const_initializer") = true);
|
||||
|
||||
py::class_<ov::pass::ConvertFP32ToFP16,
|
||||
std::shared_ptr<ov::pass::ConvertFP32ToFP16>,
|
||||
ov::pass::ModelPass,
|
||||
ov::pass::PassBase>
|
||||
convert(m, "ConvertFP32ToFP16");
|
||||
convert.doc() = "openvino.runtime.passes.ConvertFP32ToFP16 transformation";
|
||||
convert.def(py::init<>());
|
||||
}
|
@ -0,0 +1,11 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
void regclass_transformations(py::module m);
|
@ -6,10 +6,15 @@
|
||||
|
||||
#include <pybind11/numpy.h>
|
||||
|
||||
#include "openvino/core/graph_util.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/pass/manager.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
template <typename... Args>
|
||||
using overload_cast_ = pybind11::detail::overload_cast_impl<Args...>;
|
||||
|
||||
void* numpy_to_c(py::array a) {
|
||||
py::buffer_info info = a.request();
|
||||
return info.ptr;
|
||||
@ -31,4 +36,24 @@ void regmodule_graph_util(py::module m) {
|
||||
from the resulting bound, otherwise Null.
|
||||
:rtype: openvino.runtime.op.Constant or openvino.runtime.Node
|
||||
)");
|
||||
|
||||
mod.def("replace_output_update_name", &ov::replace_output_update_name, py::arg("output"), py::arg("target_output"));
|
||||
|
||||
mod.def("replace_node",
|
||||
overload_cast_<const std::shared_ptr<ov::Node>&, const std::shared_ptr<ov::Node>&>()(&ov::replace_node),
|
||||
py::arg("target"),
|
||||
py::arg("replacement"));
|
||||
|
||||
mod.def("replace_node",
|
||||
overload_cast_<const std::shared_ptr<ov::Node>&, const ov::OutputVector&>()(&ov::replace_node),
|
||||
py::arg("target"),
|
||||
py::arg("replacement"));
|
||||
|
||||
mod.def("replace_node",
|
||||
overload_cast_<const std::shared_ptr<ov::Node>&,
|
||||
const std::shared_ptr<ov::Node>&,
|
||||
const std::vector<int64_t>&>()(&ov::replace_node),
|
||||
py::arg("target"),
|
||||
py::arg("replacement"),
|
||||
py::arg("outputs_order"));
|
||||
}
|
||||
|
@ -0,0 +1,73 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.runtime import opset8
|
||||
from openvino.runtime.passes import Manager, GraphRewrite, MatcherPass, WrapType, Matcher
|
||||
|
||||
from utils.utils import count_ops, get_test_function, PatternReplacement
|
||||
|
||||
|
||||
def test_graph_rewrite():
|
||||
model = get_test_function()
|
||||
|
||||
m = Manager()
|
||||
# check that register pass returns pass instance
|
||||
anchor = m.register_pass(GraphRewrite())
|
||||
anchor.add_matcher(PatternReplacement())
|
||||
m.run_passes(model)
|
||||
|
||||
assert count_ops(model, "Relu") == [2]
|
||||
|
||||
|
||||
def test_register_new_node():
|
||||
class InsertExp(MatcherPass):
|
||||
def __init__(self):
|
||||
MatcherPass.__init__(self)
|
||||
self.model_changed = False
|
||||
|
||||
param = WrapType("opset8.Parameter")
|
||||
|
||||
def callback(m: Matcher) -> bool:
|
||||
# Input->...->Result => Input->Exp->...->Result
|
||||
root = m.get_match_value()
|
||||
consumers = root.get_target_inputs()
|
||||
|
||||
exp = opset8.exp(root)
|
||||
for consumer in consumers:
|
||||
consumer.replace_source_output(exp.output(0))
|
||||
|
||||
# For testing purpose
|
||||
self.model_changed = True
|
||||
|
||||
# Use new operation for additional matching
|
||||
self.register_new_node(exp)
|
||||
|
||||
# Root node wasn't replaced or changed
|
||||
return False
|
||||
|
||||
self.register_matcher(Matcher(param, "InsertExp"), callback)
|
||||
|
||||
class RemoveExp(MatcherPass):
|
||||
def __init__(self):
|
||||
MatcherPass.__init__(self)
|
||||
self.model_changed = False
|
||||
|
||||
param = WrapType("opset8.Exp")
|
||||
|
||||
def callback(m: Matcher) -> bool:
|
||||
root = m.get_match_root()
|
||||
root.output(0).replace(root.input_value(0))
|
||||
|
||||
# For testing purpose
|
||||
self.model_changed = True
|
||||
|
||||
return True
|
||||
|
||||
self.register_matcher(Matcher(param, "RemoveExp"), callback)
|
||||
|
||||
m = Manager()
|
||||
ins = m.register_pass(InsertExp())
|
||||
rem = m.register_pass(RemoveExp())
|
||||
m.run_passes(get_test_function())
|
||||
|
||||
assert ins.model_changed
|
||||
assert rem.model_changed
|
@ -0,0 +1,44 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.runtime.passes import Manager, GraphRewrite, BackwardGraphRewrite, Serialize
|
||||
|
||||
from utils.utils import MyModelPass, PatternReplacement, expect_exception
|
||||
|
||||
|
||||
def test_registration_and_pass_name():
|
||||
m = Manager()
|
||||
|
||||
a = m.register_pass(PatternReplacement())
|
||||
a.set_name("PatterReplacement")
|
||||
|
||||
b = m.register_pass(MyModelPass())
|
||||
b.set_name("ModelPass")
|
||||
|
||||
c = m.register_pass(GraphRewrite())
|
||||
c.set_name("Anchor")
|
||||
|
||||
d = c.add_matcher(PatternReplacement())
|
||||
d.set_name("PatterReplacement")
|
||||
|
||||
e = m.register_pass(BackwardGraphRewrite())
|
||||
e.set_name("BackAnchor")
|
||||
|
||||
f = e.add_matcher(PatternReplacement())
|
||||
f.set_name("PatterReplacement")
|
||||
|
||||
PatternReplacement().set_name("PatternReplacement")
|
||||
MyModelPass().set_name("MyModelPass")
|
||||
GraphRewrite().set_name("Anchor")
|
||||
BackwardGraphRewrite().set_name("BackAnchor")
|
||||
|
||||
# Preserve legacy behaviour when registered pass doesn't exist
|
||||
# and in this case we shouldn't throw an exception.
|
||||
m.register_pass("NotExistingPass")
|
||||
|
||||
|
||||
def test_negative_pass_registration():
|
||||
m = Manager()
|
||||
expect_exception(lambda: m.register_pass(PatternReplacement))
|
||||
expect_exception(lambda: m.register_pass("PatternReplacement", PatternReplacement()))
|
||||
expect_exception(lambda: m.register_pass("Serialize", Serialize("out.xml", "out.bin")))
|
||||
expect_exception(lambda: m.register_pass("Serialize", "out.xml", "out.bin", "out.wrong"))
|
@ -0,0 +1,56 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.runtime import opset8
|
||||
from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType
|
||||
from openvino.runtime.utils import replace_node
|
||||
|
||||
from utils.utils import count_ops, get_test_function, PatternReplacement
|
||||
|
||||
|
||||
def test_simple_pattern_replacement():
|
||||
# Simple: for Extensions. Without any classes and inheritance.
|
||||
def pattern_replacement():
|
||||
param = WrapType("opset8.Parameter")
|
||||
relu = WrapType("opset8.Relu", param.output(0))
|
||||
|
||||
def callback(m: Matcher) -> bool:
|
||||
root = m.get_match_root()
|
||||
|
||||
# Just to check that capturing works and we can
|
||||
# link pattern nodes with matched graph nodes.
|
||||
assert relu in m.get_pattern_value_map()
|
||||
|
||||
new_relu = opset8.exp(root.input_value(0)) # ot root.input(0).get_source_output()
|
||||
replace_node(root, new_relu)
|
||||
return True
|
||||
|
||||
return Matcher(relu, "SimpleReplacement"), callback
|
||||
|
||||
model = get_test_function()
|
||||
|
||||
m = Manager()
|
||||
m.register_pass(MatcherPass(*pattern_replacement()))
|
||||
m.run_passes(model)
|
||||
|
||||
assert count_ops(model, ("Relu", "Exp")) == [0, 1]
|
||||
|
||||
|
||||
def test_matcher_pass():
|
||||
model = get_test_function()
|
||||
|
||||
m = Manager()
|
||||
# check that register pass returns pass instance
|
||||
p = m.register_pass(PatternReplacement())
|
||||
m.run_passes(model)
|
||||
|
||||
assert p.model_changed
|
||||
assert count_ops(model, "Relu") == [2]
|
||||
|
||||
|
||||
def test_matcher_pass_apply():
|
||||
model = get_test_function()
|
||||
|
||||
p = PatternReplacement()
|
||||
p.apply(model.get_result().input_value(0).get_node())
|
||||
|
||||
assert count_ops(model, "Relu") == [2]
|
@ -0,0 +1,13 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from openvino.runtime.passes import Manager
|
||||
|
||||
from utils.utils import get_test_function, MyModelPass
|
||||
|
||||
|
||||
def test_model_pass():
|
||||
m = Manager()
|
||||
p = m.register_pass(MyModelPass())
|
||||
m.run_passes(get_test_function())
|
||||
|
||||
assert p.model_changed
|
@ -67,26 +67,6 @@ def test_make_stateful_transformations():
|
||||
assert len(function.get_results()) == 0
|
||||
|
||||
|
||||
def test_serialize_pass():
|
||||
core = Core()
|
||||
xml_path = "serialized_function.xml"
|
||||
bin_path = "serialized_function.bin"
|
||||
|
||||
func = get_test_function()
|
||||
|
||||
serialize(func, xml_path, bin_path)
|
||||
|
||||
assert func is not None
|
||||
|
||||
res_func = core.read_model(model=xml_path, weights=bin_path)
|
||||
|
||||
assert func.get_parameters() == res_func.get_parameters()
|
||||
assert func.get_ordered_ops() == res_func.get_ordered_ops()
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
||||
|
||||
|
||||
def test_serialize_pass_v2():
|
||||
core = Core()
|
||||
xml_path = "./serialized_function.xml"
|
||||
|
@ -0,0 +1,108 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import numpy as np
|
||||
|
||||
from openvino.runtime import PartialShape, opset8
|
||||
from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput
|
||||
from openvino.runtime.passes import consumers_count, has_static_dim, has_static_dims, \
|
||||
has_static_shape, has_static_rank, type_matches, type_matches_any
|
||||
from openvino.runtime.utils.types import get_element_type
|
||||
|
||||
from utils.utils import expect_exception
|
||||
|
||||
|
||||
def test_wrap_type_pattern_type():
|
||||
for i in range(1, 9):
|
||||
WrapType("opset{}.Parameter".format(i))
|
||||
WrapType("opset{}::Parameter".format(i))
|
||||
|
||||
# Negative check not to forget to update opset map in get_type function
|
||||
expect_exception(lambda: WrapType("opset9.Parameter"), "Unsupported opset type: opset9")
|
||||
|
||||
# Generic negative test cases
|
||||
expect_exception(lambda: WrapType(""))
|
||||
expect_exception(lambda: WrapType("opset8"))
|
||||
expect_exception(lambda: WrapType("Parameter"))
|
||||
expect_exception(lambda: WrapType("opset.Parameter"))
|
||||
expect_exception(lambda: WrapType("opset8,Parameter"))
|
||||
expect_exception(lambda: WrapType("Parameter.opset8"))
|
||||
|
||||
|
||||
def test_wrap_type_ctors():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
|
||||
relu = opset8.relu(param.output(0))
|
||||
slope = opset8.parameter(PartialShape([]))
|
||||
prelu = opset8.prelu(param.output(0), slope.output(0))
|
||||
|
||||
m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"]), "FindActivation")
|
||||
assert m.match(relu)
|
||||
assert m.match(prelu)
|
||||
|
||||
m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"],
|
||||
WrapType("opset8.Parameter").output(0)), "FindActivation")
|
||||
assert m.match(relu)
|
||||
|
||||
|
||||
def test_or():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
|
||||
relu = opset8.relu(param.output(0))
|
||||
slope = opset8.parameter(PartialShape([]))
|
||||
prelu = opset8.prelu(param.output(0), slope.output(0))
|
||||
|
||||
m = Matcher(Or([WrapType("opset8.Relu"),
|
||||
WrapType("opset8.PRelu")]), "FindActivation")
|
||||
assert m.match(relu)
|
||||
assert m.match(prelu)
|
||||
|
||||
|
||||
def test_any_input():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
|
||||
relu = opset8.relu(param.output(0))
|
||||
slope = opset8.parameter(PartialShape([]))
|
||||
prelu = opset8.prelu(param.output(0), slope.output(0))
|
||||
|
||||
m = Matcher(WrapType("opset8.PRelu", [AnyInput(), AnyInput()]), "FindActivation")
|
||||
assert not m.match(relu)
|
||||
assert m.match(prelu)
|
||||
|
||||
|
||||
def test_any_input_predicate():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]))
|
||||
slope = opset8.parameter(PartialShape([]))
|
||||
|
||||
m = Matcher(AnyInput(lambda output: len(output.get_shape()) == 4), "FindActivation")
|
||||
assert m.match(param)
|
||||
assert not m.match(slope)
|
||||
|
||||
|
||||
def test_all_predicates():
|
||||
static_param = opset8.parameter(PartialShape([1, 3, 22, 22]), np.float32)
|
||||
dynamic_param = opset8.parameter(PartialShape([-1, 6]), np.long)
|
||||
fully_dynamic_param = opset8.parameter(PartialShape.dynamic())
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter", consumers_count(0)), "Test").match(static_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter", consumers_count(1)), "Test").match(static_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter", has_static_dim(1)), "Test").match(static_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter", has_static_dim(0)), "Test").match(dynamic_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter", has_static_dims([0, 3])), "Test").match(static_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter", has_static_dims([0, 1])), "Test").match(dynamic_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(static_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(dynamic_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(dynamic_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(fully_dynamic_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter",
|
||||
type_matches(get_element_type(np.float32))), "Test").match(static_param)
|
||||
assert not Matcher(WrapType("opset8.Parameter",
|
||||
type_matches(get_element_type(np.float32))), "Test").match(dynamic_param)
|
||||
|
||||
assert Matcher(WrapType("opset8.Parameter",
|
||||
type_matches_any([get_element_type(np.float32),
|
||||
get_element_type(np.long)])), "Test").match(static_param)
|
||||
assert Matcher(WrapType("opset8.Parameter",
|
||||
type_matches_any([get_element_type(np.float32),
|
||||
get_element_type(np.long)])), "Test").match(dynamic_param)
|
@ -0,0 +1,114 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from openvino.runtime import Model, PartialShape, Shape, opset8, Core
|
||||
from openvino.runtime.passes import Manager, ConstantFolding, MakeStateful,\
|
||||
ConvertFP32ToFP16, LowLatency2, Serialize
|
||||
|
||||
from utils.utils import count_ops, get_test_function
|
||||
|
||||
|
||||
def get_model():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
param.get_output_tensor(0).set_names({"parameter"})
|
||||
relu = opset8.relu(param)
|
||||
reshape = opset8.reshape(relu, opset8.shape_of(relu), False)
|
||||
res = opset8.result(reshape, name="result")
|
||||
res.get_output_tensor(0).set_names({"result"})
|
||||
return Model([res], [param], "test")
|
||||
|
||||
|
||||
def test_make_stateful():
|
||||
model = get_model()
|
||||
|
||||
m = Manager()
|
||||
p = MakeStateful({"parameter": "result"})
|
||||
m.register_pass(p)
|
||||
m.run_passes(model)
|
||||
|
||||
assert model is not None
|
||||
assert len(model.get_parameters()) == 0
|
||||
assert len(model.get_results()) == 0
|
||||
|
||||
|
||||
def test_constant_folding():
|
||||
model = get_model()
|
||||
|
||||
m = Manager()
|
||||
m.register_pass(ConstantFolding())
|
||||
m.run_passes(model)
|
||||
|
||||
assert model is not None
|
||||
assert count_ops(model, "ShapeOf") == [0]
|
||||
|
||||
|
||||
def test_convert_precision():
|
||||
model = get_model()
|
||||
|
||||
m = Manager()
|
||||
m.register_pass(ConvertFP32ToFP16())
|
||||
m.run_passes(model)
|
||||
|
||||
assert model is not None
|
||||
# TODO: fix bug 82773 with float16 type comparison
|
||||
# assert model.get_parameters()[0].get_element_type() == np.float16
|
||||
|
||||
|
||||
def test_low_latency2():
|
||||
X = opset8.parameter(Shape([32, 40, 10]), np.float32, "X")
|
||||
Y = opset8.parameter(Shape([32, 40, 10]), np.float32, "Y")
|
||||
M = opset8.parameter(Shape([32, 2, 10]), np.float32, "M")
|
||||
|
||||
X_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "X_i")
|
||||
Y_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "Y_i")
|
||||
M_body = opset8.parameter(Shape([32, 2, 10]), np.float32, "M_body")
|
||||
|
||||
sum = opset8.add(X_i, Y_i)
|
||||
Zo = opset8.multiply(sum, M_body)
|
||||
|
||||
body = Model([Zo], [X_i, Y_i, M_body], "body_function")
|
||||
|
||||
ti = opset8.tensor_iterator()
|
||||
ti.set_body(body)
|
||||
ti.set_sliced_input(X_i, X.output(0), 0, 2, 2, 39, 1)
|
||||
ti.set_sliced_input(Y_i, Y.output(0), 0, 2, 2, -1, 1)
|
||||
ti.set_invariant_input(M_body, M.output(0))
|
||||
|
||||
out0 = ti.get_iter_value(Zo.output(0), -1)
|
||||
out1 = ti.get_concatenated_slices(Zo.output(0), 0, 2, 2, 39, 1)
|
||||
|
||||
result0 = opset8.result(out0)
|
||||
result1 = opset8.result(out1)
|
||||
|
||||
model = Model([result0, result1], [X, Y, M])
|
||||
|
||||
m = Manager()
|
||||
m.register_pass(LowLatency2())
|
||||
m.run_passes(model)
|
||||
|
||||
# TODO: create TI which will be transformed by LowLatency2
|
||||
assert count_ops(model, "TensorIterator") == [1]
|
||||
|
||||
|
||||
def test_serialize_pass():
|
||||
core = Core()
|
||||
xml_path = "serialized_function.xml"
|
||||
bin_path = "serialized_function.bin"
|
||||
|
||||
func = get_test_function()
|
||||
|
||||
m = Manager()
|
||||
m.register_pass(Serialize(xml_path, bin_path))
|
||||
m.run_passes(func)
|
||||
|
||||
assert func is not None
|
||||
|
||||
res_func = core.read_model(model=xml_path, weights=bin_path)
|
||||
|
||||
assert func.get_parameters() == res_func.get_parameters()
|
||||
assert func.get_ordered_ops() == res_func.get_ordered_ops()
|
||||
|
||||
os.remove(xml_path)
|
||||
os.remove(bin_path)
|
@ -0,0 +1,59 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.runtime import Model, PartialShape, opset8
|
||||
from openvino.runtime.utils import replace_node, replace_output_update_name
|
||||
|
||||
|
||||
def get_test_function():
|
||||
# Parameter->Relu->Result
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
res = opset8.result(relu.output(0), name="result")
|
||||
return Model([res], [param], "test")
|
||||
|
||||
|
||||
def test_output_replace():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
res = opset8.result(relu.output(0), name="result")
|
||||
|
||||
exp = opset8.exp(param.output(0))
|
||||
relu.output(0).replace(exp.output(0))
|
||||
|
||||
assert res.input_value(0).get_node() == exp
|
||||
|
||||
|
||||
def test_replace_source_output():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
res = opset8.result(relu.output(0), name="result")
|
||||
|
||||
exp = opset8.exp(param.output(0))
|
||||
res.input(0).replace_source_output(exp.output(0))
|
||||
|
||||
assert len(exp.output(0).get_target_inputs()) == 1
|
||||
assert len(relu.output(0).get_target_inputs()) == 0
|
||||
assert next(iter(exp.output(0).get_target_inputs())).get_node() == res
|
||||
|
||||
|
||||
def test_replace_node():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
res = opset8.result(relu.output(0), name="result")
|
||||
|
||||
exp = opset8.exp(param.output(0))
|
||||
replace_node(relu, exp)
|
||||
|
||||
assert res.input_value(0).get_node() == exp
|
||||
|
||||
|
||||
def test_replace_output_update_name():
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
exp = opset8.exp(relu.output(0))
|
||||
res = opset8.result(exp.output(0), name="result")
|
||||
|
||||
replace_output_update_name(exp.output(0), exp.input_value(0))
|
||||
|
||||
assert res.input_value(0).get_node() == exp
|
@ -0,0 +1,75 @@
|
||||
# Copyright (C) 2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from openvino.runtime import Model, PartialShape, opset8
|
||||
from openvino.runtime.passes import ModelPass, Matcher, MatcherPass, WrapType
|
||||
|
||||
|
||||
def get_test_function():
|
||||
# Parameter->Relu->Result
|
||||
param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter")
|
||||
relu = opset8.relu(param.output(0))
|
||||
res = opset8.result(relu.output(0), name="result")
|
||||
return Model([res], [param], "test")
|
||||
|
||||
|
||||
def count_ops(model, op_types):
|
||||
if isinstance(op_types, str):
|
||||
op_types = [op_types]
|
||||
|
||||
cnt = [0] * len(op_types)
|
||||
types = {op_types[index]: index for index in range(len(op_types))}
|
||||
for op in model.get_ops():
|
||||
op_type = op.get_type_info().name
|
||||
if op_type in types:
|
||||
cnt[types[op_type]] += 1
|
||||
return cnt
|
||||
|
||||
|
||||
def expect_exception(func, message=""):
|
||||
def check():
|
||||
try:
|
||||
func()
|
||||
return None
|
||||
except Exception as e:
|
||||
return str(e)
|
||||
res = check()
|
||||
if res is None:
|
||||
raise AssertionError("Exception is not thrown!")
|
||||
assert message in res
|
||||
|
||||
|
||||
class PatternReplacement(MatcherPass):
|
||||
def __init__(self):
|
||||
MatcherPass.__init__(self)
|
||||
self.model_changed = False
|
||||
|
||||
relu = WrapType("opset8::Relu")
|
||||
|
||||
def callback(m: Matcher) -> bool:
|
||||
self.applied = True
|
||||
root = m.get_match_root()
|
||||
new_relu = opset8.relu(root.input(0).get_source_output())
|
||||
|
||||
# For testing purpose
|
||||
self.model_changed = True
|
||||
|
||||
# # Use new operation for additional matching
|
||||
# self.register_new_node(new_relu)
|
||||
|
||||
# Input->Relu->Result => Input->Relu->Relu->Result
|
||||
root.input(0).replace_source_output(new_relu.output(0))
|
||||
return True
|
||||
|
||||
self.register_matcher(Matcher(relu, "PatternReplacement"), callback)
|
||||
|
||||
|
||||
class MyModelPass(ModelPass):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.model_changed = False
|
||||
|
||||
def run_on_model(self, model):
|
||||
for op in model.get_ops():
|
||||
if op.get_type_info().name == "Relu":
|
||||
self.model_changed = True
|
@ -61,6 +61,10 @@ public:
|
||||
set_property(property, true);
|
||||
}
|
||||
|
||||
MatcherPass(const std::shared_ptr<pattern::Matcher>& m, const matcher_pass_callback& callback) : PassBase() {
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
bool apply(std::shared_ptr<ov::Node> node);
|
||||
|
||||
template <typename T, class... Args>
|
||||
@ -76,6 +80,10 @@ public:
|
||||
return node;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> register_new_node_(const std::shared_ptr<ov::Node>& node) {
|
||||
return register_new_node(node);
|
||||
}
|
||||
|
||||
const std::vector<std::shared_ptr<ov::Node>>& get_new_nodes() {
|
||||
return m_new_nodes;
|
||||
}
|
||||
@ -88,8 +96,10 @@ public:
|
||||
|
||||
protected:
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
const matcher_pass_callback& callback,
|
||||
const PassPropertyMask& property);
|
||||
|
||||
void register_matcher(const std::shared_ptr<pattern::Matcher>& m, const matcher_pass_callback& callback);
|
||||
|
||||
private:
|
||||
handler_callback m_handler;
|
||||
@ -195,6 +205,13 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<MatcherPass> add_matcher(const std::shared_ptr<MatcherPass>& pass) {
|
||||
auto pass_config = get_pass_config();
|
||||
pass->set_pass_config(pass_config);
|
||||
m_matchers.push_back(pass);
|
||||
return pass;
|
||||
}
|
||||
|
||||
OPENVINO_DEPRECATED("Use MatcherPass instead")
|
||||
void add_matcher(const std::shared_ptr<pattern::Matcher>& m,
|
||||
const graph_rewrite_callback& callback,
|
||||
|
@ -51,6 +51,15 @@ public:
|
||||
return rc;
|
||||
}
|
||||
|
||||
std::shared_ptr<PassBase> register_pass_instance(std::shared_ptr<PassBase> pass) {
|
||||
pass->set_pass_config(m_pass_config);
|
||||
m_pass_list.push_back(pass);
|
||||
if (m_per_pass_validation) {
|
||||
push_pass<Validate>();
|
||||
}
|
||||
return pass;
|
||||
}
|
||||
|
||||
void run_passes(std::shared_ptr<Model>);
|
||||
|
||||
void set_pass_visualization(bool new_state) {
|
||||
|
@ -308,6 +308,11 @@ void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pat
|
||||
};
|
||||
}
|
||||
|
||||
void ov::pass::MatcherPass::register_matcher(const std::shared_ptr<ov::pass::pattern::Matcher>& m,
|
||||
const ov::graph_rewrite_callback& callback) {
|
||||
register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE);
|
||||
}
|
||||
|
||||
bool ov::pass::MatcherPass::apply(std::shared_ptr<ov::Node> node) {
|
||||
OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::perf_counters_graph_rewrite()[get_type_info()]);
|
||||
m_new_nodes.clear();
|
||||
|
Loading…
Reference in New Issue
Block a user