[MO] Fix fp16 in shapeof subgraphs (#4524)
* Initial working solution * moved bfs_search_apply_on_shapeof_subgraph_nodes from utils/graph.py to MarkShapeOfSubgraphDataType.py * Reused bfs from MarkSubgraphsWithCorrectLayout.py * fixed e2e precomit issues: specified correct const data_types, fixed BFS search staring point to avoid nodeless shapeof subgraphs * fixed mxnet_rnnt: added converting all Const nodes in ShapeOf subgraph in MarkAndChangeDataTypeInShapeOfSubgraphs.py, revised Const values in transformations that affect ShapeOf subgraph nodes * reverter ReverseV2ToReverseSequence.py and DecomposeBidirectionalRNNSequence.py * in MarkSubgraphsWithCorrectLayout BFS search beauty applied * apply review comments, returned back 'in_shape_subgraph' attribute * graph condition added * MO IR reader fix for mixed FP16 models, added replacer order placement comment * moved to back phase * new solution with marking nodes from bottom to top (WIP) * successfully tested on back phase * corrected unittest * removed check for start nodes size in bfs * fix transformations that insert f64 to f32 in shape subgraph * corrected log.warning -> log.debug * revised list if shape input operations added unittest for Const shape inputs * applied @lazarevevgeny's comments * licence head corrections
This commit is contained in:
parent
e072c96237
commit
6477e8ec01
@ -11,6 +11,7 @@ extensions/back/__init__.py
|
|||||||
extensions/back/AvgPool.py
|
extensions/back/AvgPool.py
|
||||||
extensions/back/blob_normalizer.py
|
extensions/back/blob_normalizer.py
|
||||||
extensions/back/CellNormalizer.py
|
extensions/back/CellNormalizer.py
|
||||||
|
extensions/back/ChangeCastOutputType.py
|
||||||
extensions/back/ClampNormalizer.py
|
extensions/back/ClampNormalizer.py
|
||||||
extensions/back/compress_quantized_weights.py
|
extensions/back/compress_quantized_weights.py
|
||||||
extensions/back/ConvolutionNormalizer.py
|
extensions/back/ConvolutionNormalizer.py
|
||||||
@ -32,6 +33,7 @@ extensions/back/LayoutChangeForGatherND.py
|
|||||||
extensions/back/LeakyReLUMutation.py
|
extensions/back/LeakyReLUMutation.py
|
||||||
extensions/back/LinearToLinearONNXReplacer.py
|
extensions/back/LinearToLinearONNXReplacer.py
|
||||||
extensions/back/LRNToNorm.py
|
extensions/back/LRNToNorm.py
|
||||||
|
extensions/back/MarkNodesWithShapeValues.py
|
||||||
extensions/back/MatMulNormalizer.py
|
extensions/back/MatMulNormalizer.py
|
||||||
extensions/back/MaxPool.py
|
extensions/back/MaxPool.py
|
||||||
extensions/back/NormalizeToNormalizeL2.py
|
extensions/back/NormalizeToNormalizeL2.py
|
||||||
@ -125,7 +127,6 @@ extensions/front/caffe/softmax_ext.py
|
|||||||
extensions/front/caffe/spatial_transformer_ext.py
|
extensions/front/caffe/spatial_transformer_ext.py
|
||||||
extensions/front/caffe/split_to_identity.py
|
extensions/front/caffe/split_to_identity.py
|
||||||
extensions/front/caffe/tanh.py
|
extensions/front/caffe/tanh.py
|
||||||
extensions/front/ChangeCastOutputType.py
|
|
||||||
extensions/front/ChangePlaceholderTypes.py
|
extensions/front/ChangePlaceholderTypes.py
|
||||||
extensions/front/create_tensor_nodes.py
|
extensions/front/create_tensor_nodes.py
|
||||||
extensions/front/disable_weights_quantize_value_propagation.py
|
extensions/front/disable_weights_quantize_value_propagation.py
|
||||||
|
43
model-optimizer/extensions/back/ChangeCastOutputType.py
Normal file
43
model-optimizer/extensions/back/ChangeCastOutputType.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging as log
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mo.back.replacement import BackReplacementPattern
|
||||||
|
from mo.graph.graph import Graph
|
||||||
|
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||||
|
|
||||||
|
|
||||||
|
class ChangeCastOutputType(BackReplacementPattern):
|
||||||
|
"""
|
||||||
|
Change the Cast dst_type from fp64 to fp32 since not all plugins support fp64 data type.
|
||||||
|
Change the Cast dst_type from fp32 to fp16 when generating IR for fp16.
|
||||||
|
But leave fp32 if node returns shape value even if --data_type=FP16 (look extensions/back/MarkNodesWithShapeValues.py).
|
||||||
|
"""
|
||||||
|
enabled = True
|
||||||
|
force_shape_inference = True
|
||||||
|
|
||||||
|
def run_after(self):
|
||||||
|
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||||
|
return [MarkNodesWithShapeValues]
|
||||||
|
|
||||||
|
def run_before(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
for node in graph.get_op_nodes(op='Cast'):
|
||||||
|
if node.dst_type == np.float64:
|
||||||
|
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
|
||||||
|
node.dst_type = np.float32
|
||||||
|
|
||||||
|
ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
|
||||||
|
if node.dst_type == np.float32 and ir_data_type == np.float16 and not node.has_and_set('returns_shape_value'):
|
||||||
|
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
|
||||||
|
node.dst_type = ir_data_type
|
||||||
|
elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16:
|
||||||
|
# return back FP32 for all Convert nodes with shape values
|
||||||
|
log.warning('Change data type from {} to {} for node {} in ShapeOf subgraph'.
|
||||||
|
format(node.dst_type, np.float32, node.name))
|
||||||
|
node.dst_type = np.float32
|
78
model-optimizer/extensions/back/MarkNodesWithShapeValues.py
Normal file
78
model-optimizer/extensions/back/MarkNodesWithShapeValues.py
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging as log
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout
|
||||||
|
from mo.back.replacement import BackReplacementPattern
|
||||||
|
from mo.graph.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
|
class MarkNodesWithShapeValues(BackReplacementPattern):
|
||||||
|
"""
|
||||||
|
This transformation marks op nodes in ShapeOf subgraphs with 'returns_shape_value' bool attribute and
|
||||||
|
data nodes of float32 constants with 'correct_data_type' attribute.
|
||||||
|
So that float Consts and Cast float will be kept in FP32 even if argument --data_type=FP16 is specified.
|
||||||
|
|
||||||
|
This is needed to enable conversion to FP16 even if values in ShapeOf subgraphs exceed max(float16)
|
||||||
|
or because of FP16 lower precession shape inference is incorrect on some nodes (e.g. if Interpolate in scales mode
|
||||||
|
accepts values from ShapeOf subgraph).
|
||||||
|
|
||||||
|
This transformation should be executed after shape inference and after all transformations which insert/modify
|
||||||
|
Cast nodes in ShapeOf subgraphs therefore it's placed at the end of the back phase.
|
||||||
|
"""
|
||||||
|
enabled = True
|
||||||
|
graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']
|
||||||
|
|
||||||
|
def run_after(self):
|
||||||
|
from extensions.back.pass_separator import BackFinish
|
||||||
|
return [BackFinish]
|
||||||
|
|
||||||
|
def run_before(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_operations_with_shape_inputs():
|
||||||
|
return {
|
||||||
|
'Interpolate': [1, 2], # sizes, scales inputs
|
||||||
|
'Reshape': [1], # shape
|
||||||
|
'Broadcast': [1], # target_shape
|
||||||
|
'ConvBackPropData ': [2], # output_shape
|
||||||
|
'GroupConvolutionBackpropData ': [2], # output_shape
|
||||||
|
'BatchToSpace': [1, 2, 3], # block_shape, crops_begin, crops_end
|
||||||
|
'SpaceToBatch': [1, 2, 3], # block_shape, pads_begin, pads_end
|
||||||
|
'StridedSlice': [1, 2, 3], # begin, end, strides
|
||||||
|
'VariadicSplit': [2], # split_lengths
|
||||||
|
'Tile': [1], # repeats input
|
||||||
|
'TopK': [1], # K input
|
||||||
|
'Pad': [1, 2], # pads_begin, pads_end
|
||||||
|
'Range': [0, 1, 2], # start, stop, step inputs
|
||||||
|
'OneHot': [1], # depth input
|
||||||
|
}
|
||||||
|
|
||||||
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
shape_input_ops_map = self.get_operations_with_shape_inputs()
|
||||||
|
|
||||||
|
nodes_with_shape_inputs = []
|
||||||
|
for node in graph.get_op_nodes():
|
||||||
|
if node.soft_get('type') in shape_input_ops_map:
|
||||||
|
nodes_with_shape_inputs.append(node)
|
||||||
|
|
||||||
|
start_nodes = []
|
||||||
|
for node in nodes_with_shape_inputs:
|
||||||
|
start_nodes.extend(
|
||||||
|
[node.in_port(port_idx).get_source().node for port_idx in shape_input_ops_map[node.type] if
|
||||||
|
node.is_in_port_connected(port_idx)])
|
||||||
|
|
||||||
|
condition = lambda node: node.soft_get('type') != 'ShapeOf'
|
||||||
|
nodes_with_shape_values = MarkSubGraphsWithCorrectLayout.bfs(start_nodes, set(), condition, forward=False)
|
||||||
|
for node in nodes_with_shape_values:
|
||||||
|
node['returns_shape_value'] = True
|
||||||
|
if node.soft_get('type') == 'Const':
|
||||||
|
if node.value.dtype == np.float32:
|
||||||
|
node.out_node(0)['correct_data_type'] = True
|
||||||
|
elif node.value.dtype in [np.float16, np.float64]:
|
||||||
|
log.debug('Const nodes {} with shape values have {} type'.format(node.soft_get('name', node.id),
|
||||||
|
node.value.dtype))
|
103
model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py
Normal file
103
model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
# Copyright (C) 2018-2021 Intel Corporation
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||||
|
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||||
|
from mo.graph.graph import Node
|
||||||
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
|
from mo.utils.unittest.graph import build_graph
|
||||||
|
from mo.utils.unittest.graph import result, regular_op_with_empty_data, \
|
||||||
|
shaped_const_with_data, connect, regular_op
|
||||||
|
|
||||||
|
|
||||||
|
class TestMarkDataTypeInShapeOfSubgraphs(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_run_with_shape_subgraph_input(self):
|
||||||
|
inp_shape = (1, 3, 1000, 1000)
|
||||||
|
dst_type = np.float32
|
||||||
|
|
||||||
|
nodes = {
|
||||||
|
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||||
|
**regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
|
||||||
|
**regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type}),
|
||||||
|
**regular_op('mul_const', {'op': 'Const'}),
|
||||||
|
**{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||||
|
**regular_op_with_empty_data('mul', {'type': 'Mul'}),
|
||||||
|
**regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64}),
|
||||||
|
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||||
|
**result('res'),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes_ref = {
|
||||||
|
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||||
|
**regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
|
||||||
|
**regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type,
|
||||||
|
'returns_shape_value': True}),
|
||||||
|
**regular_op_with_empty_data('mul', {'type': 'Mul', 'returns_shape_value': True}),
|
||||||
|
**regular_op('mul_const', {'op': 'Const', 'returns_shape_value': True}),
|
||||||
|
**{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.]),
|
||||||
|
'correct_data_type': True}},
|
||||||
|
**regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64,
|
||||||
|
'returns_shape_value': True}),
|
||||||
|
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||||
|
**result('res'),
|
||||||
|
}
|
||||||
|
|
||||||
|
edges = [
|
||||||
|
*connect('input', '0:interpolate'),
|
||||||
|
*connect('input', '0:shape', skip_data=True),
|
||||||
|
*connect('shape', '0:cast_to_float'),
|
||||||
|
*connect('cast_to_float', '0:mul'),
|
||||||
|
*connect('mul_const', '1:mul'),
|
||||||
|
*connect('mul', '0:cast_to_int'),
|
||||||
|
*connect('cast_to_int', '1:interpolate'),
|
||||||
|
*connect('interpolate', 'res'),
|
||||||
|
]
|
||||||
|
graph = build_graph(nodes, edges)
|
||||||
|
interp_node = Node(graph, 'interpolate')
|
||||||
|
interp_node.add_input_port(2)
|
||||||
|
|
||||||
|
MarkNodesWithShapeValues().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_ref, edges)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_run_with_const_input(self):
|
||||||
|
inp_shape = (1, 3, 1000, 1000)
|
||||||
|
dst_type = np.float32
|
||||||
|
|
||||||
|
nodes = {
|
||||||
|
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||||
|
**regular_op('sizes_const', {'op': 'Const'}),
|
||||||
|
**{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||||
|
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||||
|
**result('res'),
|
||||||
|
}
|
||||||
|
|
||||||
|
nodes_ref = {
|
||||||
|
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||||
|
**regular_op('sizes_const', {'op': 'Const', 'returns_shape_value': True}),
|
||||||
|
**{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||||
|
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||||
|
**result('res'),
|
||||||
|
}
|
||||||
|
|
||||||
|
edges = [
|
||||||
|
*connect('input', '0:interpolate'),
|
||||||
|
*connect('sizes_const', '1:interpolate'),
|
||||||
|
*connect('interpolate', 'res'),
|
||||||
|
]
|
||||||
|
graph = build_graph(nodes, edges)
|
||||||
|
interp_node = Node(graph, 'interpolate')
|
||||||
|
interp_node.add_input_port(2)
|
||||||
|
|
||||||
|
MarkNodesWithShapeValues().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
graph_ref = build_graph(nodes_ref, edges)
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
@ -1,38 +0,0 @@
|
|||||||
# Copyright (C) 2018-2021 Intel Corporation
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
|
|
||||||
import logging as log
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
|
||||||
from mo.front.subgraph_matcher import SubgraphMatch
|
|
||||||
from mo.graph.graph import Graph
|
|
||||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
|
||||||
|
|
||||||
|
|
||||||
class ChangeCastOutputType(FrontReplacementSubgraph):
|
|
||||||
"""
|
|
||||||
Change the Cast to fp64 to fp32 since not all plugins support fp64 data type.
|
|
||||||
Change the Cast to fp32 to fp16 when generating IR for fp16.
|
|
||||||
"""
|
|
||||||
enabled = True
|
|
||||||
|
|
||||||
def pattern(self):
|
|
||||||
return dict(
|
|
||||||
nodes=[
|
|
||||||
('cast', dict(op='Cast'))
|
|
||||||
],
|
|
||||||
edges=[]
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
|
||||||
node = match['cast']
|
|
||||||
if node.dst_type == np.float64:
|
|
||||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
|
|
||||||
node.dst_type = np.float32
|
|
||||||
|
|
||||||
ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
|
|
||||||
if node.dst_type == np.float32 and ir_data_type == np.float16:
|
|
||||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
|
|
||||||
node.dst_type = ir_data_type
|
|
@ -5,11 +5,10 @@ import numpy as np
|
|||||||
|
|
||||||
from extensions.ops.Cast import Cast
|
from extensions.ops.Cast import Cast
|
||||||
from extensions.ops.elementwise import Div
|
from extensions.ops.elementwise import Div
|
||||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||||
from mo.front.common.replacement import FrontReplacementPattern
|
from mo.front.common.replacement import FrontReplacementPattern
|
||||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
|
||||||
from mo.ops.concat import Concat
|
from mo.ops.concat import Concat
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
from mo.ops.shape import Shape
|
from mo.ops.shape import Shape
|
||||||
@ -46,21 +45,23 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
|
|||||||
node = match['conv']
|
node = match['conv']
|
||||||
node_name = node.soft_get('name', node.id)
|
node_name = node.soft_get('name', node.id)
|
||||||
|
|
||||||
|
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||||
|
|
||||||
# create Reshape before convolution
|
# create Reshape before convolution
|
||||||
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
|
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
|
||||||
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||||
shape = Cast(graph, {'name': node_name + '/to_float',
|
shape = Cast(graph, {'name': node_name + '/to_float',
|
||||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
'dst_type': dst_dtype}).create_node()
|
||||||
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
||||||
shape.in_port(0).connect(i_shape.out_port(0))
|
shape.in_port(0).connect(i_shape.out_port(0))
|
||||||
|
|
||||||
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
||||||
|
|
||||||
div = create_op_with_const_inputs(
|
div = create_op_with_const_inputs(
|
||||||
graph, Div, {1: float_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
|
graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
|
||||||
div.in_port(0).connect(H.out_port(0))
|
div.in_port(0).connect(H.out_port(0))
|
||||||
|
|
||||||
concat = create_op_with_const_inputs(graph, Concat, {2: float_array([1]), 3: float_array([node.patch_stride])},
|
concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])},
|
||||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||||
concat.in_port(0).connect(N.out_port(0))
|
concat.in_port(0).connect(N.out_port(0))
|
||||||
concat.in_port(1).connect(div.out_port(0))
|
concat.in_port(1).connect(div.out_port(0))
|
||||||
|
@ -5,11 +5,10 @@ import numpy as np
|
|||||||
|
|
||||||
from extensions.ops.Cast import Cast
|
from extensions.ops.Cast import Cast
|
||||||
from extensions.ops.elementwise import Div
|
from extensions.ops.elementwise import Div
|
||||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||||
from mo.front.common.replacement import FrontReplacementPattern
|
from mo.front.common.replacement import FrontReplacementPattern
|
||||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
|
||||||
from mo.ops.concat import Concat
|
from mo.ops.concat import Concat
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
from mo.ops.shape import Shape
|
from mo.ops.shape import Shape
|
||||||
@ -48,18 +47,20 @@ class ReplacePoolingReshape(FrontReplacementPattern):
|
|||||||
# create Reshape before convolution
|
# create Reshape before convolution
|
||||||
# shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
|
# shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
|
||||||
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||||
|
|
||||||
|
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||||
shape = Cast(graph, {'name': node_name + '/to_float',
|
shape = Cast(graph, {'name': node_name + '/to_float',
|
||||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
'dst_type': dst_dtype}).create_node()
|
||||||
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
||||||
shape.in_port(0).connect(i_shape.out_port(0))
|
shape.in_port(0).connect(i_shape.out_port(0))
|
||||||
|
|
||||||
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
||||||
|
|
||||||
div = create_op_with_const_inputs(
|
div = create_op_with_const_inputs(
|
||||||
graph, Div, {1: float_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
|
graph, Div, {1: float32_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
|
||||||
div.in_port(0).connect(H.out_port(0))
|
div.in_port(0).connect(H.out_port(0))
|
||||||
|
|
||||||
concat = create_op_with_const_inputs(graph, Concat, {1: float_array([node.pool_stride]), 2: float_array([1])},
|
concat = create_op_with_const_inputs(graph, Concat, {1: float32_array([node.pool_stride]), 2: float32_array([1])},
|
||||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||||
concat.in_port(0).connect(N.out_port(0))
|
concat.in_port(0).connect(N.out_port(0))
|
||||||
concat.in_port(3).connect(div.out_port(0))
|
concat.in_port(3).connect(div.out_port(0))
|
||||||
|
@ -41,7 +41,7 @@ def unique_id(prefix: str = 'id') -> str:
|
|||||||
unique_id.names = []
|
unique_id.names = []
|
||||||
|
|
||||||
|
|
||||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
|
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32):
|
||||||
# create init_graph connected to ReadValue
|
# create init_graph connected to ReadValue
|
||||||
graph = input_out_port.node.graph
|
graph = input_out_port.node.graph
|
||||||
input_name = input_out_port.node.name
|
input_name = input_out_port.node.name
|
||||||
|
@ -5,8 +5,8 @@ import logging as log
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.activation_ops import Floor
|
|
||||||
from extensions.ops.Cast import Cast
|
from extensions.ops.Cast import Cast
|
||||||
|
from extensions.ops.activation_ops import Floor
|
||||||
from extensions.ops.elementwise import Add, Mul
|
from extensions.ops.elementwise import Add, Mul
|
||||||
from extensions.ops.interpolate import Interpolate
|
from extensions.ops.interpolate import Interpolate
|
||||||
from extensions.ops.range import Range
|
from extensions.ops.range import Range
|
||||||
@ -15,7 +15,6 @@ from mo.front.common.partial_infer.utils import int64_array, float_array
|
|||||||
from mo.front.common.replacement import FrontReplacementOp
|
from mo.front.common.replacement import FrontReplacementOp
|
||||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||||
from mo.graph.graph import Graph, Node, rename_nodes
|
from mo.graph.graph import Graph, Node, rename_nodes
|
||||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
|
||||||
from mo.ops.shape import Shape
|
from mo.ops.shape import Shape
|
||||||
from mo.ops.strided_slice import StridedSlice
|
from mo.ops.strided_slice import StridedSlice
|
||||||
|
|
||||||
@ -79,9 +78,9 @@ def replace_resize(graph: Graph, resize: Node):
|
|||||||
{1: float_array([1.0e-5])},
|
{1: float_array([1.0e-5])},
|
||||||
{'name': resize_name + '/Add'})
|
{'name': resize_name + '/Add'})
|
||||||
|
|
||||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||||
|
|
||||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||||
|
|
||||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
|
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
|
||||||
|
@ -8,7 +8,7 @@ from extensions.ops.DetectionOutput import DetectionOutput
|
|||||||
from extensions.ops.elementwise import Mul, Sub, Pow
|
from extensions.ops.elementwise import Mul, Sub, Pow
|
||||||
from extensions.ops.gather import Gather
|
from extensions.ops.gather import Gather
|
||||||
from extensions.ops.split import VariadicSplit
|
from extensions.ops.split import VariadicSplit
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||||
from mo.front.subgraph_matcher import SubgraphMatch
|
from mo.front.subgraph_matcher import SubgraphMatch
|
||||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||||
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
|
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
|
||||||
@ -54,9 +54,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
|||||||
sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
|
sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
|
||||||
priors_scale_node.out_port(0).connect(sp_shape.in_port(0))
|
priors_scale_node.out_port(0).connect(sp_shape.in_port(0))
|
||||||
|
|
||||||
begin = Const(graph, {'value': np.array([-2])}).create_node()
|
begin = Const(graph, {'value': int64_array([-2])}).create_node()
|
||||||
end = Const(graph, {'value': np.array([-1])}).create_node()
|
end = Const(graph, {'value': int64_array([-1])}).create_node()
|
||||||
stride = Const(graph, {'value': np.array([1])}).create_node()
|
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
||||||
shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
|
shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
|
||||||
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
||||||
'shrink_axis_mask': np.array([0]),
|
'shrink_axis_mask': np.array([0]),
|
||||||
@ -72,7 +72,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
|||||||
'axis': int64_array(0)},
|
'axis': int64_array(0)},
|
||||||
shape_part_for_tiling)
|
shape_part_for_tiling)
|
||||||
|
|
||||||
variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
|
variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
|
||||||
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
|
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
|
||||||
variance.out_port(0).connect(tile.in_port(0))
|
variance.out_port(0).connect(tile.in_port(0))
|
||||||
shape_concat.out_port(0).connect(tile.in_port(1))
|
shape_concat.out_port(0).connect(tile.in_port(1))
|
||||||
@ -113,9 +113,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
|||||||
shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
|
shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
|
||||||
shape.in_port(0).connect(placeholder.out_port(0))
|
shape.in_port(0).connect(placeholder.out_port(0))
|
||||||
|
|
||||||
begin = Const(graph, {'value': np.array([1])}).create_node()
|
begin = Const(graph, {'value': int64_array([1])}).create_node()
|
||||||
end = Const(graph, {'value': np.array([3])}).create_node()
|
end = Const(graph, {'value': int64_array([3])}).create_node()
|
||||||
stride = Const(graph, {'value': np.array([1])}).create_node()
|
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
||||||
spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
|
spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
|
||||||
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
||||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||||
@ -125,7 +125,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
|||||||
spatial.in_port(2).connect(end.out_port(0))
|
spatial.in_port(2).connect(end.out_port(0))
|
||||||
spatial.in_port(3).connect(stride.out_port(0))
|
spatial.in_port(3).connect(stride.out_port(0))
|
||||||
|
|
||||||
power = Const(graph, {'value': np.array([-1.])}).create_node()
|
power = Const(graph, {'value': float32_array([-1.])}).create_node()
|
||||||
spatial_scale = Pow(graph, {}).create_node()
|
spatial_scale = Pow(graph, {}).create_node()
|
||||||
|
|
||||||
spatial_scale.in_port(0).connect(spatial.out_port(0))
|
spatial_scale.in_port(0).connect(spatial.out_port(0))
|
||||||
|
@ -3,14 +3,12 @@
|
|||||||
|
|
||||||
import logging as log
|
import logging as log
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
from typing import Set
|
from typing import Set
|
||||||
|
|
||||||
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
|
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
|
||||||
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
|
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
|
||||||
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
||||||
from extensions.middle.pass_separator import PostMiddleStart
|
from extensions.middle.pass_separator import PostMiddleStart
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
|
||||||
from mo.graph.graph import Graph, Node
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.graph.perm_inputs import PermuteInputs
|
from mo.graph.perm_inputs import PermuteInputs
|
||||||
from mo.graph.port import Port
|
from mo.graph.port import Port
|
||||||
@ -51,7 +49,8 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
|||||||
result.append(dest_port.node)
|
result.append(dest_port.node)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def bfs(self, start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
|
@staticmethod
|
||||||
|
def bfs(start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
|
||||||
"""
|
"""
|
||||||
The function performs BFS starting from selected nodes in forward or backward direction adding nodes by an
|
The function performs BFS starting from selected nodes in forward or backward direction adding nodes by an
|
||||||
optional condition
|
optional condition
|
||||||
@ -63,7 +62,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
|||||||
:return: the list of Nodes visited
|
:return: the list of Nodes visited
|
||||||
"""
|
"""
|
||||||
assert visited is not None, 'The "visited" set must be defined'
|
assert visited is not None, 'The "visited" set must be defined'
|
||||||
assert start_nodes is not None and len(start_nodes) != 0, 'The list of start nodes must be specified'
|
assert start_nodes is not None, 'The list of start nodes must be specified'
|
||||||
|
|
||||||
result = list()
|
result = list()
|
||||||
d = deque(start_nodes)
|
d = deque(start_nodes)
|
||||||
@ -72,9 +71,9 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
|||||||
result.append(cur_node)
|
result.append(cur_node)
|
||||||
visited.add(cur_node)
|
visited.add(cur_node)
|
||||||
if forward:
|
if forward:
|
||||||
next_nodes = self.get_output_nodes(cur_node)
|
next_nodes = MarkSubGraphsWithCorrectLayout.get_output_nodes(cur_node)
|
||||||
else:
|
else:
|
||||||
next_nodes = self.get_input_nodes(cur_node)
|
next_nodes = MarkSubGraphsWithCorrectLayout.get_input_nodes(cur_node)
|
||||||
for next_node in next_nodes:
|
for next_node in next_nodes:
|
||||||
if next_node not in visited and (condition is None or condition(next_node)):
|
if next_node not in visited and (condition is None or condition(next_node)):
|
||||||
d.append(next_node)
|
d.append(next_node)
|
||||||
|
@ -2,18 +2,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import logging as log
|
import logging as log
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.activation_ops import Floor
|
|
||||||
from extensions.ops.Cast import Cast
|
from extensions.ops.Cast import Cast
|
||||||
|
from extensions.ops.activation_ops import Floor
|
||||||
from extensions.ops.elementwise import Add, Div, Mul
|
from extensions.ops.elementwise import Add, Div, Mul
|
||||||
from extensions.ops.interpolate import Interpolate
|
from extensions.ops.interpolate import Interpolate
|
||||||
from mo.front.common.layout import get_depth_dim, get_height_dim, get_width_dim
|
from mo.front.common.layout import get_depth_dim, get_height_dim, get_width_dim
|
||||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
|
||||||
from mo.graph.graph import Graph, Node, rename_nodes
|
from mo.graph.graph import Graph, Node, rename_nodes
|
||||||
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
from mo.ops.shape import Shape
|
from mo.ops.shape import Shape
|
||||||
from mo.ops.strided_slice import StridedSlice
|
from mo.ops.strided_slice import StridedSlice
|
||||||
@ -94,10 +94,10 @@ def replace_resize(graph: Graph, resize: Node):
|
|||||||
{1: float_array([1.0e-5])},
|
{1: float_array([1.0e-5])},
|
||||||
{'name': resize_name + '/Add'})
|
{'name': resize_name + '/Add'})
|
||||||
|
|
||||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||||
|
|
||||||
if num_of_inputs == 3:
|
if num_of_inputs == 3:
|
||||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
|
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
|
||||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||||
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
||||||
@ -119,8 +119,8 @@ def replace_resize(graph: Graph, resize: Node):
|
|||||||
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
|
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
|
||||||
connection_of_scales.get_source().connect(mul_node.in_port(1))
|
connection_of_scales.get_source().connect(mul_node.in_port(1))
|
||||||
else:
|
else:
|
||||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||||
cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
cast_sizes_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||||
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
|
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
|
||||||
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
|
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
|
||||||
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
|
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
|
||||||
|
@ -2,9 +2,9 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import logging as log
|
import logging as log
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
|
||||||
from mo.graph.graph import Node, Graph
|
from mo.graph.graph import Node, Graph
|
||||||
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \
|
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \
|
||||||
np_data_type_to_destination_type, packed_I4, packed_U4
|
np_data_type_to_destination_type, packed_I4, packed_U4
|
||||||
|
@ -17,6 +17,7 @@ from mo.utils.class_registration import apply_replacements_list
|
|||||||
from mo.utils.ir_engine.ir_engine import IREngine
|
from mo.utils.ir_engine.ir_engine import IREngine
|
||||||
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
|
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
|
||||||
from mo.utils.utils import get_mo_root_dir
|
from mo.utils.utils import get_mo_root_dir
|
||||||
|
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||||
|
|
||||||
|
|
||||||
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
|
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
|
||||||
@ -64,6 +65,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
|||||||
BlobNormalizer,
|
BlobNormalizer,
|
||||||
ConvolutionNormalizer,
|
ConvolutionNormalizer,
|
||||||
KaldiRemoveMemoryOutputBackReplacementPattern,
|
KaldiRemoveMemoryOutputBackReplacementPattern,
|
||||||
|
MarkNodesWithShapeValues,
|
||||||
]
|
]
|
||||||
|
|
||||||
# We need to run some specific passes from MO back stage.
|
# We need to run some specific passes from MO back stage.
|
||||||
|
Loading…
Reference in New Issue
Block a user