[ONNX] Add support for Softmax v11 in opset 7 models (#4340)

This commit is contained in:
Tomasz Socha 2021-02-16 11:58:39 +01:00 committed by GitHub
parent cc645d50e4
commit 185aaacc07
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 58 additions and 50 deletions

View File

@ -78,6 +78,47 @@ namespace ngraph
}
}
return {result};
}
}
namespace set_7
{
OutputVector softmax(const Node& node)
{
const auto data = node.get_ng_inputs().at(0);
const auto data_rank = data.get_partial_shape().rank();
NGRAPH_CHECK(data_rank.is_static(),
"ONNX Softmax data rank needs to be known (static)");
const auto axis = node.get_attribute_value<int64_t>("axis", 1);
std::shared_ptr<ngraph::Node> result;
switch (data_rank.get_length())
{
case 0:
{
result =
default_opset::Constant::create(data.get_element_type(), Shape{}, {1});
break;
}
case 1:
{
// checks if the axis belongs to the allowed values set (-1 and 0 for 1D)
ngraph::normalize_axis(
node.get_description(), axis, data.get_partial_shape().rank());
result = std::make_shared<default_opset::Softmax>(data, 0);
break;
}
default:
{
const auto normalized_axis = ngraph::normalize_axis(
node.get_description(), axis, data.get_partial_shape().rank());
result = std::make_shared<default_opset::Softmax>(data, normalized_axis);
break;
}
}
return {result};
}
}

View File

@ -31,7 +31,12 @@ namespace ngraph
} // namespace set_1
} // namespace op
namespace set_7
{
OutputVector softmax(const Node& node);
} // namespace set_7
} // namespace op
} // namespace onnx_import

View File

@ -439,6 +439,9 @@ namespace ngraph
REGISTER_OPERATOR("Slice", 1, slice);
REGISTER_OPERATOR("Slice", 10, slice);
REGISTER_OPERATOR("Softmax", 1, softmax);
// Softmax v7 should be in the 11th opset but,
// other frameworks(mxnet and onnxruntime) already use for older models.
REGISTER_OPERATOR("Softmax", 7, softmax);
REGISTER_OPERATOR("Softplus", 1, softplus);
REGISTER_OPERATOR("Softsign", 1, softsign);
REGISTER_OPERATOR("SpaceToDepth", 1, space_to_depth);

View File

@ -197,7 +197,6 @@ xfail_issue_39662 = xfail_test(reason="RuntimeError: 'ScatterElementsUpdate' lay
"indices value that points to non-existing output tensor element")
xfail_issue_39704 = xfail_test(reason="ResNet101_DUC_HDC - AssertionError: zoo models results mismatch")
xfail_issue_37973 = xfail_test(reason="TF Inception V2 - AssertionError: zoo models results mismatch")
xfail_issue_47430 = xfail_test(reason="FCN ResNet models - AssertionError: zoo models results mismatch")
xfail_issue_47495 = xfail_test(reason="BertSquad-10 from MSFT - AssertionError: zoo models results mismatch")

View File

@ -37,7 +37,10 @@ def runtime(backend_name: str = "CPU") -> "Runtime":
def get_runtime():
"""Return runtime object."""
return runtime(backend_name=tests.BACKEND_NAME)
if tests.BACKEND_NAME is not None:
return runtime(backend_name=tests.BACKEND_NAME)
else:
return runtime()
def _convert_inputs(cnn_network: IENetwork) -> None:

View File

@ -619,8 +619,6 @@ tests_expected_to_fail = [
(xfail_issue_44839,
"OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu",
"OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu",
"OnnxBackendNodeModelTest.test_softmax_axis_0_cpu",
"OnnxBackendNodeModelTest.test_softmax_axis_1_cpu",
"OnnxBackendNodeModelTest.test_softmax_default_axis_cpu",
"OnnxBackendNodeModelTest.test_hardmax_axis_0_cpu",
"OnnxBackendNodeModelTest.test_hardmax_axis_1_cpu",

View File

@ -244,45 +244,6 @@ def test_hardsigmoid():
assert np.allclose(ng_results, [expected])
def test_softmax():
def softmax_2d(x):
max_x = np.max(x, axis=1).reshape((-1, 1))
exp_x = np.exp(x - max_x)
return exp_x / np.sum(exp_x, axis=1).reshape((-1, 1))
np.random.seed(133391)
data = np.random.randn(3, 4, 5).astype(np.float32)
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=0)
expected = softmax_2d(data.reshape(1, 60)).reshape(3, 4, 5)
ng_results = run_node(node, [data])
assert np.allclose(ng_results, [expected])
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=1)
expected = softmax_2d(data.reshape(3, 20)).reshape(3, 4, 5)
ng_results = run_node(node, [data])
assert np.allclose(ng_results, [expected])
# default axis is 1
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"])
ng_results = run_node(node, [data])
assert np.allclose(ng_results, [expected])
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=2)
expected = softmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
ng_results = run_node(node, [data])
assert np.allclose(ng_results, [expected])
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=-1)
expected = softmax_2d(data.reshape(12, 5)).reshape(3, 4, 5)
ng_results = run_node(node, [data])
assert np.allclose(ng_results, [expected])
with pytest.raises(RuntimeError):
node = onnx.helper.make_node("Softmax", inputs=["x"], outputs=["y"], axis=3)
ng_results = run_node(node, [data])
def test_logsoftmax():
def logsoftmax_2d(x):
max_x = np.max(x, axis=1).reshape((-1, 1))

View File

@ -35,7 +35,6 @@ from tests import (
xfail_issue_39669,
xfail_issue_38726,
xfail_issue_40686,
xfail_issue_39704,
xfail_issue_37973,
xfail_issue_47430,
xfail_issue_47495)
@ -176,7 +175,6 @@ if len(zoo_models) > 0:
if tests.MODEL_ZOO_XFAIL:
execution_xfail_list = [
# ONNX Model Zoo
(xfail_issue_39704, "test_onnx_model_zoo_vision_object_detection_segmentation_duc_model_ResNet101_DUC_7_ResNet101_DUC_HDC_ResNet101_DUC_HDC_cpu"),
(xfail_issue_40957, "test_onnx_model_zoo_text_machine_comprehension_roberta_model_roberta_base_11_roberta_base_11_roberta_base_11_cpu"),
(xfail_issue_40957, "test_onnx_model_zoo_text_machine_comprehension_bert_squad_model_bertsquad_8_download_sample_8_bertsquad8_cpu"),
(xfail_issue_39669, "test_onnx_model_zoo_text_machine_comprehension_t5_model_t5_encoder_12_t5_encoder_cpu"),

View File

@ -69,7 +69,7 @@ class OpenVinoOnnxBackend(Backend):
): # type: (...) -> OpenVinoOnnxBackendRep
super().prepare(onnx_model, device, **kwargs)
ng_model_function = import_onnx_model(onnx_model)
return OpenVinoOnnxBackendRep(ng_model_function, cls.backend_name)
return OpenVinoOnnxBackendRep(ng_model_function, device)
@classmethod
def run_model(
@ -79,7 +79,7 @@ class OpenVinoOnnxBackend(Backend):
device="CPU", # type: Text
**kwargs # type: Any
): # type: (...) -> Tuple[Any, ...]
cls.prepare(model, device, **kwargs).run()
return cls.prepare(model, device, **kwargs).run(inputs)
@classmethod
def run_node(

View File

@ -52,5 +52,5 @@ graph {
}
}
opset_import {
version: 7
version: 6
}

View File

@ -52,5 +52,5 @@ graph {
}
}
opset_import {
version: 7
version: 6
}