Added support of custom domain for OpExtension (C++/Python) and ConversionExtension (Python) (#14375)

* Added support for ONNX OpExtension with custom domain

* review remarks

* move domain ctor to onnx OpExtension

* code refactor + new test

* styles applied

* [Python API] Support extensions with custom domains

* try to fix windows build error

* removed unnecessary stores
This commit is contained in:
Mateusz Bencer 2022-12-08 18:47:45 +01:00 committed by GitHub
parent 156905c381
commit 507cbe7045
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 291 additions and 70 deletions

View File

@ -28,10 +28,12 @@ void regclass_frontend_onnx_ConversionExtension(py::module m) {
public:
using Ptr = std::shared_ptr<PyConversionExtension>;
using PyCreatorFunction = std::function<ov::OutputVector(const ov::frontend::NodeContext*)>;
PyConversionExtension(const std::string& op_type, const PyCreatorFunction& f)
: ConversionExtension(op_type, [f](const ov::frontend::NodeContext& node) -> ov::OutputVector {
PyConversionExtension(const std::string& op_type, const std::string& op_domain, const PyCreatorFunction& f)
: ConversionExtension(op_type, op_domain, [f](const ov::frontend::NodeContext& node) -> ov::OutputVector {
return f(static_cast<const ov::frontend::NodeContext*>(&node));
}) {}
PyConversionExtension(const std::string& op_type, const PyCreatorFunction& f)
: PyConversionExtension(op_type, "", f) {}
};
py::class_<PyConversionExtension, PyConversionExtension::Ptr, ConversionExtension> ext(m,
"ConversionExtensionONNX",
@ -41,6 +43,10 @@ void regclass_frontend_onnx_ConversionExtension(py::module m) {
return std::make_shared<PyConversionExtension>(op_type, f);
}));
ext.def(py::init([](const std::string& op_type, const std::string& op_domain, const PyConversionExtension::PyCreatorFunction& f) {
return std::make_shared<PyConversionExtension>(op_type, op_domain, f);
}));
ext.def_property_readonly_static("m_converter", &ConversionExtension::get_converter);
}
@ -78,4 +84,22 @@ void regclass_frontend_onnx_OpExtension(py::module m) {
py::arg("fw_type_name"),
py::arg("attr_names_map") = std::map<std::string, std::string>(),
py::arg("attr_values_map") = std::map<std::string, py::object>());
ext.def(py::init([](const std::string& ov_type_name,
const std::string& fw_type_name,
const std::string& fw_domain,
const std::map<std::string, std::string>& attr_names_map,
const std::map<std::string, py::object>& attr_values_map) {
std::map<std::string, ov::Any> any_map;
for (const auto& it : attr_values_map) {
any_map[it.first] = py_object_to_any(it.second);
}
return std::make_shared<OpExtension<void>>(ov_type_name, fw_type_name, fw_domain, attr_names_map, any_map);
}),
py::arg("ov_type_name"),
py::arg("fw_type_name"),
py::arg("fw_domain"),
py::arg("attr_names_map") = std::map<std::string, std::string>(),
py::arg("attr_values_map") = std::map<std::string, py::object>());
}

View File

