Files
openvino/model-optimizer/unit_tests/extensions/load/tf/loader_test.py
Eugeny Volosenkov 38022c4cd6 MO implementation for If with TF extractor (#6662)
* Add tf2.x impl for If

* Fix ir_engine

* Fix opset

* Fix BOM file

* Added new test

* Fix comments

* Add subgraph_utils

* Fix comments

* Fix transform

* code refactoring

* Fix description

* Rewrite support for empty tensors in If

* added onnx extractor

* delete onnx_if

* fix bug with fake_outputs

* Fix test

* Fix control_flow and fix comments

* create method results_mapping_and_finding_fake_outputs(output_nodes_in_subgraph,
2021-08-19 10:13:21 +03:00

68 lines
2.3 KiB
Python

# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.load.tf.loader import graph_or_sub_graph_has_nhwc_ops
from unit_tests.utils.graph import build_graph, result, regular_op, const, connect_front
class TFLoaderTest(unittest.TestCase):
    """Unit tests for ``graph_or_sub_graph_has_nhwc_ops``.

    Each test builds a small graph with ``stage`` set to ``'front'`` and
    checks the helper's answer. The NHWC operation (a Conv2D node with
    ``layout`` set to ``'NHWC'``) is placed either directly in the main
    graph or inside the body of a fake Loop node, or omitted entirely.
    """

    @staticmethod
    def _make_front_graph(node_specs, edge_groups):
        """Assemble a graph from node dicts and edge lists and mark it as front stage.

        :param node_specs: iterable of node dicts (as returned by ``const``,
            ``regular_op`` and ``result``); they are merged left to right, so
            a later dict wins on a duplicate key.
        :param edge_groups: iterable of edge sequences (as returned by
            ``connect_front``); they are flattened in the given order.
        :return: the graph produced by ``build_graph`` with ``stage`` set to
            ``'front'``.
        """
        merged_nodes = {}
        for spec in node_specs:
            merged_nodes.update(spec)
        flat_edges = [edge for group in edge_groups for edge in group]
        front_graph = build_graph(merged_nodes, flat_edges)
        front_graph.stage = 'front'
        return front_graph

    @staticmethod
    def build_conv_graph():
        """Build ``input`` + ``weights`` -> Conv2D (layout ``'NHWC'``) -> ``result``.

        ``input`` is a Parameter and feeds port 0 of the convolution. The
        constant ``weights`` (a random 1x1x1x1 array) feeds port 1.
        """
        return TFLoaderTest._make_front_graph(
            node_specs=(
                const('weights', np.random.randn(1, 1, 1, 1)),
                regular_op('input', {'op': 'Parameter'}),
                regular_op('conv', {'op': 'Conv2D', 'layout': 'NHWC'}),
                result('result'),
            ),
            edge_groups=(
                connect_front('input', '0:conv'),
                connect_front('weights', '1:conv'),
                connect_front('conv:0', 'result'),
            ),
        )

    @staticmethod
    def build_parameter_result_graph():
        """Build the minimal graph ``input`` (Parameter) -> ``result``, with no convolution."""
        return TFLoaderTest._make_front_graph(
            node_specs=(
                regular_op('input', {'op': 'Parameter'}),
                result('result'),
            ),
            edge_groups=(
                connect_front('input', '0:result'),
            ),
        )

    @staticmethod
    def build_loop_graph(body_graph):
        """Build ``input`` -> fake Loop -> ``result``, with *body_graph* as the Loop body.

        The Loop node stores *body_graph* under its ``body`` attribute and
        lists ``'body'`` in ``sub_graphs``. This lets tests check whether
        the helper under test looks inside sub-graphs.

        :param body_graph: graph to attach as the Loop node's ``body``.
        """
        return TFLoaderTest._make_front_graph(
            node_specs=(
                regular_op('input', {'op': 'Parameter'}),
                regular_op('loop', {'op': 'Loop', 'body': body_graph, 'sub_graphs': ['body']}),
                result('result'),
            ),
            edge_groups=(
                connect_front('input', '0:loop'),
                connect_front('loop:0', 'result'),
            ),
        )

    def test_convolution_main_graph(self):
        """An NHWC Conv2D placed directly in the main graph must be detected."""
        main_graph = self.build_conv_graph()
        self.assertTrue(graph_or_sub_graph_has_nhwc_ops(main_graph))

    def test_convolution_loop_body_graph(self):
        """An NHWC Conv2D placed only inside a Loop body must be detected."""
        body_graph = self.build_conv_graph()
        main_graph = self.build_loop_graph(body_graph)
        self.assertTrue(graph_or_sub_graph_has_nhwc_ops(main_graph))

    def test_no_convolution_main_graph(self):
        """A graph with no convolution must not be reported as containing NHWC ops."""
        main_graph = self.build_parameter_result_graph()
        self.assertFalse(graph_or_sub_graph_has_nhwc_ops(main_graph))

    def test_no_convolution_main_and_sub_graph(self):
        """No convolution in the main graph nor in the Loop body means no NHWC ops."""
        body_graph = self.build_parameter_result_graph()
        main_graph = self.build_loop_graph(body_graph)
        self.assertFalse(graph_or_sub_graph_has_nhwc_ops(main_graph))