[MO] Fix fp16 in shapeof subgraphs (#4524)
* Initial working solution * moved bfs_search_apply_on_shapeof_subgraph_nodes from utils/graph.py to MarkShapeOfSubgraphDataType.py * Reused bfs from MarkSubgraphsWithCorrectLayout.py * fixed e2e precomit issues: specified correct const data_types, fixed BFS search staring point to avoid nodeless shapeof subgraphs * fixed mxnet_rnnt: added converting all Const nodes in ShapeOf subgraph in MarkAndChangeDataTypeInShapeOfSubgraphs.py, revised Const values in transformations that affect ShapeOf subgraph nodes * reverter ReverseV2ToReverseSequence.py and DecomposeBidirectionalRNNSequence.py * in MarkSubgraphsWithCorrectLayout BFS search beauty applied * apply review comments, returned back 'in_shape_subgraph' attribute * graph condition added * MO IR reader fix for mixed FP16 models, added replacer order placement comment * moved to back phase * new solution with marking nodes from bottom to top (WIP) * successfully tested on back phase * corrected unittest * removed check for start nodes size in bfs * fix transformations that insert f64 to f32 in shape subgraph * corrected log.warning -> log.debug * revised list if shape input operations added unittest for Const shape inputs * applied @lazarevevgeny's comments * licence head corrections
This commit is contained in:
parent
e072c96237
commit
6477e8ec01
@ -11,6 +11,7 @@ extensions/back/__init__.py
|
||||
extensions/back/AvgPool.py
|
||||
extensions/back/blob_normalizer.py
|
||||
extensions/back/CellNormalizer.py
|
||||
extensions/back/ChangeCastOutputType.py
|
||||
extensions/back/ClampNormalizer.py
|
||||
extensions/back/compress_quantized_weights.py
|
||||
extensions/back/ConvolutionNormalizer.py
|
||||
@ -32,6 +33,7 @@ extensions/back/LayoutChangeForGatherND.py
|
||||
extensions/back/LeakyReLUMutation.py
|
||||
extensions/back/LinearToLinearONNXReplacer.py
|
||||
extensions/back/LRNToNorm.py
|
||||
extensions/back/MarkNodesWithShapeValues.py
|
||||
extensions/back/MatMulNormalizer.py
|
||||
extensions/back/MaxPool.py
|
||||
extensions/back/NormalizeToNormalizeL2.py
|
||||
@ -125,7 +127,6 @@ extensions/front/caffe/softmax_ext.py
|
||||
extensions/front/caffe/spatial_transformer_ext.py
|
||||
extensions/front/caffe/split_to_identity.py
|
||||
extensions/front/caffe/tanh.py
|
||||
extensions/front/ChangeCastOutputType.py
|
||||
extensions/front/ChangePlaceholderTypes.py
|
||||
extensions/front/create_tensor_nodes.py
|
||||
extensions/front/disable_weights_quantize_value_propagation.py
|
||||
|
43
model-optimizer/extensions/back/ChangeCastOutputType.py
Normal file
43
model-optimizer/extensions/back/ChangeCastOutputType.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
|
||||
|
||||
class ChangeCastOutputType(BackReplacementPattern):
|
||||
"""
|
||||
Change the Cast dst_type from fp64 to fp32 since not all plugins support fp64 data type.
|
||||
Change the Cast dst_type from fp32 to fp16 when generating IR for fp16.
|
||||
But leave fp32 if node returns shape value even if --data_type=FP16 (look extensions/back/MarkNodesWithShapeValues.py).
|
||||
"""
|
||||
enabled = True
|
||||
force_shape_inference = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
return [MarkNodesWithShapeValues]
|
||||
|
||||
def run_before(self):
|
||||
return []
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='Cast'):
|
||||
if node.dst_type == np.float64:
|
||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
|
||||
node.dst_type = np.float32
|
||||
|
||||
ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
|
||||
if node.dst_type == np.float32 and ir_data_type == np.float16 and not node.has_and_set('returns_shape_value'):
|
||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
|
||||
node.dst_type = ir_data_type
|
||||
elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16:
|
||||
# return back FP32 for all Convert nodes with shape values
|
||||
log.warning('Change data type from {} to {} for node {} in ShapeOf subgraph'.
|
||||
format(node.dst_type, np.float32, node.name))
|
||||
node.dst_type = np.float32
|
78
model-optimizer/extensions/back/MarkNodesWithShapeValues.py
Normal file
78
model-optimizer/extensions/back/MarkNodesWithShapeValues.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
|
||||
class MarkNodesWithShapeValues(BackReplacementPattern):
|
||||
"""
|
||||
This transformation marks op nodes in ShapeOf subgraphs with 'returns_shape_value' bool attribute and
|
||||
data nodes of float32 constants with 'correct_data_type' attribute.
|
||||
So that float Consts and Cast float will be kept in FP32 even if argument --data_type=FP16 is specified.
|
||||
|
||||
This is needed to enable conversion to FP16 even if values in ShapeOf subgraphs exceed max(float16)
|
||||
or because of FP16 lower precession shape inference is incorrect on some nodes (e.g. if Interpolate in scales mode
|
||||
accepts values from ShapeOf subgraph).
|
||||
|
||||
This transformation should be executed after shape inference and after all transformations which insert/modify
|
||||
Cast nodes in ShapeOf subgraphs therefore it's placed at the end of the back phase.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']
|
||||
|
||||
def run_after(self):
|
||||
from extensions.back.pass_separator import BackFinish
|
||||
return [BackFinish]
|
||||
|
||||
def run_before(self):
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_operations_with_shape_inputs():
|
||||
return {
|
||||
'Interpolate': [1, 2], # sizes, scales inputs
|
||||
'Reshape': [1], # shape
|
||||
'Broadcast': [1], # target_shape
|
||||
'ConvBackPropData ': [2], # output_shape
|
||||
'GroupConvolutionBackpropData ': [2], # output_shape
|
||||
'BatchToSpace': [1, 2, 3], # block_shape, crops_begin, crops_end
|
||||
'SpaceToBatch': [1, 2, 3], # block_shape, pads_begin, pads_end
|
||||
'StridedSlice': [1, 2, 3], # begin, end, strides
|
||||
'VariadicSplit': [2], # split_lengths
|
||||
'Tile': [1], # repeats input
|
||||
'TopK': [1], # K input
|
||||
'Pad': [1, 2], # pads_begin, pads_end
|
||||
'Range': [0, 1, 2], # start, stop, step inputs
|
||||
'OneHot': [1], # depth input
|
||||
}
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
shape_input_ops_map = self.get_operations_with_shape_inputs()
|
||||
|
||||
nodes_with_shape_inputs = []
|
||||
for node in graph.get_op_nodes():
|
||||
if node.soft_get('type') in shape_input_ops_map:
|
||||
nodes_with_shape_inputs.append(node)
|
||||
|
||||
start_nodes = []
|
||||
for node in nodes_with_shape_inputs:
|
||||
start_nodes.extend(
|
||||
[node.in_port(port_idx).get_source().node for port_idx in shape_input_ops_map[node.type] if
|
||||
node.is_in_port_connected(port_idx)])
|
||||
|
||||
condition = lambda node: node.soft_get('type') != 'ShapeOf'
|
||||
nodes_with_shape_values = MarkSubGraphsWithCorrectLayout.bfs(start_nodes, set(), condition, forward=False)
|
||||
for node in nodes_with_shape_values:
|
||||
node['returns_shape_value'] = True
|
||||
if node.soft_get('type') == 'Const':
|
||||
if node.value.dtype == np.float32:
|
||||
node.out_node(0)['correct_data_type'] = True
|
||||
elif node.value.dtype in [np.float16, np.float64]:
|
||||
log.debug('Const nodes {} with shape values have {} type'.format(node.soft_get('name', node.id),
|
||||
node.value.dtype))
|
103
model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py
Normal file
103
model-optimizer/extensions/back/MarkNodesWithShapeValues_test.py
Normal file
@ -0,0 +1,103 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph
|
||||
from mo.utils.unittest.graph import result, regular_op_with_empty_data, \
|
||||
shaped_const_with_data, connect, regular_op
|
||||
|
||||
|
||||
class TestMarkDataTypeInShapeOfSubgraphs(unittest.TestCase):
|
||||
|
||||
def test_run_with_shape_subgraph_input(self):
|
||||
inp_shape = (1, 3, 1000, 1000)
|
||||
dst_type = np.float32
|
||||
|
||||
nodes = {
|
||||
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||
**regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type}),
|
||||
**regular_op('mul_const', {'op': 'Const'}),
|
||||
**{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||
**regular_op_with_empty_data('mul', {'type': 'Mul'}),
|
||||
**regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64}),
|
||||
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
nodes_ref = {
|
||||
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||
**regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
|
||||
**regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type,
|
||||
'returns_shape_value': True}),
|
||||
**regular_op_with_empty_data('mul', {'type': 'Mul', 'returns_shape_value': True}),
|
||||
**regular_op('mul_const', {'op': 'Const', 'returns_shape_value': True}),
|
||||
**{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.]),
|
||||
'correct_data_type': True}},
|
||||
**regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64,
|
||||
'returns_shape_value': True}),
|
||||
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
edges = [
|
||||
*connect('input', '0:interpolate'),
|
||||
*connect('input', '0:shape', skip_data=True),
|
||||
*connect('shape', '0:cast_to_float'),
|
||||
*connect('cast_to_float', '0:mul'),
|
||||
*connect('mul_const', '1:mul'),
|
||||
*connect('mul', '0:cast_to_int'),
|
||||
*connect('cast_to_int', '1:interpolate'),
|
||||
*connect('interpolate', 'res'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
interp_node = Node(graph, 'interpolate')
|
||||
interp_node.add_input_port(2)
|
||||
|
||||
MarkNodesWithShapeValues().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_ref, edges)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_run_with_const_input(self):
|
||||
inp_shape = (1, 3, 1000, 1000)
|
||||
dst_type = np.float32
|
||||
|
||||
nodes = {
|
||||
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||
**regular_op('sizes_const', {'op': 'Const'}),
|
||||
**{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
nodes_ref = {
|
||||
**shaped_const_with_data('input', int64_array(inp_shape)),
|
||||
**regular_op('sizes_const', {'op': 'Const', 'returns_shape_value': True}),
|
||||
**{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
|
||||
**regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_model': 'scales'}),
|
||||
**result('res'),
|
||||
}
|
||||
|
||||
edges = [
|
||||
*connect('input', '0:interpolate'),
|
||||
*connect('sizes_const', '1:interpolate'),
|
||||
*connect('interpolate', 'res'),
|
||||
]
|
||||
graph = build_graph(nodes, edges)
|
||||
interp_node = Node(graph, 'interpolate')
|
||||
interp_node.add_input_port(2)
|
||||
|
||||
MarkNodesWithShapeValues().find_and_replace_pattern(graph)
|
||||
|
||||
graph_ref = build_graph(nodes_ref, edges)
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
@ -1,38 +0,0 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
|
||||
|
||||
class ChangeCastOutputType(FrontReplacementSubgraph):
|
||||
"""
|
||||
Change the Cast to fp64 to fp32 since not all plugins support fp64 data type.
|
||||
Change the Cast to fp32 to fp16 when generating IR for fp16.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('cast', dict(op='Cast'))
|
||||
],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
|
||||
node = match['cast']
|
||||
if node.dst_type == np.float64:
|
||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
|
||||
node.dst_type = np.float32
|
||||
|
||||
ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
|
||||
if node.dst_type == np.float32 and ir_data_type == np.float16:
|
||||
log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
|
||||
node.dst_type = ir_data_type
|
@ -5,11 +5,10 @@ import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Div
|
||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.ops.shape import Shape
|
||||
@ -46,21 +45,23 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
|
||||
node = match['conv']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||
|
||||
# create Reshape before convolution
|
||||
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
|
||||
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||
shape = Cast(graph, {'name': node_name + '/to_float',
|
||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
||||
'dst_type': dst_dtype}).create_node()
|
||||
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
||||
shape.in_port(0).connect(i_shape.out_port(0))
|
||||
|
||||
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
||||
|
||||
div = create_op_with_const_inputs(
|
||||
graph, Div, {1: float_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
|
||||
graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
|
||||
div.in_port(0).connect(H.out_port(0))
|
||||
|
||||
concat = create_op_with_const_inputs(graph, Concat, {2: float_array([1]), 3: float_array([node.patch_stride])},
|
||||
concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])},
|
||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||
concat.in_port(0).connect(N.out_port(0))
|
||||
concat.in_port(1).connect(div.out_port(0))
|
||||
|
@ -5,11 +5,10 @@ import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Div
|
||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.ops.shape import Shape
|
||||
@ -48,18 +47,20 @@ class ReplacePoolingReshape(FrontReplacementPattern):
|
||||
# create Reshape before convolution
|
||||
# shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
|
||||
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||
|
||||
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||
shape = Cast(graph, {'name': node_name + '/to_float',
|
||||
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
|
||||
'dst_type': dst_dtype}).create_node()
|
||||
i_shape.in_port(0).connect(node.in_port(0).get_source())
|
||||
shape.in_port(0).connect(i_shape.out_port(0))
|
||||
|
||||
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
|
||||
|
||||
div = create_op_with_const_inputs(
|
||||
graph, Div, {1: float_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
|
||||
graph, Div, {1: float32_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
|
||||
div.in_port(0).connect(H.out_port(0))
|
||||
|
||||
concat = create_op_with_const_inputs(graph, Concat, {1: float_array([node.pool_stride]), 2: float_array([1])},
|
||||
concat = create_op_with_const_inputs(graph, Concat, {1: float32_array([node.pool_stride]), 2: float32_array([1])},
|
||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||
concat.in_port(0).connect(N.out_port(0))
|
||||
concat.in_port(3).connect(div.out_port(0))
|
||||
|
@ -41,7 +41,7 @@ def unique_id(prefix: str = 'id') -> str:
|
||||
unique_id.names = []
|
||||
|
||||
|
||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
|
||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32):
|
||||
# create init_graph connected to ReadValue
|
||||
graph = input_out_port.node.graph
|
||||
input_name = input_out_port.node.name
|
||||
|
@ -5,8 +5,8 @@ import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from extensions.ops.range import Range
|
||||
@ -15,7 +15,6 @@ from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
@ -79,9 +78,9 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
{1: float_array([1.0e-5])},
|
||||
{'name': resize_name + '/Add'})
|
||||
|
||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
|
||||
|
@ -8,7 +8,7 @@ from extensions.ops.DetectionOutput import DetectionOutput
|
||||
from extensions.ops.elementwise import Mul, Sub, Pow
|
||||
from extensions.ops.gather import Gather
|
||||
from extensions.ops.split import VariadicSplit
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.front.subgraph_matcher import SubgraphMatch
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
|
||||
@ -54,9 +54,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
||||
sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
|
||||
priors_scale_node.out_port(0).connect(sp_shape.in_port(0))
|
||||
|
||||
begin = Const(graph, {'value': np.array([-2])}).create_node()
|
||||
end = Const(graph, {'value': np.array([-1])}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1])}).create_node()
|
||||
begin = Const(graph, {'value': int64_array([-2])}).create_node()
|
||||
end = Const(graph, {'value': int64_array([-1])}).create_node()
|
||||
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
|
||||
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
||||
'shrink_axis_mask': np.array([0]),
|
||||
@ -72,7 +72,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
||||
'axis': int64_array(0)},
|
||||
shape_part_for_tiling)
|
||||
|
||||
variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
|
||||
variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
|
||||
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
|
||||
variance.out_port(0).connect(tile.in_port(0))
|
||||
shape_concat.out_port(0).connect(tile.in_port(1))
|
||||
@ -113,9 +113,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
||||
shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
|
||||
shape.in_port(0).connect(placeholder.out_port(0))
|
||||
|
||||
begin = Const(graph, {'value': np.array([1])}).create_node()
|
||||
end = Const(graph, {'value': np.array([3])}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1])}).create_node()
|
||||
begin = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
end = Const(graph, {'value': int64_array([3])}).create_node()
|
||||
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
|
||||
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
|
||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||
@ -125,7 +125,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
|
||||
spatial.in_port(2).connect(end.out_port(0))
|
||||
spatial.in_port(3).connect(stride.out_port(0))
|
||||
|
||||
power = Const(graph, {'value': np.array([-1.])}).create_node()
|
||||
power = Const(graph, {'value': float32_array([-1.])}).create_node()
|
||||
spatial_scale = Pow(graph, {}).create_node()
|
||||
|
||||
spatial_scale.in_port(0).connect(spatial.out_port(0))
|
||||
|
@ -3,14 +3,12 @@
|
||||
|
||||
import logging as log
|
||||
from collections import deque
|
||||
|
||||
from typing import Set
|
||||
|
||||
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
|
||||
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
|
||||
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
|
||||
from extensions.middle.pass_separator import PostMiddleStart
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.perm_inputs import PermuteInputs
|
||||
from mo.graph.port import Port
|
||||
@ -51,7 +49,8 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
result.append(dest_port.node)
|
||||
return result
|
||||
|
||||
def bfs(self, start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
|
||||
@staticmethod
|
||||
def bfs(start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
|
||||
"""
|
||||
The function performs BFS starting from selected nodes in forward or backward direction adding nodes by an
|
||||
optional condition
|
||||
@ -63,7 +62,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
:return: the list of Nodes visited
|
||||
"""
|
||||
assert visited is not None, 'The "visited" set must be defined'
|
||||
assert start_nodes is not None and len(start_nodes) != 0, 'The list of start nodes must be specified'
|
||||
assert start_nodes is not None, 'The list of start nodes must be specified'
|
||||
|
||||
result = list()
|
||||
d = deque(start_nodes)
|
||||
@ -72,9 +71,9 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
|
||||
result.append(cur_node)
|
||||
visited.add(cur_node)
|
||||
if forward:
|
||||
next_nodes = self.get_output_nodes(cur_node)
|
||||
next_nodes = MarkSubGraphsWithCorrectLayout.get_output_nodes(cur_node)
|
||||
else:
|
||||
next_nodes = self.get_input_nodes(cur_node)
|
||||
next_nodes = MarkSubGraphsWithCorrectLayout.get_input_nodes(cur_node)
|
||||
for next_node in next_nodes:
|
||||
if next_node not in visited and (condition is None or condition(next_node)):
|
||||
d.append(next_node)
|
||||
|
@ -2,18 +2,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.elementwise import Add, Div, Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.layout import get_depth_dim, get_height_dim, get_width_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
@ -94,10 +94,10 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
{1: float_array([1.0e-5])},
|
||||
{'name': resize_name + '/Add'})
|
||||
|
||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
|
||||
|
||||
if num_of_inputs == 3:
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
||||
@ -119,8 +119,8 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
|
||||
connection_of_scales.get_source().connect(mul_node.in_port(1))
|
||||
else:
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||
cast_sizes_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
|
||||
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
|
||||
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
|
||||
|
@ -2,9 +2,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.elemental import copy_shape_infer
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \
|
||||
np_data_type_to_destination_type, packed_I4, packed_U4
|
||||
|
@ -17,6 +17,7 @@ from mo.utils.class_registration import apply_replacements_list
|
||||
from mo.utils.ir_engine.ir_engine import IREngine
|
||||
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
|
||||
from mo.utils.utils import get_mo_root_dir
|
||||
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
|
||||
|
||||
|
||||
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
|
||||
@ -64,6 +65,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
|
||||
BlobNormalizer,
|
||||
ConvolutionNormalizer,
|
||||
KaldiRemoveMemoryOutputBackReplacementPattern,
|
||||
MarkNodesWithShapeValues,
|
||||
]
|
||||
|
||||
# We need to run some specific passes from MO back stage.
|
||||
|
Loading…
Reference in New Issue
Block a user