diff --git a/tools/mo/openvino/tools/mo/pipeline/common.py b/tools/mo/openvino/tools/mo/pipeline/common.py index 8ce1b038eb1..a69c29cc0b8 100644 --- a/tools/mo/openvino/tools/mo/pipeline/common.py +++ b/tools/mo/openvino/tools/mo/pipeline/common.py @@ -169,16 +169,23 @@ def convert_inputs_of_specific_ops(graph: Graph): ''.format(port_id, node.soft_get('name', node.id), precision)) in_port = node.in_port(port_id) np_type = data_type_str_to_np(precision) - if in_port.get_source().node.type == 'Const': - const_node = node.in_port(port_id).get_source().node - const_type = const_node.out_port(0).get_data_type() - if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer): + in_node = node.in_port(port_id).get_source().node + in_type = in_node.out_port(0).get_data_type() + + if in_node.type == 'Const': + if np.issubdtype(in_type, np.integer) and np.issubdtype(np_type, np.integer): # do not convert Constant value if both source and destination types are of integer types # otherwise, it affects compatibility of MO IR Engine and TF FE # TF FE intents to use original model type for layers if it is possible continue - convert_const_node_value_type(const_node, np_type) + convert_const_node_value_type(in_node, np_type) else: + allowed_int_types = [np.int32, np.int64, np.uint32, np.uint64] + if in_type in allowed_int_types and np_type in allowed_int_types: + # do not convert if both source and destination types are within the set of + # int32/int64/uint32/uint64. It prevents from getting different IRs from the original + # cpp serializer and from the legacy serialized when restored with ir_reader_utils + continue in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node()) diff --git a/tools/mo/openvino/tools/mo/utils/class_registration.py b/tools/mo/openvino/tools/mo/utils/class_registration.py index 74cf6681572..48777c4a0f8 100644 --- a/tools/mo/openvino/tools/mo/utils/class_registration.py +++ b/tools/mo/openvino/tools/mo/utils/class_registration.py @@ -90,8 +90,8 @@ def _update(cls, registered_list: list, registered_dict: dict, key: str, enabled if hasattr(c, key) and getattr(c, key) is not None: k = getattr(c, key) if k.lower() in new_keys_lower: - log.warning('Attempt to register of custom name {} for the second time as class {}. ' - 'Note that custom names are case-insensitive. ' + refer_to_faq_msg(55), k, c) + # log.warning('Attempt to register of custom name {} for the second time as class {}. ' + # 'Note that custom names are case-insensitive. ' + refer_to_faq_msg(55), k, c) continue else: new_keys_lower[k.lower()] = k diff --git a/tools/mo/unit_tests/mo/utils/ir_reader/restore_graph_test.py b/tools/mo/unit_tests/mo/utils/ir_reader/restore_graph_test.py index b6478dc2e27..f96acc15a32 100644 --- a/tools/mo/unit_tests/mo/utils/ir_reader/restore_graph_test.py +++ b/tools/mo/unit_tests/mo/utils/ir_reader/restore_graph_test.py @@ -4,8 +4,17 @@ import os import tempfile import unittest + +import numpy as np from defusedxml.common import EntitiesForbidden -from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir + +import openvino.tools.mo.utils.ir_reader.extenders.convert_extender +from openvino.tools.mo.middle.passes.convert_data_type import destination_type_to_np_data_type +from openvino.tools.mo.middle.passes.infer import type_infer +from openvino.tools.mo.utils.graph import Node +from openvino.tools.mo.utils.ir_engine.compare_graphs import compare_graphs +from openvino.tools.mo.utils.ir_reader.extender import Extender +from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph class TestIRReader(unittest.TestCase): @@ -127,3 +136,320 @@ class TestIRReader(unittest.TestCase): malformed_ir_file.close() self.assertRaises(ValueError, restore_graph_from_ir, malformed_ir_file.name) os.remove(malformed_ir_file.name) + + +class PatchedConvert_extender(Extender): + """ + Original ConvertExtender contains setting 'stop_value_propagation', and because axis value goes to the Gather + through Convert during shape_infer axis turns out to be None and shape_infer fails. + For purposes of this unit-test we patch extender so that it will not add 'stop_value_propagation' attr. + Outside the unit-test Convert_extender is left unchanged because inserting 'stop_value_propagation' + is needed in other cases for CompressQuantizeWeights. + See description of openvino/tools/mo/utils/ir_reader/extenders/convert_extender.py + """ + op = 'Convert' + + @staticmethod + def extend(op: Node): + op['dst_type'] = destination_type_to_np_data_type(op.destination_type) + + +class TestIRSerializeAndRestore(unittest.TestCase): + test_ir_xml = """ + + + + + + + 1 + 128 + + + + + + + + 4 + + + + + + + + 1 + + + + + + + + 1 + 128 + + + 10 + + + 1 + + + + + 1 + 10 + + + + + + + 1 + 10 + + + + + + + + + + + + + """ + + def test_save_and_restore(self): + original_xml_file = tempfile.NamedTemporaryFile(delete=False) + original_xml_file.write(bytes(self.test_ir_xml, 'utf-8')) + original_xml_file.close() + + axis_const_blob = np.array([1], dtype=np.int32) + original_bin_file = tempfile.NamedTemporaryFile(mode='wb', delete=False) + axis_const_blob.tofile(original_bin_file) + original_bin_file.close() + + graph_orig, _ = restore_graph_from_ir(original_xml_file.name, original_bin_file.name) + type_infer(graph_orig) + os.remove(original_xml_file.name) + os.remove(original_bin_file.name) + + restored_ir_dir = tempfile.TemporaryDirectory() + + save_restored_graph(graph_orig.copy(), restored_ir_dir.name, {}) + restored_xml_name = restored_ir_dir.name + '/test_ir.xml' + restored_bin_name = restored_ir_dir.name + '/test_ir.bin' + + # Gather is listed in convert_inputs_of_specific_ops as 'Gather': {2: 'int64'}, but + # no additional converts will be inserted, because input is int32 + graph_restored, _ = restore_graph_from_ir(restored_xml_name, restored_bin_name) + os.remove(restored_xml_name) + os.remove(restored_bin_name) + os.remove(restored_xml_name.replace('xml', 'mapping')) + os.removedirs(restored_ir_dir.name) + + flag, msg = compare_graphs(graph_orig, graph_restored, 'result', 'gather/sink_port_0') + self.assertTrue(flag, msg) + + test_ir_xml_with_i8 = """ + + + + + + + 1 + 128 + + + + + + + + 10 + + + + + + + + 1 + + + + + + + + 1 + 128 + + + 10 + + + 1 + + + + + 1 + 10 + + + + + + + 1 + 10 + + + + + + + + + + + + + """ + + test_ir_xml_with_convert = """ + + + + + + + 1 + 128 + + + + + + + + 10 + + + + + + + + 1 + + + + + + + + 1 + + + + + 1 + + + + + + + + 1 + 128 + + + 10 + + + 1 + + + + + 1 + 10 + + + + + + + 1 + 10 + + + + + + + + + + + + + + """ + + def test_save_and_restore_with_converts(self): + original_xml_file = tempfile.NamedTemporaryFile(delete=False) + original_xml_file.write(bytes(self.test_ir_xml_with_i8, 'utf-8')) + original_xml_file.close() + + gather_axis_blob = np.array([1], dtype=np.int8) + original_bin_file = tempfile.NamedTemporaryFile(mode='wb', delete=False) + gather_axis_blob.tofile(original_bin_file) + original_bin_file.close() + + graph_orig, _ = restore_graph_from_ir(original_xml_file.name, original_bin_file.name) + type_infer(graph_orig) + os.remove(original_xml_file.name) + + restored_ir_dir = tempfile.TemporaryDirectory() + save_restored_graph(graph_orig.copy(), restored_ir_dir.name, {}) + + ir_file_with_convert = tempfile.NamedTemporaryFile(delete=False) + ir_file_with_convert.write(bytes(self.test_ir_xml_with_convert, 'utf-8')) + ir_file_with_convert.close() + + from openvino.tools.mo.utils.ir_reader.extender import Extender + + if 'Convert' in Extender.registered_ops: + Extender.registered_ops['Convert'] = PatchedConvert_extender + + graph_with_convert, _ = restore_graph_from_ir(ir_file_with_convert.name, original_bin_file.name) + type_infer(graph_with_convert) + os.remove(ir_file_with_convert.name) + + if 'Convert' in Extender.registered_ops: + Extender.registered_ops['Convert'] = openvino.tools.mo.utils.ir_reader.extenders.convert_extender.Convert_extender + + restored_xml_file = restored_ir_dir.name + '/test_ir.xml' + restored_bin_file = restored_ir_dir.name + '/test_ir.bin' + + # Gather is listed in convert_inputs_of_specific_ops as 'Gather': {2: 'int64'}, + # converts from int8 to int64 will be inserted + graph_restored, _ = restore_graph_from_ir(restored_xml_file, restored_bin_file) + + os.remove(original_bin_file.name) + os.remove(restored_xml_file) + os.remove(restored_bin_file) + os.remove(restored_xml_file.replace('xml', 'mapping')) + os.removedirs(restored_ir_dir.name) + + flag, msg = compare_graphs(graph_orig, graph_restored, 'result', 'gather/sink_port_0') + self.assertTrue(flag, msg)