@ -142,6 +142,24 @@ def create_onnx_model_for_op_extension():
return make_model(graph, producer_name="ONNX Frontend")
def create_onnx_model_extension_with_custom_domain():
add = onnx.helper.make_node("CustomAdd", inputs=["x", "y"], outputs=["z"], domain="custom_domain")
const_tensor = onnx.helper.make_tensor("const_tensor",
onnx.TensorProto.FLOAT,
(2, 2),
[0.5, 1, 1.5, 2.0])
const_node = onnx.helper.make_node("Constant", [], outputs=["const_node"],
value=const_tensor, name="const_node")
mul = onnx.helper.make_node("Mul", inputs=["z", "const_node"], outputs=["out"])
input_tensors = [
make_tensor_value_info("x", onnx.TensorProto.FLOAT, (2, 2)),
make_tensor_value_info("y", onnx.TensorProto.FLOAT, (2, 2)),
]
output_tensors = [make_tensor_value_info("out", onnx.TensorProto.FLOAT, (2, 2))]
graph = make_graph([add, const_node, mul], "graph", input_tensors, output_tensors)
return make_model(graph, producer_name="ONNX Frontend")
def run_model(model, *inputs, expected):
runtime = get_runtime()
computation = runtime.computation(model)
@ -159,6 +177,7 @@ onnx_model_2_filename = "model2.onnx"
onnx_model_with_custom_attributes_filename = "model_custom_attributes.onnx"
onnx_model_with_subgraphs_filename = "model_subgraphs.onnx"
onnx_model_for_op_extension_test = "model_op_extension.onnx"
onnx_model_extension_with_custom_domain = "model_extension_custom_domain.onnx"
ONNX_FRONTEND_NAME = "onnx"
@ -169,6 +188,7 @@ def setup_module():
onnx_model_with_custom_attributes_filename)
onnx.save_model(create_onnx_model_with_subgraphs(), onnx_model_with_subgraphs_filename)
onnx.save_model(create_onnx_model_for_op_extension(), onnx_model_for_op_extension_test)
onnx.save_model(create_onnx_model_extension_with_custom_domain(), onnx_model_extension_with_custom_domain)
def teardown_module():
@ -177,6 +197,7 @@ def teardown_module():
os.remove(onnx_model_with_custom_attributes_filename)
os.remove(onnx_model_with_subgraphs_filename)
os.remove(onnx_model_for_op_extension_test)
os.remove(onnx_model_extension_with_custom_domain)
def skip_if_onnx_frontend_is_disabled():
@ -482,6 +503,53 @@ def test_onnx_conversion_extension():
assert invoked
def test_onnx_conversion_extension_with_custom_domain():
skip_if_onnx_frontend_is_disabled()
# use specific (openvino.frontend.onnx) import here
from openvino.frontend.onnx import ConversionExtension
from openvino.frontend import NodeContext
import openvino.runtime.opset8 as ops
fe = fem.load_by_model(onnx_model_extension_with_custom_domain)
assert fe
assert fe.get_name() == "onnx"
invoked = False
def custom_converter(node: NodeContext):
nonlocal invoked
invoked = True
input_1 = node.get_input(0)
input_2 = node.get_input(1)
add = ops.add(input_1, input_2)
return [add.output(0)]
fe.add_extension(ConversionExtension("CustomAdd", "custom_domain", custom_converter))
input_model = fe.load(onnx_model_extension_with_custom_domain)
assert input_model
model = fe.convert(input_model)
assert model
assert invoked
def test_onnx_op_extension_with_custom_domain():
skip_if_onnx_frontend_is_disabled()
# use specific (openvino.frontend.onnx) import here
from openvino.frontend.onnx import OpExtension
fe = fem.load_by_model(onnx_model_extension_with_custom_domain)
assert fe
assert fe.get_name() == "onnx"
fe.add_extension(OpExtension("opset1.Add", "CustomAdd", "custom_domain", {}, {"auto_broadcast": "numpy"}))
input_model = fe.load(onnx_model_extension_with_custom_domain)
assert input_model
model = fe.convert(input_model)
assert model
@pytest.mark.parametrize("opset_prefix", ["opset1.", "opset1::", "opset8.", "opset8::", ""])
def test_op_extension_specify_opset(opset_prefix):
skip_if_onnx_frontend_is_disabled()

View File

