Files
openvino/model-optimizer/mo/utils/ir_engine/ir_engine_test.py
2020-02-11 22:48:49 +03:00

154 lines
6.6 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
import os
import sys
import unittest
from unittest import mock
from generator import generator, generate
from mo.graph.graph import Node
from mo.utils.ir_engine.ir_engine import IREngine
# Send log records of level DEBUG and above to stdout in "[ LEVEL ] message" form.
log.basicConfig(
    stream=sys.stdout,
    level=log.DEBUG,
    format="[ %(levelname)s ] %(message)s",
)
@generator
class TestFunction(unittest.TestCase):
    """Unit tests for IREngine, run against the IR files under unittest/test_data."""

    def setUp(self):
        """Resolve the test IR paths and load the engines used by the tests.

        Three engines are created for every test: ``IR`` and ``IR_ref`` from the
        same XML/BIN pair, and ``IR_negative`` from the "_negative" XML combined
        with the same BIN file.
        """
        path, _ = os.path.split(os.path.dirname(__file__))
        test_data_dir = os.path.join(path, "unittest", "test_data")
        self.xml = os.path.join(test_data_dir, "mxnet_synthetic_gru_bidirectional_FP16_1_v6.xml")
        self.xml_negative = os.path.join(test_data_dir,
                                         "mxnet_synthetic_gru_bidirectional_FP16_1_v6_negative.xml")
        self.bin = os.path.splitext(self.xml)[0] + '.bin'

        # Fail early with a readable message when any input file is missing.
        self.assertTrue(os.path.exists(self.xml), 'XML file not found: {}'.format(self.xml))
        self.assertTrue(os.path.exists(self.xml_negative), 'XML file not found: {}'.format(self.xml_negative))
        self.assertTrue(os.path.exists(self.bin), 'BIN file not found: {}'.format(self.bin))

        self.IR = IREngine(path_to_xml=str(self.xml), path_to_bin=str(self.bin))
        self.IR_ref = IREngine(path_to_xml=str(self.xml), path_to_bin=str(self.bin))
        self.IR_negative = IREngine(path_to_xml=str(self.xml_negative), path_to_bin=str(self.bin))

    # Checks the name-mangled private helper IREngine.__isfloat with one float and
    # one non-numeric string. Documented with a comment (not a docstring) so the
    # method object handed to @generate is unchanged.
    @generate(*[(4.4, True), ('aaaa', False)])
    def test_is_float(self, test_data, result):
        self.assertEqual(IREngine._IREngine__isfloat(test_data), result,
                         "Function __isfloat is not working with value: {}".format(test_data))
        log.info('Test for function __is_float passed wit value: {}, expected result: {}'.format(test_data, result))

    # TODO add comparison not for type IREngine
    def test_compare(self):
        """compare() must report success for two engines loaded from the same files."""
        # The second returned value (the message) is not needed in the positive case.
        flag, _ = self.IR.compare(self.IR_ref)
        self.assertTrue(flag, 'Comparing false, test compare function failed')
        log.info('Test for function compare passed')

    def test_comare_negative(self):
        """compare() must report failure and the expected message for the negative IR."""
        # Reference data for test:
        reference_msg = 'Current node "2" with type Const and reference node "2" with type Input have different attr "type" : Const and Input'
        # Check function:
        flag, msg = self.IR.compare(self.IR_negative)
        self.assertFalse(flag, 'Comparing flag failed, test compare function failed')
        # msg is joined with newlines, so it is expected to be an iterable of strings.
        self.assertEqual('\n'.join(msg), reference_msg, 'Comparing message failes, test compare negative failed')
        log.info('Test for function compare passed')

    def test_find_input(self):
        """The private __find_input helper must return exactly the node with id '0'."""
        # Create references for this test:
        ref_nodes = [Node(self.IR.graph, '0')]
        # Check function:
        found_nodes = IREngine._IREngine__find_input(self.IR.graph)
        # assertEqual keeps the same == comparison but prints both values on failure.
        self.assertEqual(found_nodes, ref_nodes, 'Error')

    def test_get_inputs(self):
        """get_inputs() must return the single 'data' input with value [1, 10, 16]."""
        # Reference data for test:
        ref_input_dict = {'data': [1, 10, 16]}
        # Check function:
        inputs_dict = self.IR.get_inputs()
        self.assertEqual(ref_input_dict, inputs_dict, 'Test on function get_inputs failed')
        log.info('Test for function get_inputs passed')

    def test_eq_function(self):
        """Two engines loaded from the same files must compare equal with ``==``."""
        # The == operator is spelled out on purpose: it is the behaviour under test.
        self.assertTrue(self.IR == self.IR_ref, 'Comparing false, test eq function failed')
        log.info('Test for function eq passed')

    @mock.patch('numpy.savez_compressed')
    def test_generate_bin_hashes_file(self, numpy_savez):
        """generate_bin_hashes_file() with no arguments must call numpy.savez_compressed once."""
        # Generate bin_hashes file in default directory; the numpy call is mocked.
        self.IR.generate_bin_hashes_file()
        numpy_savez.assert_called_once()
        log.info('Test for function generate_bin_hashes_file with default folder passed')

    @mock.patch('numpy.savez_compressed')
    def test_generate_bin_hashes_file_custom_directory(self, numpy_savez):
        """generate_bin_hashes_file(path_for_file=...) must call numpy.savez_compressed once."""
        # Generate bin_hashes file in custom directory; the numpy call is mocked.
        directory_for_file = os.path.join(os.path.split(os.path.dirname(__file__))[0], "unittest", "test_data",
                                          "bin_hash")
        self.IR.generate_bin_hashes_file(path_for_file=directory_for_file)
        numpy_savez.assert_called_once()
        log.info('Test for function generate_bin_hashes_file with custom folder passed')

    # Checks the private __normalize_attrs helper: a comma-separated string is
    # expected to become a list of ints and a single number string an int.
    # Documented with a comment (not a docstring) so the method object handed to
    # @generate is unchanged.
    @generate(*[({'order': '1,0,2'}, {'order': [1, 0, 2]}),
                ({'order': '1'}, {'order': 1})])
    def test_normalize_attr(self, test_data, reference):
        result_dict = IREngine._IREngine__normalize_attrs(attrs=test_data)
        self.assertEqual(reference, result_dict, 'Test on function normalize_attr failed')
        log.info('Test for function normalize_attr passed')

    def test_load_bin_hashes(self):
        """An engine loaded from a generated hashes file must have 'hashes' on every Const node.

        Const nodes of the main graph and of every TensorIterator body are checked.
        The value returned by generate_bin_hashes_file() is used as the path of the
        generated file.
        """
        path_for_file = self.IR.generate_bin_hashes_file()
        # Registered immediately so the generated file is removed even when loading
        # or one of the checks below fails.
        self.addCleanup(os.remove, path_for_file)

        ir_with_hashes = IREngine(path_to_xml=str(self.xml), path_to_bin=str(path_for_file))
        is_ok = True

        # Check for constant nodes
        for node in ir_with_hashes.graph.get_op_nodes(type='Const'):
            if not node.has_valid('hashes'):
                log.error('Constant node {} do not include hashes'.format(node.name))
                is_ok = False

        # Check for TensorIterator Body
        for ti in ir_with_hashes.graph.get_op_nodes(type='TensorIterator'):
            if not ti.has_valid('body'):
                # NOTE(review): a missing body is only logged and does not fail the
                # test - confirm whether that is intended.
                log.error('TensorIterator has not body attrubite for node: {}'.format(ti.name))
                continue
            for node in ti.body.graph.get_op_nodes(type='Const'):
                if not node.has_valid('hashes'):
                    log.error('Constant node {} do not include hashes'.format(node.name))
                    is_ok = False

        self.assertTrue(is_ok, 'Test for function load_bin_hashes failed')

    # Checks the private __isint helper: optionally signed digit strings are
    # accepted; bare signs, decimal strings and words are rejected.
    # Documented with a comment (not a docstring) so the method object handed to
    # @generate is unchanged.
    @generate(*[
        ("0", True),
        ("1", True),
        ("-1", True),
        ("-", False),
        ("+1", True),
        ("+", False),
        ("1.0", False),
        ("-1.0", False),
        ("1.5", False),
        ("+1.5", False),
        ("abracadabra", False),
    ])
    def test_isint(self, value, result):
        self.assertEqual(IREngine._IREngine__isint(value), result)