Apply RIC for dynamic dimension in legacy MO (#10130)

* Apply RIC for dynamic dimension in legacy MO and fail if RIC wasn't applied to any input

* Fix moc tests
This commit is contained in:
Maxim Vafin 2022-02-08 22:17:19 +03:00 committed by GitHub
parent d951433b12
commit 1970baeb1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 49 deletions

View File

@ -7,7 +7,7 @@ import numpy as np
from openvino.tools.mo.back.replacement import BackReplacementPattern from openvino.tools.mo.back.replacement import BackReplacementPattern
from openvino.tools.mo.front.common.layout import get_dim_from_layout, get_features_dim from openvino.tools.mo.front.common.layout import get_dim_from_layout, get_features_dim
from openvino.tools.mo.front.common.partial_infer.utils import int64_array from openvino.tools.mo.front.common.partial_infer.utils import int64_array, compatible_dims
from openvino.tools.mo.front.common.partial_infer.utils import mo_array from openvino.tools.mo.front.common.partial_infer.utils import mo_array
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
from openvino.tools.mo.graph.graph import Graph from openvino.tools.mo.graph.graph import Graph
@ -16,6 +16,7 @@ from openvino.tools.mo.ops.concat import Concat
from openvino.tools.mo.ops.gather import Gather from openvino.tools.mo.ops.gather import Gather
from openvino.tools.mo.ops.op import Op, PermuteAttrs from openvino.tools.mo.ops.op import Op, PermuteAttrs
from openvino.tools.mo.ops.split import Split from openvino.tools.mo.ops.split import Split
from openvino.tools.mo.utils.error import Error
class ReverseChannels(Op): class ReverseChannels(Op):
@ -53,7 +54,10 @@ class InsertReverseChannels(BackReplacementPattern):
enabled = False enabled = False
@staticmethod @staticmethod
def get_channel_index(node: Node) -> int: def get_suitable_channel_index(node: Node, shape):
if len(shape) != 4:
return None
guessed_layout = 'NCHW' guessed_layout = 'NCHW'
if node.has_valid('rt_info'): if node.has_valid('rt_info'):
rt_info = node.rt_info rt_info = node.rt_info
@ -66,25 +70,30 @@ class InsertReverseChannels(BackReplacementPattern):
guessed_layout = np.array(list(guessed_layout))[order] guessed_layout = np.array(list(guessed_layout))[order]
guessed_layout = ''.join(guessed_layout) guessed_layout = ''.join(guessed_layout)
idx, has_layout = get_dim_from_layout(node, 'C') idx, has_layout = get_dim_from_layout(node, 'C')
if has_layout: if not has_layout:
idx = get_features_dim(guessed_layout, len(node.shape))
if compatible_dims(shape[idx], 3):
return idx return idx
else: else:
return get_features_dim(guessed_layout, len(node.shape)) return None
def find_and_replace_pattern(self, graph: Graph): def find_and_replace_pattern(self, graph: Graph):
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape())) all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
for p in graph.get_op_nodes(type='Parameter')] for p in graph.get_op_nodes(type='Parameter')]
suitable_params = [] suitable_params = []
for name, p, shape in all_params: for name, p, shape in all_params:
if len(shape) == 4: idx = self.get_suitable_channel_index(p, shape)
idx = self.get_channel_index(p) if idx is not None:
if idx is not None and shape[idx] == 3: suitable_params.append((name, p, shape, idx))
suitable_params.append((name, p, shape, idx))
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params})) log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape, _ in suitable_params})) log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape, _ in suitable_params}))
if len(suitable_params) < len(all_params): if not len(suitable_params):
raise Error('Network has {} inputs overall, but none of them are suitable for input channels reversing.\n'
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of '
'dynamic dimensions C channel must be provided in a layout for this input)\n'
'All inputs: {}'.format(len(all_params), all_params))
elif len(suitable_params) < len(all_params):
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n' log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n' 'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
'Suitable inputs {}'.format(len(all_params), len(suitable_params), 'Suitable inputs {}'.format(len(all_params), len(suitable_params),

View File

@ -310,9 +310,15 @@ def guess_source_layouts_for_reverse_channels(ov_function: Model, layout_values)
if layout and check_suitable_for_reverse(Layout(layout), ov_input): if layout and check_suitable_for_reverse(Layout(layout), ov_input):
suitable_params.append(param_info) suitable_params.append(param_info)
if len(suitable_params) < len(all_params): if not len(suitable_params):
raise Error('Network has {} inputs overall, but none of them are suitable for input channels reversing.\n'
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of dynamic '
'dimensions C channel must be provided in a layout for this input)\nAll inputs: {}'.format(
len(all_params), all_params))
elif len(suitable_params) < len(all_params):
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n' log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n' 'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of dynamic '
'dimensions C channel must be provided in a layout for this input)\nAll inputs: {}\n'
'Suitable inputs {}'.format(len(all_params), len(suitable_params), all_params, suitable_params), 'Suitable inputs {}'.format(len(all_params), len(suitable_params), all_params, suitable_params),
extra={'is_warning': True}) extra={'is_warning': True})
return suitable_params return suitable_params

