pytest:- extractor_test.py (#19487)
This commit is contained in:
parent
e3f1ff7f2a
commit
9250d17e01
@ -4,7 +4,7 @@
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from generator import generator, generate
|
import pytest
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import strict_compare_tensors
|
from openvino.tools.mo.front.common.partial_infer.utils import strict_compare_tensors
|
||||||
from openvino.tools.mo.front.extractor import input_user_data_repack, output_user_data_repack, update_ie_fields, add_input_op, \
|
from openvino.tools.mo.front.extractor import input_user_data_repack, output_user_data_repack, update_ie_fields, add_input_op, \
|
||||||
get_node_id_with_ports
|
get_node_id_with_ports
|
||||||
@ -438,10 +438,9 @@ class TestInputAddition(UnitTestWithMockedTelemetry):
|
|||||||
self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value')
|
self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value')
|
||||||
|
|
||||||
|
|
||||||
@generator
|
class TestOutputCut():
|
||||||
class TestOutputCut(unittest.TestCase):
|
|
||||||
# {'embeddings': [{'port': None}]}
|
# {'embeddings': [{'port': None}]}
|
||||||
@generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]})
|
@pytest.mark.parametrize("output",[{'C':[{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]}])
|
||||||
def test_output_port_cut(self, output):
|
def test_output_port_cut(self, output):
|
||||||
nodes = {'A': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
nodes = {'A': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
||||||
'B': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
'B': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
||||||
@ -458,10 +457,10 @@ class TestOutputCut(unittest.TestCase):
|
|||||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||||
sinks = add_output_ops(graph, output)
|
sinks = add_output_ops(graph, output)
|
||||||
graph.clean_up()
|
graph.clean_up()
|
||||||
self.assertEqual(len(Node(graph, 'C').out_nodes()), 1)
|
assert len(Node(graph, 'C').out_nodes()) == 1
|
||||||
self.assertEqual(len(Node(graph, 'C').in_nodes()), 2)
|
assert len(Node(graph, 'C').in_nodes()) == 2
|
||||||
|
|
||||||
@generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]})
|
@pytest.mark.parametrize("output",[{'C': [{'in': 0}]}, {'C': [{'in': 1}]}])
|
||||||
def test_output_port_cut(self, output):
|
def test_output_port_cut(self, output):
|
||||||
nodes = {'A': {'op': 'Parameter', 'kind': 'op'},
|
nodes = {'A': {'op': 'Parameter', 'kind': 'op'},
|
||||||
'B': {'op': 'Parameter', 'kind': 'op'},
|
'B': {'op': 'Parameter', 'kind': 'op'},
|
||||||
@ -478,7 +477,7 @@ class TestOutputCut(unittest.TestCase):
|
|||||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||||
sinks = add_output_ops(graph, output)
|
sinks = add_output_ops(graph, output)
|
||||||
graph.clean_up()
|
graph.clean_up()
|
||||||
self.assertEqual(len(graph.nodes()), 2)
|
assert len(graph.nodes()) == 2
|
||||||
|
|
||||||
|
|
||||||
class TestUserDataRepack(UnitTestWithMockedTelemetry):
|
class TestUserDataRepack(UnitTestWithMockedTelemetry):
|
||||||
|
Loading…
Reference in New Issue
Block a user