[MO] Fix Interpolate-11 in MO (#17002)
* Fix Interpolate-11 in MO * Add forgotten file * Fix output type of TopK-11 * Do not force precision on port 1 for mode scales * Update tools/mo/openvino/tools/mo/ops/interpolate.py --------- Co-authored-by: Ilya Lavrenov <ilya.lavrenov@intel.com> Co-authored-by: Andrei Kochin <andrei.kochin@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com>
This commit is contained in:
parent
5026aa044a
commit
552143c9cd
@ -49,7 +49,8 @@ def infer_for_opsetX(node: Node, opset: str):
|
||||
if node.shape_calculation_mode == 'sizes':
|
||||
dst_shape = node.in_port(1).data.get_value()
|
||||
assert dst_shape is not None
|
||||
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
|
||||
if node.get_opset() != "opset11":
|
||||
correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
|
||||
for i, axis in enumerate(axes):
|
||||
output_shape[axis] = dst_shape[i]
|
||||
else:
|
||||
@ -151,12 +152,13 @@ class Interpolate(Op):
|
||||
'pads_end': 0,
|
||||
|
||||
'infer': self.infer,
|
||||
|
||||
'force_precision_in_ports': {1: 'int64'},
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
if self.attrs['version'] == 'opset11' and self.attrs['shape_calculation_mode'] != 'sizes':
|
||||
del self.attrs['force_precision_in_ports']
|
||||
|
||||
def supported_attrs(self):
|
||||
opset = self.get_opset()
|
||||
|
@ -76,7 +76,7 @@ class TopK(Op):
|
||||
@staticmethod
|
||||
def type_infer(node):
|
||||
node.out_port(0).set_data_type(node.in_port(0).get_data_type())
|
||||
if node.get_opset() == 'opset3':
|
||||
if node.get_opset() in ['opset3', 'opset11']:
|
||||
node.out_port(1).set_data_type(node.index_element_type)
|
||||
else:
|
||||
node.out_port(1).set_data_type(np.int32)
|
||||
|
@ -159,25 +159,27 @@ def convert_inputs_of_specific_ops(graph: Graph):
|
||||
}
|
||||
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('type') in type_port:
|
||||
ports_to_update = type_port[node.soft_get('type')]
|
||||
for port_id, precision in ports_to_update.items():
|
||||
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
|
||||
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
|
||||
''.format(port_id, node.soft_get('name', node.id), precision))
|
||||
in_port = node.in_port(port_id)
|
||||
np_type = data_type_str_to_np(precision)
|
||||
if in_port.get_source().node.type == 'Const':
|
||||
const_node = node.in_port(port_id).get_source().node
|
||||
const_type = const_node.out_port(0).get_data_type()
|
||||
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
|
||||
# do not convert Constant value if both source and destination types are of integer types
|
||||
# otherwise, it affects compatibility of MO IR Engine and TF FE
|
||||
# TF FE intents to use original model type for layers if it is possible
|
||||
continue
|
||||
convert_const_node_value_type(const_node, np_type)
|
||||
else:
|
||||
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
||||
if node.soft_get('version') != "opset11":
|
||||
# opset11 cannot be produced by legacy MO frontends, it can only be read by MO IR Reader
|
||||
if node.soft_get('type') in type_port:
|
||||
ports_to_update = type_port[node.soft_get('type')]
|
||||
for port_id, precision in ports_to_update.items():
|
||||
if port_id in node.in_ports() and not node.in_port(port_id).disconnected():
|
||||
log.debug('Converting value for the input port "{}" of op "{}" to "{}".'
|
||||
''.format(port_id, node.soft_get('name', node.id), precision))
|
||||
in_port = node.in_port(port_id)
|
||||
np_type = data_type_str_to_np(precision)
|
||||
if in_port.get_source().node.type == 'Const':
|
||||
const_node = node.in_port(port_id).get_source().node
|
||||
const_type = const_node.out_port(0).get_data_type()
|
||||
if np.issubdtype(const_type, np.integer) and np.issubdtype(np_type, np.integer):
|
||||
# do not convert Constant value if both source and destination types are of integer types
|
||||
# otherwise, it affects compatibility of MO IR Engine and TF FE
|
||||
# TF FE intents to use original model type for layers if it is possible
|
||||
continue
|
||||
convert_const_node_value_type(const_node, np_type)
|
||||
else:
|
||||
in_port.get_connection().insert_node(Cast(graph, {'dst_type': np_type}).create_node())
|
||||
|
||||
|
||||
def set_default_tensor_names_for_parameters_results(graph: Graph):
|
||||
|
@ -8,7 +8,7 @@ from pathlib import Path
|
||||
|
||||
import openvino.runtime.opset11 as opset11
|
||||
import openvino.runtime.opset10 as opset10
|
||||
from openvino.runtime import Model, serialize
|
||||
from openvino.runtime import Model, serialize, Core
|
||||
|
||||
from openvino.tools.mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph
|
||||
from openvino.tools.mo.utils.logger import init_logger
|
||||
@ -28,6 +28,8 @@ class TestOps(unittest.TestCase):
|
||||
save_restored_graph(graph, tmp, {}, name)
|
||||
# restore 2 times to validate that after save graph doesn't lose attributes etc.
|
||||
graph, _ = restore_graph_from_ir(model_xml, model_bin)
|
||||
# check that re-saved model can be read in runtime
|
||||
Core().read_model(model_xml)
|
||||
return graph
|
||||
|
||||
def test_topk_11(self):
|
||||
@ -43,6 +45,7 @@ class TestOps(unittest.TestCase):
|
||||
topk_node = graph.get_op_nodes(op="TopK")[0]
|
||||
self.assertEqual(topk_node["version"], "opset11")
|
||||
self.assertTrue(topk_node["stable"])
|
||||
self.assertEqual(topk_node["index_element_type"], np.int32)
|
||||
|
||||
def test_interpolate_11(self):
|
||||
data_shape = [6, 12, 10, 24]
|
||||
@ -54,6 +57,33 @@ class TestOps(unittest.TestCase):
|
||||
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||
self.assertEqual(interpolate_node["version"], "opset11")
|
||||
self.assertTrue("force_precision_in_ports" in interpolate_node)
|
||||
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
|
||||
|
||||
def test_interpolate_11_scales(self):
|
||||
data_shape = [6, 12, 10, 24]
|
||||
data_parameter = opset11.parameter(
|
||||
data_shape, name="Data", dtype=np.float32)
|
||||
interpolate = opset11.interpolate(data_parameter, np.float32(
|
||||
[2., 2.]), "nearest", "scales", axes=np.int32([2, 3]), name="Interpolate_11")
|
||||
model = Model(interpolate, [data_parameter])
|
||||
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||
self.assertEqual(interpolate_node["version"], "opset11")
|
||||
self.assertTrue("force_precision_in_ports" not in interpolate_node)
|
||||
|
||||
def test_interpolate_11_no_axes(self):
|
||||
data_shape = [6, 12, 10, 24]
|
||||
data_parameter = opset11.parameter(
|
||||
data_shape, name="Data", dtype=np.float32)
|
||||
interpolate = opset11.interpolate(data_parameter, np.int32(
|
||||
[6, 12, 20, 48]), "nearest", "sizes", name="Interpolate_11")
|
||||
model = Model(interpolate, [data_parameter])
|
||||
graph = TestOps.check_graph_can_save(model, 'interpolate_model')
|
||||
interpolate_node = graph.get_op_nodes(op="Interpolate")[0]
|
||||
self.assertEqual(interpolate_node["version"], "opset11")
|
||||
self.assertTrue("force_precision_in_ports" in interpolate_node)
|
||||
self.assertEqual(interpolate_node["force_precision_in_ports"], {1: 'int64'})
|
||||
|
||||
def test_interpolate_4(self):
|
||||
data_shape = [6, 12, 10, 24]
|
||||
|
Loading…
Reference in New Issue
Block a user