@ -31,6 +31,72 @@ inline const ov::OpSet& get_opset_by_name(const std::string& opset_name) {
}
}
/// \brief The helper function to create an instance of ov::Node class initialized by provided type name.
/// Expected formats:
/// - opsetN::OpName
/// - opsetN.OpName
/// - OpName
/// \param ov_type_name Type name of created ov::Node.
inline std::shared_ptr<ov::Node> create_ov_node_by_name(const std::string& ov_type_name) {
auto split = [](const std::string& s, const std::string& delimiter) {
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
std::string token;
std::vector<std::string> res;
while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
token = s.substr(pos_start, pos_end - pos_start);
pos_start = pos_end + delim_len;
res.push_back(token);
}
res.push_back(s.substr(pos_start));
return res;
};
std::string opset_name;
std::string op_name;
auto cnt_colons = std::count(ov_type_name.begin(), ov_type_name.end(), ':');
auto cnt_dots = std::count(ov_type_name.begin(), ov_type_name.end(), '.');
if (cnt_colons == 2 && cnt_dots == 0) {
auto divided = split(ov_type_name, "::");
if (divided.size() != 2) {
FRONT_END_GENERAL_CHECK(false,
"Invalid OpenVINO operation format, one of the next is expected:"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
opset_name = divided[0];
op_name = divided[1];
} else if (cnt_colons == 0 && cnt_dots == 1) {
auto divided = split(ov_type_name, ".");
if (divided.size() != 2) {
FRONT_END_GENERAL_CHECK(false,
"Invalid OpenVINO operation format, one of the next is expected:"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
opset_name = divided[0];
op_name = divided[1];
} else if (cnt_colons == 0 && cnt_dots == 0) {
opset_name = "latest";
op_name = ov_type_name;
} else {
FRONT_END_GENERAL_CHECK(false,
"Invalid OpenVINO operation format, one of the next is expected: \n"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
const auto& opset = get_opset_by_name(opset_name);
if (!opset.contains_type(op_name)) {
FRONT_END_GENERAL_CHECK(false,
"OpenVINO opset doesn't contain operation with "
"name ",
op_name);
}
return std::shared_ptr<ngraph::Node>(opset.create(op_name));
}
// One-to-one operation mapping for OVOpType != void which means OV type is specified by OVOpType
// See a specialization for OVOptype = void
template <typename BaseConversionType, typename OVOpType = void>
@ -139,72 +205,8 @@ OpExtensionBase<BaseConversionType, void>::OpExtensionBase(const std::string& ov
const std::map<std::string, ov::Any>& attr_values_map)
: BaseConversionType(fw_type_name,
OpConversionFunction(
[=]() -> std::shared_ptr<ov::Node> {
auto split = [](const std::string& s, const std::string& delimiter) {
size_t pos_start = 0, pos_end, delim_len = delimiter.length();
std::string token;
std::vector<std::string> res;
while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) {
token = s.substr(pos_start, pos_end - pos_start);
pos_start = pos_end + delim_len;
res.push_back(token);
}
res.push_back(s.substr(pos_start));
return res;
};
// Expected formats:
// opsetN::OpName
// opsetN.OpName
// OpName
std::string opset_name;
std::string op_name;
auto cnt_colons = std::count(ov_type_name.begin(), ov_type_name.end(), ':');
auto cnt_dots = std::count(ov_type_name.begin(), ov_type_name.end(), '.');
if (cnt_colons == 2 && cnt_dots == 0) {
auto divided = split(ov_type_name, "::");
if (divided.size() != 2) {
FRONT_END_GENERAL_CHECK(
false,
"Invalid OpenVINO operation format, one of the next is expected:"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
opset_name = divided[0];
op_name = divided[1];
} else if (cnt_colons == 0 && cnt_dots == 1) {
auto divided = split(ov_type_name, ".");
if (divided.size() != 2) {
FRONT_END_GENERAL_CHECK(
false,
"Invalid OpenVINO operation format, one of the next is expected:"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
opset_name = divided[0];
op_name = divided[1];
} else if (cnt_colons == 0 && cnt_dots == 0) {
opset_name = "latest";
op_name = ov_type_name;
} else {
FRONT_END_GENERAL_CHECK(
false,
"Invalid OpenVINO operation format, one of the next is expected: \n"
"opsetN::OpName or opsetN.OpName or OpName. Provided operation format: ",
ov_type_name);
}
const auto& opset = get_opset_by_name(opset_name);
if (!opset.contains_type(op_name)) {
FRONT_END_GENERAL_CHECK(false,
"OpenVINO opset doesn't contain operation with "
"name ",
op_name);
}
return std::shared_ptr<ngraph::Node>(opset.create(op_name));
[ov_type_name]() {
return create_ov_node_by_name(ov_type_name);
},
attr_names_map,
attr_values_map)) {}

View File

@ -11,7 +11,61 @@ namespace frontend {
namespace onnx {
template <typename OVOpType = void>
using OpExtension = ov::frontend::OpExtensionBase<ConversionExtension, OVOpType>;
class OpExtension : public ConversionExtension {
public:
OpExtension(const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: OpExtension(OVOpType::get_type_info_static().name, "", attr_names_map, attr_values_map) {}
OpExtension(const std::string& fw_type_name,
const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: OpExtension(fw_type_name, "", attr_names_map, attr_values_map) {}
OpExtension(const std::string& fw_type_name,
const std::string& fw_domain,
const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: ConversionExtension(fw_type_name,
fw_domain,
OpConversionFunction(
[]() {
return std::make_shared<OVOpType>();
},
attr_names_map,
attr_values_map)) {}
};
template <>
class OpExtension<void> : public ConversionExtension {
public:
OpExtension() = delete;
explicit OpExtension(const std::string& fw_ov_type_name,
const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: OpExtension(fw_ov_type_name, fw_ov_type_name, attr_names_map, attr_values_map) {}
OpExtension(const std::string& ov_type_name,
const std::string& fw_type_name,
const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: OpExtension(ov_type_name, fw_type_name, "", attr_names_map, attr_values_map) {}
OpExtension(const std::string& ov_type_name,
const std::string& fw_type_name,
const std::string& fw_domain_name,
const std::map<std::string, std::string>& attr_names_map = {},
const std::map<std::string, ov::Any>& attr_values_map = {})
: ConversionExtension(fw_type_name,
fw_domain_name,
OpConversionFunction(
[ov_type_name]() {
return create_ov_node_by_name(ov_type_name);
},
attr_names_map,
attr_values_map)) {}
};
} // namespace onnx
} // namespace frontend

View File

@ -0,0 +1,45 @@
ir_version: 4
producer_name: "OV ONNX FE"
graph {
node {
input: "A"
output: "B"
op_type: "CustomRelu"
domain: "my_custom_domain"
}
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
output {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 1
}

View File

@ -4,10 +4,12 @@
#include "op_extension.hpp"
#include "common_test_utils/file_utils.hpp"
#include "onnx_utils.hpp"
#include "openvino/frontend/extension/op.hpp"
#include "openvino/frontend/onnx/extension/op.hpp"
#include "openvino/frontend/onnx/frontend.hpp"
#include "openvino/op/relu.hpp"
#include "so_extension.hpp"
using namespace ov::frontend;
@ -141,4 +143,30 @@ INSTANTIATE_TEST_SUITE_P(ONNXOpExtensionViaONNXConstructor,
INSTANTIATE_TEST_SUITE_P(ONNXOpExtensionViaCommonConstructor,
FrontEndOpExtensionTest,
::testing::Values(getTestDataOpExtensionViaCommonConstructor()),
FrontEndOpExtensionTest::getTestCaseName);
FrontEndOpExtensionTest::getTestCaseName);
TEST(ONNXOpExtensionViaCommonConstructor, onnx_op_extension_via_template_arg_with_custom_domain) {
const auto ext = std::make_shared<onnx::OpExtension<ov::op::v0::Relu>>("CustomRelu", "my_custom_domain");
auto fe = std::make_shared<ov::frontend::onnx::FrontEnd>();
fe->add_extension(ext);
const auto input_model = fe->load(CommonTestUtils::getModelFromTestModelZoo(
ov::util::path_join({TEST_ONNX_MODELS_DIRNAME, "relu_custom_domain.onnx"})));
std::shared_ptr<ov::Model> model;
EXPECT_NO_THROW(fe->convert(input_model));
}
TEST(ONNXOpExtensionViaCommonConstructor, onnx_op_extension_via_ov_type_name_with_custom_domain) {
const auto ext = std::make_shared<onnx::OpExtension<>>("opset1::Relu", "CustomRelu", "my_custom_domain");
auto fe = std::make_shared<ov::frontend::onnx::FrontEnd>();
fe->add_extension(ext);
const auto input_model = fe->load(CommonTestUtils::getModelFromTestModelZoo(
ov::util::path_join({TEST_ONNX_MODELS_DIRNAME, "relu_custom_domain.onnx"})));
std::shared_ptr<ov::Model> model;
EXPECT_NO_THROW(fe->convert(input_model));
}