[MO] Fix Interpolate-11 in MO (#17002)

* Fix Interpolate-11 in MO

* Add forgotten file

* Fix output type of TopK-11

* Do not force precision on port 1 for mode scales

* Update tools/mo/openvino/tools/mo/ops/interpolate.py

---------

Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com>
Co-authored-by: Andrei Kochin <andrei.kochin@intel.com>
Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
Maxim Vafin 2023-04-20 09:51:38 +02:00 committed by GitHub
parent 5026aa044a
commit 552143c9cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 57 additions and 23 deletions

View File

@ -49,6 +49,7 @@ def infer_for_opsetX(node: Node, opset: str):
if node.shape_calculation_mode == 'sizes': if node.shape_calculation_mode == 'sizes':
dst_shape = node.in_port(1).data.get_value() dst_shape = node.in_port(1).data.get_value()
assert dst_shape is not None assert dst_shape is not None
if node.get_opset() != "opset11":
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes) correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
for i, axis in enumerate(axes): for i, axis in enumerate(axes):
output_shape[axis] = dst_shape[i] output_shape[axis] = dst_shape[i]
@ -151,12 +152,13 @@ class Interpolate(Op):
'pads_end': 0, 'pads_end': 0,
'infer': self.infer, 'infer': self.infer,
'force_precision_in_ports': {1: 'int64'}, 'force_precision_in_ports': {1: 'int64'},
'in_ports_count': 2, 'in_ports_count': 2,
'out_ports_count': 1, 'out_ports_count': 1,
} }
super().__init__(graph, mandatory_props, attrs) super().__init__(graph, mandatory_props, attrs)
if self.attrs['version'] == 'opset11' and self.attrs['shape_calculation_mode'] != 'sizes':
del self.attrs['force_precision_in_ports']
def supported_attrs(self): def supported_attrs(self):
opset = self.get_opset() opset = self.get_opset()

View File

@ -76,7 +76,7 @@ class TopK(Op):
@staticmethod @staticmethod
def type_infer(node): def type_infer(node):
node.out_port(0).set_data_type(node.in_port(0).get_data_type()) node.out_port(0).set_data_type(node.in_port(0).get_data_type())
if node.get_opset() == 'opset3': if node.get_opset() in ['opset3', 'opset11']:
node.out_port(1).set_data_type(node.index_element_type) node.out_port(1).set_data_type(node.index_element_type)
else: else:
node.out_port(1).set_data_type(np.int32) node.out_port(1).set_data_type(np.int32)

View File

@ -159,6 +159,8 @@ def convert_inputs_of_specific_ops(graph: Graph):
} }
for node in graph.get_op_nodes(): for node in graph.get_op_nodes():
if node.soft_get('version') != "opset11":
# opset11 cannot be produced by legacy MO frontends, it can only be read by MO IR Reader
if node.soft_get('type') in type_port: if node.soft_get('type') in type_port:
ports_to_update = type_port[node.soft_get('type')] ports_to_update = type_port[node.soft_get('type')]
for port_id, precision in ports_to_update.items(): for port_id, precision in ports_to_update.items():

View File

@ -8,7 +8,7 @@ from pathlib import Path
import openvino.runtime.opset11 as opset11 import openvino.runtime.opset11 as opset11
import openvino.runtime.opset10 as opset10 import openvino.runtime.opset10 as opset10
from openvino.runtime import Model, serialize from openvino.runtime import Model, serialize, Core
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
from openvino.tools.mo.utils.logger import init_logger from openvino.tools.mo.utils.logger import init_logger
@ -28,6 +28,8 @@ class TestOps(unittest.TestCase):
save_restored_graph(graph, tmp, {}, name) save_restored_graph(graph, tmp, {}, name)
# restore 2 times to validate that after save graph doesn't lose attributes etc. # restore 2 times to validate that after save graph doesn't lose attributes etc.
graph, _ = restore_graph_from_ir(model_xml, model_bin) graph, _ = restore_graph_from_ir(model_xml, model_bin)
# check that re-saved model can be read in runtime
Core().read_model(model_xml)
return graph return graph
def test_topk_11(self): def test_topk_11(self):
@ -43,6 +45,7 @@ class TestOps(unittest.TestCase):
topk_node = graph.get_op_nodes(op="TopK")[0] topk_node = graph.get_op_nodes(op="TopK")[0]
self.assertEqual(topk_node["version"], "opset11") self.assertEqual(topk_node["version"], "opset11")
self.assertTrue(topk_node["stable"]) self.assertTrue(topk_node["stable"])
self.assertEqual(topk_node["index_element_type"], np.int32)
def test_interpolate_11(self): def test_interpolate_11(self):
data_shape = [6, 12, 10, 24] data_shape = [6, 12, 10, 24]
@ -54,6 +57,33 @@ class TestOps(unittest.TestCase):
graph = TestOps.check_graph_can_save(model, 'interpolate_model') graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0] interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11") self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_11_scales(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.float32(
[2., 2.]), "nearest", "scales", axes=np.int32([2, 3]), name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" not in interpolate_node)
def test_interpolate_11_no_axes(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
interpolate = opset11.interpolate(data_parameter, np.int32(
[6, 12, 20, 48]), "nearest", "sizes", name="Interpolate_11")
model = Model(interpolate, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
self.assertEqual(interpolate_node["version"], "opset11")
self.assertTrue("force_precision_in_ports" in interpolate_node)
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
def test_interpolate_4(self): def test_interpolate_4(self):
data_shape = [6, 12, 10, 24] data_shape = [6, 12, 10, 24]