[MO] Fix Interpolate-11 in MO (#17002)

* Fix Interpolate-11 in MO

* Add forgotten file

* Fix output type of TopK-11

* Do not force precision on port 1 for mode scales

* Update tools/mo/openvino/tools/mo/ops/interpolate.py

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Maxim Vafin 2023-04-20 09:51:38 +02:00 committed by GitHub
parent 5026aa044a
commit 552143c9cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 23 deletions

View File

@ -49,7 +49,8 @@ def infer_for_opsetX(node: Node, opset: str):
if node.shape_calculation_mode == 'sizes':
dst_shape = node.in_port(1).data.get_value()
assert dst_shape is not None
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
if node.get_opset() != "opset11":
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
for i, axis in enumerate(axes):
output_shape[axis] = dst_shape[i]
else:
@ -151,12 +152,13 @@ class Interpolate(Op):
'pads_end': 0,
'infer': self.infer,
'force_precision_in_ports': {1: 'int64'},
'in_ports_count': 2,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)
if self.attrs['version'] == 'opset11' and self.attrs['shape_calculation_mode'] != 'sizes':
del self.attrs['force_precision_in_ports']
def supported_attrs(self):
opset = self.get_opset()

View File

@ -76,7 +76,7 @@ class TopK(Op):
@staticmethod
def type_infer(node):
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
if node.get_opset() == 'opset3':
if node.get_opset() in ['opset3', 'opset11']:
node.out_port(1).set_data_type(node.index_element_type)
else:
node.out_port(1).set_data_type(np.int32)

View File

@ -159,25 +159,27 @@ def convert_inputs_of_specific_ops(graph: Graph):
}
for node in graph.get_op_nodes():
if node.soft_get('type') in type_port:
ports_to_update = type_port[node.soft_get('type')]
for port_id, precision in ports_to_update.items():
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
''.format(port_id, node.soft_get('name', node.id), precision))
in_port = node.in_port(port_id)
np_type = data_type_str_to_np(precision)
if in_port.get_source().node.type == 'Const':
const_node = node.in_port(port_id).get_source().node
const_type = const_node.out_port(0).get_data_type()
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
# do not convert Constant value if both source and destination types are of integer types
# otherwise, it affects compatibility of MO IR Engine and TF FE
# TF FE intents to use original model type for layers if it is possible
continue
convert_const_node_value_type(const_node, np_type)
else:
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
if node.soft_get('version') != "opset11":
# opset11 cannot be produced by legacy MO frontends, it can only be read by MO IR Reader
if node.soft_get('type') in type_port:
ports_to_update = type_port[node.soft_get('type')]
for port_id, precision in ports_to_update.items():
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
''.format(port_id, node.soft_get('name', node.id), precision))
in_port = node.in_port(port_id)
np_type = data_type_str_to_np(precision)
if in_port.get_source().node.type == 'Const':
const_node = node.in_port(port_id).get_source().node
const_type = const_node.out_port(0).get_data_type()
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
# do not convert Constant value if both source and destination types are of integer types
# otherwise, it affects compatibility of MO IR Engine and TF FE
# TF FE intents to use original model type for layers if it is possible
continue
convert_const_node_value_type(const_node, np_type)
else:
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
def set_default_tensor_names_for_parameters_results(graph: Graph):

View File

@ -8,7 +8,7 @@ from pathlib import Path
import openvino.runtime.opset11 as opset11
import openvino.runtime.opset10 as opset10
from openvino.runtime import Model, serialize
from openvino.runtime import Model, serialize, Core
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
from openvino.tools.mo.utils.logger import init_logger
@ -28,6 +28,8 @@ class TestOps(unittest.TestCase):
save_restored_graph(graph, tmp, {}, name)
# restore 2 times to validate that after save graph doesn't lose attributes etc.
graph, _ = restore_graph_from_ir(model_xml, model_bin)
# check that re-saved model can be read in runtime
Core().read_model(model_xml)
return graph
def test_topk_11(self):
@ -43,6 +45,7 @@ class TestOps(unittest.TestCase):
topk_node = graph.get_op_nodes(op="TopK")[0]
self.assertEqual(topk_node["version"], "opset11")
self.assertTrue(topk_node["stable"])
self.assertEqual(topk_node["index_element_type"], np.int32)
def test_interpolate_11(self):
data_shape = [6, 12, 10, 24]
@ -54,6 +57,33 @@ class TestOps(unittest.TestCase):
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_11_scales(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.float32(
[2., 2.]), "nearest", "scales", axes=np.int32([2, 3]), name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" not in interpolate_node)
def test_interpolate_11_no_axes(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.int32(
[6, 12, 20, 48]), "nearest", "sizes", name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_4(self):
data_shape = [6, 12, 10, 24]