pytest:- extractor_test.py (#19487)
This commit is contained in:
parent
e3f1ff7f2a
commit
9250d17e01
@ -4,7 +4,7 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
import pytest
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import strict_compare_tensors
|
||||
from openvino.tools.mo.front.extractor import input_user_data_repack, output_user_data_repack, update_ie_fields, add_input_op, \
|
||||
get_node_id_with_ports
|
||||
@ -438,10 +438,9 @@ class TestInputAddition(UnitTestWithMockedTelemetry):
|
||||
self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value')
|
||||
|
||||
|
||||
@generator
|
||||
class TestOutputCut(unittest.TestCase):
|
||||
class TestOutputCut():
|
||||
# {'embeddings': [{'port': None}]}
|
||||
@generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]})
|
||||
@pytest.mark.parametrize("output",[{'C':[{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]}])
|
||||
def test_output_port_cut(self, output):
|
||||
nodes = {'A': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
||||
'B': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
|
||||
@ -458,10 +457,10 @@ class TestOutputCut(unittest.TestCase):
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
sinks = add_output_ops(graph, output)
|
||||
graph.clean_up()
|
||||
self.assertEqual(len(Node(graph, 'C').out_nodes()), 1)
|
||||
self.assertEqual(len(Node(graph, 'C').in_nodes()), 2)
|
||||
assert len(Node(graph, 'C').out_nodes()) == 1
|
||||
assert len(Node(graph, 'C').in_nodes()) == 2
|
||||
|
||||
@generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]})
|
||||
@pytest.mark.parametrize("output",[{'C': [{'in': 0}]}, {'C': [{'in': 1}]}])
|
||||
def test_output_port_cut(self, output):
|
||||
nodes = {'A': {'op': 'Parameter', 'kind': 'op'},
|
||||
'B': {'op': 'Parameter', 'kind': 'op'},
|
||||
@ -478,7 +477,7 @@ class TestOutputCut(unittest.TestCase):
|
||||
graph = build_graph_with_edge_attrs(nodes, edges)
|
||||
sinks = add_output_ops(graph, output)
|
||||
graph.clean_up()
|
||||
self.assertEqual(len(graph.nodes()), 2)
|
||||
assert len(graph.nodes()) == 2
|
||||
|
||||
|
||||
class TestUserDataRepack(UnitTestWithMockedTelemetry):
|
||||
|
Loading…
Reference in New Issue
Block a user