[MO] Fix fp16 in shapeof subgraphs (#4524)

* Initial working solution

* moved bfs_search_apply_on_shapeof_subgraph_nodes from utils/graph.py to MarkShapeOfSubgraphDataType.py

* Reused bfs from MarkSubgraphsWithCorrectLayout.py

* fixed e2e precommit issues: specified correct Const data_types, fixed BFS search starting point to avoid nodeless ShapeOf subgraphs

* fixed mxnet_rnnt: added conversion of all Const nodes in ShapeOf subgraphs in MarkAndChangeDataTypeInShapeOfSubgraphs.py, revised Const values in transformations that affect ShapeOf subgraph nodes

* reverted ReverseV2ToReverseSequence.py and DecomposeBidirectionalRNNSequence.py

* cleaned up the BFS search in MarkSubgraphsWithCorrectLayout

* applied review comments, restored the 'in_shape_subgraph' attribute

* graph condition added

* MO IR reader fix for mixed FP16 models, added replacer order placement comment

* moved to back phase

* new solution with marking nodes from bottom to top (WIP)

* successfully tested on back phase

* corrected unittest

* removed check for start nodes size in bfs

* fixed transformations that inserted fp64 instead of fp32 into shape subgraphs

* corrected log.warning -> log.debug

* revised the list of shape input operations, added a unittest for Const shape inputs

* applied @lazarevevgeny's comments

* license header corrections
Pavel Esir 2021-03-30 23:16:32 +03:00, committed by GitHub
parent e072c96237
commit 6477e8ec01
14 changed files with 266 additions and 77 deletions


@@ -11,6 +11,7 @@ extensions/back/__init__.py
extensions/back/AvgPool.py
extensions/back/blob_normalizer.py
extensions/back/CellNormalizer.py
extensions/back/ChangeCastOutputType.py
extensions/back/ClampNormalizer.py
extensions/back/compress_quantized_weights.py
extensions/back/ConvolutionNormalizer.py
@@ -32,6 +33,7 @@ extensions/back/LayoutChangeForGatherND.py
extensions/back/LeakyReLUMutation.py
extensions/back/LinearToLinearONNXReplacer.py
extensions/back/LRNToNorm.py
extensions/back/MarkNodesWithShapeValues.py
extensions/back/MatMulNormalizer.py
extensions/back/MaxPool.py
extensions/back/NormalizeToNormalizeL2.py
@@ -125,7 +127,6 @@ extensions/front/caffe/softmax_ext.py
extensions/front/caffe/spatial_transformer_ext.py
extensions/front/caffe/split_to_identity.py
extensions/front/caffe/tanh.py
extensions/front/ChangeCastOutputType.py
extensions/front/ChangePlaceholderTypes.py
extensions/front/create_tensor_nodes.py
extensions/front/disable_weights_quantize_value_propagation.py


@@ -0,0 +1,43 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log

import numpy as np

from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_np


class ChangeCastOutputType(BackReplacementPattern):
    """
    Change the Cast dst_type from fp64 to fp32 since not all plugins support fp64 data type.
    Change the Cast dst_type from fp32 to fp16 when generating IR for fp16.
    But leave fp32 if the node returns a shape value, even if --data_type=FP16 is specified
    (see extensions/back/MarkNodesWithShapeValues.py).
    """
    enabled = True
    force_shape_inference = True

    def run_after(self):
        from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
        return [MarkNodesWithShapeValues]

    def run_before(self):
        return []

    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(op='Cast'):
            if node.dst_type == np.float64:
                log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
                node.dst_type = np.float32

            ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
            if node.dst_type == np.float32 and ir_data_type == np.float16 and not node.has_and_set('returns_shape_value'):
                log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
                node.dst_type = ir_data_type
            elif node.has_and_set('returns_shape_value') and node.dst_type == np.float16:
                # return back FP32 for all Convert nodes with shape values
                log.warning('Change data type from {} to {} for node {} in ShapeOf subgraph'.
                            format(node.dst_type, np.float32, node.name))
                node.dst_type = np.float32
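
Note: the three dst_type rules above can be read in isolation. Below is a minimal standalone sketch that mirrors the sequence of checks; choose_cast_dst_type is a hypothetical helper for illustration, not part of the MO codebase.

import numpy as np

def choose_cast_dst_type(dst_type, ir_data_type, returns_shape_value):
    # fp64 is not supported by plugins: always downcast to fp32 first
    if dst_type == np.float64:
        dst_type = np.float32
    # for an FP16 IR, compress fp32 Casts unless they carry shape values
    if dst_type == np.float32 and ir_data_type == np.float16 and not returns_shape_value:
        dst_type = ir_data_type
    # Casts carrying shape values are forced back to fp32 to keep shape inference exact
    elif returns_shape_value and dst_type == np.float16:
        dst_type = np.float32
    return dst_type

assert choose_cast_dst_type(np.float64, np.float16, False) == np.float16  # fp64 -> fp32 -> fp16
assert choose_cast_dst_type(np.float32, np.float16, True) == np.float32   # shape value stays fp32
assert choose_cast_dst_type(np.float16, np.float16, True) == np.float32   # returned back to fp32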


@@ -0,0 +1,78 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log

import numpy as np

from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph


class MarkNodesWithShapeValues(BackReplacementPattern):
    """
    This transformation marks op nodes in ShapeOf subgraphs with the 'returns_shape_value' bool attribute and
    data nodes of float32 constants with the 'correct_data_type' attribute, so that float Consts and float Casts
    are kept in FP32 even if --data_type=FP16 is specified.
    This is needed to enable conversion to FP16 even when values in ShapeOf subgraphs exceed max(float16),
    or when shape inference on some nodes is incorrect because of FP16's lower precision (e.g. if Interpolate
    in scales mode accepts values from a ShapeOf subgraph).
    This transformation should be executed after shape inference and after all transformations that insert/modify
    Cast nodes in ShapeOf subgraphs; therefore it is placed at the end of the back phase.
    """
    enabled = True
    graph_condition = [lambda graph: graph.graph['cmd_params'].data_type == 'FP16']

    def run_after(self):
        from extensions.back.pass_separator import BackFinish
        return [BackFinish]

    def run_before(self):
        return []

    @staticmethod
    def get_operations_with_shape_inputs():
        return {
            'Interpolate': [1, 2],  # sizes, scales inputs
            'Reshape': [1],  # shape
            'Broadcast': [1],  # target_shape
            'ConvolutionBackpropData': [2],  # output_shape
            'GroupConvolutionBackpropData': [2],  # output_shape
            'BatchToSpace': [1, 2, 3],  # block_shape, crops_begin, crops_end
            'SpaceToBatch': [1, 2, 3],  # block_shape, pads_begin, pads_end
            'StridedSlice': [1, 2, 3],  # begin, end, strides
            'VariadicSplit': [2],  # split_lengths
            'Tile': [1],  # repeats input
            'TopK': [1],  # K input
            'Pad': [1, 2],  # pads_begin, pads_end
            'Range': [0, 1, 2],  # start, stop, step inputs
            'OneHot': [1],  # depth input
        }

    def find_and_replace_pattern(self, graph: Graph):
        shape_input_ops_map = self.get_operations_with_shape_inputs()
        nodes_with_shape_inputs = []
        for node in graph.get_op_nodes():
            if node.soft_get('type') in shape_input_ops_map:
                nodes_with_shape_inputs.append(node)

        start_nodes = []
        for node in nodes_with_shape_inputs:
            start_nodes.extend(
                [node.in_port(port_idx).get_source().node for port_idx in shape_input_ops_map[node.type] if
                 node.is_in_port_connected(port_idx)])

        condition = lambda node: node.soft_get('type') != 'ShapeOf'
        nodes_with_shape_values = MarkSubGraphsWithCorrectLayout.bfs(start_nodes, set(), condition, forward=False)
        for node in nodes_with_shape_values:
            node['returns_shape_value'] = True
            if node.soft_get('type') == 'Const':
                if node.value.dtype == np.float32:
                    node.out_node(0)['correct_data_type'] = True
                elif node.value.dtype in [np.float16, np.float64]:
                    log.debug('Const node {} with shape values has {} type'.format(node.soft_get('name', node.id),
                                                                                   node.value.dtype))
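
Note: the marking pass above is essentially a reverse BFS that stops at ShapeOf nodes. Below is a self-contained sketch of that traversal over a toy graph, using plain dicts instead of MO Node objects; all names are hypothetical.

from collections import deque

def mark_shape_subgraph(graph, start_nodes):
    # graph: node name -> {'type': op type, 'inputs': names of producer nodes}
    visited, marked = set(), []
    d = deque(start_nodes)
    while len(d) != 0:
        cur = d.popleft()
        marked.append(cur)
        visited.add(cur)
        for producer in graph[cur]['inputs']:
            # stop the walk at ShapeOf: the shape subgraph ends where shapes are produced
            if producer not in visited and graph[producer]['type'] != 'ShapeOf':
                d.append(producer)
    return marked

# toy subgraph: ShapeOf -> Cast -> Mul (with a Const scales input) feeding a shape input
toy = {
    'shape': {'type': 'ShapeOf', 'inputs': []},
    'cast': {'type': 'Cast', 'inputs': ['shape']},
    'scales': {'type': 'Const', 'inputs': []},
    'mul': {'type': 'Mul', 'inputs': ['cast', 'scales']},
}
assert set(mark_shape_subgraph(toy, ['mul'])) == {'mul', 'cast', 'scales'}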


@@ -0,0 +1,103 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

import numpy as np

from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.graph.graph import Node
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph
from mo.utils.unittest.graph import result, regular_op_with_empty_data, \
    shaped_const_with_data, connect, regular_op


class TestMarkDataTypeInShapeOfSubgraphs(unittest.TestCase):
    def test_run_with_shape_subgraph_input(self):
        inp_shape = (1, 3, 1000, 1000)
        dst_type = np.float32

        nodes = {
            **shaped_const_with_data('input', int64_array(inp_shape)),
            **regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
            **regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type}),
            **regular_op('mul_const', {'op': 'Const'}),
            **{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
            **regular_op_with_empty_data('mul', {'type': 'Mul'}),
            **regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64}),
            **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_mode': 'scales'}),
            **result('res'),
        }

        nodes_ref = {
            **shaped_const_with_data('input', int64_array(inp_shape)),
            **regular_op_with_empty_data('shape', {'type': 'ShapeOf'}),
            **regular_op_with_empty_data('cast_to_float', {'type': 'Cast', 'dst_type': dst_type,
                                                           'returns_shape_value': True}),
            **regular_op_with_empty_data('mul', {'type': 'Mul', 'returns_shape_value': True}),
            **regular_op('mul_const', {'op': 'Const', 'returns_shape_value': True}),
            **{'mul_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.]),
                               'correct_data_type': True}},
            **regular_op_with_empty_data('cast_to_int', {'type': 'Cast', 'dst_type': np.int64,
                                                         'returns_shape_value': True}),
            **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_mode': 'scales'}),
            **result('res'),
        }

        edges = [
            *connect('input', '0:interpolate'),
            *connect('input', '0:shape', skip_data=True),
            *connect('shape', '0:cast_to_float'),
            *connect('cast_to_float', '0:mul'),
            *connect('mul_const', '1:mul'),
            *connect('mul', '0:cast_to_int'),
            *connect('cast_to_int', '1:interpolate'),
            *connect('interpolate', 'res'),
        ]
        graph = build_graph(nodes, edges)
        interp_node = Node(graph, 'interpolate')
        interp_node.add_input_port(2)

        MarkNodesWithShapeValues().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_ref, edges)
        (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_run_with_const_input(self):
        inp_shape = (1, 3, 1000, 1000)

        nodes = {
            **shaped_const_with_data('input', int64_array(inp_shape)),
            **regular_op('sizes_const', {'op': 'Const'}),
            **{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
            **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_mode': 'scales'}),
            **result('res'),
        }

        nodes_ref = {
            **shaped_const_with_data('input', int64_array(inp_shape)),
            **regular_op('sizes_const', {'op': 'Const', 'returns_shape_value': True}),
            **{'sizes_const_d': {'kind': 'data', 'value': float32_array([1., 1., 1., 100.])}},
            **regular_op_with_empty_data('interpolate', {'type': 'Interpolate', 'shape_calculation_mode': 'scales'}),
            **result('res'),
        }

        edges = [
            *connect('input', '0:interpolate'),
            *connect('sizes_const', '1:interpolate'),
            *connect('interpolate', 'res'),
        ]
        graph = build_graph(nodes, edges)
        interp_node = Node(graph, 'interpolate')
        interp_node.add_input_port(2)

        MarkNodesWithShapeValues().find_and_replace_pattern(graph)

        graph_ref = build_graph(nodes_ref, edges)
        (flag, resp) = compare_graphs(graph, graph_ref, 'res', check_op_attrs=True)
        self.assertTrue(flag, resp)


@@ -1,38 +0,0 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging as log

import numpy as np

from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_np


class ChangeCastOutputType(FrontReplacementSubgraph):
    """
    Change the Cast to fp64 to fp32 since not all plugins support fp64 data type.
    Change the Cast to fp32 to fp16 when generating IR for fp16.
    """
    enabled = True

    def pattern(self):
        return dict(
            nodes=[
                ('cast', dict(op='Cast'))
            ],
            edges=[]
        )

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        node = match['cast']
        if node.dst_type == np.float64:
            log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, np.float32, node.name))
            node.dst_type = np.float32

        ir_data_type = data_type_str_to_np(node.graph.graph['cmd_params'].data_type)
        if node.dst_type == np.float32 and ir_data_type == np.float16:
            log.warning('Change data type from {} to {} for node {}'.format(node.dst_type, ir_data_type, node.name))
            node.dst_type = ir_data_type


@@ -5,11 +5,10 @@ import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Div
from mo.front.common.partial_infer.utils import int64_array, float_array
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.concat import Concat
from mo.ops.reshape import Reshape
from mo.ops.shape import Shape
@@ -46,21 +45,23 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
node = match['conv']
node_name = node.soft_get('name', node.id)
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
# create Reshape before convolution
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
shape = Cast(graph, {'name': node_name + '/to_float',
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
'dst_type': dst_dtype}).create_node()
i_shape.in_port(0).connect(node.in_port(0).get_source())
shape.in_port(0).connect(i_shape.out_port(0))
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
div = create_op_with_const_inputs(
graph, Div, {1: float_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
graph, Div, {1: float32_array([node.patch_stride])}, {'name': node_name + '/div_stride_h'})
div.in_port(0).connect(H.out_port(0))
concat = create_op_with_const_inputs(graph, Concat, {2: float_array([1]), 3: float_array([node.patch_stride])},
concat = create_op_with_const_inputs(graph, Concat, {2: float32_array([1]), 3: float32_array([node.patch_stride])},
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
concat.in_port(0).connect(N.out_port(0))
concat.in_port(1).connect(div.out_port(0))


@@ -5,11 +5,10 @@ import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Div
from mo.front.common.partial_infer.utils import int64_array, float_array
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.concat import Concat
from mo.ops.reshape import Reshape
from mo.ops.shape import Shape
@@ -48,18 +47,20 @@ class ReplacePoolingReshape(FrontReplacementPattern):
# create Reshape before convolution
# shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
shape = Cast(graph, {'name': node_name + '/to_float',
'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
'dst_type': dst_dtype}).create_node()
i_shape.in_port(0).connect(node.in_port(0).get_source())
shape.in_port(0).connect(i_shape.out_port(0))
N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])
div = create_op_with_const_inputs(
graph, Div, {1: float_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
graph, Div, {1: float32_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
div.in_port(0).connect(H.out_port(0))
concat = create_op_with_const_inputs(graph, Concat, {1: float_array([node.pool_stride]), 2: float_array([1])},
concat = create_op_with_const_inputs(graph, Concat, {1: float32_array([node.pool_stride]), 2: float32_array([1])},
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
concat.in_port(0).connect(N.out_port(0))
concat.in_port(3).connect(div.out_port(0))


@@ -41,7 +41,7 @@ def unique_id(prefix: str = 'id') -> str:
unique_id.names = []
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision = np.float):
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32):
# create init_graph connected to ReadValue
graph = input_out_port.node.graph
input_name = input_out_port.node.name


@@ -5,8 +5,8 @@ import logging as log
import numpy as np
from extensions.ops.activation_ops import Floor
from extensions.ops.Cast import Cast
from extensions.ops.activation_ops import Floor
from extensions.ops.elementwise import Add, Mul
from extensions.ops.interpolate import Interpolate
from extensions.ops.range import Range
@@ -15,7 +15,6 @@ from mo.front.common.partial_infer.utils import int64_array, float_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
@@ -79,9 +78,9 @@ def replace_resize(graph: Graph, resize: Node):
{1: float_array([1.0e-5])},
{'name': resize_name + '/Add'})
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])


@@ -8,7 +8,7 @@ from extensions.ops.DetectionOutput import DetectionOutput
from extensions.ops.elementwise import Mul, Sub, Pow
from extensions.ops.gather import Gather
from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.front.tf.replacement import FrontReplacementFromConfigFileSubGraph
@@ -54,9 +54,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
priors_scale_node.out_port(0).connect(sp_shape.in_port(0))
begin = Const(graph, {'value': np.array([-2])}).create_node()
end = Const(graph, {'value': np.array([-1])}).create_node()
stride = Const(graph, {'value': np.array([1])}).create_node()
begin = Const(graph, {'value': int64_array([-2])}).create_node()
end = Const(graph, {'value': int64_array([-1])}).create_node()
stride = Const(graph, {'value': int64_array([1])}).create_node()
shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim', 'begin_mask': np.array([1]),
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]),
@@ -72,7 +72,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
'axis': int64_array(0)},
shape_part_for_tiling)
variance = Const(graph, {'name': name + '/variance', 'value': np.array(variance)}).create_node()
variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
variance.out_port(0).connect(tile.in_port(0))
shape_concat.out_port(0).connect(tile.in_port(1))
@@ -113,9 +113,9 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
shape = Shape(graph, {'name': 'input_image_shape'}).create_node()
shape.in_port(0).connect(placeholder.out_port(0))
begin = Const(graph, {'value': np.array([1])}).create_node()
end = Const(graph, {'value': np.array([3])}).create_node()
stride = Const(graph, {'value': np.array([1])}).create_node()
begin = Const(graph, {'value': int64_array([1])}).create_node()
end = Const(graph, {'value': int64_array([3])}).create_node()
stride = Const(graph, {'value': int64_array([1])}).create_node()
spatial = StridedSlice(graph, {'name': name + '/get_h_w', 'begin_mask': np.array([1]),
'end_mask': np.array([1]), 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
@@ -125,7 +125,7 @@ class RetinaNetFilteredDetectionsReplacement(FrontReplacementFromConfigFileSubGr
spatial.in_port(2).connect(end.out_port(0))
spatial.in_port(3).connect(stride.out_port(0))
power = Const(graph, {'value': np.array([-1.])}).create_node()
power = Const(graph, {'value': float32_array([-1.])}).create_node()
spatial_scale = Pow(graph, {}).create_node()
spatial_scale.in_port(0).connect(spatial.out_port(0))


@@ -3,14 +3,12 @@
import logging as log
from collections import deque
from typing import Set
from extensions.middle.InsertLayoutPropagationTransposes import InsertLayoutPropagationTranspose, \
mark_as_correct_data_layout, mark_output_as_in_correct_layout, mark_input_as_in_correct_layout
from extensions.middle.LayoutChangeForConstantShapePaths import LayoutChangeForConstantShapePaths
from extensions.middle.pass_separator import PostMiddleStart
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.graph.port import Port
@@ -51,7 +49,8 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
result.append(dest_port.node)
return result
def bfs(self, start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
@staticmethod
def bfs(start_nodes: list, visited: set, condition: callable = None, forward: bool = True):
"""
The function performs BFS starting from the selected nodes, in the forward or backward direction, adding nodes
that satisfy an optional condition
@@ -63,7 +62,7 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
:return: the list of Nodes visited
"""
assert visited is not None, 'The "visited" set must be defined'
assert start_nodes is not None and len(start_nodes) != 0, 'The list of start nodes must be specified'
assert start_nodes is not None, 'The list of start nodes must be specified'
result = list()
d = deque(start_nodes)
@@ -72,9 +71,9 @@ class MarkSubGraphsWithCorrectLayout(MiddleReplacementPattern):
result.append(cur_node)
visited.add(cur_node)
if forward:
next_nodes = self.get_output_nodes(cur_node)
next_nodes = MarkSubGraphsWithCorrectLayout.get_output_nodes(cur_node)
else:
next_nodes = self.get_input_nodes(cur_node)
next_nodes = MarkSubGraphsWithCorrectLayout.get_input_nodes(cur_node)
for next_node in next_nodes:
if next_node not in visited and (condition is None or condition(next_node)):
d.append(next_node)
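
Making bfs a @staticmethod lets other transformations reuse the traversal without instantiating the middle-phase pass. This is how the new back-phase MarkNodesWithShapeValues calls it; the sketch below assumes start_nodes is already a list of mo.graph.graph.Node objects.

from extensions.middle.MarkSubgraphsWithCorrectLayout import MarkSubGraphsWithCorrectLayout

# walk the graph backward from the shape inputs, never stepping past a ShapeOf node
condition = lambda node: node.soft_get('type') != 'ShapeOf'
shape_nodes = MarkSubGraphsWithCorrectLayout.bfs(start_nodes, set(), condition, forward=False)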


@@ -2,18 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from extensions.ops.activation_ops import Floor
from extensions.ops.Cast import Cast
from extensions.ops.activation_ops import Floor
from extensions.ops.elementwise import Add, Div, Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.layout import get_depth_dim, get_height_dim, get_width_dim
from mo.front.common.partial_infer.utils import int64_array, float_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.middle.replacement import MiddleReplacementPattern
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
@@ -94,10 +94,10 @@ def replace_resize(graph: Graph, resize: Node):
{1: float_array([1.0e-5])},
{'name': resize_name + '/Add'})
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
dst_dtype = np.float32 # even if data_type=FP16 use float32 for shape values
if num_of_inputs == 3:
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
@@ -119,8 +119,8 @@ def replace_resize(graph: Graph, resize: Node):
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
connection_of_scales.get_source().connect(mul_node.in_port(1))
else:
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
cast_shape_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
cast_sizes_to_float = Cast(graph, {'dst_type': dst_dtype}).create_node()
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))


@@ -2,9 +2,9 @@
# SPDX-License-Identifier: Apache-2.0
import logging as log
import numpy as np
from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.graph.graph import Node, Graph
from mo.middle.passes.convert_data_type import np_data_type_to_precision, convert_blob, \
np_data_type_to_destination_type, packed_I4, packed_U4


@@ -17,6 +17,7 @@ from mo.utils.class_registration import apply_replacements_list
from mo.utils.ir_engine.ir_engine import IREngine
from mo.utils.ir_reader.layer_to_class import copy_graph_with_ops, collect_extenders, collect_ops
from mo.utils.utils import get_mo_root_dir
from extensions.back.MarkNodesWithShapeValues import MarkNodesWithShapeValues
def restore_graph_from_ir(path_to_xml: str, path_to_bin: str = None) -> (Graph, dict):
@@ -64,6 +65,7 @@ def save_restored_graph(graph: Graph, path: str, meta_data, name=None):
BlobNormalizer,
ConvolutionNormalizer,
KaldiRemoveMemoryOutputBackReplacementPattern,
MarkNodesWithShapeValues,
]
# We need to run some specific passes from MO back stage.
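
With MarkNodesWithShapeValues registered in this list, an FP16 IR re-serialized through the IR reader keeps its shape subgraphs in FP32. A hedged usage sketch, assuming the functions shown above live in mo.utils.ir_reader.restore_graph and with placeholder paths:

from mo.utils.ir_reader.restore_graph import restore_graph_from_ir, save_restored_graph

# read an existing IR back into an MO graph, then re-run the back-stage passes and save it
graph, meta_data = restore_graph_from_ir('model.xml', 'model.bin')
save_restored_graph(graph, 'output_dir', meta_data, name='model_restored')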