Apply RIC for dynamic dimension in legacy MO (#10130)
* Apply RIC for dynamic dimension in legacy MO and fail if RIC wasn't applied to any input * Fix moc tests
This commit is contained in:
parent
d951433b12
commit
1970baeb1c
@ -7,7 +7,7 @@ import numpy as np
|
|||||||
|
|
||||||
from openvino.tools.mo.back.replacement import BackReplacementPattern
|
from openvino.tools.mo.back.replacement import BackReplacementPattern
|
||||||
from openvino.tools.mo.front.common.layout import get_dim_from_layout, get_features_dim
|
from openvino.tools.mo.front.common.layout import get_dim_from_layout, get_features_dim
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array
|
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, compatible_dims
|
||||||
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
|
from openvino.tools.mo.front.common.partial_infer.utils import mo_array
|
||||||
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
from openvino.tools.mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||||
from openvino.tools.mo.graph.graph import Graph
|
from openvino.tools.mo.graph.graph import Graph
|
||||||
@ -16,6 +16,7 @@ from openvino.tools.mo.ops.concat import Concat
|
|||||||
from openvino.tools.mo.ops.gather import Gather
|
from openvino.tools.mo.ops.gather import Gather
|
||||||
from openvino.tools.mo.ops.op import Op, PermuteAttrs
|
from openvino.tools.mo.ops.op import Op, PermuteAttrs
|
||||||
from openvino.tools.mo.ops.split import Split
|
from openvino.tools.mo.ops.split import Split
|
||||||
|
from openvino.tools.mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
class ReverseChannels(Op):
|
class ReverseChannels(Op):
|
||||||
@ -53,7 +54,10 @@ class InsertReverseChannels(BackReplacementPattern):
|
|||||||
enabled = False
|
enabled = False
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_channel_index(node: Node) -> int:
|
def get_suitable_channel_index(node: Node, shape):
|
||||||
|
if len(shape) != 4:
|
||||||
|
return None
|
||||||
|
|
||||||
guessed_layout = 'NCHW'
|
guessed_layout = 'NCHW'
|
||||||
if node.has_valid('rt_info'):
|
if node.has_valid('rt_info'):
|
||||||
rt_info = node.rt_info
|
rt_info = node.rt_info
|
||||||
@ -66,25 +70,30 @@ class InsertReverseChannels(BackReplacementPattern):
|
|||||||
guessed_layout = np.array(list(guessed_layout))[order]
|
guessed_layout = np.array(list(guessed_layout))[order]
|
||||||
guessed_layout = ''.join(guessed_layout)
|
guessed_layout = ''.join(guessed_layout)
|
||||||
idx, has_layout = get_dim_from_layout(node, 'C')
|
idx, has_layout = get_dim_from_layout(node, 'C')
|
||||||
if has_layout:
|
if not has_layout:
|
||||||
|
idx = get_features_dim(guessed_layout, len(node.shape))
|
||||||
|
if compatible_dims(shape[idx], 3):
|
||||||
return idx
|
return idx
|
||||||
else:
|
else:
|
||||||
return get_features_dim(guessed_layout, len(node.shape))
|
return None
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
all_params = [(p.soft_get('name', p.id), p, list(p.out_port(0).data.get_shape()))
|
||||||
for p in graph.get_op_nodes(type='Parameter')]
|
for p in graph.get_op_nodes(type='Parameter')]
|
||||||
suitable_params = []
|
suitable_params = []
|
||||||
for name, p, shape in all_params:
|
for name, p, shape in all_params:
|
||||||
if len(shape) == 4:
|
idx = self.get_suitable_channel_index(p, shape)
|
||||||
idx = self.get_channel_index(p)
|
if idx is not None:
|
||||||
if idx is not None and shape[idx] == 3:
|
suitable_params.append((name, p, shape, idx))
|
||||||
suitable_params.append((name, p, shape, idx))
|
|
||||||
|
|
||||||
|
|
||||||
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
log.debug('All network inputs: {}'.format({name: shape for name, _, shape in all_params}))
|
||||||
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape, _ in suitable_params}))
|
log.debug('Will reverse input channels for: {}'.format({name: shape for name, _, shape, _ in suitable_params}))
|
||||||
if len(suitable_params) < len(all_params):
|
if not len(suitable_params):
|
||||||
|
raise Error('Network has {} inputs overall, but none of them are suitable for input channels reversing.\n'
|
||||||
|
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of '
|
||||||
|
'dynamic dimensions C channel must be provided in a layout for this input)\n'
|
||||||
|
'All inputs: {}'.format(len(all_params), all_params))
|
||||||
|
elif len(suitable_params) < len(all_params):
|
||||||
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
|
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
|
||||||
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
|
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
|
||||||
'Suitable inputs {}'.format(len(all_params), len(suitable_params),
|
'Suitable inputs {}'.format(len(all_params), len(suitable_params),
|
||||||
|
@ -310,9 +310,15 @@ def guess_source_layouts_for_reverse_channels(ov_function: Model, layout_values)
|
|||||||
if layout and check_suitable_for_reverse(Layout(layout), ov_input):
|
if layout and check_suitable_for_reverse(Layout(layout), ov_input):
|
||||||
suitable_params.append(param_info)
|
suitable_params.append(param_info)
|
||||||
|
|
||||||
if len(suitable_params) < len(all_params):
|
if not len(suitable_params):
|
||||||
|
raise Error('Network has {} inputs overall, but none of them are suitable for input channels reversing.\n'
|
||||||
|
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of dynamic '
|
||||||
|
'dimensions C channel must be provided in a layout for this input)\nAll inputs: {}'.format(
|
||||||
|
len(all_params), all_params))
|
||||||
|
elif len(suitable_params) < len(all_params):
|
||||||
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
|
log.error('Network has {} inputs overall, but only {} of them are suitable for input channels reversing.\n'
|
||||||
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels\nAll inputs: {}\n'
|
'Suitable for input channel reversing inputs are 4-dimensional with 3 channels (in case of dynamic '
|
||||||
|
'dimensions C channel must be provided in a layout for this input)\nAll inputs: {}\n'
|
||||||
'Suitable inputs {}'.format(len(all_params), len(suitable_params), all_params, suitable_params),
|
'Suitable inputs {}'.format(len(all_params), len(suitable_params), all_params, suitable_params),
|
||||||
extra={'is_warning': True})
|
extra={'is_warning': True})
|
||||||
return suitable_params
|
return suitable_params
|
||||||
|
@ -434,13 +434,9 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
|
|||||||
'input2a': { 'source_layout': 'nchw' }
|
'input2a': { 'source_layout': 'nchw' }
|
||||||
})
|
})
|
||||||
function = create_function2(shape1=[1, 224, 224, 4], shape2=[1, 4, 224, 224])
|
function = create_function2(shape1=[1, 224, 224, 4], shape2=[1, 4, 224, 224])
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# In future, consider using mock PrePostProcessor to verify that 'reverse_channels' was not called
|
with self.assertRaises(Exception):
|
||||||
# Verify that reverse_channels are not applied.
|
process_function(ov_function=function, argv=argv)
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_reverse_input_channels_3d(self):
|
def test_reverse_input_channels_3d(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
||||||
@ -457,23 +453,17 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
|
|||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
||||||
layout_values=None)
|
layout_values=None)
|
||||||
function = create_function2(shape1=[4, 4, 4, 4, 4, 3], shape2=[4, 3, 4, 4, 4, 4])
|
function = create_function2(shape1=[4, 4, 4, 4, 4, 3], shape2=[4, 3, 4, 4, 4, 4])
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# Verify that reverse_channels are NOT applied.
|
with self.assertRaises(Exception):
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
process_function(ov_function=function, argv=argv)
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_reverse_input_channels_dynamic(self):
|
def test_reverse_input_channels_dynamic(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
||||||
layout_values=None)
|
layout_values=None)
|
||||||
function = create_function2(shape1=[1, -1, 5, 5], shape2=[-1, -1, -1, -1])
|
function = create_function2(shape1=[1, -1, 5, 5], shape2=[-1, -1, -1, -1])
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# Verify that reverse_channels are NOT applied.
|
with self.assertRaises(Exception):
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
process_function(ov_function=function, argv=argv)
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_reverse_input_channels_dynamic_layout(self):
|
def test_reverse_input_channels_dynamic_layout(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None,
|
||||||
@ -582,34 +572,25 @@ class TestPreprocessingMOC(UnitTestWithMockedTelemetry):
|
|||||||
function = create_function2(shape1=[1, 224, 224, 3], shape2=[1, 3, 224, 224])
|
function = create_function2(shape1=[1, 224, 224, 3], shape2=[1, 3, 224, 224])
|
||||||
function.get_parameters()[0].layout = Layout("NHW?")
|
function.get_parameters()[0].layout = Layout("NHW?")
|
||||||
function.get_parameters()[1].layout = Layout("N?HW")
|
function.get_parameters()[1].layout = Layout("N?HW")
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# Nothing has applied
|
with self.assertRaises(Exception):
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
process_function(ov_function=function, argv=argv)
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_guess_layout_reverse_channels_incorrect_pos(self):
|
def test_guess_layout_reverse_channels_incorrect_pos(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
|
||||||
function = create_function2(shape1=[1, 4, 224, 224], shape2=[1, 224, 224, 2])
|
function = create_function2(shape1=[1, 4, 224, 224], shape2=[1, 224, 224, 2])
|
||||||
function.get_parameters()[0].layout = Layout("NCHW")
|
function.get_parameters()[0].layout = Layout("NCHW")
|
||||||
function.get_parameters()[1].layout = Layout("NHWC")
|
function.get_parameters()[1].layout = Layout("NHWC")
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# Nothing has applied
|
with self.assertRaises(Exception):
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
process_function(ov_function=function, argv=argv)
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_no_reverse_channels_even_with_layout(self):
|
def test_no_reverse_channels_even_with_layout(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
|
argv = Namespace(reverse_input_channels=True, mean_scale_values=None, scale=None)
|
||||||
function = create_function2(shape1=[3, 4, 224, 224], shape2=[1, 224, 3, 224])
|
function = create_function2(shape1=[3, 4, 224, 224], shape2=[1, 224, 3, 224])
|
||||||
process_function(ov_function=function, argv=argv)
|
# no suitable inputs
|
||||||
# Nothing has applied
|
with self.assertRaises(Exception):
|
||||||
op_node0 = list(function.get_parameters()[0].output(0).get_target_inputs())[0].get_node()
|
process_function(ov_function=function, argv=argv)
|
||||||
self.assertTrue(op_node0.get_type_name() == 'Relu')
|
|
||||||
op_node1 = list(function.get_parameters()[1].output(0).get_target_inputs())[0].get_node()
|
|
||||||
self.assertTrue(op_node1.get_type_name() == 'Relu')
|
|
||||||
|
|
||||||
def test_reverse_channels_and_mean_scale(self):
|
def test_reverse_channels_and_mean_scale(self):
|
||||||
argv = Namespace(reverse_input_channels=True, mean_scale_values={
|
argv = Namespace(reverse_input_channels=True, mean_scale_values={
|
||||||
|
Loading…
Reference in New Issue
Block a user