diff --git a/model-optimizer/automation/package_BOM.txt b/model-optimizer/automation/package_BOM.txt index 7646cc8ffe9..f18563afc08 100644 --- a/model-optimizer/automation/package_BOM.txt +++ b/model-optimizer/automation/package_BOM.txt @@ -11,6 +11,7 @@ extensions/back/__init__.py extensions/back/AvgPool.py extensions/back/blob_normalizer.py extensions/back/CellNormalizer.py +extensions/back/ChangeCastOutputType.py extensions/back/ClampNormalizer.py extensions/back/compress_quantized_weights.py extensions/back/ConvolutionNormalizer.py @@ -32,6 +33,7 @@ extensions/back/LayoutChangeForGatherND.py extensions/back/LeakyReLUMutation.py extensions/back/LinearToLinearONNXReplacer.py extensions/back/LRNToNorm.py +extensions/back/MarkNodesWithShapeValues.py extensions/back/MatMulNormalizer.py extensions/back/MaxPool.py extensions/back/NormalizeToNormalizeL2.py @@ -125,7 +127,6 @@ extensions/front/caffe/softmax_ext.py extensions/front/caffe/spatial_transformer_ext.py extensions/front/caffe/split_to_identity.py extensions/front/caffe/tanh.py -extensions/front/ChangeCastOutputType.py extensions/front/ChangePlaceholderTypes.py extensions/front/create_tensor_nodes.py extensions/front/disable_weights_quantize_value_propagation.py diff --git a/model-optimizer/extensions/back/ChangeCastOutputType.py b/model-optimizer/extensions/back/ChangeCastOutputType.py new file mode 100644 index 00000000000..976b6b50a29 --- /dev/null +++ b/model-optimizer/extensions/back/ChangeCastOutputType.py @@ -0,0 +1,43 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging as log + +import numpy as np + +from mo.back.replacement import BackReplacementPattern +from mo.graph.graph import Graph +from mo.middle.passes.convert_data_type import data_type_str_to_np + + +class ChangeCastOutputType(BackReplacementPattern): + """ + Change the Cast dst_type from fp64 to fp32 since not all plugins support fp64 data type. + Change the Cast dst_type from fp32 to fp16 when generating IR for fp16. + But leave fp32 if node returns shape value even if --data_type=FP16 (look extensions/back/MarkNodesWithShapeValues.py). + """ + enabled = True + force_shape_inference = True + + def run_after(self): + from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues + return [MarkNodesWithShapeValues] + + def run_before(self): + return [] + + def find_and_replace_pattern(self, graph: Graph): + for node in graph.get_op_nodes(op='Cast'): + if node.dst_type == np.float64: + log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name)) + node.dst_type = np.float32 + + ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type) + if node.dst_type == np.float32 and ir_data_type == np.float16 and not node.has_and_set('returns_shape_value'): + log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name)) + node.dst_type = ir_data_type + elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16: + # return back FP32 for all Convert nodes with shape values + log.warning('Change data type from {} to {} for node {} in ShapeOf subgraph'. + format(node.dst_type, np.float32, node.name)) + node.dst_type = np.float32 diff --git a/model-optimizer/extensions/back/MarkNodesWithShapeValues.py b/model-optimizer/extensions/back/MarkNodesWithShapeValues.py new file mode 100644 index 00000000000..0201e31423a --- /dev/null +++ b/model-optimizer/extensions/back/MarkNodesWithShapeValues.py @@ -0,0 +1,78 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging as log + +import numpy as np + +from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout +from mo.back.replacement import BackReplacementPattern +from mo.graph.graph import Graph + + +class MarkNodesWithShapeValues(BackReplacementPattern): + """ + This transformation marks op nodes in ShapeOf subgraphs with 'returns_shape_value' bool attribute and + data nodes of float32 constants with 'correct_data_type' attribute. + So that float Consts and Cast float will be kept in FP32 even if argument --data_type=FP16 is specified. + + This is needed to enable conversion to FP16 even if values in ShapeOf subgraphs exceed max(float16) + or because of FP16 lower precession shape inference is incorrect on some nodes (e.g. if Interpolate in scales mode + accepts values from ShapeOf subgraph). + + This transformation should be executed after shape inference and after all transformations which insert/modify + Cast nodes in ShapeOf subgraphs therefore it's placed at the end of the back phase. + """ + enabled = True + graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16'] + + def run_after(self): + from extensions.back.pass_separator import BackFinish + return [BackFinish] + + def run_before(self): + return [] + + @staticmethod + def get_operations_with_shape_inputs(): + return { + 'Interpolate': [1, 2], # sizes, scales inputs + 'Reshape': [1], # shape + 'Broadcast': [1], # target_shape + 'ConvBackPropData ': [2], # output_shape + 'GroupConvolutionBackpropData ': [2], # output_shape + 'BatchToSpace': [1, 2, 3], # block_shape, crops_begin, crops_end + 'SpaceToBatch': [1, 2, 3], # block_shape, pads_begin, pads_end + 'StridedSlice': [1, 2, 3], # begin, end, strides + 'VariadicSplit': [2], # split_lengths + 'Tile': [1], # repeats input + 'TopK': [1], # K input + 'Pad': [1, 2], # pads_begin, pads_end + 'Range': [0, 1, 2], # start, stop, step inputs + 'OneHot': [1], # depth input + } + + def find_and_replace_pattern(self, graph: Graph): + shape_input_ops_map = self.get_operations_with_shape_inputs() + + nodes_with_shape_inputs = [] + for node in graph.get_op_nodes(): + if node.soft_get('type') in shape_input_ops_map: + nodes_with_shape_inputs.append(node) + + start_nodes = [] + for node in nodes_with_shape_inputs: + start_nodes.extend( + [node.in_port(port_idx).get_source().node for port_idx in shape_input_ops_map[node.type] if + node.is_in_port_connected(port_idx)]) + + condition = lambda node: node.soft_get('type') != 'ShapeOf' + nodes_with_shape_values = MarkSubGraphsWithCorrectLayout.bfs(start_nodes, set(), condition, forward=False) + for node in nodes_with_shape_values: + node['returns_shape_value'] = True + if node.soft_get('type') == 'Const': + if node.value.dtype == np.float32: + node.out_node(0)['correct_data_type'] = True + elif node.value.dtype in [np.float16, np.float64]: + log.debug('Const nodes {} with shape values have {} type'.format(node.soft_get('name', node.id), + node.value.dtype)) diff --git a/model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py b/model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py new file mode 100644 index 00000000000..0d33f7b6e21 --- /dev/null +++ b/model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py @@ -0,0 +1,103 @@ +# Copyright (C) 2018-2021 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import unittest + +import numpy as np + +from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues +from mo.front.common.partial_infer.utils import int64_array, float32_array +from mo.graph.graph import Node +from mo.utils.ir_engine.compare_graphs import compare_graphs +from mo.utils.unittest.graph import build_graph +from mo.utils.unittest.graph import result, regular_op_with_empty_data, \ + shaped_const_with_data, connect, regular_op + + +class TestMarkDataTypeInShapeOfSubgraphs(unittest.TestCase): + + def test_run_with_shape_subgraph_input(self): + inp_shape = (1, 3, 1000, 1000) + dst_type = np.float32 + + nodes = { + **shaped_const_with_data('input', int64_array(inp_shape)), + **regular_op_with_empty_data('shape', {'type': 'ShapeOf'}), + **regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type}), + **regular_op('mul_const', {'op': 'Const'}), + **{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}}, + **regular_op_with_empty_data('mul', {'type': 'Mul'}), + **regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64}), + **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}), + **result('res'), + } + + nodes_ref = { + **shaped_const_with_data('input', int64_array(inp_shape)), + **regular_op_with_empty_data('shape', {'type': 'ShapeOf'}), + **regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type, + 'returns_shape_value': True}), + **regular_op_with_empty_data('mul', {'type': 'Mul', 'returns_shape_value': True}), + **regular_op('mul_const', {'op': 'Const', 'returns_shape_value': True}), + **{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.]), + 'correct_data_type': True}}, + **regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64, + 'returns_shape_value': True}), + **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}), + **result('res'), + } + + edges = [ + *connect('input', '0:interpolate'), + *connect('input', '0:shape', skip_data=True), + *connect('shape', '0:cast_to_float'), + *connect('cast_to_float', '0:mul'), + *connect('mul_const', '1:mul'), + *connect('mul', '0:cast_to_int'), + *connect('cast_to_int', '1:interpolate'), + *connect('interpolate', 'res'), + ] + graph = build_graph(nodes, edges) + interp_node = Node(graph, 'interpolate') + interp_node.add_input_port(2) + + MarkNodesWithShapeValues().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes_ref, edges) + (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True) + self.assertTrue(flag, resp) + + def test_run_with_const_input(self): + inp_shape = (1, 3, 1000, 1000) + dst_type = np.float32 + + nodes = { + **shaped_const_with_data('input', int64_array(inp_shape)), + **regular_op('sizes_const', {'op': 'Const'}), + **{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}}, + **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}), + **result('res'), + } + + nodes_ref = { + **shaped_const_with_data('input', int64_array(inp_shape)), + **regular_op('sizes_const', {'op': 'Const', 'returns_shape_value': True}), + **{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}}, + **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}), + **result('res'), + } + + edges = [ + *connect('input', '0:interpolate'), + *connect('sizes_const', '1:interpolate'), + *connect('interpolate', 'res'), + ] + graph = build_graph(nodes, edges) + interp_node = Node(graph, 'interpolate') + interp_node.add_input_port(2) + + MarkNodesWithShapeValues().find_and_replace_pattern(graph) + + graph_ref = build_graph(nodes_ref, edges) + (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True) + self.assertTrue(flag, resp) diff --git a/model-optimizer/extensions/front/ChangeCastOutputType.py b/model-optimizer/extensions/front/ChangeCastOutputType.py deleted file mode 100644 index 45bd72d2d0f..00000000000 --- a/model-optimizer/extensions/front/ChangeCastOutputType.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (C) 2018-2021 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - -import logging as log - -import numpy as np - -from mo.front.common.replacement import FrontReplacementSubgraph -from mo.front.subgraph_matcher import SubgraphMatch -from mo.graph.graph import Graph -from mo.middle.passes.convert_data_type import data_type_str_to_np - - -class ChangeCastOutputType(FrontReplacementSubgraph): - """ - Change the Cast to fp64 to fp32 since not all plugins support fp64 data type. - Change the Cast to fp32 to fp16 when generating IR for fp16. - """ - enabled = True - - def pattern(self): - return dict( - nodes=[ - ('cast', dict(op='Cast')) - ], - edges=[] - ) - - def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]): - node = match['cast'] - if node.dst_type == np.float64: - log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name)) - node.dst_type = np.float32 - - ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type) - if node.dst_type == np.float32 and ir_data_type == np.float16: - log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name)) - node.dst_type = ir_data_type diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py index 6708ca737ae..42ef5d4da06 100644 --- a/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py +++ b/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py @@ -5,11 +5,10 @@ import numpy as np from extensions.ops.Cast import Cast from extensions.ops.elementwise import Div -from mo.front.common.partial_infer.utils import int64_array, float_array +from mo.front.common.partial_infer.utils import int64_array, float32_array from mo.front.common.replacement import FrontReplacementPattern from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input from mo.graph.graph import Graph -from mo.middle.passes.convert_data_type import data_type_str_to_np from mo.ops.concat import Concat from mo.ops.reshape import Reshape from mo.ops.shape import Shape @@ -46,21 +45,23 @@ class ReplaceConvolutionReshape(FrontReplacementPattern): node = match['conv'] node_name = node.soft_get('name', node.id) + dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values + # create Reshape before convolution # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride] i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node() shape = Cast(graph, {'name': node_name + '/to_float', - 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node() + 'dst_type': dst_dtype}).create_node() i_shape.in_port(0).connect(node.in_port(0).get_source()) shape.in_port(0).connect(i_shape.out_port(0)) N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1]) div = create_op_with_const_inputs( - graph, Div, {1: float_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'}) + graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'}) div.in_port(0).connect(H.out_port(0)) - concat = create_op_with_const_inputs(graph, Concat, {2: float_array([1]), 3: float_array([node.patch_stride])}, + concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])}, {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0}) concat.in_port(0).connect(N.out_port(0)) concat.in_port(1).connect(div.out_port(0)) diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py index 0bead3863d7..10c992d6a83 100644 --- a/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py +++ b/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py @@ -5,11 +5,10 @@ import numpy as np from extensions.ops.Cast import Cast from extensions.ops.elementwise import Div -from mo.front.common.partial_infer.utils import int64_array, float_array +from mo.front.common.partial_infer.utils import int64_array, float32_array from mo.front.common.replacement import FrontReplacementPattern from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs from mo.graph.graph import Graph -from mo.middle.passes.convert_data_type import data_type_str_to_np from mo.ops.concat import Concat from mo.ops.reshape import Reshape from mo.ops.shape import Shape @@ -48,18 +47,20 @@ class ReplacePoolingReshape(FrontReplacementPattern): # create Reshape before convolution # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride] i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node() + + dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values shape = Cast(graph, {'name': node_name + '/to_float', - 'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node() + 'dst_type': dst_dtype}).create_node() i_shape.in_port(0).connect(node.in_port(0).get_source()) shape.in_port(0).connect(i_shape.out_port(0)) N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1]) div = create_op_with_const_inputs( - graph, Div, {1: float_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'}) + graph, Div, {1: float32_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'}) div.in_port(0).connect(H.out_port(0)) - concat = create_op_with_const_inputs(graph, Concat, {1: float_array([node.pool_stride]), 2: float_array([1])}, + concat = create_op_with_const_inputs(graph, Concat, {1: float32_array([node.pool_stride]), 2: float32_array([1])}, {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0}) concat.in_port(0).connect(N.out_port(0)) concat.in_port(3).connect(div.out_port(0)) diff --git a/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py b/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py index bc74b1500ff..38704b4b286 100644 --- a/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py +++ b/model-optimizer/extensions/front/kaldi/replace_lstm_node_pattern.py @@ -41,7 +41,7 @@ def unique_id(prefix: str = 'id') -> str: unique_id.names = [] -def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float): +def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32): # create init_graph connected to ReadValue graph = input_out_port.node.graph input_name = input_out_port.node.name diff --git a/model-optimizer/extensions/front/onnx/ONNXResize10ToInterpolate.py b/model-optimizer/extensions/front/onnx/ONNXResize10ToInterpolate.py index ca1c2ba71f1..ae2b28ece05 100644 --- a/model-optimizer/extensions/front/onnx/ONNXResize10ToInterpolate.py +++ b/model-optimizer/extensions/front/onnx/ONNXResize10ToInterpolate.py @@ -5,8 +5,8 @@ import logging as log import numpy as np -from extensions.ops.activation_ops import Floor from extensions.ops.Cast import Cast +from extensions.ops.activation_ops import Floor from extensions.ops.elementwise import Add, Mul from extensions.ops.interpolate import Interpolate from extensions.ops.range import Range @@ -15,7 +15,6 @@ from mo.front.common.partial_infer.utils import int64_array, float_array from mo.front.common.replacement import FrontReplacementOp from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Graph, Node, rename_nodes -from mo.middle.passes.convert_data_type import data_type_str_to_np from mo.ops.shape import Shape from mo.ops.strided_slice import StridedSlice @@ -79,9 +78,9 @@ def replace_resize(graph: Graph, resize: Node): {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) - input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type) + dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values - cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node() + cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node]) diff --git a/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py b/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py index 3cc7a8b3120..cdcdfa0a522 100644 --- a/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py +++ b/model-optimizer/extensions/front/tf/RetinaNetFilteredDetectionsReplacement.py @@ -8,7 +8,7 @@ from extensions.ops.DetectionOutput import DetectionOutput from extensions.ops.elementwise import Mul, Sub, Pow from extensions.ops.gather import Gather from extensions.ops.split import VariadicSplit -from mo.front.common.partial_infer.utils import int64_array +from mo.front.common.partial_infer.utils import int64_array, float32_array from mo.front.subgraph_matcher import SubgraphMatch from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph @@ -54,9 +54,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr sp_shape = Shape(graph, {'name': name + '/shape'}).create_node() priors_scale_node.out_port(0).connect(sp_shape.in_port(0)) - begin = Const(graph, {'value': np.array([-2])}).create_node() - end = Const(graph, {'value': np.array([-1])}).create_node() - stride = Const(graph, {'value': np.array([1])}).create_node() + begin = Const(graph, {'value': int64_array([-2])}).create_node() + end = Const(graph, {'value': int64_array([-1])}).create_node() + stride = Const(graph, {'value': int64_array([1])}).create_node() shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]), 'end_mask': np.array([1]), 'new_axis_mask': np.array([0]), 'shrink_axis_mask': np.array([0]), @@ -72,7 +72,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr 'axis': int64_array(0)}, shape_part_for_tiling) - variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node() + variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node() tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node() variance.out_port(0).connect(tile.in_port(0)) shape_concat.out_port(0).connect(tile.in_port(1)) @@ -113,9 +113,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr shape = Shape(graph, {'name': 'input_image_shape'}).create_node() shape.in_port(0).connect(placeholder.out_port(0)) - begin = Const(graph, {'value': np.array([1])}).create_node() - end = Const(graph, {'value': np.array([3])}).create_node() - stride = Const(graph, {'value': np.array([1])}).create_node() + begin = Const(graph, {'value': int64_array([1])}).create_node() + end = Const(graph, {'value': int64_array([3])}).create_node() + stride = Const(graph, {'value': int64_array([1])}).create_node() spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]), 'end_mask': np.array([1]), 'new_axis_mask': np.array([0]), 'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node() @@ -125,7 +125,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr spatial.in_port(2).connect(end.out_port(0)) spatial.in_port(3).connect(stride.out_port(0)) - power = Const(graph, {'value': np.array([-1.])}).create_node() + power = Const(graph, {'value': float32_array([-1.])}).create_node() spatial_scale = Pow(graph, {}).create_node() spatial_scale.in_port(0).connect(spatial.out_port(0)) diff --git a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py index 2ec24fbba03..38058f76fc4 100644 --- a/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py +++ b/model-optimizer/extensions/middle/MarkSubgraphsWithCorrectLayout.py @@ -3,14 +3,12 @@ import logging as log from collections import deque - from typing import Set from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \ mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths from extensions.middle.pass_separator import PostMiddleStart -from mo.front.common.partial_infer.utils import int64_array from mo.graph.graph import Graph, Node from mo.graph.perm_inputs import PermuteInputs from mo.graph.port import Port @@ -51,7 +49,8 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): result.append(dest_port.node) return result - def bfs(self, start_nodes: list, visited: set, condition: callable = None, forward: bool = True): + @staticmethod + def bfs(start_nodes: list, visited: set, condition: callable = None, forward: bool = True): """ The function performs BFS starting from selected nodes in forward or backward direction adding nodes by an optional condition @@ -63,7 +62,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): :return: the list of Nodes visited """ assert visited is not None, 'The "visited" set must be defined' - assert start_nodes is not None and len(start_nodes) != 0, 'The list of start nodes must be specified' + assert start_nodes is not None, 'The list of start nodes must be specified' result = list() d = deque(start_nodes) @@ -72,9 +71,9 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern): result.append(cur_node) visited.add(cur_node) if forward: - next_nodes = self.get_output_nodes(cur_node) + next_nodes = MarkSubGraphsWithCorrectLayout.get_output_nodes(cur_node) else: - next_nodes = self.get_input_nodes(cur_node) + next_nodes = MarkSubGraphsWithCorrectLayout.get_input_nodes(cur_node) for next_node in next_nodes: if next_node not in visited and (condition is None or condition(next_node)): d.append(next_node) diff --git a/model-optimizer/extensions/middle/ONNXResize11ToInterpolate.py b/model-optimizer/extensions/middle/ONNXResize11ToInterpolate.py index 956e6d017cd..ffd3c27eebd 100644 --- a/model-optimizer/extensions/middle/ONNXResize11ToInterpolate.py +++ b/model-optimizer/extensions/middle/ONNXResize11ToInterpolate.py @@ -2,18 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 import logging as log + import numpy as np -from extensions.ops.activation_ops import Floor from extensions.ops.Cast import Cast +from extensions.ops.activation_ops import Floor from extensions.ops.elementwise import Add, Div, Mul from extensions.ops.interpolate import Interpolate from mo.front.common.layout import get_depth_dim, get_height_dim, get_width_dim from mo.front.common.partial_infer.utils import int64_array, float_array from mo.front.tf.graph_utils import create_op_with_const_inputs -from mo.middle.passes.convert_data_type import data_type_str_to_np -from mo.middle.replacement import MiddleReplacementPattern from mo.graph.graph import Graph, Node, rename_nodes +from mo.middle.replacement import MiddleReplacementPattern from mo.ops.const import Const from mo.ops.shape import Shape from mo.ops.strided_slice import StridedSlice @@ -94,10 +94,10 @@ def replace_resize(graph: Graph, resize: Node): {1: float_array([1.0e-5])}, {'name': resize_name + '/Add'}) - input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type) + dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values if num_of_inputs == 3: - cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node() + cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node() shape_of.out_port(0).connect(cast_shape_to_float.in_port(0)) cast_shape_to_float.out_port(0).connect(mul_node.in_port(0)) @@ -119,8 +119,8 @@ def replace_resize(graph: Graph, resize: Node): connection_of_resize_input.get_source().connect(shape_of.in_port(0)) connection_of_scales.get_source().connect(mul_node.in_port(1)) else: - cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node() - cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node() + cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() + cast_sizes_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node() div_node = Div(graph, {'name': resize_name + '/Div'}).create_node() cast_sizes_to_float.out_port(0).connect(div_node.in_port(0)) cast_shape_to_float.out_port(0).connect(div_node.in_port(1)) diff --git a/model-optimizer/extensions/ops/Cast.py b/model-optimizer/extensions/ops/Cast.py index 066ef1cf0d6..8f7ac73ffbe 100644 --- a/model-optimizer/extensions/ops/Cast.py +++ b/model-optimizer/extensions/ops/Cast.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import logging as log + import numpy as np -from mo.front.common.partial_infer.elemental import copy_shape_infer from mo.graph.graph import Node, Graph from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \ np_data_type_to_destination_type, packed_I4, packed_U4 diff --git a/model-optimizer/mo/utils/ir_reader/restore_graph.py b/model-optimizer/mo/utils/ir_reader/restore_graph.py index 0f92ed61c47..9f1c8e0e9e9 100644 --- a/model-optimizer/mo/utils/ir_reader/restore_graph.py +++ b/model-optimizer/mo/utils/ir_reader/restore_graph.py @@ -17,6 +17,7 @@ from mo.utils.class_registration import apply_replacements_list from mo.utils.ir_engine.ir_engine import IREngine from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops from mo.utils.utils import get_mo_root_dir +from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict): @@ -64,6 +65,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None): BlobNormalizer, ConvolutionNormalizer, KaldiRemoveMemoryOutputBackReplacementPattern, + MarkNodesWithShapeValues, ] # We need to run some specific passes from MO back stage.