Enable creation of custom ops in Python (NodeFactory.add_extension) (#18858)

* [WIP] Added load_extension on Python side

* Added load_extension to NodeFactory as an alternative way to expose them to user (openvino path only, no changes in ngraph legacy path

* Reverted adding load_extensions in openvino.runtime

* Renamed load_extension to add_extension to be aligned with other part of extension API

* Applied code style rules

* Shorter description of NodeFactory.add_extension

* Fixed accidentally deleted indent

* Explicit error when custom op without intpus is attempted to be created, better help for NodeFactory.add_extension (op version clarification)

* Style fixes

* Test to cover NodeFactory.add_extension

* Minor wording changes

* Code style

* Fix code style

* Limit NodeFactory.add_extension test to specific test configurations
This commit is contained in:
Sergey Lyalin 2023-08-01 11:53:23 +04:00 committed by GitHub
parent 974ef62ce6
commit 5587d59bff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 117 additions and 11 deletions

View File

@ -6,6 +6,7 @@ import logging as log
from functools import partial
from typing import Any, Dict, List, Optional, Union
from pathlib import Path
from openvino._pyopenvino import NodeFactory as _NodeFactory
@ -92,6 +93,30 @@ class NodeFactory(object):
return node
def add_extension(self, lib_path: Union[Path, str]) -> None:
"""Add custom operations from extension library.
Extends operation types available for creation by operations
loaded from prebuilt C++ library. Enables instantiation of custom
operations exposed in that library without direct use of
operation classes. Other types of extensions, e.g. conversion
extensions, if they are exposed in the library, are ignored.
In case if an extension operation type from a library match
one of existing operations registered before (from the standard
OpenVINO opset or from another extension loaded earlier), a new
operation overrides an old operation.
Version of an operation is ignored: an operation with a given type and
a given version/opset will override operation with the same type but
different version/opset in the same NodeFactory instance.
Use separate libraries and NodeFactory instances to differentiate
versions/opsets.
:param lib_path: A path to the library with extension.
"""
self.factory.add_extension(lib_path)
@staticmethod
def _arguments_as_outputs(arguments: List[Union[Node, Output]]) -> List[Output]:
outputs = []

View File

@ -20,11 +20,14 @@
#include "ngraph/check.hpp"
#include "openvino/core/except.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/op_extension.hpp"
#include "openvino/core/so_extension.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/opsets/opset.hpp"
#include "openvino/util/log.hpp"
#include "pyopenvino/core/common.hpp"
#include "pyopenvino/utils/utils.hpp"
namespace py = pybind11;
@ -37,26 +40,52 @@ public:
std::shared_ptr<ov::Node> create(const std::string op_type_name,
const ov::OutputVector& arguments,
const py::dict& attributes = py::dict()) {
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
// Check for available extensions first, because they may override ops from main opset
auto ext_it = m_opset_so_extensions.find(op_type_name);
if (ext_it != m_opset_so_extensions.end()) {
auto op_extension = std::dynamic_pointer_cast<ov::BaseOpExtension>(ext_it->second->extension());
NGRAPH_CHECK(op_extension); // guaranteed by add_extension method
util::DictAttributeDeserializer visitor(attributes, m_variables);
auto outputs = op_extension->create(arguments, visitor);
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
"Currently NodeFactory doesn't support Constant node: ",
op_type_name);
NGRAPH_CHECK(outputs.size() > 0,
"Failed to create extension operation with type: ",
op_type_name,
" because it doesn't contain output ports. Operation should has at least one output port.");
util::DictAttributeDeserializer visitor(attributes, m_variables);
auto node = outputs[0].get_node_shared_ptr();
return node;
} else {
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
op_node->set_arguments(arguments);
op_node->visit_attributes(visitor);
op_node->constructor_validate_and_infer_types();
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operation: ", op_type_name);
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
"Currently NodeFactory doesn't support Constant operation: ",
op_type_name);
return op_node;
util::DictAttributeDeserializer visitor(attributes, m_variables);
op_node->set_arguments(arguments);
op_node->visit_attributes(visitor);
op_node->constructor_validate_and_infer_types();
return op_node;
}
}
std::shared_ptr<ov::Node> create(const std::string op_type_name) {
// Check for available extensions first, because they may override ops from main opset
auto ext_it = m_opset_so_extensions.find(op_type_name);
// No way to instantiate operation without inputs, so if extension operation is found report an error.
NGRAPH_CHECK(ext_it == m_opset_so_extensions.end(),
"Couldn't create operation of type ",
op_type_name,
" from an extension library as no inputs were provided. Currently NodeFactory doesn't support ",
"operations without inputs. Provide at least one input.");
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operation: ", op_type_name);
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
"Currently NodeFactory doesn't support Constant node: ",
op_type_name);
@ -66,6 +95,24 @@ public:
return op_node;
}
void add_extension(const std::string& lib_path) {
// Load extension library, seach for operation extensions (derived from ov::BaseOpExtension) and keep
// them in m_opset_so_extensions for future use in create methods.
// NodeFactory provides a simplified API for node creation withotu involving version of operation.
// It means all operations share the same name space and real operation versions (opsets) from extension
// library are ignored.
auto extensions = ov::detail::load_extensions(lib_path);
for (auto extension : extensions) {
auto so_extension = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension);
ov::Extension::Ptr extension_extracted = so_extension ? so_extension->extension() : extension;
if (auto op_extension = std::dynamic_pointer_cast<ov::BaseOpExtension>(extension_extracted)) {
auto op_type = op_extension->get_type_info().name;
// keep so extension instead of extension_extracted to hold loaded library
m_opset_so_extensions[op_type] = so_extension;
}
}
}
private:
const ov::OpSet& get_opset(std::string opset_ver) {
std::locale loc;
@ -81,6 +128,7 @@ private:
}
const ov::OpSet& m_opset = ov::get_opset12();
std::map<std::string, std::shared_ptr<ov::detail::SOExtension>> m_opset_so_extensions;
std::unordered_map<std::string, std::shared_ptr<ov::op::util::Variable>> m_variables;
};
} // namespace
@ -101,6 +149,10 @@ void regclass_graph_NodeFactory(py::module m) {
return self.create(name, arguments, attributes);
});
node_factory.def("add_extension", [](NodeFactory& self, const py::object& lib_path) {
return self.add_extension(Common::utils::convert_path_to_string(lib_path));
});
node_factory.def("__repr__", [](const NodeFactory& self) {
return Common::get_simple_repr(self);
});

View File

@ -3,6 +3,9 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from sys import platform
from openvino.runtime import compile_model, Model
import openvino.runtime.opset8 as ov
from openvino.runtime.exceptions import UserInputError
from openvino.runtime.utils.node_factory import NodeFactory
@ -92,3 +95,29 @@ def test_node_factory_validate_missing_arguments():
pass
else:
raise AssertionError("Validation of missing arguments has unexpectedly passed.")
@pytest.mark.template_plugin()
def test_extension_added_from_library():
if platform == "win32":
library_path = "openvino_template_extension.dll"
else:
library_path = "libopenvino_template_extension.so"
factory = NodeFactory()
factory.add_extension(library_path)
data = ov.parameter([1, 2], dtype=np.float32)
identity = factory.create("Identity", data.outputs())
model = Model([identity], [data])
compiled = compile_model(model)
tensor = np.array([[3, 4]], dtype=np.float32)
result = compiled(tensor)
# TODO: There is an issue with life time of objects, free resources explicitly
# otherwise segfault will occur. Workaround: create factory as a global variable.
del compiled
del model
del identity
assert np.array_equal(tensor, result[0])