diff --git a/tools/mo/unit_tests/mo/front/extractor_test.py b/tools/mo/unit_tests/mo/front/extractor_test.py index 0211620fcc9..3478131e77f 100644 --- a/tools/mo/unit_tests/mo/front/extractor_test.py +++ b/tools/mo/unit_tests/mo/front/extractor_test.py @@ -4,7 +4,7 @@ import unittest import numpy as np -from generator import generator, generate +import pytest from openvino.tools.mo.front.common.partial_infer.utils import strict_compare_tensors from openvino.tools.mo.front.extractor import input_user_data_repack, output_user_data_repack, update_ie_fields, add_input_op, \ get_node_id_with_ports @@ -438,10 +438,9 @@ class TestInputAddition(UnitTestWithMockedTelemetry): self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value') -@generator -class TestOutputCut(unittest.TestCase): +class TestOutputCut(): # {'embeddings': [{'port': None}]} - @generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]}) + @pytest.mark.parametrize("output",[{'C':[{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]}]) def test_output_port_cut(self, output): nodes = {'A': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'}, 'B': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'}, @@ -458,10 +457,10 @@ class TestOutputCut(unittest.TestCase): graph = build_graph_with_edge_attrs(nodes, edges) sinks = add_output_ops(graph, output) graph.clean_up() - self.assertEqual(len(Node(graph, 'C').out_nodes()), 1) - self.assertEqual(len(Node(graph, 'C').in_nodes()), 2) + assert len(Node(graph, 'C').out_nodes()) == 1 + assert len(Node(graph, 'C').in_nodes()) == 2 - @generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]}) + @pytest.mark.parametrize("output",[{'C': [{'in': 0}]}, {'C': [{'in': 1}]}]) def test_output_port_cut(self, output): nodes = {'A': {'op': 'Parameter', 'kind': 'op'}, 'B': {'op': 'Parameter', 'kind': 'op'}, @@ -478,7 +477,7 @@ class TestOutputCut(unittest.TestCase): graph = build_graph_with_edge_attrs(nodes, edges) sinks = add_output_ops(graph, output) graph.clean_up() - self.assertEqual(len(graph.nodes()), 2) + assert len(graph.nodes()) == 2 class TestUserDataRepack(UnitTestWithMockedTelemetry):