Enable creation of custom ops in Python (NodeFactory.add_extension) (#18858)
* [WIP] Added load_extension on Python side * Added load_extension to NodeFactory as an alternative way to expose them to user (openvino path only, no changes in ngraph legacy path * Reverted adding load_extensions in openvino.runtime * Renamed load_extension to add_extension to be aligned with other part of extension API * Applied code style rules * Shorter description of NodeFactory.add_extension * Fixed accidentally deleted indent * Explicit error when custom op without intpus is attempted to be created, better help for NodeFactory.add_extension (op version clarification) * Style fixes * Test to cover NodeFactory.add_extension * Minor wording changes * Code style * Fix code style * Limit NodeFactory.add_extension test to specific test configurations
This commit is contained in:
parent
974ef62ce6
commit
5587d59bff
@ -6,6 +6,7 @@ import logging as log
|
||||
|
||||
from functools import partial
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from pathlib import Path
|
||||
|
||||
from openvino._pyopenvino import NodeFactory as _NodeFactory
|
||||
|
||||
@ -92,6 +93,30 @@ class NodeFactory(object):
|
||||
|
||||
return node
|
||||
|
||||
def add_extension(self, lib_path: Union[Path, str]) -> None:
|
||||
"""Add custom operations from extension library.
|
||||
|
||||
Extends operation types available for creation by operations
|
||||
loaded from prebuilt C++ library. Enables instantiation of custom
|
||||
operations exposed in that library without direct use of
|
||||
operation classes. Other types of extensions, e.g. conversion
|
||||
extensions, if they are exposed in the library, are ignored.
|
||||
|
||||
In case if an extension operation type from a library match
|
||||
one of existing operations registered before (from the standard
|
||||
OpenVINO opset or from another extension loaded earlier), a new
|
||||
operation overrides an old operation.
|
||||
|
||||
Version of an operation is ignored: an operation with a given type and
|
||||
a given version/opset will override operation with the same type but
|
||||
different version/opset in the same NodeFactory instance.
|
||||
Use separate libraries and NodeFactory instances to differentiate
|
||||
versions/opsets.
|
||||
|
||||
:param lib_path: A path to the library with extension.
|
||||
"""
|
||||
self.factory.add_extension(lib_path)
|
||||
|
||||
@staticmethod
|
||||
def _arguments_as_outputs(arguments: List[Union[Node, Output]]) -> List[Output]:
|
||||
outputs = []
|
||||
|
@ -20,11 +20,14 @@
|
||||
#include "ngraph/check.hpp"
|
||||
#include "openvino/core/except.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/core/op_extension.hpp"
|
||||
#include "openvino/core/so_extension.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/op/util/variable.hpp"
|
||||
#include "openvino/opsets/opset.hpp"
|
||||
#include "openvino/util/log.hpp"
|
||||
#include "pyopenvino/core/common.hpp"
|
||||
#include "pyopenvino/utils/utils.hpp"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
@ -37,26 +40,52 @@ public:
|
||||
std::shared_ptr<ov::Node> create(const std::string op_type_name,
|
||||
const ov::OutputVector& arguments,
|
||||
const py::dict& attributes = py::dict()) {
|
||||
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
|
||||
// Check for available extensions first, because they may override ops from main opset
|
||||
auto ext_it = m_opset_so_extensions.find(op_type_name);
|
||||
if (ext_it != m_opset_so_extensions.end()) {
|
||||
auto op_extension = std::dynamic_pointer_cast<ov::BaseOpExtension>(ext_it->second->extension());
|
||||
NGRAPH_CHECK(op_extension); // guaranteed by add_extension method
|
||||
util::DictAttributeDeserializer visitor(attributes, m_variables);
|
||||
auto outputs = op_extension->create(arguments, visitor);
|
||||
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
|
||||
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
|
||||
"Currently NodeFactory doesn't support Constant node: ",
|
||||
op_type_name);
|
||||
NGRAPH_CHECK(outputs.size() > 0,
|
||||
"Failed to create extension operation with type: ",
|
||||
op_type_name,
|
||||
" because it doesn't contain output ports. Operation should has at least one output port.");
|
||||
|
||||
util::DictAttributeDeserializer visitor(attributes, m_variables);
|
||||
auto node = outputs[0].get_node_shared_ptr();
|
||||
return node;
|
||||
} else {
|
||||
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
|
||||
|
||||
op_node->set_arguments(arguments);
|
||||
op_node->visit_attributes(visitor);
|
||||
op_node->constructor_validate_and_infer_types();
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operation: ", op_type_name);
|
||||
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
|
||||
"Currently NodeFactory doesn't support Constant operation: ",
|
||||
op_type_name);
|
||||
|
||||
return op_node;
|
||||
util::DictAttributeDeserializer visitor(attributes, m_variables);
|
||||
|
||||
op_node->set_arguments(arguments);
|
||||
op_node->visit_attributes(visitor);
|
||||
op_node->constructor_validate_and_infer_types();
|
||||
|
||||
return op_node;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> create(const std::string op_type_name) {
|
||||
// Check for available extensions first, because they may override ops from main opset
|
||||
auto ext_it = m_opset_so_extensions.find(op_type_name);
|
||||
// No way to instantiate operation without inputs, so if extension operation is found report an error.
|
||||
NGRAPH_CHECK(ext_it == m_opset_so_extensions.end(),
|
||||
"Couldn't create operation of type ",
|
||||
op_type_name,
|
||||
" from an extension library as no inputs were provided. Currently NodeFactory doesn't support ",
|
||||
"operations without inputs. Provide at least one input.");
|
||||
|
||||
std::shared_ptr<ov::Node> op_node = std::shared_ptr<ov::Node>(m_opset.create(op_type_name));
|
||||
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operator: ", op_type_name);
|
||||
NGRAPH_CHECK(op_node != nullptr, "Couldn't create operation: ", op_type_name);
|
||||
NGRAPH_CHECK(!ov::op::util::is_constant(op_node),
|
||||
"Currently NodeFactory doesn't support Constant node: ",
|
||||
op_type_name);
|
||||
@ -66,6 +95,24 @@ public:
|
||||
return op_node;
|
||||
}
|
||||
|
||||
void add_extension(const std::string& lib_path) {
|
||||
// Load extension library, seach for operation extensions (derived from ov::BaseOpExtension) and keep
|
||||
// them in m_opset_so_extensions for future use in create methods.
|
||||
// NodeFactory provides a simplified API for node creation withotu involving version of operation.
|
||||
// It means all operations share the same name space and real operation versions (opsets) from extension
|
||||
// library are ignored.
|
||||
auto extensions = ov::detail::load_extensions(lib_path);
|
||||
for (auto extension : extensions) {
|
||||
auto so_extension = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension);
|
||||
ov::Extension::Ptr extension_extracted = so_extension ? so_extension->extension() : extension;
|
||||
if (auto op_extension = std::dynamic_pointer_cast<ov::BaseOpExtension>(extension_extracted)) {
|
||||
auto op_type = op_extension->get_type_info().name;
|
||||
// keep so extension instead of extension_extracted to hold loaded library
|
||||
m_opset_so_extensions[op_type] = so_extension;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const ov::OpSet& get_opset(std::string opset_ver) {
|
||||
std::locale loc;
|
||||
@ -81,6 +128,7 @@ private:
|
||||
}
|
||||
|
||||
const ov::OpSet& m_opset = ov::get_opset12();
|
||||
std::map<std::string, std::shared_ptr<ov::detail::SOExtension>> m_opset_so_extensions;
|
||||
std::unordered_map<std::string, std::shared_ptr<ov::op::util::Variable>> m_variables;
|
||||
};
|
||||
} // namespace
|
||||
@ -101,6 +149,10 @@ void regclass_graph_NodeFactory(py::module m) {
|
||||
return self.create(name, arguments, attributes);
|
||||
});
|
||||
|
||||
node_factory.def("add_extension", [](NodeFactory& self, const py::object& lib_path) {
|
||||
return self.add_extension(Common::utils::convert_path_to_string(lib_path));
|
||||
});
|
||||
|
||||
node_factory.def("__repr__", [](const NodeFactory& self) {
|
||||
return Common::get_simple_repr(self);
|
||||
});
|
||||
|
@ -3,6 +3,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from sys import platform
|
||||
from openvino.runtime import compile_model, Model
|
||||
import openvino.runtime.opset8 as ov
|
||||
from openvino.runtime.exceptions import UserInputError
|
||||
from openvino.runtime.utils.node_factory import NodeFactory
|
||||
@ -92,3 +95,29 @@ def test_node_factory_validate_missing_arguments():
|
||||
pass
|
||||
else:
|
||||
raise AssertionError("Validation of missing arguments has unexpectedly passed.")
|
||||
|
||||
|
||||
@pytest.mark.template_plugin()
|
||||
def test_extension_added_from_library():
|
||||
if platform == "win32":
|
||||
library_path = "openvino_template_extension.dll"
|
||||
else:
|
||||
library_path = "libopenvino_template_extension.so"
|
||||
|
||||
factory = NodeFactory()
|
||||
factory.add_extension(library_path)
|
||||
|
||||
data = ov.parameter([1, 2], dtype=np.float32)
|
||||
identity = factory.create("Identity", data.outputs())
|
||||
model = Model([identity], [data])
|
||||
compiled = compile_model(model)
|
||||
tensor = np.array([[3, 4]], dtype=np.float32)
|
||||
result = compiled(tensor)
|
||||
|
||||
# TODO: There is an issue with life time of objects, free resources explicitly
|
||||
# otherwise segfault will occur. Workaround: create factory as a global variable.
|
||||
del compiled
|
||||
del model
|
||||
del identity
|
||||
|
||||
assert np.array_equal(tensor, result[0])
|
||||
|
Loading…
Reference in New Issue
Block a user