pytest:- extractor_test.py (#19487)

This commit is contained in:
Pratham Ingawale 2023-09-12 23:53:09 +05:30 committed by GitHub
parent e3f1ff7f2a
commit 9250d17e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@
import unittest
import numpy as np
from generator import generator, generate
import pytest
from openvino.tools.mo.front.common.partial_infer.utils import strict_compare_tensors
from openvino.tools.mo.front.extractor import input_user_data_repack, output_user_data_repack, update_ie_fields, add_input_op, \
get_node_id_with_ports
@ -438,10 +438,9 @@ class TestInputAddition(UnitTestWithMockedTelemetry):
self.assertTrue(Node(graph, 'relu_1').in_edge(0)['edge_attr'] == 'edge_value')
@generator
class TestOutputCut(unittest.TestCase):
class TestOutputCut():
# {'embeddings': [{'port': None}]}
@generate({'C': [{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]})
@pytest.mark.parametrize("output",[{'C':[{'port': None}]}, {'C': [{'out': 0}]}, {'C': [{'out': 1}]}])
def test_output_port_cut(self, output):
nodes = {'A': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
'B': {'type': 'Identity', 'kind': 'op', 'op': 'Identity'},
@ -458,10 +457,10 @@ class TestOutputCut(unittest.TestCase):
graph = build_graph_with_edge_attrs(nodes, edges)
sinks = add_output_ops(graph, output)
graph.clean_up()
self.assertEqual(len(Node(graph, 'C').out_nodes()), 1)
self.assertEqual(len(Node(graph, 'C').in_nodes()), 2)
assert len(Node(graph, 'C').out_nodes()) == 1
assert len(Node(graph, 'C').in_nodes()) == 2
@generate({'C': [{'in': 0}]}, {'C': [{'in': 1}]})
@pytest.mark.parametrize("output",[{'C': [{'in': 0}]}, {'C': [{'in': 1}]}])
def test_output_port_cut(self, output):
nodes = {'A': {'op': 'Parameter', 'kind': 'op'},
'B': {'op': 'Parameter', 'kind': 'op'},
@ -478,7 +477,7 @@ class TestOutputCut(unittest.TestCase):
graph = build_graph_with_edge_attrs(nodes, edges)
sinks = add_output_ops(graph, output)
graph.clean_up()
self.assertEqual(len(graph.nodes()), 2)
assert len(graph.nodes()) == 2
class TestUserDataRepack(UnitTestWithMockedTelemetry):