[MO] Fix Interpolate-11 in MO (#17002)

* Fix Interpolate-11 in MO

* Add forgotten file

* Fix output type of TopK-11

* Do not force precision on port 1 for mode scales

* Update tools/mo/openvino/tools/mo/ops/interpolate.py

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Maxim Vafin 2023-04-20 09:51:38 +02:00 committed by GitHub
parent 5026aa044a
commit 552143c9cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 23 deletions

View File

@ -49,7 +49,8 @@ def infer_for_opsetX(node: Node, opset: str):
if node.shape_calculation_mode == 'sizes': if node.shape_calculation_mode == 'sizes':
dst_shape = node.in_port(1).data.get_value() dst_shape = node.in_port(1).data.get_value()
assert dst_shape is not None assert dst_shape is not None
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes) if node.get_opset() != "opset11":
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
for i, axis in enumerate(axes): for i, axis in enumerate(axes):
output_shape[axis] = dst_shape[i] output_shape[axis] = dst_shape[i]
else: else:
@ -151,12 +152,13 @@ class Interpolate(Op):
'pads_end': 0, 'pads_end': 0,
'infer': self.infer, 'infer': self.infer,
'force_precision_in_ports': {1: 'int64'}, 'force_precision_in_ports': {1: 'int64'},
'in_ports_count': 2, 'in_ports_count': 2,
'out_ports_count': 1, 'out_ports_count': 1,
} }
super().__init__(graph, mandatory_props, attrs) super().__init__(graph, mandatory_props, attrs)
if self.attrs['version'] == 'opset11' and self.attrs['shape_calculation_mode'] != 'sizes':
del self.attrs['force_precision_in_ports']
def supported_attrs(self): def supported_attrs(self):
opset = self.get_opset() opset = self.get_opset()

View File

@ -76,7 +76,7 @@ class TopK(Op):
@staticmethod @staticmethod
def type_infer(node): def type_infer(node):
node.out_port(0).set_data_type(node.in_port(0).get_data_type()) node.out_port(0).set_data_type(node.in_port(0).get_data_type())
if node.get_opset() == 'opset3': if node.get_opset() in ['opset3', 'opset11']:
node.out_port(1).set_data_type(node.index_element_type) node.out_port(1).set_data_type(node.index_element_type)
else: else:
node.out_port(1).set_data_type(np.int32) node.out_port(1).set_data_type(np.int32)

View File

@ -159,25 +159,27 @@ def convert_inputs_of_specific_ops(graph: Graph):
} }
for node in graph.get_op_nodes(): for node in graph.get_op_nodes():
if node.soft_get('type') in type_port: if node.soft_get('version') != "opset11":
ports_to_update = type_port[node.soft_get('type')] # opset11 cannot be produced by legacy MO frontends, it can only be read by MO IR Reader
for port_id, precision in ports_to_update.items(): if node.soft_get('type') in type_port:
if port_id in node.in_ports() and not node.in_port(port_id).disconnected(): ports_to_update = type_port[node.soft_get('type')]
log.debug('Converting value for the input port "{}" of op "{}" to "{}".' for port_id, precision in ports_to_update.items():
''.format(port_id, node.soft_get('name', node.id), precision)) if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
in_port = node.in_port(port_id) log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
np_type = data_type_str_to_np(precision) ''.format(port_id, node.soft_get('name', node.id), precision))
if in_port.get_source().node.type == 'Const': in_port = node.in_port(port_id)
const_node = node.in_port(port_id).get_source().node np_type = data_type_str_to_np(precision)
const_type = const_node.out_port(0).get_data_type() if in_port.get_source().node.type == 'Const':
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer): const_node = node.in_port(port_id).get_source().node
# do not convert Constant value if both source and destination types are of integer types const_type = const_node.out_port(0).get_data_type()
# otherwise, it affects compatibility of MO IR Engine and TF FE if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
# TF FE intents to use original model type for layers if it is possible # do not convert Constant value if both source and destination types are of integer types
continue # otherwise, it affects compatibility of MO IR Engine and TF FE
convert_const_node_value_type(const_node, np_type) # TF FE intents to use original model type for layers if it is possible
else: continue
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node()) convert_const_node_value_type(const_node, np_type)
else:
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
def set_default_tensor_names_for_parameters_results(graph: Graph): def set_default_tensor_names_for_parameters_results(graph: Graph):

View File

@ -8,7 +8,7 @@ from pathlib import Path
import openvino.runtime.opset11 as opset11 import openvino.runtime.opset11 as opset11
import openvino.runtime.opset10 as opset10 import openvino.runtime.opset10 as opset10
from openvino.runtime import Model, serialize from openvino.runtime import Model, serialize, Core
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
from openvino.tools.mo.utils.logger import init_logger from openvino.tools.mo.utils.logger import init_logger
@ -28,6 +28,8 @@ class TestOps(unittest.TestCase):
save_restored_graph(graph, tmp, {}, name) save_restored_graph(graph, tmp, {}, name)
# restore 2 times to validate that after save graph doesn't lose attributes etc. # restore 2 times to validate that after save graph doesn't lose attributes etc.
graph, _ = restore_graph_from_ir(model_xml, model_bin) graph, _ = restore_graph_from_ir(model_xml, model_bin)
# check that re-saved model can be read in runtime
Core().read_model(model_xml)
return graph return graph
def test_topk_11(self): def test_topk_11(self):
@ -43,6 +45,7 @@ class TestOps(unittest.TestCase):
topk_node = graph.get_op_nodes(op="TopK")[0] topk_node = graph.get_op_nodes(op="TopK")[0]
self.assertEqual(topk_node["version"], "opset11") self.assertEqual(topk_node["version"], "opset11")
self.assertTrue(topk_node["stable"]) self.assertTrue(topk_node["stable"])
self.assertEqual(topk_node["index_element_type"], np.int32)
def test_interpolate_11(self): def test_interpolate_11(self):
data_shape = [6, 12, 10, 24] data_shape = [6, 12, 10, 24]
@ -54,6 +57,33 @@ class TestOps(unittest.TestCase):
graph = TestOps.check_graph_can_save(model, 'interpolate_model') graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0] interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11") self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_11_scales(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.float32(
[2., 2.]), "nearest", "scales", axes=np.int32([2, 3]), name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" not in interpolate_node)
def test_interpolate_11_no_axes(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.int32(
[6, 12, 20, 48]), "nearest", "sizes", name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_4(self): def test_interpolate_4(self):
data_shape = [6, 12, 10, 24] data_shape = [6, 12, 10, 24]