View File

@ -434,13 +434,9 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
'input2a': { 'source_layout': 'nchw' } 'input2a': { 'source_layout': 'nchw' }
}) })
function = create_function2(shape1=[1, 224, 224, 4], shape2=[1, 4, 224, 224]) function = create_function2(shape1=[1, 224, 224, 4], shape2=[1, 4, 224, 224])
process_function(ov_function=function, argv=argv) # no suitable inputs
# In future, consider using mock PrePostProcessor to verify that 'reverse_channels' was not called with self.assertRaises(Exception):
# Verify that reverse_channels are not applied. process_function(ov_function=function, argv=argv)
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_reverse_input_channels_3d(self): def test_reverse_input_channels_3d(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None, argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
@ -457,23 +453,17 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None, argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
layout_values=None) layout_values=None)
function = create_function2(shape1=[4, 4, 4, 4, 4, 3], shape2=[4, 3, 4, 4, 4, 4]) function = create_function2(shape1=[4, 4, 4, 4, 4, 3], shape2=[4, 3, 4, 4, 4, 4])
process_function(ov_function=function, argv=argv) # no suitable inputs
# Verify that reverse_channels are NOT applied. with self.assertRaises(Exception):
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node() process_function(ov_function=function, argv=argv)
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_reverse_input_channels_dynamic(self): def test_reverse_input_channels_dynamic(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None, argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
layout_values=None) layout_values=None)
function = create_function2(shape1=[1, -1, 5, 5], shape2=[-1, -1, -1, -1]) function = create_function2(shape1=[1, -1, 5, 5], shape2=[-1, -1, -1, -1])
process_function(ov_function=function, argv=argv) # no suitable inputs
# Verify that reverse_channels are NOT applied. with self.assertRaises(Exception):
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node() process_function(ov_function=function, argv=argv)
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_reverse_input_channels_dynamic_layout(self): def test_reverse_input_channels_dynamic_layout(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None, argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
@ -582,34 +572,25 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
function = create_function2(shape1=[1, 224, 224, 3], shape2=[1, 3, 224, 224]) function = create_function2(shape1=[1, 224, 224, 3], shape2=[1, 3, 224, 224])
function.get_parameters()[0].layout = Layout("NHW?") function.get_parameters()[0].layout = Layout("NHW?")
function.get_parameters()[1].layout = Layout("N?HW") function.get_parameters()[1].layout = Layout("N?HW")
process_function(ov_function=function, argv=argv) # no suitable inputs
# Nothing has applied with self.assertRaises(Exception):
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node() process_function(ov_function=function, argv=argv)
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_guess_layout_reverse_channels_incorrect_pos(self): def test_guess_layout_reverse_channels_incorrect_pos(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None) argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
function = create_function2(shape1=[1, 4, 224, 224], shape2=[1, 224, 224, 2]) function = create_function2(shape1=[1, 4, 224, 224], shape2=[1, 224, 224, 2])
function.get_parameters()[0].layout = Layout("NCHW") function.get_parameters()[0].layout = Layout("NCHW")
function.get_parameters()[1].layout = Layout("NHWC") function.get_parameters()[1].layout = Layout("NHWC")
process_function(ov_function=function, argv=argv) # no suitable inputs
# Nothing has applied with self.assertRaises(Exception):
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node() process_function(ov_function=function, argv=argv)
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_no_reverse_channels_even_with_layout(self): def test_no_reverse_channels_even_with_layout(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None) argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
function = create_function2(shape1=[3, 4, 224, 224], shape2=[1, 224, 3, 224]) function = create_function2(shape1=[3, 4, 224, 224], shape2=[1, 224, 3, 224])
process_function(ov_function=function, argv=argv) # no suitable inputs
# Nothing has applied with self.assertRaises(Exception):
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node() process_function(ov_function=function, argv=argv)
self.assertTrue(op_node0.get_type_name() == 'Relu')
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
self.assertTrue(op_node1.get_type_name() == 'Relu')
def test_reverse_channels_and_mean_scale(self): def test_reverse_channels_and_mean_scale(self):
argv = Namespace(reverse_input_channels=True, mean_scale_values={ argv = Namespace(reverse_input_channels=True, mean_scale_values={