From dbf855b3205995a7b2348a2cad83270de927e8f2 Mon Sep 17 00:00:00 2001 From: Tomasz Socha Date: Fri, 11 Dec 2020 11:57:50 +0100 Subject: [PATCH] [ONNX][Python][Tests] Update ONNX to onnx 1.8 (#3557) --- .../include/onnx_import/op/split.hpp | 6 + .../include/onnx_import/op/squeeze.hpp | 6 + .../include/onnx_import/op/unsqueeze.hpp | 7 +- ngraph/frontend/onnx_import/src/op/split.cpp | 25 +- .../frontend/onnx_import/src/op/squeeze.cpp | 22 ++ .../frontend/onnx_import/src/op/unsqueeze.cpp | 11 +- .../frontend/onnx_import/src/ops_bridge.cpp | 3 + ngraph/python/requirements_test.txt | 2 +- ngraph/python/tests/__init__.py | 18 +- .../tests/test_ngraph/test_data_movement.py | 2 - .../tests/test_ngraph/test_ops_fused.py | 6 +- .../tests/test_ngraph/test_ops_unary.py | 3 +- .../test_ngraph/test_sequence_processing.py | 5 +- ngraph/python/tests/test_onnx/test_backend.py | 245 +++++++++++++++--- .../python/tests/test_onnx/test_ops_binary.py | 10 +- .../tests/test_onnx/test_ops_reduction.py | 26 +- .../tests/test_onnx/test_ops_reshape.py | 68 +++-- .../python/tests/test_onnx/test_ops_unary.py | 11 +- 18 files changed, 384 insertions(+), 92 deletions(-) diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/split.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/split.hpp index c2cdeba117e..bdba88f37a6 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/op/split.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/op/split.hpp @@ -31,6 +31,12 @@ namespace ngraph } // namespace set_1 + namespace set_13 + { + OutputVector split(const Node& node); + + } // namespace set_13 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/squeeze.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/squeeze.hpp index ad303d85b2a..0a4d59e1d48 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/op/squeeze.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/op/squeeze.hpp @@ -31,6 +31,12 @@ namespace ngraph } // namespace set_1 + namespace set_13 + { + OutputVector squeeze(const Node& node); + + } // namespace set_13 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/include/onnx_import/op/unsqueeze.hpp b/ngraph/frontend/onnx_import/include/onnx_import/op/unsqueeze.hpp index 4f4fe1cdd74..06ecbba923a 100644 --- a/ngraph/frontend/onnx_import/include/onnx_import/op/unsqueeze.hpp +++ b/ngraph/frontend/onnx_import/include/onnx_import/op/unsqueeze.hpp @@ -31,7 +31,12 @@ namespace ngraph } // namespace set_1 - } // namespace op + namespace set_13 + { + OutputVector unsqueeze(const Node& node); + + } // namespace set_13 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/src/op/split.cpp b/ngraph/frontend/onnx_import/src/op/split.cpp index 8dce83789c6..aac6148de9a 100644 --- a/ngraph/frontend/onnx_import/src/op/split.cpp +++ b/ngraph/frontend/onnx_import/src/op/split.cpp @@ -49,7 +49,30 @@ namespace ngraph } // namespace set_1 - } // namespace op + namespace set_13 + { + OutputVector split(const Node& node) + { + const auto inputs = node.get_ng_inputs(); + const auto axis = node.get_attribute_value("axis", 0); + + if (inputs.size() < 2) + { + const auto outputs_number = node.get_output_names().size(); + return ngraph::builder::opset1::split(inputs.at(0), outputs_number, axis); + } + else + { + const auto axis_node = + default_opset::Constant::create(element::Type_t::i64, Shape{}, {axis}); + return {std::make_shared( + inputs.at(0), axis_node, inputs.at(1)) + ->outputs()}; + } + } + + } // namespace set_13 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/src/op/squeeze.cpp b/ngraph/frontend/onnx_import/src/op/squeeze.cpp index 035f5902957..58d49d573ee 100644 --- a/ngraph/frontend/onnx_import/src/op/squeeze.cpp +++ b/ngraph/frontend/onnx_import/src/op/squeeze.cpp @@ -45,6 +45,28 @@ namespace ngraph } } // namespace set_1 + + namespace set_13 + { + OutputVector squeeze(const Node& node) + { + auto inputs = node.get_ng_inputs(); + if (inputs.size() < 2) + { + std::vector axes{}; + auto axes_node = std::make_shared( + element::Type_t::u64, Shape{}, axes); + + return {std::make_shared(inputs.at(0), axes_node)}; + } + else + { + return { + std::make_shared(inputs.at(0), inputs.at(1))}; + } + } + + } // namespace set_13 } // namespace op } // namespace onnx_import } // namespace ngraph diff --git a/ngraph/frontend/onnx_import/src/op/unsqueeze.cpp b/ngraph/frontend/onnx_import/src/op/unsqueeze.cpp index ba2a64778e8..87d50867cc1 100644 --- a/ngraph/frontend/onnx_import/src/op/unsqueeze.cpp +++ b/ngraph/frontend/onnx_import/src/op/unsqueeze.cpp @@ -41,7 +41,16 @@ namespace ngraph } // namespace set_1 - } // namespace op + namespace set_13 + { + OutputVector unsqueeze(const Node& node) + { + auto inputs = node.get_ng_inputs(); + return {std::make_shared(inputs.at(0), inputs.at(1))}; + } + + } // namespace set_13 + } // namespace op } // namespace onnx_import diff --git a/ngraph/frontend/onnx_import/src/ops_bridge.cpp b/ngraph/frontend/onnx_import/src/ops_bridge.cpp index 7c113f0a0c9..6d1571240f7 100644 --- a/ngraph/frontend/onnx_import/src/ops_bridge.cpp +++ b/ngraph/frontend/onnx_import/src/ops_bridge.cpp @@ -431,8 +431,10 @@ namespace ngraph REGISTER_OPERATOR("Softsign", 1, softsign); REGISTER_OPERATOR("SpaceToDepth", 1, space_to_depth); REGISTER_OPERATOR("Split", 1, split); + REGISTER_OPERATOR("Split", 13, split); REGISTER_OPERATOR("Sqrt", 1, sqrt); REGISTER_OPERATOR("Squeeze", 1, squeeze); + REGISTER_OPERATOR("Squeeze", 13, squeeze); REGISTER_OPERATOR("Sub", 1, sub); REGISTER_OPERATOR("Sub", 7, sub); REGISTER_OPERATOR("Sum", 1, sum); @@ -446,6 +448,7 @@ namespace ngraph REGISTER_OPERATOR("TopK", 11, topk); REGISTER_OPERATOR("Transpose", 1, transpose); REGISTER_OPERATOR("Unsqueeze", 1, unsqueeze); + REGISTER_OPERATOR("Unsqueeze", 13, unsqueeze); REGISTER_OPERATOR("Upsample", 1, upsample); REGISTER_OPERATOR("Upsample", 9, upsample); REGISTER_OPERATOR("Where", 1, where); diff --git a/ngraph/python/requirements_test.txt b/ngraph/python/requirements_test.txt index c6bb0dd98fc..7ebee9a9404 100644 --- a/ngraph/python/requirements_test.txt +++ b/ngraph/python/requirements_test.txt @@ -2,7 +2,7 @@ flake8==3.8.4 flake8-comprehensions==3.3.0 flake8-docstrings==1.5.0 flake8-quotes==3.2.0 -onnx==1.7.0 +onnx==1.8.0 pydocstyle==5.1.1 pytest==6.1.2 retrying==1.3.3 diff --git a/ngraph/python/tests/__init__.py b/ngraph/python/tests/__init__.py index ba015a41523..378bbcf1885 100644 --- a/ngraph/python/tests/__init__.py +++ b/ngraph/python/tests/__init__.py @@ -110,7 +110,6 @@ xfail_issue_38084 = xfail_test(reason="RuntimeError: AssertionFailed: layer->get "with index 0 contains dynamic shapes: {}. Try to use " "CNNNetwork::reshape() method in order to specialize shapes " "before the conversion.") -xfail_issue_38085 = xfail_test(reason="RuntimeError: Interpolate operation should be converted to Interp") xfail_issue_38086 = xfail_test(reason="RuntimeError: Quantize layer input '' doesn't have blobs") xfail_issue_38087 = xfail_test(reason="RuntimeError: Cannot cast to tensor desc. Format is unsupported!") xfail_issue_38091 = xfail_test(reason="AssertionError: Mismatched elements") @@ -170,6 +169,23 @@ xfail_issue_38735 = xfail_test(reason="RuntimeError: nGraph does not support the "ai.onnx.preview.training.Adagrad") xfail_issue_38736 = xfail_test(reason="RuntimeError: nGraph does not support the following ONNX operations:" "NegativeLogLikelihoodLoss") +xfail_issue_43523 = xfail_test(reason="onnx.onnx_cpp2py_export.checker.ValidationError:" + " Unrecognized attribute: axes for operator ReduceSum") +xfail_issue_44839 = xfail_test(reason="Huge computation missmatch") +xfail_issue_44848 = xfail_test(reason="E Unsupported dynamic op: Range") +xfail_issue_44851 = xfail_test(reason="E Unsupported dynamic op: Broadcast") +xfail_issue_44854 = xfail_test(reason="E Unsupported dynamic op: VariadicSplit") +xfail_issue_44858 = xfail_test(reason="E Unsupported dynamic op: Unsqueeze") +xfail_issue_44956 = xfail_test(reason="E Unsupported dynamic op: Loop") +xfail_issue_44957 = xfail_test(reason="E Unsupported dynamic op: NonZero") +xfail_issue_44958 = xfail_test(reason="E Unsupported dynamic op: Interpolate") +xfail_issue_44965 = xfail_test(reason="E RuntimeError: value info has no element") +xfail_issue_44967 = xfail_test(reason="E RuntimeError: unsupported element type: BFLOAT16") +xfail_issue_44968 = xfail_test(reason="E Unsupported dynamic op: Squeeze") +xfail_issue_44970 = xfail_test(reason="Assertion error") +xfail_issue_44976 = xfail_test(reason="E RuntimeError: Quantize layer with name:" + "FakeQuantize_xxx has non const input on 1 port") + # Model ONNX Zoo issues: xfail_issue_39684 = xfail_test(reason="ngraph.exceptions.UserInputError:" diff --git a/ngraph/python/tests/test_ngraph/test_data_movement.py b/ngraph/python/tests/test_ngraph/test_data_movement.py index 7cad0e272dc..f2e85bcd179 100644 --- a/ngraph/python/tests/test_ngraph/test_data_movement.py +++ b/ngraph/python/tests/test_ngraph/test_data_movement.py @@ -14,7 +14,6 @@ # limitations under the License. # ****************************************************************************** import numpy as np -import pytest import ngraph as ng from ngraph.impl import Type @@ -167,7 +166,6 @@ def test_pad_edge(): assert np.allclose(result, expected) -@pytest.mark.xfail(reason="AssertionError") def test_pad_constant(): input_data = np.arange(1, 13).reshape([3, 4]) pads_begin = np.array([0, 1], dtype=np.int32) diff --git a/ngraph/python/tests/test_ngraph/test_ops_fused.py b/ngraph/python/tests/test_ngraph/test_ops_fused.py index 49412d3a54d..f7e37805a1f 100644 --- a/ngraph/python/tests/test_ngraph/test_ops_fused.py +++ b/ngraph/python/tests/test_ngraph/test_ops_fused.py @@ -19,11 +19,11 @@ import pytest import ngraph as ng from tests.runtime import get_runtime from tests import (xfail_issue_40957, - skip_segfault, xfail_issue_34327, xfail_issue_36485, xfail_issue_36486, - xfail_issue_36487) + xfail_issue_36487, + xfail_issue_44976) @xfail_issue_40957 @@ -58,7 +58,7 @@ def test_elu_operator_with_scalar(): assert np.allclose(result, expected) -@skip_segfault +@xfail_issue_44976 def test_fake_quantize(): runtime = get_runtime() diff --git a/ngraph/python/tests/test_ngraph/test_ops_unary.py b/ngraph/python/tests/test_ngraph/test_ops_unary.py index f0327de1be3..61cce9ba8d6 100644 --- a/ngraph/python/tests/test_ngraph/test_ops_unary.py +++ b/ngraph/python/tests/test_ngraph/test_ops_unary.py @@ -19,6 +19,7 @@ import pytest import ngraph as ng from ngraph.impl import Shape, Type from tests.test_ngraph.util import run_op_node +from tests import xfail_issue_44970 @pytest.mark.parametrize( @@ -110,7 +111,7 @@ def test_sigmoid(): assert np.allclose(result, expected) -@pytest.mark.skip(reason="Wrong results are broadcasted along given axis") +@xfail_issue_44970 def test_softmax(): axis = 0 input_tensor = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) diff --git a/ngraph/python/tests/test_ngraph/test_sequence_processing.py b/ngraph/python/tests/test_ngraph/test_sequence_processing.py index 2c9c5d25b63..e9b922b1066 100644 --- a/ngraph/python/tests/test_ngraph/test_sequence_processing.py +++ b/ngraph/python/tests/test_ngraph/test_sequence_processing.py @@ -18,7 +18,8 @@ import numpy as np import ngraph as ng from tests.runtime import get_runtime from tests.test_ngraph.util import run_op_node -from tests import xfail_issue_36478, skip_issue_38084 +from tests import (xfail_issue_36478, + xfail_issue_44848) def test_onehot(): @@ -46,7 +47,7 @@ def test_one_hot(): assert np.allclose(result, excepted) -@skip_issue_38084 +@xfail_issue_44848 def test_range(): start = 5 stop = 35 diff --git a/ngraph/python/tests/test_onnx/test_backend.py b/ngraph/python/tests/test_onnx/test_backend.py index 68ebd44ca3a..8a67427952a 100644 --- a/ngraph/python/tests/test_onnx/test_backend.py +++ b/ngraph/python/tests/test_onnx/test_backend.py @@ -25,7 +25,6 @@ import onnx.backend.test from tests.test_onnx.utils.onnx_backend import OpenVinoTestBackend from tests import (BACKEND_NAME, - skip_issue_38084, xfail_issue_36535, xfail_issue_39656, xfail_issue_39658, @@ -77,7 +76,21 @@ from tests import (BACKEND_NAME, xfail_issue_38735, xfail_issue_40319, xfail_issue_40485, - xfail_issue_41894) + xfail_issue_41894, + xfail_issue_43523, + xfail_issue_43742, + xfail_issue_44839, + xfail_issue_44848, + xfail_issue_44851, + xfail_issue_44854, + xfail_issue_44858, + xfail_issue_44956, + xfail_issue_44957, + xfail_issue_44958, + xfail_issue_44965, + xfail_issue_44967, + xfail_issue_44968, + xfail_issue_44976) def expect_fail(test_case_path, xfail): # type: (str) -> None @@ -124,21 +137,6 @@ OnnxBackendPyTorchConvertedModelTest = None globals().update(backend_test.enable_report().test_cases) tests_expected_to_fail = [ - (skip_issue_38084, - "OnnxBackendNodeModelTest.test_expand_dim_changed_cpu", - "OnnxBackendNodeModelTest.test_expand_dim_unchanged_cpu", - "OnnxBackendSimpleModelTest.test_expand_shape_model1_cpu", - "OnnxBackendSimpleModelTest.test_expand_shape_model2_cpu", - "OnnxBackendSimpleModelTest.test_expand_shape_model3_cpu", - "OnnxBackendSimpleModelTest.test_expand_shape_model4_cpu", - "OnnxBackendNodeModelTest.test_slice_default_axes_cpu", - "OnnxBackendNodeModelTest.test_top_k_cpu", - "OnnxBackendNodeModelTest.test_top_k_negative_axis_cpu", - "OnnxBackendNodeModelTest.test_top_k_smallest_cpu", - "OnnxBackendNodeModelTest.test_nonzero_example_cpu", - "OnnxBackendNodeModelTest.test_range_int32_type_negative_delta_cpu", - "OnnxBackendNodeModelTest.test_range_float_type_positive_delta_cpu", - "OnnxBackendNodeModelTest.test_upsample_nearest_cpu"), (xfail_issue_34314, "OnnxBackendNodeModelTest.test_rnn_seq_length_cpu", "OnnxBackendNodeModelTest.test_simple_rnn_defaults_cpu", @@ -432,22 +430,126 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_expanded_cpu", "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1_weight_expanded_cpu", "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded_cpu", - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NC_expanded_cpu", - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded_cpu", # noqa - "OnnxBackendNodeModelTest.test_gather_elements_0_cpu", - "OnnxBackendNodeModelTest.test_gather_elements_negative_indices_cpu", - "OnnxBackendNodeModelTest.test_gather_elements_1_cpu"), + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1_mean_weight_negative_ignore_index_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_none_no_weight_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_ignore_index_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1_expanded_cpu", + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3d4d5_mean_weight_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_sum_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NC_expanded_cpu", + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_sum_weight_high_ignore_index_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_iinput_shape_is_NCd1_weight_ignore_index_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_reduction_mean_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_with_weight_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_sum_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_negative_log_likelihood_loss_input_shape_is_NCd1d2_reduction_mean_expanded_cpu", # noqa + "OnnxBackendNodeModelTest.test_gather_elements_0_cpu", + "OnnxBackendNodeModelTest.test_gather_elements_negative_indices_cpu", + "OnnxBackendNodeModelTest.test_gather_elements_1_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NC_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NC_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_mean_weight_negative_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1_weight_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_no_weight_reduction_mean_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_mean_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_reduction_sum_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_mean_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2_with_weight_reduction_sum_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3_sum_weight_high_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_mean_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_cpu", + "OnnxBackendNodeModelTest.test_nllloss_NCd1d2d3d4d5_none_no_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1_mean_weight_negative_ii_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_none_no_weight_negative_ii_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3_sum_weight_high_ii_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_mean_weight_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_NCd1d2d3d4d5_none_no_weight_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_3d_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_3d_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_3d_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_3d_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_4d_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_no_weight_ii_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_3d_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_4d_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_ii_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_mean_weight_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_none_cpu", + "OnnxBackendNodeModelTest.test_sce_none_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_none_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_none_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_none_weights_cpu", + "OnnxBackendNodeModelTest.test_sce_none_weights_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_none_weights_log_prob_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_sum_cpu", + "OnnxBackendNodeModelTest.test_sce_sum_expanded_cpu", + "OnnxBackendNodeModelTest.test_sce_sum_log_prob_cpu", + "OnnxBackendNodeModelTest.test_sce_sum_log_prob_expanded_cpu"), (xfail_issue_38712, "OnnxBackendNodeModelTest.test_mod_mixed_sign_int16_cpu", "OnnxBackendNodeModelTest.test_mod_uint8_cpu", @@ -534,7 +636,82 @@ tests_expected_to_fail = [ "OnnxBackendNodeModelTest.test_adagrad_cpu"), (xfail_issue_41894, "OnnxBackendNodeModelTest.test_max_uint16_cpu", - "OnnxBackendNodeModelTest.test_mod_int64_fmod_cpu") + "OnnxBackendNodeModelTest.test_mod_int64_fmod_cpu"), + (xfail_issue_43523, + "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_example_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_do_not_keepdims_random_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_example_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_keepdims_random_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_example_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_example_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_default_axes_keepdims_random_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_example_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_empty_axes_input_noop_random_cpu", + "OnnxBackendNodeModelTest.test_reduce_sum_negative_axes_keepdims_random_cpu"), + (xfail_issue_43742, + "OnnxBackendNodeModelTest.test_if_cpu", + "OnnxBackendNodeModelTest.test_if_seq_cpu"), + (xfail_issue_44839, + "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_axis_0_expanded_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_axis_1_expanded_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_axis_2_expanded_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_default_axis_expanded_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_large_number_expanded_cpu", + "OnnxBackendNodeModelTest.test_logsoftmax_negative_axis_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_axis_0_cpu", + "OnnxBackendNodeModelTest.test_softmax_axis_0_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_axis_1_cpu", + "OnnxBackendNodeModelTest.test_softmax_axis_1_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_axis_2_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_default_axis_cpu", + "OnnxBackendNodeModelTest.test_softmax_default_axis_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_large_number_expanded_cpu", + "OnnxBackendNodeModelTest.test_softmax_negative_axis_expanded_cpu", + "OnnxBackendNodeModelTest.test_hardmax_axis_0_cpu", + "OnnxBackendNodeModelTest.test_hardmax_axis_1_cpu", + "OnnxBackendNodeModelTest.test_hardmax_default_axis_cpu",), + (xfail_issue_44848, + "OnnxBackendNodeModelTest.test_range_float_type_positive_delta_cpu", + "OnnxBackendNodeModelTest.test_range_int32_type_negative_delta_cpu",), + (xfail_issue_44851, + "OnnxBackendNodeModelTest.test_expand_dim_changed_cpu", + "OnnxBackendNodeModelTest.test_expand_dim_unchanged_cpu", + "OnnxBackendSimpleModelTest.test_expand_shape_model1_cpu", + "OnnxBackendSimpleModelTest.test_expand_shape_model2_cpu", + "OnnxBackendSimpleModelTest.test_expand_shape_model3_cpu", + "OnnxBackendSimpleModelTest.test_expand_shape_model4_cpu",), + (xfail_issue_44854, + "OnnxBackendNodeModelTest.test_split_variable_parts_1d_cpu", + "OnnxBackendNodeModelTest.test_split_variable_parts_2d_cpu", + "OnnxBackendNodeModelTest.test_split_variable_parts_default_axis_cpu",), + (xfail_issue_44858, + "OnnxBackendNodeModelTest.test_unsqueeze_axis_0_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_axis_1_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_axis_2_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_negative_axes_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_three_axes_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_two_axes_cpu", + "OnnxBackendNodeModelTest.test_unsqueeze_unsorted_axes_cpu",), + (xfail_issue_44956, + "OnnxBackendNodeModelTest.test_loop11_cpu"), + (xfail_issue_44957, + "OnnxBackendNodeModelTest.test_nonzero_example_cpu"), + (xfail_issue_44958, + "OnnxBackendNodeModelTest.test_upsample_nearest_cpu"), + (xfail_issue_44965, + "OnnxBackendNodeModelTest.test_loop13_seq_cpu", + "OnnxBackendNodeModelTest.test_sequence_insert_at_back_cpu", + "OnnxBackendNodeModelTest.test_sequence_insert_at_front_cpu",), + (xfail_issue_44967, + "OnnxBackendNodeModelTest.test_cast_BFLOAT16_to_FLOAT_cpu", + "OnnxBackendNodeModelTest.test_cast_FLOAT_to_BFLOAT16_cpu",), + (xfail_issue_44968, + "OnnxBackendNodeModelTest.test_squeeze_cpu", + "OnnxBackendNodeModelTest.test_squeeze_negative_axes_cpu",), + (xfail_issue_44976, + "OnnxBackendNodeModelTest.test_quantizelinear_axis_cpu",) ] for test_group in tests_expected_to_fail: diff --git a/ngraph/python/tests/test_onnx/test_ops_binary.py b/ngraph/python/tests/test_onnx/test_ops_binary.py index e9b23832fee..2a19208aa88 100644 --- a/ngraph/python/tests/test_onnx/test_ops_binary.py +++ b/ngraph/python/tests/test_onnx/test_ops_binary.py @@ -19,7 +19,7 @@ import pytest from onnx.helper import make_graph, make_model, make_tensor_value_info from tests.test_onnx.utils import run_model -from tests import skip_segfault +from tests import xfail_issue_44970 def import_and_compute(op_type, input_data_left, input_data_right, opset=7, **node_attributes): @@ -38,7 +38,7 @@ def import_and_compute(op_type, input_data_left, input_data_right, opset=7, **no return run_model(model, inputs)[0] -@skip_segfault +@xfail_issue_44970 def test_add_opset4(): assert np.array_equal(import_and_compute("Add", 1, 2, opset=4), np.array(3, dtype=np.float32)) @@ -111,7 +111,7 @@ def test_add_opset7(left_shape, right_shape): assert np.array_equal(import_and_compute("Add", left_input, right_input), left_input + right_input) -@skip_segfault +@xfail_issue_44970 def test_sub(): assert np.array_equal(import_and_compute("Sub", 20, 1), np.array(19, dtype=np.float32)) @@ -125,7 +125,7 @@ def test_sub(): ) -@skip_segfault +@xfail_issue_44970 def test_mul(): assert np.array_equal(import_and_compute("Mul", 2, 3), np.array(6, dtype=np.float32)) @@ -139,7 +139,7 @@ def test_mul(): ) -@skip_segfault +@xfail_issue_44970 def test_div(): assert np.array_equal(import_and_compute("Div", 6, 3), np.array(2, dtype=np.float32)) diff --git a/ngraph/python/tests/test_onnx/test_ops_reduction.py b/ngraph/python/tests/test_onnx/test_ops_reduction.py index a2c9e824c44..76e96dd170c 100644 --- a/ngraph/python/tests/test_onnx/test_ops_reduction.py +++ b/ngraph/python/tests/test_onnx/test_ops_reduction.py @@ -18,11 +18,11 @@ import onnx import pytest from tests.test_onnx.utils import run_node -from tests import xfail_issue_35925 +from tests import (xfail_issue_35925, + xfail_issue_43523) reduce_data = np.array([[[5, 1], [20, 2]], [[30, 1], [40, 2]], [[55, 1], [60, 2]]], dtype=np.float32) reduce_axis_parameters = [ - None, (0,), (1,), (2,), @@ -36,7 +36,7 @@ reduce_operation_parameters = [ ("ReduceMax", np.max), ("ReduceMin", np.min), ("ReduceMean", np.mean), - ("ReduceSum", np.sum), + pytest.param("ReduceSum", np.sum, marks=xfail_issue_43523), ("ReduceProd", np.prod) ] @@ -47,15 +47,23 @@ def import_and_compute(op_type, input_data, **node_attrs): return run_node(node, data_inputs).pop() +@pytest.mark.parametrize("operation, ref_operation", [ + ("ReduceMax", np.max), + ("ReduceMin", np.min), + ("ReduceMean", np.mean), + ("ReduceSum", np.sum), + ("ReduceProd", np.prod) +]) +def test_reduce_operation_keepdims_none_axes(operation, ref_operation): + assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=True), + ref_operation(reduce_data, keepdims=True)) + + @pytest.mark.parametrize("operation, ref_operation", reduce_operation_parameters) @pytest.mark.parametrize("axes", reduce_axis_parameters) def test_reduce_operation_keepdims(operation, ref_operation, axes): - if axes: - assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=True), - ref_operation(reduce_data, keepdims=True, axis=axes)) - else: - assert np.array_equal(import_and_compute(operation, reduce_data, keepdims=True), - ref_operation(reduce_data, keepdims=True)) + assert np.array_equal(import_and_compute(operation, reduce_data, axes=axes, keepdims=True), + ref_operation(reduce_data, keepdims=True, axis=axes)) @pytest.mark.parametrize("axes", [ diff --git a/ngraph/python/tests/test_onnx/test_ops_reshape.py b/ngraph/python/tests/test_onnx/test_ops_reshape.py index b428c541dff..2bfceb3407f 100644 --- a/ngraph/python/tests/test_onnx/test_ops_reshape.py +++ b/ngraph/python/tests/test_onnx/test_ops_reshape.py @@ -26,7 +26,10 @@ from tests.test_onnx.utils import ( run_model, run_node, ) -from tests import xfail_issue_35927 +from tests import (xfail_issue_35927, + xfail_issue_44854, + xfail_issue_44858, + xfail_issue_44968) def test_reshape(): @@ -228,36 +231,43 @@ def test_concat(): assert np.array_equal(ng_results, [expected_output]) +@xfail_issue_44968 def test_squeeze(): data = np.arange(6, dtype=np.int32).reshape([1, 2, 3, 1]) expected_output = data.reshape([2, 3]) - node = onnx.helper.make_node("Squeeze", inputs=["x"], outputs=["y"], axes=[0, 3]) - ng_results = run_node(node, [data]) + axes = np.array([0, 3]).astype(np.int64) + node = onnx.helper.make_node("Squeeze", inputs=["x", "axes"], outputs=["y"]) + ng_results = run_node(node, [data, axes]) assert np.array_equal(ng_results, [expected_output]) data = np.random.randn(1, 3, 4, 5).astype(np.float32) expected_output = np.squeeze(data, axis=0) - node = onnx.helper.make_node("Squeeze", inputs=["x"], outputs=["y"], axes=[0]) - ng_results = run_node(node, [data]) + axes = np.array([0]).astype(np.int64) + node = onnx.helper.make_node("Squeeze", inputs=["x", "axes"], outputs=["y"]) + ng_results = run_node(node, [data, axes]) assert np.array_equal(ng_results, [expected_output]) +@xfail_issue_44858 def test_unsqueeze(): data = np.random.randn(3, 4, 5).astype(np.float32) expected_output = np.expand_dims(data, axis=0) - node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0]) - ng_results = run_node(node, [data]) + axes = np.array([0]).astype(np.int64) + node = onnx.helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"]) + ng_results = run_node(node, [data, axes]) assert np.array_equal(ng_results, [expected_output]) expected_output = np.reshape(data, [1, 3, 4, 5, 1]) - node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0, 4]) - ng_results = run_node(node, [data]) + axes = np.array([0, 4]).astype(np.int64) + node = onnx.helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"]) + ng_results = run_node(node, [data, axes]) assert np.array_equal(ng_results, [expected_output]) expected_output = np.reshape(data, [1, 3, 1, 4, 5]) - node = onnx.helper.make_node("Unsqueeze", inputs=["x"], outputs=["y"], axes=[0, 2]) - ng_results = run_node(node, [data]) + axes = np.array([0, 2]).astype(np.int64) + node = onnx.helper.make_node("Unsqueeze", inputs=["x", "axes"], outputs=["y"]) + ng_results = run_node(node, [data, axes]) assert np.array_equal(ng_results, [expected_output]) @@ -300,16 +310,6 @@ def test_unsqueeze(): np.array([[3], [7]], dtype=np.int32), ], ), - # Split into 2 unequal parts along axis=1 - ( - onnx.helper.make_node( - "Split", inputs=["x"], outputs=["a", "b"], axis=1, split=(3, 1) - ), - [ - np.array([[0, 1, 2], [4, 5, 6]], dtype=np.int32), - np.array([[3], [7]], dtype=np.int32), - ], - ), ], ) def test_split_2d(node, expected_output): @@ -318,6 +318,22 @@ def test_split_2d(node, expected_output): assert all_arrays_equal(ng_results, expected_output) +@xfail_issue_44854 +def test_split_2d_splits_input(): + data = np.arange(8, dtype=np.int32).reshape(2, 4) + splits = np.array([3, 1]).astype(np.int64) + node = onnx.helper.make_node( + "Split", inputs=["x", "splits"], outputs=["a", "b"], axis=1 + ) + expected_outputs = [ + np.array([[0, 1, 2], [4, 5, 6]], dtype=np.int32), + np.array([[3], [7]], dtype=np.int32), + ] + ng_results = run_node(node, [data, splits]) + assert all_arrays_equal(ng_results, expected_outputs) + + +@xfail_issue_44854 def test_split_1d(): # 1D data = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).astype(np.float32) @@ -330,15 +346,16 @@ def test_split_1d(): ng_results = run_node(node, [data]) assert all_arrays_equal(ng_results, expected_outputs) + splits = np.array([2, 3, 1]).astype(np.int64) node = onnx.helper.make_node( - "Split", inputs=["input"], outputs=["y", "z", "w"], axis=0, split=[2, 3, 1] + "Split", inputs=["input", "splits"], outputs=["y", "z", "w"], axis=0 ) expected_outputs = [ np.array([1.0, 2.0]).astype(np.float32), np.array([3.0, 4.0, 5.0]).astype(np.float32), np.array([6.0]).astype(np.float32), ] - ng_results = run_node(node, [data]) + ng_results = run_node(node, [data, splits]) assert all_arrays_equal(ng_results, expected_outputs) # Default values @@ -353,14 +370,15 @@ def test_split_1d(): ng_results = run_node(node, [data]) assert all_arrays_equal(ng_results, expected_outputs) + splits = np.array([2, 4]).astype(np.int64) node = onnx.helper.make_node( - "Split", inputs=["input"], outputs=["y", "z"], split=[2, 4] + "Split", inputs=["input", "splits"], outputs=["y", "z"], split=[2, 4] ) expected_outputs = [ np.array([1.0, 2.0]).astype(np.float32), np.array([3.0, 4.0, 5.0, 6.0]).astype(np.float32), ] - ng_results = run_node(node, [data]) + ng_results = run_node(node, [data, splits]) assert all_arrays_equal(ng_results, expected_outputs) diff --git a/ngraph/python/tests/test_onnx/test_ops_unary.py b/ngraph/python/tests/test_onnx/test_ops_unary.py index 5b28c636d43..d5ba32bd580 100644 --- a/ngraph/python/tests/test_onnx/test_ops_unary.py +++ b/ngraph/python/tests/test_onnx/test_ops_unary.py @@ -302,13 +302,13 @@ def test_logsoftmax(): ng_results = run_node(node, [data]) assert np.allclose(ng_results, [expected]) - # default axis is 1 - node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"]) + node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=2) + expected = logsoftmax_2d(data.reshape(12, 5)).reshape(3, 4, 5) ng_results = run_node(node, [data]) assert np.allclose(ng_results, [expected]) - node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"], axis=2) - expected = logsoftmax_2d(data.reshape(12, 5)).reshape(3, 4, 5) + # default axis is -1 + node = onnx.helper.make_node("LogSoftmax", inputs=["x"], outputs=["y"]) ng_results = run_node(node, [data]) assert np.allclose(ng_results, [expected]) @@ -388,8 +388,7 @@ def test_cast_to_bool(val_type, input_data): "val_type, range_start, range_end, in_dtype", [ (np.dtype(np.float32), -8, 8, np.dtype(np.int32)), - pytest.param(np.dtype(np.float64), -16383, 16383, np.dtype(np.int64), - marks=pytest.mark.xfail(reason="RuntimeError: Unsupported type")), + (np.dtype(np.float64), -16383, 16383, np.dtype(np.int64)), ], ) def test_cast_to_float(val_type, range_start, range_end, in_dtype):