Added support of custom domain for OpExtension (C++/Python) and ConversionExtension (Python) (#14375)
* Added support for ONNX OpExtension with custom domain * review remarks * move domain ctor to onnx OpExtension * code refactor + new test * styles applied * [Python API] Support extensions with custom domains * try to fix windows build error * removed unnecessary stores
This commit is contained in:
parent
156905c381
commit
507cbe7045
@ -28,10 +28,12 @@ void regclass_frontend_onnx_ConversionExtension(py::module m) {
|
||||
public:
|
||||
using Ptr = std::shared_ptr<PyConversionExtension>;
|
||||
using PyCreatorFunction = std::function<ov::OutputVector(const ov::frontend::NodeContext*)>;
|
||||
PyConversionExtension(const std::string& op_type, const PyCreatorFunction& f)
|
||||
: ConversionExtension(op_type, [f](const ov::frontend::NodeContext& node) -> ov::OutputVector {
|
||||
PyConversionExtension(const std::string& op_type, const std::string& op_domain, const PyCreatorFunction& f)
|
||||
: ConversionExtension(op_type, op_domain, [f](const ov::frontend::NodeContext& node) -> ov::OutputVector {
|
||||
return f(static_cast<const ov::frontend::NodeContext*>(&node));
|
||||
}) {}
|
||||
PyConversionExtension(const std::string& op_type, const PyCreatorFunction& f)
|
||||
: PyConversionExtension(op_type, "", f) {}
|
||||
};
|
||||
py::class_<PyConversionExtension, PyConversionExtension::Ptr, ConversionExtension> ext(m,
|
||||
"ConversionExtensionONNX",
|
||||
@ -41,6 +43,10 @@ void regclass_frontend_onnx_ConversionExtension(py::module m) {
|
||||
return std::make_shared<PyConversionExtension>(op_type, f);
|
||||
}));
|
||||
|
||||
ext.def(py::init([](const std::string& op_type, const std::string& op_domain, const PyConversionExtension::PyCreatorFunction& f) {
|
||||
return std::make_shared<PyConversionExtension>(op_type, op_domain, f);
|
||||
}));
|
||||
|
||||
ext.def_property_readonly_static("m_converter", &ConversionExtension::get_converter);
|
||||
}
|
||||
|
||||
@ -78,4 +84,22 @@ void regclass_frontend_onnx_OpExtension(py::module m) {
|
||||
py::arg("fw_type_name"),
|
||||
py::arg("attr_names_map") = std::map<std::string, std::string>(),
|
||||
py::arg("attr_values_map") = std::map<std::string, py::object>());
|
||||
|
||||
ext.def(py::init([](const std::string& ov_type_name,
|
||||
const std::string& fw_type_name,
|
||||
const std::string& fw_domain,
|
||||
const std::map<std::string, std::string>& attr_names_map,
|
||||
const std::map<std::string, py::object>& attr_values_map) {
|
||||
|
||||
std::map<std::string, ov::Any> any_map;
|
||||
for (const auto& it : attr_values_map) {
|
||||
any_map[it.first] = py_object_to_any(it.second);
|
||||
}
|
||||
return std::make_shared<OpExtension<void>>(ov_type_name, fw_type_name, fw_domain, attr_names_map, any_map);
|
||||
}),
|
||||
py::arg("ov_type_name"),
|
||||
py::arg("fw_type_name"),
|
||||
py::arg("fw_domain"),
|
||||
py::arg("attr_names_map") = std::map<std::string, std::string>(),
|
||||
py::arg("attr_values_map") = std::map<std::string, py::object>());
|
||||
}
|
||||
|
@ -142,6 +142,24 @@ def create_onnx_model_for_op_extension():
|
||||
return make_model(graph, producer_name="ONNX Frontend")
|
||||
|
||||
|
||||
def create_onnx_model_extension_with_custom_domain():
|
||||
add = onnx.helper.make_node("CustomAdd", inputs=["x", "y"], outputs=["z"], domain="custom_domain")
|
||||
const_tensor = onnx.helper.make_tensor("const_tensor",
|
||||
onnx.TensorProto.FLOAT,
|
||||
(2, 2),
|
||||
[0.5, 1, 1.5, 2.0])
|
||||
const_node = onnx.helper.make_node("Constant", [], outputs=["const_node"],
|
||||
value=const_tensor, name="const_node")
|
||||
mul = onnx.helper.make_node("Mul", inputs=["z", "const_node"], outputs=["out"])
|
||||
input_tensors = [
|
||||
make_tensor_value_info("x", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
make_tensor_value_info("y", onnx.TensorProto.FLOAT, (2, 2)),
|
||||
]
|
||||
output_tensors = [make_tensor_value_info("out", onnx.TensorProto.FLOAT, (2, 2))]
|
||||
graph = make_graph([add, const_node, mul], "graph", input_tensors, output_tensors)
|
||||
return make_model(graph, producer_name="ONNX Frontend")
|
||||
|
||||
|
||||
def run_model(model, *inputs, expected):
|
||||
runtime = get_runtime()
|
||||
computation = runtime.computation(model)
|
||||
@ -159,6 +177,7 @@ onnx_model_2_filename = "model2.onnx"
|
||||
onnx_model_with_custom_attributes_filename = "model_custom_attributes.onnx"
|
||||
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
|
||||
onnx_model_for_op_extension_test = "model_op_extension.onnx"
|
||||
onnx_model_extension_with_custom_domain = "model_extension_custom_domain.onnx"
|
||||
ONNX_FRONTEND_NAME = "onnx"
|
||||
|
||||
|
||||
@ -169,6 +188,7 @@ def setup_module():
|
||||
onnx_model_with_custom_attributes_filename)
|
||||
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
|
||||
onnx.save_model(create_onnx_model_for_op_extension(), onnx_model_for_op_extension_test)
|
||||
onnx.save_model(create_onnx_model_extension_with_custom_domain(), onnx_model_extension_with_custom_domain)
|
||||
|
||||
|
||||
def teardown_module():
|
||||
@ -177,6 +197,7 @@ def teardown_module():
|
||||
os.remove(onnx_model_with_custom_attributes_filename)
|
||||
os.remove(onnx_model_with_subgraphs_filename)
|
||||
os.remove(onnx_model_for_op_extension_test)
|
||||
os.remove(onnx_model_extension_with_custom_domain)
|
||||
|
||||
|
||||
def skip_if_onnx_frontend_is_disabled():
|
||||
@ -482,6 +503,53 @@ def test_onnx_conversion_extension():
|
||||
assert invoked
|
||||
|
||||
|
||||
def test_onnx_conversion_extension_with_custom_domain():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
|
||||
# use specific (openvino.frontend.onnx) import here
|
||||
from openvino.frontend.onnx import ConversionExtension
|
||||
from openvino.frontend import NodeContext
|
||||
import openvino.runtime.opset8 as ops
|
||||
|
||||
fe = fem.load_by_model(onnx_model_extension_with_custom_domain)
|
||||
assert fe
|
||||
assert fe.get_name() == "onnx"
|
||||
|
||||
invoked = False
|
||||
|
||||
def custom_converter(node: NodeContext):
|
||||
nonlocal invoked
|
||||
invoked = True
|
||||
input_1 = node.get_input(0)
|
||||
input_2 = node.get_input(1)
|
||||
add = ops.add(input_1, input_2)
|
||||
return [add.output(0)]
|
||||
|
||||
fe.add_extension(ConversionExtension("CustomAdd", "custom_domain", custom_converter))
|
||||
input_model = fe.load(onnx_model_extension_with_custom_domain)
|
||||
assert input_model
|
||||
model = fe.convert(input_model)
|
||||
assert model
|
||||
assert invoked
|
||||
|
||||
|
||||
def test_onnx_op_extension_with_custom_domain():
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
|
||||
# use specific (openvino.frontend.onnx) import here
|
||||
from openvino.frontend.onnx import OpExtension
|
||||
|
||||
fe = fem.load_by_model(onnx_model_extension_with_custom_domain)
|
||||
assert fe
|
||||
assert fe.get_name() == "onnx"
|
||||
|
||||
fe.add_extension(OpExtension("opset1.Add", "CustomAdd", "custom_domain", {}, {"auto_broadcast": "numpy"}))
|
||||
input_model = fe.load(onnx_model_extension_with_custom_domain)
|
||||
assert input_model
|
||||
model = fe.convert(input_model)
|
||||
assert model
|
||||
|
||||
|
||||
@pytest.mark.parametrize("opset_prefix", ["opset1.", "opset1::", "opset8.", "opset8::", ""])
|
||||
def test_op_extension_specify_opset(opset_prefix):
|
||||
skip_if_onnx_frontend_is_disabled()
|
||||
|
@ -31,6 +31,72 @@ inline const ov::OpSet& get_opset_by_name(const std::string& opset_name) {
|
||||
}
|
||||
}
|
||||
|
||||
/// \brief The helper function to create an instance of ov::Node class initialized by provided type name.
|
||||
/// Expected formats:
|
||||
/// - opsetN::OpName
|
||||
/// - opsetN.OpName
|
||||
/// - OpName
|
||||
/// \param ov_type_name Type name of created ov::Node.
|
||||
inline std::shared_ptr<ov::Node> create_ov_node_by_name(const std::string& ov_type_name) {
|
||||
auto split = [](const std::string& s, const std::string& delimiter) {
|
||||
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
|
||||
std::string token;
|
||||
std::vector<std::string> res;
|
||||
|
||||
while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
|
||||
token = s.substr(pos_start, pos_end - pos_start);
|
||||
pos_start = pos_end + delim_len;
|
||||
res.push_back(token);
|
||||
}
|
||||
|
||||
res.push_back(s.substr(pos_start));
|
||||
return res;
|
||||
};
|
||||
|
||||
std::string opset_name;
|
||||
std::string op_name;
|
||||
auto cnt_colons = std::count(ov_type_name.begin(), ov_type_name.end(), ':');
|
||||
auto cnt_dots = std::count(ov_type_name.begin(), ov_type_name.end(), '.');
|
||||
if (cnt_colons == 2 && cnt_dots == 0) {
|
||||
auto divided = split(ov_type_name, "::");
|
||||
if (divided.size() != 2) {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected:"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
opset_name = divided[0];
|
||||
op_name = divided[1];
|
||||
} else if (cnt_colons == 0 && cnt_dots == 1) {
|
||||
auto divided = split(ov_type_name, ".");
|
||||
if (divided.size() != 2) {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected:"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
opset_name = divided[0];
|
||||
op_name = divided[1];
|
||||
} else if (cnt_colons == 0 && cnt_dots == 0) {
|
||||
opset_name = "latest";
|
||||
op_name = ov_type_name;
|
||||
} else {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected: \n"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
|
||||
const auto& opset = get_opset_by_name(opset_name);
|
||||
if (!opset.contains_type(op_name)) {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"OpenVINO opset doesn't contain operation with "
|
||||
"name ",
|
||||
op_name);
|
||||
}
|
||||
return std::shared_ptr<ngraph::Node>(opset.create(op_name));
|
||||
}
|
||||
|
||||
// One-to-one operation mapping for OVOpType != void which means OV type is specified by OVOpType
|
||||
// See a specialization for OVOptype = void
|
||||
template <typename BaseConversionType, typename OVOpType = void>
|
||||
@ -139,72 +205,8 @@ OpExtensionBase<BaseConversionType, void>::OpExtensionBase(const std::string& ov
|
||||
const std::map<std::string, ov::Any>& attr_values_map)
|
||||
: BaseConversionType(fw_type_name,
|
||||
OpConversionFunction(
|
||||
[=]() -> std::shared_ptr<ov::Node> {
|
||||
auto split = [](const std::string& s, const std::string& delimiter) {
|
||||
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
|
||||
std::string token;
|
||||
std::vector<std::string> res;
|
||||
|
||||
while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
|
||||
token = s.substr(pos_start, pos_end - pos_start);
|
||||
pos_start = pos_end + delim_len;
|
||||
res.push_back(token);
|
||||
}
|
||||
|
||||
res.push_back(s.substr(pos_start));
|
||||
return res;
|
||||
};
|
||||
|
||||
// Expected formats:
|
||||
// opsetN::OpName
|
||||
// opsetN.OpName
|
||||
// OpName
|
||||
std::string opset_name;
|
||||
std::string op_name;
|
||||
auto cnt_colons = std::count(ov_type_name.begin(), ov_type_name.end(), ':');
|
||||
auto cnt_dots = std::count(ov_type_name.begin(), ov_type_name.end(), '.');
|
||||
if (cnt_colons == 2 && cnt_dots == 0) {
|
||||
auto divided = split(ov_type_name, "::");
|
||||
if (divided.size() != 2) {
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected:"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
opset_name = divided[0];
|
||||
op_name = divided[1];
|
||||
} else if (cnt_colons == 0 && cnt_dots == 1) {
|
||||
auto divided = split(ov_type_name, ".");
|
||||
if (divided.size() != 2) {
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected:"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
opset_name = divided[0];
|
||||
op_name = divided[1];
|
||||
} else if (cnt_colons == 0 && cnt_dots == 0) {
|
||||
opset_name = "latest";
|
||||
op_name = ov_type_name;
|
||||
} else {
|
||||
FRONT_END_GENERAL_CHECK(
|
||||
false,
|
||||
"Invalid OpenVINO operation format, one of the next is expected: \n"
|
||||
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
|
||||
ov_type_name);
|
||||
}
|
||||
|
||||
const auto& opset = get_opset_by_name(opset_name);
|
||||
if (!opset.contains_type(op_name)) {
|
||||
FRONT_END_GENERAL_CHECK(false,
|
||||
"OpenVINO opset doesn't contain operation with "
|
||||
"name ",
|
||||
op_name);
|
||||
}
|
||||
|
||||
return std::shared_ptr<ngraph::Node>(opset.create(op_name));
|
||||
[ov_type_name]() {
|
||||
return create_ov_node_by_name(ov_type_name);
|
||||
},
|
||||
attr_names_map,
|
||||
attr_values_map)) {}
|
||||
|
@ -11,7 +11,61 @@ namespace frontend {
|
||||
namespace onnx {
|
||||
|
||||
template <typename OVOpType = void>
|
||||
using OpExtension = ov::frontend::OpExtensionBase<ConversionExtension, OVOpType>;
|
||||
class OpExtension : public ConversionExtension {
|
||||
public:
|
||||
OpExtension(const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: OpExtension(OVOpType::get_type_info_static().name, "", attr_names_map, attr_values_map) {}
|
||||
|
||||
OpExtension(const std::string& fw_type_name,
|
||||
const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: OpExtension(fw_type_name, "", attr_names_map, attr_values_map) {}
|
||||
|
||||
OpExtension(const std::string& fw_type_name,
|
||||
const std::string& fw_domain,
|
||||
const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: ConversionExtension(fw_type_name,
|
||||
fw_domain,
|
||||
OpConversionFunction(
|
||||
[]() {
|
||||
return std::make_shared<OVOpType>();
|
||||
},
|
||||
attr_names_map,
|
||||
attr_values_map)) {}
|
||||
};
|
||||
|
||||
template <>
|
||||
class OpExtension<void> : public ConversionExtension {
|
||||
public:
|
||||
OpExtension() = delete;
|
||||
|
||||
explicit OpExtension(const std::string& fw_ov_type_name,
|
||||
const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: OpExtension(fw_ov_type_name, fw_ov_type_name, attr_names_map, attr_values_map) {}
|
||||
|
||||
OpExtension(const std::string& ov_type_name,
|
||||
const std::string& fw_type_name,
|
||||
const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: OpExtension(ov_type_name, fw_type_name, "", attr_names_map, attr_values_map) {}
|
||||
|
||||
OpExtension(const std::string& ov_type_name,
|
||||
const std::string& fw_type_name,
|
||||
const std::string& fw_domain_name,
|
||||
const std::map<std::string, std::string>& attr_names_map = {},
|
||||
const std::map<std::string, ov::Any>& attr_values_map = {})
|
||||
: ConversionExtension(fw_type_name,
|
||||
fw_domain_name,
|
||||
OpConversionFunction(
|
||||
[ov_type_name]() {
|
||||
return create_ov_node_by_name(ov_type_name);
|
||||
},
|
||||
attr_names_map,
|
||||
attr_values_map)) {}
|
||||
};
|
||||
|
||||
} // namespace onnx
|
||||
} // namespace frontend
|
||||
|
45
src/frontends/onnx/tests/models/relu_custom_domain.prototxt
Normal file
45
src/frontends/onnx/tests/models/relu_custom_domain.prototxt
Normal file
@ -0,0 +1,45 @@
|
||||
ir_version: 4
|
||||
producer_name: "OV ONNX FE"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
output: "B"
|
||||
op_type: "CustomRelu"
|
||||
domain: "my_custom_domain"
|
||||
}
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "B"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 1
|
||||
}
|
@ -4,10 +4,12 @@
|
||||
|
||||
#include "op_extension.hpp"
|
||||
|
||||
#include "common_test_utils/file_utils.hpp"
|
||||
#include "onnx_utils.hpp"
|
||||
#include "openvino/frontend/extension/op.hpp"
|
||||
#include "openvino/frontend/onnx/extension/op.hpp"
|
||||
#include "openvino/frontend/onnx/frontend.hpp"
|
||||
#include "openvino/op/relu.hpp"
|
||||
#include "so_extension.hpp"
|
||||
|
||||
using namespace ov::frontend;
|
||||
@ -141,4 +143,30 @@ INSTANTIATE_TEST_SUITE_P(ONNXOpExtensionViaONNXConstructor,
|
||||
INSTANTIATE_TEST_SUITE_P(ONNXOpExtensionViaCommonConstructor,
|
||||
FrontEndOpExtensionTest,
|
||||
::testing::Values(getTestDataOpExtensionViaCommonConstructor()),
|
||||
FrontEndOpExtensionTest::getTestCaseName);
|
||||
FrontEndOpExtensionTest::getTestCaseName);
|
||||
|
||||
TEST(ONNXOpExtensionViaCommonConstructor, onnx_op_extension_via_template_arg_with_custom_domain) {
|
||||
const auto ext = std::make_shared<onnx::OpExtension<ov::op::v0::Relu>>("CustomRelu", "my_custom_domain");
|
||||
|
||||
auto fe = std::make_shared<ov::frontend::onnx::FrontEnd>();
|
||||
fe->add_extension(ext);
|
||||
|
||||
const auto input_model = fe->load(CommonTestUtils::getModelFromTestModelZoo(
|
||||
ov::util::path_join({TEST_ONNX_MODELS_DIRNAME, "relu_custom_domain.onnx"})));
|
||||
|
||||
std::shared_ptr<ov::Model> model;
|
||||
EXPECT_NO_THROW(fe->convert(input_model));
|
||||
}
|
||||
|
||||
TEST(ONNXOpExtensionViaCommonConstructor, onnx_op_extension_via_ov_type_name_with_custom_domain) {
|
||||
const auto ext = std::make_shared<onnx::OpExtension<>>("opset1::Relu", "CustomRelu", "my_custom_domain");
|
||||
|
||||
auto fe = std::make_shared<ov::frontend::onnx::FrontEnd>();
|
||||
fe->add_extension(ext);
|
||||
|
||||
const auto input_model = fe->load(CommonTestUtils::getModelFromTestModelZoo(
|
||||
ov::util::path_join({TEST_ONNX_MODELS_DIRNAME, "relu_custom_domain.onnx"})));
|
||||
|
||||
std::shared_ptr<ov::Model> model;
|
||||
EXPECT_NO_THROW(fe->convert(input_model));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user