[MO] Fix Interpolate-11 in MO (#17002)
* Fix Interpolate-11 in MO * Add forgotten file * Fix output type of TopK-11 * Do not force precision on port 1 for mode scales * Update tools/mo/openvino/tools/mo/ops/interpolate.py --------- Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com> Co-authored-by: Andrei Kochin <andrei.kochin@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
5026aa044a
commit
552143c9cd
@ -49,7 +49,8 @@ def infer_for_opsetX(node: Node, opset: str):
|
|||||||
if node.shape_calculation_mode == 'sizes':
|
if node.shape_calculation_mode == 'sizes':
|
||||||
dst_shape = node.in_port(1).data.get_value()
|
dst_shape = node.in_port(1).data.get_value()
|
||||||
assert dst_shape is not None
|
assert dst_shape is not None
|
||||||
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
|
if node.get_opset() != "opset11":
|
||||||
|
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
|
||||||
for i, axis in enumerate(axes):
|
for i, axis in enumerate(axes):
|
||||||
output_shape[axis] = dst_shape[i]
|
output_shape[axis] = dst_shape[i]
|
||||||
else:
|
else:
|
||||||
@ -151,12 +152,13 @@ class Interpolate(Op):
|
|||||||
'pads_end': 0,
|
'pads_end': 0,
|
||||||
|
|
||||||
'infer': self.infer,
|
'infer': self.infer,
|
||||||
|
|
||||||
'force_precision_in_ports': {1: 'int64'},
|
'force_precision_in_ports': {1: 'int64'},
|
||||||
'in_ports_count': 2,
|
'in_ports_count': 2,
|
||||||
'out_ports_count': 1,
|
'out_ports_count': 1,
|
||||||
}
|
}
|
||||||
super().__init__(graph, mandatory_props, attrs)
|
super().__init__(graph, mandatory_props, attrs)
|
||||||
|
if self.attrs['version'] == 'opset11' and self.attrs['shape_calculation_mode'] != 'sizes':
|
||||||
|
del self.attrs['force_precision_in_ports']
|
||||||
|
|
||||||
def supported_attrs(self):
|
def supported_attrs(self):
|
||||||
opset = self.get_opset()
|
opset = self.get_opset()
|
||||||
|
@ -76,7 +76,7 @@ class TopK(Op):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def type_infer(node):
|
def type_infer(node):
|
||||||
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||||
if node.get_opset() == 'opset3':
|
if node.get_opset() in ['opset3', 'opset11']:
|
||||||
node.out_port(1).set_data_type(node.index_element_type)
|
node.out_port(1).set_data_type(node.index_element_type)
|
||||||
else:
|
else:
|
||||||
node.out_port(1).set_data_type(np.int32)
|
node.out_port(1).set_data_type(np.int32)
|
||||||
|
@ -159,25 +159,27 @@ def convert_inputs_of_specific_ops(graph: Graph):
|
|||||||
}
|
}
|
||||||
|
|
||||||
for node in graph.get_op_nodes():
|
for node in graph.get_op_nodes():
|
||||||
if node.soft_get('type') in type_port:
|
if node.soft_get('version') != "opset11":
|
||||||
ports_to_update = type_port[node.soft_get('type')]
|
# opset11 cannot be produced by legacy MO frontends, it can only be read by MO IR Reader
|
||||||
for port_id, precision in ports_to_update.items():
|
if node.soft_get('type') in type_port:
|
||||||
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
|
ports_to_update = type_port[node.soft_get('type')]
|
||||||
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
|
for port_id, precision in ports_to_update.items():
|
||||||
''.format(port_id, node.soft_get('name', node.id), precision))
|
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
|
||||||
in_port = node.in_port(port_id)
|
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
|
||||||
np_type = data_type_str_to_np(precision)
|
''.format(port_id, node.soft_get('name', node.id), precision))
|
||||||
if in_port.get_source().node.type == 'Const':
|
in_port = node.in_port(port_id)
|
||||||
const_node = node.in_port(port_id).get_source().node
|
np_type = data_type_str_to_np(precision)
|
||||||
const_type = const_node.out_port(0).get_data_type()
|
if in_port.get_source().node.type == 'Const':
|
||||||
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
|
const_node = node.in_port(port_id).get_source().node
|
||||||
# do not convert Constant value if both source and destination types are of integer types
|
const_type = const_node.out_port(0).get_data_type()
|
||||||
# otherwise, it affects compatibility of MO IR Engine and TF FE
|
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
|
||||||
# TF FE intents to use original model type for layers if it is possible
|
# do not convert Constant value if both source and destination types are of integer types
|
||||||
continue
|
# otherwise, it affects compatibility of MO IR Engine and TF FE
|
||||||
convert_const_node_value_type(const_node, np_type)
|
# TF FE intents to use original model type for layers if it is possible
|
||||||
else:
|
continue
|
||||||
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
convert_const_node_value_type(const_node, np_type)
|
||||||
|
else:
|
||||||
|
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
||||||
|
|
||||||
|
|
||||||
def set_default_tensor_names_for_parameters_results(graph: Graph):
|
def set_default_tensor_names_for_parameters_results(graph: Graph):
|
||||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import openvino.runtime.opset11 as opset11
|
import openvino.runtime.opset11 as opset11
|
||||||
import openvino.runtime.opset10 as opset10
|
import openvino.runtime.opset10 as opset10
|
||||||
from openvino.runtime import Model, serialize
|
from openvino.runtime import Model, serialize, Core
|
||||||
|
|
||||||
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||||
from openvino.tools.mo.utils.logger import init_logger
|
from openvino.tools.mo.utils.logger import init_logger
|
||||||
@ -28,6 +28,8 @@ class TestOps(unittest.TestCase):
|
|||||||
save_restored_graph(graph, tmp, {}, name)
|
save_restored_graph(graph, tmp, {}, name)
|
||||||
# restore 2 times to validate that after save graph doesn't lose attributes etc.
|
# restore 2 times to validate that after save graph doesn't lose attributes etc.
|
||||||
graph, _ = restore_graph_from_ir(model_xml, model_bin)
|
graph, _ = restore_graph_from_ir(model_xml, model_bin)
|
||||||
|
# check that re-saved model can be read in runtime
|
||||||
|
Core().read_model(model_xml)
|
||||||
return graph
|
return graph
|
||||||
|
|
||||||
def test_topk_11(self):
|
def test_topk_11(self):
|
||||||
@ -43,6 +45,7 @@ class TestOps(unittest.TestCase):
|
|||||||
topk_node = graph.get_op_nodes(op="TopK")[0]
|
topk_node = graph.get_op_nodes(op="TopK")[0]
|
||||||
self.assertEqual(topk_node["version"], "opset11")
|
self.assertEqual(topk_node["version"], "opset11")
|
||||||
self.assertTrue(topk_node["stable"])
|
self.assertTrue(topk_node["stable"])
|
||||||
|
self.assertEqual(topk_node["index_element_type"], np.int32)
|
||||||
|
|
||||||
def test_interpolate_11(self):
|
def test_interpolate_11(self):
|
||||||
data_shape = [6, 12, 10, 24]
|
data_shape = [6, 12, 10, 24]
|
||||||
@ -54,6 +57,33 @@ class TestOps(unittest.TestCase):
|
|||||||
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||||
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||||
self.assertEqual(interpolate_node["version"], "opset11")
|
self.assertEqual(interpolate_node["version"], "opset11")
|
||||||
|
self.assertTrue("force_precision_in_ports" in interpolate_node)
|
||||||
|
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
|
||||||
|
|
||||||
|
def test_interpolate_11_scales(self):
|
||||||
|
data_shape = [6, 12, 10, 24]
|
||||||
|
data_parameter = opset11.parameter(
|
||||||
|
data_shape, name="Data", dtype=np.float32)
|
||||||
|
interpolate = opset11.interpolate(data_parameter, np.float32(
|
||||||
|
[2., 2.]), "nearest", "scales", axes=np.int32([2, 3]), name="Interpolate_11")
|
||||||
|
model = Model(interpolate, [data_parameter])
|
||||||
|
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||||
|
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||||
|
self.assertEqual(interpolate_node["version"], "opset11")
|
||||||
|
self.assertTrue("force_precision_in_ports" not in interpolate_node)
|
||||||
|
|
||||||
|
def test_interpolate_11_no_axes(self):
|
||||||
|
data_shape = [6, 12, 10, 24]
|
||||||
|
data_parameter = opset11.parameter(
|
||||||
|
data_shape, name="Data", dtype=np.float32)
|
||||||
|
interpolate = opset11.interpolate(data_parameter, np.int32(
|
||||||
|
[6, 12, 20, 48]), "nearest", "sizes", name="Interpolate_11")
|
||||||
|
model = Model(interpolate, [data_parameter])
|
||||||
|
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||||
|
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||||
|
self.assertEqual(interpolate_node["version"], "opset11")
|
||||||
|
self.assertTrue("force_precision_in_ports" in interpolate_node)
|
||||||
|
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
|
||||||
|
|
||||||
def test_interpolate_4(self):
|
def test_interpolate_4(self):
|
||||||
data_shape = [6, 12, 10, 24]
|
data_shape = [6, 12, 10, 24]
|
||||||
|
Loading…
Reference in New Issue
Block a user