diff --git a/src/bindings/python/src/openvino/runtime/passes/__init__.py b/src/bindings/python/src/openvino/runtime/passes/__init__.py index 1ed557a3a37..5ae337cb3c7 100644 --- a/src/bindings/python/src/openvino/runtime/passes/__init__.py +++ b/src/bindings/python/src/openvino/runtime/passes/__init__.py @@ -1,6 +1,11 @@ # Copyright (C) 2018-2022 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - +# type: ignore # flake8: noqa -from openvino.pyopenvino.passes import Manager +from openvino.pyopenvino.passes import ModelPass, Matcher, MatcherPass, PassBase, WrapType, Or, AnyInput +from openvino.pyopenvino.passes import consumers_count, has_static_dim, has_static_dims, has_static_shape,\ + has_static_rank, rank_equals, type_matches, type_matches_any +from openvino.pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16 +from openvino.runtime.passes.manager import Manager +from openvino.runtime.passes.graph_rewrite import GraphRewrite, BackwardGraphRewrite diff --git a/src/bindings/python/src/openvino/runtime/passes/graph_rewrite.py b/src/bindings/python/src/openvino/runtime/passes/graph_rewrite.py new file mode 100644 index 00000000000..6dc0f4e8594 --- /dev/null +++ b/src/bindings/python/src/openvino/runtime/passes/graph_rewrite.py @@ -0,0 +1,30 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +from openvino.pyopenvino.passes import MatcherPass +from openvino.pyopenvino.passes import GraphRewrite as GraphRewriteBase +from openvino.pyopenvino.passes import BackwardGraphRewrite as BackwardGraphRewriteBase + + +class GraphRewrite(GraphRewriteBase): + """GraphRewrite that additionally holds python transformations objects.""" + def __init__(self) -> None: + super().__init__() + self.passes_list = [] # need to keep python instances alive + + def add_matcher(self, transformation: MatcherPass) -> MatcherPass: + """Append MatcherPass instance to the end of execution list.""" + self.passes_list.append(transformation) + return super().add_matcher(transformation) + + +class BackwardGraphRewrite(BackwardGraphRewriteBase): + """BackwardGraphRewriteBase that additionally holds python transformations objects.""" + def __init__(self) -> None: + super().__init__() + self.passes_list = [] # need to keep python instances alive + + def add_matcher(self, transformation: MatcherPass) -> MatcherPass: + """Append MatcherPass instance to the end of execution list.""" + self.passes_list.append(transformation) + return super().add_matcher(transformation) diff --git a/src/bindings/python/src/openvino/runtime/passes/manager.py b/src/bindings/python/src/openvino/runtime/passes/manager.py new file mode 100644 index 00000000000..5d99ca65ada --- /dev/null +++ b/src/bindings/python/src/openvino/runtime/passes/manager.py @@ -0,0 +1,24 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +# type: ignore +from openvino.pyopenvino.passes import Manager as ManagerBase +from openvino.pyopenvino.passes import PassBase + + +class Manager(ManagerBase): + """Manager that additionally holds transformations objects.""" + def __init__(self) -> None: + super().__init__() + self.passes_list = [] # need to keep python instances alive + + def register_pass(self, *args, **kwargs) -> PassBase: + """Register transformation for further execution.""" + for arg in args: + if isinstance(arg, PassBase): + self.passes_list.append(arg) + + for arg in kwargs.values(): + if isinstance(arg, PassBase): + self.passes_list.append(arg) + + return super().register_pass(*args, **kwargs) diff --git a/src/bindings/python/src/openvino/runtime/utils/__init__.py b/src/bindings/python/src/openvino/runtime/utils/__init__.py index 2f77dab83ca..90356e3bd98 100644 --- a/src/bindings/python/src/openvino/runtime/utils/__init__.py +++ b/src/bindings/python/src/openvino/runtime/utils/__init__.py @@ -4,4 +4,4 @@ """Generic utilities. Factor related functions out to separate files.""" from openvino.pyopenvino.util import numpy_to_c -from openvino.pyopenvino.util import get_constant_from_source +from openvino.pyopenvino.util import get_constant_from_source, replace_node, replace_output_update_name diff --git a/src/bindings/python/src/pyopenvino/CMakeLists.txt b/src/bindings/python/src/pyopenvino/CMakeLists.txt index c4c14077828..12c6b5a9da9 100644 --- a/src/bindings/python/src/pyopenvino/CMakeLists.txt +++ b/src/bindings/python/src/pyopenvino/CMakeLists.txt @@ -62,7 +62,7 @@ endif() # create target -file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp pyopenvino.cpp) +file(GLOB_RECURSE SOURCES core/*.cpp graph/*.cpp frontend/*.cpp utils/*cpp pyopenvino.cpp) list(FILTER SOURCES EXCLUDE REGEX frontend/onnx|tensorflow|paddle/* ) pybind11_add_module(${PROJECT_NAME} MODULE ${SOURCES}) diff --git a/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.cpp b/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.cpp new file mode 100644 index 00000000000..0ae831a4af5 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.cpp @@ -0,0 +1,75 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/graph_rewrite.hpp" + +#include + +#include +#include + +namespace py = pybind11; + +void regclass_passes_GraphRewrite(py::module m) { + py::class_, ov::pass::ModelPass, ov::pass::PassBase> + graph_rewrite(m, "GraphRewrite"); + graph_rewrite.doc() = + "openvino.runtime.passes.GraphRewrite executes sequence of MatcherPass transformations in topological order"; + + graph_rewrite.def(py::init<>()); + graph_rewrite.def(py::init([](const std::shared_ptr& pass) { + return std::make_shared(pass); + }), + py::arg("pass"), + R"( + Register single MatcherPass pass inside GraphRewrite. + + :param pass: openvino.runtime.passes.MatcherPass instance + :type pass: openvino.runtime.passes.MatcherPass + )"); + + graph_rewrite.def("add_matcher", + static_cast (ov::pass::GraphRewrite::*)( + const std::shared_ptr&)>(&ov::pass::GraphRewrite::add_matcher), + py::arg("pass"), + R"( + Register single MatcherPass pass inside GraphRewrite. + + :param pass: openvino.runtime.passes.MatcherPass instance + :type pass: openvino.runtime.passes.MatcherPass + )"); + + py::class_, + ov::pass::GraphRewrite, + ov::pass::ModelPass, + ov::pass::PassBase> + back_graph_rewrite(m, "BackwardGraphRewrite"); + back_graph_rewrite.doc() = "openvino.runtime.passes.BackwardGraphRewrite executes sequence of MatcherPass " + "transformations in reversed topological order"; + + back_graph_rewrite.def(py::init<>()); + back_graph_rewrite.def(py::init([](const std::shared_ptr& pass) { + return std::make_shared(pass); + }), + py::arg("pass"), + R"( + Register single MatcherPass pass inside BackwardGraphRewrite. + + :param pass: openvino.runtime.passes.MatcherPass instance + :type pass: openvino.runtime.passes.MatcherPass + )"); + + back_graph_rewrite.def( + "add_matcher", + static_cast (ov::pass::BackwardGraphRewrite::*)( + const std::shared_ptr&)>(&ov::pass::BackwardGraphRewrite::add_matcher), + py::arg("pass"), + R"( + Register single MatcherPass pass inside BackwardGraphRewrite. + + :param pass: openvino.runtime.passes.MatcherPass instance + :type pass: openvino.runtime.passes.MatcherPass + )"); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.hpp b/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.hpp new file mode 100644 index 00000000000..430b0bc319a --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/graph_rewrite.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_passes_GraphRewrite(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/manager.cpp b/src/bindings/python/src/pyopenvino/graph/passes/manager.cpp index 34211baeca7..f4f167e4054 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/manager.cpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/manager.cpp @@ -31,63 +31,84 @@ inline Version convert_to_version(const std::string& version) { "'! The supported versions are: 'UNSPECIFIED'(default), 'IR_V10', 'IR_V11'."); } -namespace { -class ManagerWrapper : public ov::pass::Manager { -public: - ManagerWrapper() {} - ~ManagerWrapper() {} - - void register_pass(const std::string& pass_name) { - if (pass_name == "ConstantFolding") - push_pass(); - - if (m_per_pass_validation) - push_pass(); - return; - } - - void register_pass(const std::string& pass_name, const FilePaths& file_paths, const std::string& version) { - if (pass_name == "Serialize") { - push_pass(file_paths.first, file_paths.second, convert_to_version(version)); - } - return; - } - void register_pass(const std::string& pass_name, - const std::string& xml_path, - const std::string& bin_path, - const std::string& version) { - if (pass_name == "Serialize") - push_pass(xml_path, bin_path, convert_to_version(version)); - return; - } -}; -} // namespace - -void regclass_graph_passes_Manager(py::module m) { - py::class_ manager(m, "Manager"); - manager.doc() = "openvino.runtime.passes.Manager wraps ov::pass::Manager using ManagerWrapper"; +void regclass_passes_Manager(py::module m) { + py::class_ manager(m, "Manager"); + manager.doc() = "openvino.runtime.passes.Manager executes sequence of transformation on a given Model"; manager.def(py::init<>()); - - manager.def("set_per_pass_validation", &ManagerWrapper::set_per_pass_validation); - manager.def("run_passes", &ManagerWrapper::run_passes); - manager.def("register_pass", - (void (ManagerWrapper::*)(const std::string&)) & ManagerWrapper::register_pass, - py::arg("pass_name"), + manager.def("set_per_pass_validation", + &ov::pass::Manager::set_per_pass_validation, + py::arg("new_state"), R"( - Set the type of register pass for pass manager. + Enables or disables Model validation after each pass execution. - :param pass_name : String to set the type of a pass. - :type pass_name: str - // )"); + :param new_state: flag which enables or disables model validation. + :type new_state: bool + )"); + + manager.def("run_passes", + &ov::pass::Manager::run_passes, + py::arg("model"), + R"( + Executes sequence of transformations on given Model. + + :param model: openvino.runtime.Model to be transformed. + :type model: openvino.runtime.Model + )"); manager.def("register_pass", - (void (ManagerWrapper::*)(const std::string&, const FilePaths&, const std::string&)) & - ManagerWrapper::register_pass, - py::arg("pass_name"), - py::arg("output_files"), - py::arg("version") = "UNSPECIFIED", + &ov::pass::Manager::register_pass_instance, + py::arg("transformation"), R"( + Register pass instance for execution. Execution order matches the registration order. + + :param transformation: transformation instance. + :type transformation: openvino.runtime.passes.PassBase + )"); + + manager.def( + "register_pass", + [](ov::pass::Manager& self, const std::string& pass_name) -> void { + PyErr_WarnEx(PyExc_DeprecationWarning, + "register_pass with this arguments is deprecated! " + "Please use register_pass(ConstantFolding()) instead.", + 1); + if (pass_name == "ConstantFolding") { + self.register_pass(); + } + }, + py::arg("pass_name"), + R"( + This method is deprecated. Please use m.register_pass(ConstantFolding()) instead. + + Register pass by name from the list of predefined passes. + + :param pass_name: String to set the type of a pass. + :type pass_name: str + )"); + + manager.def( + "register_pass", + [](ov::pass::Manager& self, + const std::string& pass_name, + const FilePaths& file_paths, + const std::string& version) -> void { + PyErr_WarnEx(PyExc_DeprecationWarning, + "register_pass with this arguments is deprecated! " + "Please use register_pass(Serialize(xml, bin, version)) instead.", + 1); + if (pass_name == "Serialize") { + self.register_pass(file_paths.first, + file_paths.second, + convert_to_version(version)); + } + }, + py::arg("pass_name"), + py::arg("output_files"), + py::arg("version") = "UNSPECIFIED", + R"( + This method is deprecated. Please use m.register_pass(Serialize(...)) instead. + Set the type of register pass for pass manager. :param pass_name: String to set the type of a pass. @@ -100,6 +121,7 @@ void regclass_graph_passes_Manager(py::module m) { - "IR_V10" : v10 IR - "IR_V11" : v11 IR :type version: str + Examples ---------- 1. Default Version @@ -108,16 +130,30 @@ void regclass_graph_passes_Manager(py::module m) { 2. IR version 11 pass_manager = Manager() pass_manager.register_pass("Serialize", output_files=("example.xml", "example.bin"), version="IR_V11") - // )"); + )"); + manager.def( "register_pass", - (void (ManagerWrapper::*)(const std::string&, const std::string&, const std::string&, const std::string&)) & - ManagerWrapper::register_pass, + [](ov::pass::Manager& self, + const std::string& pass_name, + const std::string& xml_path, + const std::string& bin_path, + const std::string& version) -> void { + PyErr_WarnEx(PyExc_DeprecationWarning, + "register_pass with this arguments is deprecated! " + "Please use register_pass(Serialize(xml, bin, version)) instead.", + 1); + if (pass_name == "Serialize") { + self.register_pass(xml_path, bin_path, convert_to_version(version)); + } + }, py::arg("pass_name"), py::arg("xml_path"), py::arg("bin_path"), py::arg("version") = "UNSPECIFIED", R"( + This method is deprecated. Please use m.register_pass(Serialize(...)) instead. + Set the type of register pass for pass manager. :param pass_name: String to set the type of a pass. @@ -132,6 +168,7 @@ void regclass_graph_passes_Manager(py::module m) { - "IR_V10" : v10 IR - "IR_V11" : v11 IR :type version: str + Examples ---------- 1. Default Version @@ -140,5 +177,5 @@ void regclass_graph_passes_Manager(py::module m) { 2. IR version 11 pass_manager = Manager() pass_manager.register_pass("Serialize", xml_path="example.xml", bin_path="example.bin", version="IR_V11") - // )"); + )"); } diff --git a/src/bindings/python/src/pyopenvino/graph/passes/manager.hpp b/src/bindings/python/src/pyopenvino/graph/passes/manager.hpp index 5e8bbe23525..beb84dc2442 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/manager.hpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/manager.hpp @@ -8,4 +8,4 @@ namespace py = pybind11; -void regclass_graph_passes_Manager(py::module m); +void regclass_passes_Manager(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.cpp b/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.cpp new file mode 100644 index 00000000000..3b0fababed3 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.cpp @@ -0,0 +1,201 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/matcher_pass.hpp" + +#include +#include +#include + +#include + +#include "openvino/pass/graph_rewrite.hpp" +#include "openvino/pass/pattern/matcher.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" + +namespace py = pybind11; + +void regclass_passes_Matcher(py::module m) { + py::class_> matcher(m, "Matcher"); + matcher.doc() = "openvino.runtime.passes.Matcher wraps ov::pass::pattern::Matcher"; + matcher.def(py::init([](const std::shared_ptr& node, const std::string& name) { + return std::make_shared(node, name); + }), + py::arg("node"), + py::arg("name"), + R"( + Creates Matcher object with given pattern root node and matcher name. + Matcher object is used for pattern matching on Model. + + :param node: pattern root node. + :type node: openvino.runtime.Node + + :param name: pattern name. Usually matches the MatcherPass class name. + :type name: str + )"); + + matcher.def(py::init([](ov::Output& output, const std::string& name) { + return std::make_shared(output, name); + }), + py::arg("output"), + py::arg("name"), + R"( + Creates Matcher object with given pattern root node output and matcher name. + Matcher object is used for pattern matching on Model. + + :param node: pattern root node output. + :type node: openvino.runtime.Output + + :param name: pattern name. Usually matches the MatcherPass class name. + :type name: str + )"); + + matcher.def("get_name", + &ov::pass::pattern::Matcher::get_name, + R"( + Get Matcher name. + + :return: openvino.runtime.passes.Matcher name. + :rtype: str + )"); + + matcher.def("get_match_root", + &ov::pass::pattern::Matcher::get_match_root, + R"( + Get matched root node inside Model. Should be used after match() method is called. + + :return: matched node. + :rtype: openvino.runtime.Node + )"); + + matcher.def("get_match_value", + &ov::pass::pattern::Matcher::get_match_value, + R"( + Get matched node output inside Model. Should be used after match() method is called. + + :return: matched node output. + :rtype: openvino.runtime.Output + )"); + + matcher.def("get_match_nodes", + &ov::pass::pattern::Matcher::get_matched_nodes, + R"( + Get NodeVector of matched nodes. Should be used after match() method is called. + + :return: matched nodes vector. + :rtype: List[openvino.runtime.Node] + )"); + + matcher.def("get_match_values", + static_cast( + &ov::pass::pattern::Matcher::get_matched_values), + R"( + Get OutputVector of matched outputs. Should be used after match() method is called. + + :return: matched outputs vector. + :rtype: List[openvino.runtime.Output] + )"); + + matcher.def("get_pattern_value_map", + &ov::pass::pattern::Matcher::get_pattern_value_map, + R"( + Get map which can be used to access matched nodes using nodes from pattern. + Should be used after match() method is called. + + :return: mapping of pattern nodes to matched nodes. + :rtype: dict + )"); + + matcher.def("match", + static_cast&)>( + &ov::pass::pattern::Matcher::match), + R"( + Matches registered pattern starting from given output. + + :return: status of matching. + :rtype: bool + )"); + + matcher.def("match", + static_cast)>( + &ov::pass::pattern::Matcher::match), + R"( + Matches registered pattern starting from given Node. + + :return: status of matching. + :rtype: bool + )"); +} + +class PyMatcherPass : public ov::pass::MatcherPass { +public: + void py_register_matcher(const std::shared_ptr& matcher, + const ov::matcher_pass_callback& callback) { + register_matcher(matcher, callback); + } +}; + +void regclass_passes_MatcherPass(py::module m) { + py::class_, ov::pass::PassBase, PyMatcherPass> + matcher_pass(m, "MatcherPass"); + matcher_pass.doc() = "openvino.runtime.passes.MatcherPass wraps ov::pass::MatcherPass"; + matcher_pass.def(py::init<>()); + matcher_pass.def( + py::init([](const std::shared_ptr& m, ov::matcher_pass_callback callback) { + return std::make_shared(m, callback); + }), + py::arg("matcher"), + py::arg("callback"), + R"( + Create MatcherPass from existing Matcher and callback objects. + + :param matcher: openvino.runtime.passes.Matcher with registered pattern. + :type matcher: openvino.runtime.passes.Matcher + + :param callback: Function that performs transformation on the matched nodes. + :type callback: function + + :return: created openvino.runtime.passes.MatcherPass instance. + :rtype: openvino.runtime.passes.MatcherPass + )"); + + matcher_pass.def("apply", + &ov::pass::MatcherPass::apply, + py::arg("node"), + R"( + Execute MatcherPass on given Node. + + :return: callback return code. + :rtype: bool + )"); + + matcher_pass.def("register_new_node", + &ov::pass::MatcherPass::register_new_node_, + py::arg("node"), + R"( + Register node for additional pattern matching. + + :param node: openvino.runtime.Node for matching. + :type node: openvino.runtime.Node + + :return: registered node instance + :rtype: openvino.runtime.Node + )"); + + matcher_pass.def("register_matcher", + static_cast&, + const ov::graph_rewrite_callback& callback)>( + &PyMatcherPass::py_register_matcher), + py::arg("matcher"), + py::arg("callback"), + R"( + Initialize matcher and callback for further execution. + + :param matcher: openvino.runtime.passes.Matcher with registered pattern. + :type matcher: openvino.runtime.passes.Matcher + + :param callback: Function that performs transformation on the matched nodes. + :type callback: function + )"); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.hpp b/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.hpp new file mode 100644 index 00000000000..bc33f8e1752 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/matcher_pass.hpp @@ -0,0 +1,13 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_passes_Matcher(py::module m); + +void regclass_passes_MatcherPass(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/model_pass.cpp b/src/bindings/python/src/pyopenvino/graph/passes/model_pass.cpp new file mode 100644 index 00000000000..c3e8e7ca6bf --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/model_pass.cpp @@ -0,0 +1,47 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/model_pass.hpp" + +#include + +#include +#include + +namespace py = pybind11; + +class PyModelPass : public ov::pass::ModelPass { +public: + /* Inherit the constructors */ + using ov::pass::ModelPass::ModelPass; + + /* Trampoline (need one for each virtual function) */ + bool run_on_model(const std::shared_ptr& model) override { + PYBIND11_OVERRIDE_PURE(bool, /* Return type */ + ov::pass::ModelPass, /* Parent class */ + run_on_model, /* Name of function in C++ (must match Python name) */ + model /* Argument(s) */ + ); + } +}; + +void regclass_passes_ModelPass(py::module m) { + py::class_, ov::pass::PassBase, PyModelPass> model_pass( + m, + "ModelPass"); + model_pass.doc() = "openvino.runtime.passes.ModelPass wraps ov::pass::ModelPass"; + model_pass.def(py::init<>()); + model_pass.def("run_on_model", + &ov::pass::ModelPass::run_on_model, + py::arg("model"), + R"( + run_on_model must be defined in inherited class. This method is used to work with Model directly. + + :param model: openvino.runtime.Model to be transformed. + :type model: openvino.runtime.Model + + :return: True in case if Model was changed and False otherwise. + :rtype: bool + )"); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/model_pass.hpp b/src/bindings/python/src/pyopenvino/graph/passes/model_pass.hpp new file mode 100644 index 00000000000..d738cfdd7dd --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/model_pass.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_passes_ModelPass(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pass_base.cpp b/src/bindings/python/src/pyopenvino/graph/passes/pass_base.cpp new file mode 100644 index 00000000000..8f6bbfac1a3 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/pass_base.cpp @@ -0,0 +1,34 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/pass_base.hpp" + +#include + +#include +#include + +namespace py = pybind11; + +void regclass_passes_PassBase(py::module m) { + py::class_> pass_base(m, "PassBase"); + pass_base.doc() = "openvino.runtime.passes.PassBase wraps ov::pass::PassBase"; + pass_base.def("set_name", + &ov::pass::PassBase::set_name, + py::arg("name"), + R"( + Set transformation name. + + :param name: Transformation name. + :type name: str + )"); + pass_base.def("get_name", + &ov::pass::PassBase::get_name, + R"( + Get transformation name. + + :return: Transformation name. + :rtype: str + )"); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pass_base.hpp b/src/bindings/python/src/pyopenvino/graph/passes/pass_base.hpp new file mode 100644 index 00000000000..878951236ad --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/pass_base.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_passes_PassBase(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp new file mode 100644 index 00000000000..0bd0510778d --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp @@ -0,0 +1,504 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/pattern_ops.hpp" + +#include +#include +#include + +#include +#include +#include + +#include "ngraph/opsets/opset.hpp" +#include "openvino/pass/pattern/op/label.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/pattern.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" + +ov::NodeTypeInfo get_type(const std::string& type_name) { + // Supported types: opsetX.OpName or opsetX::OpName + std::string opset_type; + auto it = type_name.cbegin(); + while (it != type_name.cend() && *it != '.' && *it != ':') { + opset_type += *it; + ++it; + } + + // Skip delimiter + while (it != type_name.cend() && (*it == '.' || *it == ':')) { + ++it; + } + + // Get operation type name + std::string operation_type(it, type_name.end()); + + // TODO: create generic opset factory in Core so it can be reused + const std::unordered_map> get_opset{ + {"opset1", ngraph::get_opset1}, + {"opset2", ngraph::get_opset2}, + {"opset3", ngraph::get_opset3}, + {"opset4", ngraph::get_opset4}, + {"opset5", ngraph::get_opset5}, + {"opset6", ngraph::get_opset6}, + {"opset7", ngraph::get_opset7}, + {"opset8", ngraph::get_opset8}, + }; + + if (!get_opset.count(opset_type)) { + throw std::runtime_error("Unsupported opset type: " + opset_type); + } + + const ngraph::OpSet& m_opset = get_opset.at(opset_type)(); + if (!m_opset.contains_type(operation_type)) { + throw std::runtime_error("Unrecognized operation type: " + operation_type); + } + + return m_opset.create(operation_type)->get_type_info(); +} + +std::vector get_types(const std::vector& type_names) { + std::vector types; + for (const auto& type_name : type_names) { + types.emplace_back(get_type(type_name)); + } + return types; +} + +using Predicate = const ov::pass::pattern::op::ValuePredicate; + +void reg_pattern_wrap_type(py::module m) { + py::class_, ov::Node> wrap_type( + m, + "WrapType"); + wrap_type.doc() = "openvino.runtime.passes.WrapType wraps ov::pass::pattern::op::WrapType"; + + wrap_type.def(py::init([](const std::string& type_name) { + return std::make_shared(get_type(type_name)); + }), + py::arg("type_name"), + R"( + Create WrapType with given node type. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + )"); + + wrap_type.def(py::init([](const std::string& type_name, const Predicate& pred) { + return std::make_shared(get_type(type_name), pred); + }), + py::arg("type_name"), + py::arg("pred"), + R"( + Create WrapType with given node type and predicate. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::Output& input) { + return std::make_shared(get_type(type_name), + nullptr, + ov::OutputVector{input}); + }), + py::arg("type_name"), + py::arg("input"), + R"( + Create WrapType with given node type and input node. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param input: Node output. + :type input: openvino.runtime.Output + )"); + + wrap_type.def(py::init([](const std::string& type_name, const std::shared_ptr& input) { + return std::make_shared(get_type(type_name), + nullptr, + ov::OutputVector{input}); + }), + py::arg("type_name"), + py::arg("input"), + R"( + Create WrapType with given node type and input node. + + :param type_name: node type. For example: opset8.Abs + :type type_name: str + + :param input: Input node. + :type input: openvino.runtime.Node + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::Output& input, const Predicate& pred) { + return std::make_shared(get_type(type_name), + pred, + ov::OutputVector{input}); + }), + py::arg("type_name"), + py::arg("input"), + py::arg("predicate"), + R"( + Create WrapType with given node type, input node and predicate. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param input: Node output. + :type input: openvino.runtime.Output + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def( + py::init([](const std::string& type_name, const std::shared_ptr& input, const Predicate& pred) { + return std::make_shared(get_type(type_name), + pred, + ov::OutputVector{input}); + }), + py::arg("type_name"), + py::arg("input"), + py::arg("predicate"), + R"( + Create WrapType with given node type, input node and predicate. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param input: Input node. + :type input: openvino.runtime.Node + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs) { + return std::make_shared(get_type(type_name), nullptr, inputs); + }), + py::arg("type_name"), + py::arg("inputs"), + R"( + Create WrapType with given node type and input nodes. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param inputs: Node outputs. + :type inputs: List[openvino.runtime.Output] + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs) { + return std::make_shared(get_type(type_name), + nullptr, + ov::as_output_vector(inputs)); + }), + py::arg("type_name"), + py::arg("inputs"), + R"( + Create WrapType with given node type and input nodes. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param inputs: Input nodes. + :type inputs: List[openvino.runtime.Node] + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::OutputVector& inputs, const Predicate& pred) { + return std::make_shared(get_type(type_name), pred, inputs); + }), + py::arg("type_name"), + py::arg("inputs"), + py::arg("predicate"), + R"( + Create WrapType with given node type, input nodes and predicate. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param inputs: Node outputs. + :type inputs: List[openvino.runtime.Output] + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::string& type_name, const ov::NodeVector& inputs, const Predicate& pred) { + return std::make_shared(get_type(type_name), + pred, + ov::as_output_vector(inputs)); + }), + py::arg("type_name"), + py::arg("inputs"), + py::arg("predicate"), + R"( + Create WrapType with given node type, input nodes and predicate. + + :param type_name: node type. For example: "opset8.Abs" + :type type_name: str + + :param inputs: Input nodes. + :type inputs: List[openvino.runtime.Node] + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::vector& type_names) { + return std::make_shared(get_types(type_names)); + }), + py::arg("type_names"), + R"( + Create WrapType with given node types. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + )"); + + wrap_type.def(py::init([](const std::vector& type_names, const Predicate& pred) { + return std::make_shared(get_types(type_names), pred); + }), + py::arg("type_names"), + py::arg("predicate"), + R"( + Create WrapType with given node types and predicate. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::vector& type_names, const ov::Output& input) { + return std::make_shared(get_types(type_names), + nullptr, + ov::OutputVector{input}); + }), + py::arg("type_names"), + py::arg("input"), + R"( + Create WrapType with given node types and input. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param input: Node output. + :type input: openvino.runtime.Output + )"); + + wrap_type.def(py::init([](const std::vector& type_names, const std::shared_ptr& input) { + return std::make_shared(get_types(type_names), + nullptr, + ov::OutputVector{input}); + }), + py::arg("type_names"), + py::arg("input"), + R"( + Create WrapType with given node types and input. + + :param type_name: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_name: List[str] + + :param input: Input node. + :type input: openvino.runtime.Node + )"); + + wrap_type.def( + py::init( + [](const std::vector& type_names, const ov::Output& input, const Predicate& pred) { + return std::make_shared(get_types(type_names), + pred, + ov::OutputVector{input}); + }), + py::arg("type_names"), + py::arg("input"), + py::arg("predicate"), + R"( + Create WrapType with given node types, input and predicate. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param input: Node output. + :type input: openvino.runtime.Output + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::vector& type_names, + const std::shared_ptr& input, + const Predicate& pred) { + return std::make_shared(get_types(type_names), + pred, + ov::OutputVector{input}); + }), + py::arg("type_names"), + py::arg("input"), + py::arg("predicate"), + R"( + Create WrapType with given node types, input and predicate. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param input: Input node. + :type input: openvino.runtime.Node + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def(py::init([](const std::vector& type_names, const ov::OutputVector& inputs) { + return std::make_shared(get_types(type_names), nullptr, inputs); + }), + py::arg("type_names"), + py::arg("inputs"), + R"( + Create WrapType with given node types and input. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param inputs: Nodes outputs. + :type inputs: List[openvino.runtime.Output] + )"); + + wrap_type.def(py::init([](const std::vector& type_names, const ov::NodeVector& inputs) { + return std::make_shared(get_types(type_names), + nullptr, + ov::as_output_vector(inputs)); + }), + py::arg("type_names"), + py::arg("inputs"), + R"( + Create WrapType with given node types and inputs. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param inputs: Input nodes. + :type inputs: List[openvino.runtime.Node] + )"); + + wrap_type.def( + py::init([](const std::vector& type_names, const ov::OutputVector& inputs, const Predicate& pred) { + return std::make_shared(get_types(type_names), pred, inputs); + }), + py::arg("type_names"), + py::arg("inputs"), + py::arg("predicate"), + R"( + Create WrapType with given node types, inputs and predicate. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param inputs: Nodes outputs. + :type inputs: List[openvino.runtime.Output] + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); + + wrap_type.def( + py::init([](const std::vector& type_names, const ov::NodeVector& inputs, const Predicate& pred) { + return std::make_shared(get_types(type_names), + pred, + ov::as_output_vector(inputs)); + }), + py::arg("type_names"), + py::arg("inputs"), + py::arg("predicate"), + R"( + Create WrapType with given node types, inputs and predicate. + + :param type_names: node types. For example: ["opset8.Abs", "opset8.Relu"] + :type type_names: List[str] + + :param inputs: Input nodes. + :type inputs: List[openvino.runtime.Node] + + :param predicate: Function that performs additional checks for matching. + :type predicate: function + )"); +} + +void reg_pattern_or(py::module m) { + py::class_, ov::Node> or_type(m, "Or"); + or_type.doc() = "openvino.runtime.passes.Or wraps ov::pass::pattern::op::Or"; + + or_type.def(py::init([](const ov::OutputVector& inputs) { + return std::make_shared(inputs); + }), + py::arg("inputs"), + R"( + Create pattern Or operation which is used to match any of given inputs. + + :param inputs: Operation inputs. + :type inputs: List[openvino.runtime.Output] + )"); + + or_type.def(py::init([](const ov::NodeVector& inputs) { + return std::make_shared(ov::as_output_vector(inputs)); + }), + py::arg("inputs"), + R"( + Create pattern Or operation which is used to match any of given inputs. + + :param inputs: Operation inputs. + :type inputs: List[openvino.runtime.Node] + )"); +} + +void reg_pattern_any_input(py::module m) { + py::class_, ov::Node> any_input( + m, + "AnyInput"); + any_input.doc() = "openvino.runtime.passes.AnyInput wraps ov::pass::pattern::op::Label"; + + any_input.def(py::init([]() { + return std::make_shared(); + }), + R"( + Create pattern AnyInput operation which is used to match any type of node. + )"); + + any_input.def(py::init([](const Predicate& pred) { + return std::make_shared(ov::element::dynamic, + ov::PartialShape::dynamic(), + pred); + }), + py::arg("predicate"), + R"( + Create pattern AnyInput operation which is used to match any type of node. + + :param pred: Function that performs additional checks for matching. + :type pred: function + )"); +} + +void reg_predicates(py::module m) { + m.def("consumers_count", &ov::pass::pattern::consumers_count); + m.def("has_static_dim", &ov::pass::pattern::has_static_dim); + m.def("has_static_dims", &ov::pass::pattern::has_static_dims); + m.def("has_static_shape", &ov::pass::pattern::has_static_shape); + m.def("has_static_rank", &ov::pass::pattern::has_static_rank); + m.def("rank_equals", &ov::pass::pattern::rank_equals); + m.def("type_matches", &ov::pass::pattern::type_matches); + m.def("type_matches_any", &ov::pass::pattern::type_matches_any); +} + +void reg_passes_pattern_ops(py::module m) { + reg_pattern_any_input(m); + reg_pattern_wrap_type(m); + reg_pattern_or(m); + reg_predicates(m); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.hpp b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.hpp new file mode 100644 index 00000000000..beef6f7dfdc --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void reg_passes_pattern_ops(py::module m); diff --git a/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.cpp b/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.cpp index 8e35eb0c53f..ef38abf24c9 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.cpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.cpp @@ -6,9 +6,24 @@ #include +#include "pyopenvino/graph/passes/graph_rewrite.hpp" +#include "pyopenvino/graph/passes/manager.hpp" +#include "pyopenvino/graph/passes/matcher_pass.hpp" +#include "pyopenvino/graph/passes/model_pass.hpp" +#include "pyopenvino/graph/passes/pass_base.hpp" +#include "pyopenvino/graph/passes/pattern_ops.hpp" +#include "pyopenvino/graph/passes/transformations.hpp" + namespace py = pybind11; void regmodule_graph_passes(py::module m) { py::module m_passes = m.def_submodule("passes", "Package openvino.runtime.passes wraps ov::passes"); - regclass_graph_passes_Manager(m_passes); + regclass_passes_PassBase(m_passes); + regclass_passes_ModelPass(m_passes); + regclass_passes_GraphRewrite(m_passes); + regclass_passes_Matcher(m_passes); + regclass_passes_MatcherPass(m_passes); + regclass_transformations(m_passes); + regclass_passes_Manager(m_passes); + reg_passes_pattern_ops(m_passes); } diff --git a/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.hpp b/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.hpp index dd31861241f..3c8bbdbc4af 100644 --- a/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.hpp +++ b/src/bindings/python/src/pyopenvino/graph/passes/regmodule_graph_passes.hpp @@ -5,7 +5,6 @@ #pragma once #include -#include "pyopenvino/graph/passes/manager.hpp" namespace py = pybind11; diff --git a/src/bindings/python/src/pyopenvino/graph/passes/transformations.cpp b/src/bindings/python/src/pyopenvino/graph/passes/transformations.cpp new file mode 100644 index 00000000000..41b5f3ca601 --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/transformations.cpp @@ -0,0 +1,111 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pyopenvino/graph/passes/transformations.hpp" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +void regclass_transformations(py::module m) { + py::class_, ov::pass::ModelPass, ov::pass::PassBase> + serialize(m, "Serialize"); + serialize.doc() = "openvino.runtime.passes.Serialize transformation"; + serialize.def(py::init([](const std::string& path_to_xml, const std::string& path_to_bin) { + return std::make_shared(path_to_xml, path_to_bin); + }), + py::arg("path_to_xml"), + py::arg("path_to_bin"), + R"( + Create Serialize pass which is used for Model to IR serialization. + + :param path_to_xml: Path where *.xml file will be saved. + :type path_to_xml: str + + :param path_to_xml: Path where *.bin file will be saved. + :type path_to_xml: str + )"); + + serialize.def( + py::init( + [](const std::string& path_to_xml, const std::string& path_to_bin, ov::pass::Serialize::Version version) { + return std::make_shared(path_to_xml, path_to_bin, version); + }), + py::arg("path_to_xml"), + py::arg("path_to_bin"), + py::arg("version"), + R"( + Create Serialize pass which is used for Model to IR serialization. + + :param path_to_xml: Path where *.xml file will be saved. + :type path_to_xml: str + + :param path_to_xml: Path where *.bin file will be saved. + :type path_to_xml: str + + :param version: serialized IR version. + :type version: int + )"); + + py::class_, + ov::pass::ModelPass, + ov::pass::PassBase> + cf(m, "ConstantFolding"); + cf.doc() = "openvino.runtime.passes.ConstantFolding transformation"; + cf.def(py::init<>()); + + py::class_, + ov::pass::ModelPass, + ov::pass::PassBase> + visualize(m, "VisualizeTree"); + visualize.doc() = "openvino.runtime.passes.VisualizeTree transformation"; + visualize.def(py::init(), + py::arg("file_name"), + py::arg("nm") = nullptr, + py::arg("don_only") = false, + R"( + Create VisualizeTree pass which is used for Model to dot serialization. + + :param file_name: Path where serialized model will be saved. For example: /tmp/out.svg + :type file_name: str + + :param nm: Node modifier function. + :type nm: function + + :param don_only: Enable only dot file generation. + :type don_only: bool + )"); + + py::class_, ov::pass::ModelPass, ov::pass::PassBase> + make_stateful(m, "MakeStateful"); + make_stateful.doc() = "openvino.runtime.passes.MakeStateful transformation"; + // TODO: update docstrings for c-tors below + make_stateful.def(py::init(), py::arg("pairs_to_replace")); + make_stateful.def(py::init&>()); + + py::class_, ov::pass::ModelPass, ov::pass::PassBase> + low_latency(m, "LowLatency2"); + low_latency.doc() = "openvino.runtime.passes.LowLatency2 transformation"; + // TODO: update docstrings for c-tor below + low_latency.def(py::init(), py::arg("use_const_initializer") = true); + + py::class_, + ov::pass::ModelPass, + ov::pass::PassBase> + convert(m, "ConvertFP32ToFP16"); + convert.doc() = "openvino.runtime.passes.ConvertFP32ToFP16 transformation"; + convert.def(py::init<>()); +} diff --git a/src/bindings/python/src/pyopenvino/graph/passes/transformations.hpp b/src/bindings/python/src/pyopenvino/graph/passes/transformations.hpp new file mode 100644 index 00000000000..b285a9bd4ea --- /dev/null +++ b/src/bindings/python/src/pyopenvino/graph/passes/transformations.hpp @@ -0,0 +1,11 @@ +// Copyright (C) 2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace py = pybind11; + +void regclass_transformations(py::module m); \ No newline at end of file diff --git a/src/bindings/python/src/pyopenvino/graph/util.cpp b/src/bindings/python/src/pyopenvino/graph/util.cpp index b35bc1c5cbd..e9d01b32760 100644 --- a/src/bindings/python/src/pyopenvino/graph/util.cpp +++ b/src/bindings/python/src/pyopenvino/graph/util.cpp @@ -6,10 +6,15 @@ #include +#include "openvino/core/graph_util.hpp" #include "openvino/core/validation_util.hpp" +#include "openvino/pass/manager.hpp" namespace py = pybind11; +template +using overload_cast_ = pybind11::detail::overload_cast_impl; + void* numpy_to_c(py::array a) { py::buffer_info info = a.request(); return info.ptr; @@ -31,4 +36,24 @@ void regmodule_graph_util(py::module m) { from the resulting bound, otherwise Null. :rtype: openvino.runtime.op.Constant or openvino.runtime.Node )"); + + mod.def("replace_output_update_name", &ov::replace_output_update_name, py::arg("output"), py::arg("target_output")); + + mod.def("replace_node", + overload_cast_&, const std::shared_ptr&>()(&ov::replace_node), + py::arg("target"), + py::arg("replacement")); + + mod.def("replace_node", + overload_cast_&, const ov::OutputVector&>()(&ov::replace_node), + py::arg("target"), + py::arg("replacement")); + + mod.def("replace_node", + overload_cast_&, + const std::shared_ptr&, + const std::vector&>()(&ov::replace_node), + py::arg("target"), + py::arg("replacement"), + py::arg("outputs_order")); } diff --git a/src/bindings/python/tests/test_transformations/test_graph_rewrite.py b/src/bindings/python/tests/test_transformations/test_graph_rewrite.py new file mode 100644 index 00000000000..02131be5e71 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_graph_rewrite.py @@ -0,0 +1,73 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.runtime import opset8 +from openvino.runtime.passes import Manager, GraphRewrite, MatcherPass, WrapType, Matcher + +from utils.utils import count_ops, get_test_function, PatternReplacement + + +def test_graph_rewrite(): + model = get_test_function() + + m = Manager() + # check that register pass returns pass instance + anchor = m.register_pass(GraphRewrite()) + anchor.add_matcher(PatternReplacement()) + m.run_passes(model) + + assert count_ops(model, "Relu") == [2] + + +def test_register_new_node(): + class InsertExp(MatcherPass): + def __init__(self): + MatcherPass.__init__(self) + self.model_changed = False + + param = WrapType("opset8.Parameter") + + def callback(m: Matcher) -> bool: + # Input->...->Result => Input->Exp->...->Result + root = m.get_match_value() + consumers = root.get_target_inputs() + + exp = opset8.exp(root) + for consumer in consumers: + consumer.replace_source_output(exp.output(0)) + + # For testing purpose + self.model_changed = True + + # Use new operation for additional matching + self.register_new_node(exp) + + # Root node wasn't replaced or changed + return False + + self.register_matcher(Matcher(param, "InsertExp"), callback) + + class RemoveExp(MatcherPass): + def __init__(self): + MatcherPass.__init__(self) + self.model_changed = False + + param = WrapType("opset8.Exp") + + def callback(m: Matcher) -> bool: + root = m.get_match_root() + root.output(0).replace(root.input_value(0)) + + # For testing purpose + self.model_changed = True + + return True + + self.register_matcher(Matcher(param, "RemoveExp"), callback) + + m = Manager() + ins = m.register_pass(InsertExp()) + rem = m.register_pass(RemoveExp()) + m.run_passes(get_test_function()) + + assert ins.model_changed + assert rem.model_changed diff --git a/src/bindings/python/tests/test_transformations/test_manager.py b/src/bindings/python/tests/test_transformations/test_manager.py new file mode 100644 index 00000000000..88b05c164d3 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_manager.py @@ -0,0 +1,44 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.runtime.passes import Manager, GraphRewrite, BackwardGraphRewrite, Serialize + +from utils.utils import MyModelPass, PatternReplacement, expect_exception + + +def test_registration_and_pass_name(): + m = Manager() + + a = m.register_pass(PatternReplacement()) + a.set_name("PatterReplacement") + + b = m.register_pass(MyModelPass()) + b.set_name("ModelPass") + + c = m.register_pass(GraphRewrite()) + c.set_name("Anchor") + + d = c.add_matcher(PatternReplacement()) + d.set_name("PatterReplacement") + + e = m.register_pass(BackwardGraphRewrite()) + e.set_name("BackAnchor") + + f = e.add_matcher(PatternReplacement()) + f.set_name("PatterReplacement") + + PatternReplacement().set_name("PatternReplacement") + MyModelPass().set_name("MyModelPass") + GraphRewrite().set_name("Anchor") + BackwardGraphRewrite().set_name("BackAnchor") + + # Preserve legacy behaviour when registered pass doesn't exist + # and in this case we shouldn't throw an exception. + m.register_pass("NotExistingPass") + + +def test_negative_pass_registration(): + m = Manager() + expect_exception(lambda: m.register_pass(PatternReplacement)) + expect_exception(lambda: m.register_pass("PatternReplacement", PatternReplacement())) + expect_exception(lambda: m.register_pass("Serialize", Serialize("out.xml", "out.bin"))) + expect_exception(lambda: m.register_pass("Serialize", "out.xml", "out.bin", "out.wrong")) diff --git a/src/bindings/python/tests/test_transformations/test_matcher_pass.py b/src/bindings/python/tests/test_transformations/test_matcher_pass.py new file mode 100644 index 00000000000..8c6a5ebc6d9 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_matcher_pass.py @@ -0,0 +1,56 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.runtime import opset8 +from openvino.runtime.passes import Manager, Matcher, MatcherPass, WrapType +from openvino.runtime.utils import replace_node + +from utils.utils import count_ops, get_test_function, PatternReplacement + + +def test_simple_pattern_replacement(): + # Simple: for Extensions. Without any classes and inheritance. + def pattern_replacement(): + param = WrapType("opset8.Parameter") + relu = WrapType("opset8.Relu", param.output(0)) + + def callback(m: Matcher) -> bool: + root = m.get_match_root() + + # Just to check that capturing works and we can + # link pattern nodes with matched graph nodes. + assert relu in m.get_pattern_value_map() + + new_relu = opset8.exp(root.input_value(0)) # ot root.input(0).get_source_output() + replace_node(root, new_relu) + return True + + return Matcher(relu, "SimpleReplacement"), callback + + model = get_test_function() + + m = Manager() + m.register_pass(MatcherPass(*pattern_replacement())) + m.run_passes(model) + + assert count_ops(model, ("Relu", "Exp")) == [0, 1] + + +def test_matcher_pass(): + model = get_test_function() + + m = Manager() + # check that register pass returns pass instance + p = m.register_pass(PatternReplacement()) + m.run_passes(model) + + assert p.model_changed + assert count_ops(model, "Relu") == [2] + + +def test_matcher_pass_apply(): + model = get_test_function() + + p = PatternReplacement() + p.apply(model.get_result().input_value(0).get_node()) + + assert count_ops(model, "Relu") == [2] diff --git a/src/bindings/python/tests/test_transformations/test_model_pass.py b/src/bindings/python/tests/test_transformations/test_model_pass.py new file mode 100644 index 00000000000..04fb1908bd6 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_model_pass.py @@ -0,0 +1,13 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from openvino.runtime.passes import Manager + +from utils.utils import get_test_function, MyModelPass + + +def test_model_pass(): + m = Manager() + p = m.register_pass(MyModelPass()) + m.run_passes(get_test_function()) + + assert p.model_changed diff --git a/src/bindings/python/tests/test_transformations/test_offline_api.py b/src/bindings/python/tests/test_transformations/test_offline_api.py index 195d8c597bf..807961771a3 100644 --- a/src/bindings/python/tests/test_transformations/test_offline_api.py +++ b/src/bindings/python/tests/test_transformations/test_offline_api.py @@ -67,26 +67,6 @@ def test_make_stateful_transformations(): assert len(function.get_results()) == 0 -def test_serialize_pass(): - core = Core() - xml_path = "serialized_function.xml" - bin_path = "serialized_function.bin" - - func = get_test_function() - - serialize(func, xml_path, bin_path) - - assert func is not None - - res_func = core.read_model(model=xml_path, weights=bin_path) - - assert func.get_parameters() == res_func.get_parameters() - assert func.get_ordered_ops() == res_func.get_ordered_ops() - - os.remove(xml_path) - os.remove(bin_path) - - def test_serialize_pass_v2(): core = Core() xml_path = "./serialized_function.xml" diff --git a/src/bindings/python/tests/test_transformations/test_pattern_ops.py b/src/bindings/python/tests/test_transformations/test_pattern_ops.py new file mode 100644 index 00000000000..a670ae96d43 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_pattern_ops.py @@ -0,0 +1,108 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import numpy as np + +from openvino.runtime import PartialShape, opset8 +from openvino.runtime.passes import Matcher, WrapType, Or, AnyInput +from openvino.runtime.passes import consumers_count, has_static_dim, has_static_dims, \ + has_static_shape, has_static_rank, type_matches, type_matches_any +from openvino.runtime.utils.types import get_element_type + +from utils.utils import expect_exception + + +def test_wrap_type_pattern_type(): + for i in range(1, 9): + WrapType("opset{}.Parameter".format(i)) + WrapType("opset{}::Parameter".format(i)) + + # Negative check not to forget to update opset map in get_type function + expect_exception(lambda: WrapType("opset9.Parameter"), "Unsupported opset type: opset9") + + # Generic negative test cases + expect_exception(lambda: WrapType("")) + expect_exception(lambda: WrapType("opset8")) + expect_exception(lambda: WrapType("Parameter")) + expect_exception(lambda: WrapType("opset.Parameter")) + expect_exception(lambda: WrapType("opset8,Parameter")) + expect_exception(lambda: WrapType("Parameter.opset8")) + + +def test_wrap_type_ctors(): + param = opset8.parameter(PartialShape([1, 3, 22, 22])) + relu = opset8.relu(param.output(0)) + slope = opset8.parameter(PartialShape([])) + prelu = opset8.prelu(param.output(0), slope.output(0)) + + m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"]), "FindActivation") + assert m.match(relu) + assert m.match(prelu) + + m = Matcher(WrapType(["opset8.Relu", "opset8.PRelu"], + WrapType("opset8.Parameter").output(0)), "FindActivation") + assert m.match(relu) + + +def test_or(): + param = opset8.parameter(PartialShape([1, 3, 22, 22])) + relu = opset8.relu(param.output(0)) + slope = opset8.parameter(PartialShape([])) + prelu = opset8.prelu(param.output(0), slope.output(0)) + + m = Matcher(Or([WrapType("opset8.Relu"), + WrapType("opset8.PRelu")]), "FindActivation") + assert m.match(relu) + assert m.match(prelu) + + +def test_any_input(): + param = opset8.parameter(PartialShape([1, 3, 22, 22])) + relu = opset8.relu(param.output(0)) + slope = opset8.parameter(PartialShape([])) + prelu = opset8.prelu(param.output(0), slope.output(0)) + + m = Matcher(WrapType("opset8.PRelu", [AnyInput(), AnyInput()]), "FindActivation") + assert not m.match(relu) + assert m.match(prelu) + + +def test_any_input_predicate(): + param = opset8.parameter(PartialShape([1, 3, 22, 22])) + slope = opset8.parameter(PartialShape([])) + + m = Matcher(AnyInput(lambda output: len(output.get_shape()) == 4), "FindActivation") + assert m.match(param) + assert not m.match(slope) + + +def test_all_predicates(): + static_param = opset8.parameter(PartialShape([1, 3, 22, 22]), np.float32) + dynamic_param = opset8.parameter(PartialShape([-1, 6]), np.long) + fully_dynamic_param = opset8.parameter(PartialShape.dynamic()) + + assert Matcher(WrapType("opset8.Parameter", consumers_count(0)), "Test").match(static_param) + assert not Matcher(WrapType("opset8.Parameter", consumers_count(1)), "Test").match(static_param) + + assert Matcher(WrapType("opset8.Parameter", has_static_dim(1)), "Test").match(static_param) + assert not Matcher(WrapType("opset8.Parameter", has_static_dim(0)), "Test").match(dynamic_param) + + assert Matcher(WrapType("opset8.Parameter", has_static_dims([0, 3])), "Test").match(static_param) + assert not Matcher(WrapType("opset8.Parameter", has_static_dims([0, 1])), "Test").match(dynamic_param) + + assert Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(static_param) + assert not Matcher(WrapType("opset8.Parameter", has_static_shape()), "Test").match(dynamic_param) + + assert Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(dynamic_param) + assert not Matcher(WrapType("opset8.Parameter", has_static_rank()), "Test").match(fully_dynamic_param) + + assert Matcher(WrapType("opset8.Parameter", + type_matches(get_element_type(np.float32))), "Test").match(static_param) + assert not Matcher(WrapType("opset8.Parameter", + type_matches(get_element_type(np.float32))), "Test").match(dynamic_param) + + assert Matcher(WrapType("opset8.Parameter", + type_matches_any([get_element_type(np.float32), + get_element_type(np.long)])), "Test").match(static_param) + assert Matcher(WrapType("opset8.Parameter", + type_matches_any([get_element_type(np.float32), + get_element_type(np.long)])), "Test").match(dynamic_param) diff --git a/src/bindings/python/tests/test_transformations/test_public_transformations.py b/src/bindings/python/tests/test_transformations/test_public_transformations.py new file mode 100644 index 00000000000..2318c9bfcd9 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_public_transformations.py @@ -0,0 +1,114 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import os +import numpy as np + +from openvino.runtime import Model, PartialShape, Shape, opset8, Core +from openvino.runtime.passes import Manager, ConstantFolding, MakeStateful,\ + ConvertFP32ToFP16, LowLatency2, Serialize + +from utils.utils import count_ops, get_test_function + + +def get_model(): + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + param.get_output_tensor(0).set_names({"parameter"}) + relu = opset8.relu(param) + reshape = opset8.reshape(relu, opset8.shape_of(relu), False) + res = opset8.result(reshape, name="result") + res.get_output_tensor(0).set_names({"result"}) + return Model([res], [param], "test") + + +def test_make_stateful(): + model = get_model() + + m = Manager() + p = MakeStateful({"parameter": "result"}) + m.register_pass(p) + m.run_passes(model) + + assert model is not None + assert len(model.get_parameters()) == 0 + assert len(model.get_results()) == 0 + + +def test_constant_folding(): + model = get_model() + + m = Manager() + m.register_pass(ConstantFolding()) + m.run_passes(model) + + assert model is not None + assert count_ops(model, "ShapeOf") == [0] + + +def test_convert_precision(): + model = get_model() + + m = Manager() + m.register_pass(ConvertFP32ToFP16()) + m.run_passes(model) + + assert model is not None + # TODO: fix bug 82773 with float16 type comparison + # assert model.get_parameters()[0].get_element_type() == np.float16 + + +def test_low_latency2(): + X = opset8.parameter(Shape([32, 40, 10]), np.float32, "X") + Y = opset8.parameter(Shape([32, 40, 10]), np.float32, "Y") + M = opset8.parameter(Shape([32, 2, 10]), np.float32, "M") + + X_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "X_i") + Y_i = opset8.parameter(Shape([32, 2, 10]), np.float32, "Y_i") + M_body = opset8.parameter(Shape([32, 2, 10]), np.float32, "M_body") + + sum = opset8.add(X_i, Y_i) + Zo = opset8.multiply(sum, M_body) + + body = Model([Zo], [X_i, Y_i, M_body], "body_function") + + ti = opset8.tensor_iterator() + ti.set_body(body) + ti.set_sliced_input(X_i, X.output(0), 0, 2, 2, 39, 1) + ti.set_sliced_input(Y_i, Y.output(0), 0, 2, 2, -1, 1) + ti.set_invariant_input(M_body, M.output(0)) + + out0 = ti.get_iter_value(Zo.output(0), -1) + out1 = ti.get_concatenated_slices(Zo.output(0), 0, 2, 2, 39, 1) + + result0 = opset8.result(out0) + result1 = opset8.result(out1) + + model = Model([result0, result1], [X, Y, M]) + + m = Manager() + m.register_pass(LowLatency2()) + m.run_passes(model) + + # TODO: create TI which will be transformed by LowLatency2 + assert count_ops(model, "TensorIterator") == [1] + + +def test_serialize_pass(): + core = Core() + xml_path = "serialized_function.xml" + bin_path = "serialized_function.bin" + + func = get_test_function() + + m = Manager() + m.register_pass(Serialize(xml_path, bin_path)) + m.run_passes(func) + + assert func is not None + + res_func = core.read_model(model=xml_path, weights=bin_path) + + assert func.get_parameters() == res_func.get_parameters() + assert func.get_ordered_ops() == res_func.get_ordered_ops() + + os.remove(xml_path) + os.remove(bin_path) diff --git a/src/bindings/python/tests/test_transformations/test_replacement_api.py b/src/bindings/python/tests/test_transformations/test_replacement_api.py new file mode 100644 index 00000000000..6e85738ebe7 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/test_replacement_api.py @@ -0,0 +1,59 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from openvino.runtime import Model, PartialShape, opset8 +from openvino.runtime.utils import replace_node, replace_output_update_name + + +def get_test_function(): + # Parameter->Relu->Result + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + res = opset8.result(relu.output(0), name="result") + return Model([res], [param], "test") + + +def test_output_replace(): + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + res = opset8.result(relu.output(0), name="result") + + exp = opset8.exp(param.output(0)) + relu.output(0).replace(exp.output(0)) + + assert res.input_value(0).get_node() == exp + + +def test_replace_source_output(): + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + res = opset8.result(relu.output(0), name="result") + + exp = opset8.exp(param.output(0)) + res.input(0).replace_source_output(exp.output(0)) + + assert len(exp.output(0).get_target_inputs()) == 1 + assert len(relu.output(0).get_target_inputs()) == 0 + assert next(iter(exp.output(0).get_target_inputs())).get_node() == res + + +def test_replace_node(): + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + res = opset8.result(relu.output(0), name="result") + + exp = opset8.exp(param.output(0)) + replace_node(relu, exp) + + assert res.input_value(0).get_node() == exp + + +def test_replace_output_update_name(): + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + exp = opset8.exp(relu.output(0)) + res = opset8.result(exp.output(0), name="result") + + replace_output_update_name(exp.output(0), exp.input_value(0)) + + assert res.input_value(0).get_node() == exp diff --git a/src/bindings/python/tests/test_transformations/utils/utils.py b/src/bindings/python/tests/test_transformations/utils/utils.py new file mode 100644 index 00000000000..a6835edb543 --- /dev/null +++ b/src/bindings/python/tests/test_transformations/utils/utils.py @@ -0,0 +1,75 @@ +# Copyright (C) 2022 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from openvino.runtime import Model, PartialShape, opset8 +from openvino.runtime.passes import ModelPass, Matcher, MatcherPass, WrapType + + +def get_test_function(): + # Parameter->Relu->Result + param = opset8.parameter(PartialShape([1, 3, 22, 22]), name="parameter") + relu = opset8.relu(param.output(0)) + res = opset8.result(relu.output(0), name="result") + return Model([res], [param], "test") + + +def count_ops(model, op_types): + if isinstance(op_types, str): + op_types = [op_types] + + cnt = [0] * len(op_types) + types = {op_types[index]: index for index in range(len(op_types))} + for op in model.get_ops(): + op_type = op.get_type_info().name + if op_type in types: + cnt[types[op_type]] += 1 + return cnt + + +def expect_exception(func, message=""): + def check(): + try: + func() + return None + except Exception as e: + return str(e) + res = check() + if res is None: + raise AssertionError("Exception is not thrown!") + assert message in res + + +class PatternReplacement(MatcherPass): + def __init__(self): + MatcherPass.__init__(self) + self.model_changed = False + + relu = WrapType("opset8::Relu") + + def callback(m: Matcher) -> bool: + self.applied = True + root = m.get_match_root() + new_relu = opset8.relu(root.input(0).get_source_output()) + + # For testing purpose + self.model_changed = True + + # # Use new operation for additional matching + # self.register_new_node(new_relu) + + # Input->Relu->Result => Input->Relu->Relu->Result + root.input(0).replace_source_output(new_relu.output(0)) + return True + + self.register_matcher(Matcher(relu, "PatternReplacement"), callback) + + +class MyModelPass(ModelPass): + def __init__(self): + super().__init__() + self.model_changed = False + + def run_on_model(self, model): + for op in model.get_ops(): + if op.get_type_info().name == "Relu": + self.model_changed = True diff --git a/src/core/include/openvino/pass/graph_rewrite.hpp b/src/core/include/openvino/pass/graph_rewrite.hpp index 13210657d15..a309a559823 100644 --- a/src/core/include/openvino/pass/graph_rewrite.hpp +++ b/src/core/include/openvino/pass/graph_rewrite.hpp @@ -61,6 +61,10 @@ public: set_property(property, true); } + MatcherPass(const std::shared_ptr& m, const matcher_pass_callback& callback) : PassBase() { + register_matcher(m, callback); + } + bool apply(std::shared_ptr node); template @@ -76,6 +80,10 @@ public: return node; } + std::shared_ptr register_new_node_(const std::shared_ptr& node) { + return register_new_node(node); + } + const std::vector>& get_new_nodes() { return m_new_nodes; } @@ -88,8 +96,10 @@ public: protected: void register_matcher(const std::shared_ptr& m, - const graph_rewrite_callback& callback, - const PassPropertyMask& property = PassProperty::CHANGE_DYNAMIC_STATE); + const matcher_pass_callback& callback, + const PassPropertyMask& property); + + void register_matcher(const std::shared_ptr& m, const matcher_pass_callback& callback); private: handler_callback m_handler; @@ -195,6 +205,13 @@ public: } } + std::shared_ptr add_matcher(const std::shared_ptr& pass) { + auto pass_config = get_pass_config(); + pass->set_pass_config(pass_config); + m_matchers.push_back(pass); + return pass; + } + OPENVINO_DEPRECATED("Use MatcherPass instead") void add_matcher(const std::shared_ptr& m, const graph_rewrite_callback& callback, diff --git a/src/core/include/openvino/pass/manager.hpp b/src/core/include/openvino/pass/manager.hpp index be55be4ed41..2cd18982778 100644 --- a/src/core/include/openvino/pass/manager.hpp +++ b/src/core/include/openvino/pass/manager.hpp @@ -51,6 +51,15 @@ public: return rc; } + std::shared_ptr register_pass_instance(std::shared_ptr pass) { + pass->set_pass_config(m_pass_config); + m_pass_list.push_back(pass); + if (m_per_pass_validation) { + push_pass(); + } + return pass; + } + void run_passes(std::shared_ptr); void set_pass_visualization(bool new_state) { diff --git a/src/core/src/pass/graph_rewrite.cpp b/src/core/src/pass/graph_rewrite.cpp index 8898f8e3da6..000b7222c4f 100644 --- a/src/core/src/pass/graph_rewrite.cpp +++ b/src/core/src/pass/graph_rewrite.cpp @@ -308,6 +308,11 @@ void ov::pass::MatcherPass::register_matcher(const std::shared_ptr& m, + const ov::graph_rewrite_callback& callback) { + register_matcher(m, callback, PassProperty::CHANGE_DYNAMIC_STATE); +} + bool ov::pass::MatcherPass::apply(std::shared_ptr node) { OV_ITT_SCOPED_TASK(ov::itt::domains::nGraph, pass::perf_counters_graph_rewrite()[get_type_info()]); m_new_nodes.clear();