Fix const node non-deterministic names (part 2) (#1081)

* Fix non-deterministic node names generation in the Model Optimizer (part 2)
Anton Chetverikov authored 2020-07-07 09:37:48 +03:00; committed by GitHub
parent 5d1c5ee6a9
commit 56916ace61
25 changed files with 205 additions and 175 deletions
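
Every change below applies the same recipe: helper nodes (mostly Const) that used to be created without an explicit 'name' now get one derived from the node being transformed, because auto-generated names carry a counter that depends on graph traversal order and therefore differs between Model Optimizer runs. A minimal sketch of the before/after pattern (a simplified illustration, not a line from this diff):

    # Before: the Const gets an auto-generated name whose numeric suffix
    # depends on traversal order, so it can change from run to run.
    order = Const(graph, {'value': int64_array(transpose_order)}).create_node()

    # After: derive the name deterministically from the transformed node;
    # soft_get() falls back to the unique node id if 'name' is missing.
    node_name = node.soft_get('name', node.id)
    order = Const(graph, {'value': int64_array(transpose_order),
                          'name': node_name + '/Order'}).create_node()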

View File

@@ -20,7 +20,7 @@ from extensions.back.ReshapeMutation import ReshapeMutation
 from extensions.back.ReverseInputChannels import ApplyReverseChannels
 from mo.back.replacement import BackReplacementPattern
 from mo.front.common.partial_infer.utils import int64_array
-from mo.front.tf.graph_utils import create_op_node_with_second_input
+from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
 from mo.graph.graph import Graph
 from mo.ops.const import Const
 from mo.ops.reshape import Reshape
@@ -224,6 +224,7 @@ class DeconvolutionNormalizer(BackReplacementPattern):
     def replace_pattern(self, graph: Graph, match: dict):
         node = match['node']
+        node_name = node.soft_get('name', node.id)
 
         if 2 in node.in_ports() and not node.in_port(2).disconnected():
             # Third input represents output shape. Cutting its value according to scheme:
@@ -233,22 +234,17 @@ class DeconvolutionNormalizer(BackReplacementPattern):
             shape_src = node.in_port(2).get_source()
             node.in_port(2).disconnect()
 
-            begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
-            end = Const(graph, {'value': np.array([in_rank], dtype=np.int32)}).create_node()
-            stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
-
-            ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
-                                        'begin_mask': np.array([1], dtype=np.int32),
-                                        'end_mask': np.array([0], dtype=np.int32),
-                                        'new_axis_mask': np.array([0], dtype=np.int32),
-                                        'shrink_axis_mask': np.array([0], dtype=np.int32),
-                                        'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()
+            ss_0 = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([2], dtype=np.int32),
+                                                                     2: np.array([in_rank], dtype=np.int32),
+                                                                     3: np.array([1], dtype=np.int32)},
+                                               {'name': node_name + '/ss_0_port',
+                                                'begin_mask': np.array([1], dtype=np.int32),
+                                                'end_mask': np.array([0], dtype=np.int32),
+                                                'new_axis_mask': np.array([0], dtype=np.int32),
+                                                'shrink_axis_mask': np.array([0], dtype=np.int32),
+                                                'ellipsis_mask': np.array([0], dtype=np.int32)})
 
             shape_src.connect(ss_0.in_port(0))
-            begin.out_port(0).connect(ss_0.in_port(1))
-            end.out_port(0).connect(ss_0.in_port(2))
-            stride.out_port(0).connect(ss_0.in_port(3))
-
             ss_0.out_port(0).connect(node.in_port(2))
 
             # Specification: *padding amount* is deduced from relation of input and output spatial shapes
@@ -256,7 +252,8 @@ class DeconvolutionNormalizer(BackReplacementPattern):
         elif node.has_valid('original_output_spatial_shape'):
             # node had fixed output spatial shape set in original framework, so we restore it here
-            const = Const(graph, {'value': int64_array(node.original_output_spatial_shape)}).create_node()
+            const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
+                                  'name': node_name + '/original_spatial_shape'}).create_node()
             node.add_input_port(2, skip_if_exist=True)
             const.out_port(0).connect(node.in_port(2))

View File

@@ -55,43 +55,52 @@ class CropToStridedSlice(BackReplacementPattern):
     def replace_pattern(self, graph: Graph, match: [str, Node]):
         node = match['crop']
         assert node.has_valid('axis')
-        node.axis = self.list_to_ndarray(node.axis)
+        node_axis = self.list_to_ndarray(node.axis)
 
         in_shape = node.in_port(0).data.get_shape()
         shape_rank = in_shape.size
-        axis_mask = int64_array([1 if i in node.axis else 0 for i in range(shape_rank)])
+        axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
         begin_mask = axis_mask.copy()
         end_mask = axis_mask.copy()
 
+        ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
+                                  'end_mask': end_mask, 'new_axis_mask': np.array([0]),
+                                  'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
+
         if len(node.in_nodes()) == 2 and node.has_valid('offset'):
             # Crop Type 1
-            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
-            shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
-            end = Add(graph, {'name': node.name + '/end'}).create_node()
+            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
+                                  'name': ss.name + '/begin'}).create_node()
+            shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
+            end = Add(graph, {'name': ss.name + '/end'}).create_node()
             node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
             node.in_port(1).disconnect()
             shape.out_port(0).connect(end.in_port(0))
             begin.out_port(0).connect(end.in_port(1))
         elif node.has_valid('dim') and node.has_valid('offset'):
             # Crop Type 2
-            node.dim = self.list_to_ndarray(node.dim)
-            node.offset = self.list_to_ndarray(node.offset)
-            assert node.dim.size == node.offset.size == node.axis.size
-            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
-            end_values = np.array([node.offset[i] + node.dim[i] for i in range(len(node.dim))])
-            end = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, end_values)}).create_node()
+            node_dim = self.list_to_ndarray(node.dim)
+            node_offset = self.list_to_ndarray(node.offset)
+            assert node_dim.size == node_offset.size == node_axis.size
+            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
+                                  'name': ss.name + '/begin'}).create_node()
+            end_values = np.array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
+            end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
+                                'name': ss.name + '/end'}).create_node()
         elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
             # Crop Type 3
-            node.crop_begin = self.list_to_ndarray(node.crop_begin)
-            node.crop_end = self.list_to_ndarray(node.crop_end)
-            assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)
-            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.crop_begin)}).create_node()
-            shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
-            const = Const(graph,
-                          {'value': -1 * self.mask_normalizer(shape_rank, node.axis, node.crop_end)}).create_node()
-            end = Add(graph, {'name': node.name + '/end'}).create_node()
+            node_crop_begin = self.list_to_ndarray(node.crop_begin)
+            node_crop_end = self.list_to_ndarray(node.crop_end)
+            assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)
+            begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
+                                  'name': ss.name + '/begin'}).create_node()
+            shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()
+            end = Add(graph, {'name': ss.name + '/end'}).create_node()
+            const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
+                                  'name': ss.name + '/const'}).create_node()
 
             node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
             shape.out_port(0).connect(end.in_port(0))
@@ -102,9 +111,8 @@ class CropToStridedSlice(BackReplacementPattern):
         source = node.in_port(0).get_connection().get_source()
 
-        stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64)}).create_node()
-        ss = StridedSlice(graph, {'name': 'Crop_', 'begin_mask': begin_mask, 'end_mask': end_mask, 'new_axis_mask': np.array([0]),
-                                  'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
+        stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
+                               'name': ss.name + '/stride'}).create_node()
 
         source.connect(ss.in_port(0))
         begin.out_port(0).connect(ss.in_port(1))

View File

@@ -44,6 +44,7 @@ class GroupedConvWeightsNormalize(BackReplacementPattern):
         weights = match['weights']
         input_shape = conv.in_port(0).data.get_shape()
         new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
-        new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
+        new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape),
+                                    'name': weights.soft_get('name', weights.id) + '_new'}).create_node()
         weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
         new_weights.infer(new_weights)

View File

@@ -18,7 +18,7 @@ import numpy as np
 from extensions.back.ForceStrictPrecision import ForceStrictPrecision
 from extensions.ops.prelu import PreluOp
 from mo.back.replacement import BackReplacementPattern
-from mo.graph.graph import Graph
+from mo.graph.graph import Graph, rename_node
 from mo.ops.const import Const
@@ -39,11 +39,16 @@ class LeakyReLUMutation(BackReplacementPattern):
     @staticmethod
     def replace_pattern(graph: Graph, match: dict):
         relu = match['leakyrelu']
+        relu_name = relu.soft_get('name', relu.id)
         if not relu.has_valid('negative_slope'):
             return
 
+        rename_node(relu, relu_name + '/to_delete')
         # Create PReLU op and reconnect input/output from LeakyReLU to PReLU
-        prelu = PreluOp(graph, dict(name=relu.name)).create_node()
-        const = Const(graph, dict(name=relu.name + "/weights", value=np.array([relu.negative_slope]))).create_node()
+        prelu = PreluOp(graph, dict(name=relu_name)).create_node()
+        rename_node(prelu, relu_name)
+        const = Const(graph, dict(name=relu_name + "/weights", value=np.array([relu.negative_slope]))).create_node()
 
         relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
         const.out_port(0).connect(prelu.in_port(1))
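
Where a node is replaced outright, as with LeakyReLU above (and GatherNd and Reverse further down), the fix also uses rename_node so the replacement inherits the original, user-visible name. The hand-off from the new code above, with comments added:

    relu_name = relu.soft_get('name', relu.id)
    # Free the original name by renaming the node that is about to be removed.
    rename_node(relu, relu_name + '/to_delete')
    # Create the replacement, then give it the original name so the op keeps
    # a stable, deterministic name in the generated IR.
    prelu = PreluOp(graph, dict(name=relu_name)).create_node()
    rename_node(prelu, relu_name)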

View File

@@ -19,6 +19,7 @@ import numpy as np
 from extensions.ops.transpose import Transpose
 from mo.back.replacement import BackReplacementPattern
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_node_with_second_input
 from mo.graph.graph import Graph
 from mo.ops.const import Const
 from mo.ops.unsqueeze import Unsqueeze
@@ -56,13 +57,12 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
             transpose_order = list(range(port_shape.size))
             transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
 
-            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
-            transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()
+            transpose = create_op_node_with_second_input(graph, Transpose, int64_array(transpose_order),
+                                                         {'name': name + '/{}_port_transpose'.format(in_port_idx)})
 
             port_source = in_port.get_source()
             in_port.get_connection().set_source(transpose.out_port(0))
             transpose.in_port(0).connect(port_source)
-            transpose.in_port(1).connect(order.out_port(0))
 
             transpose['override_output_shape'] = True

View File

@@ -21,8 +21,8 @@ from extensions.back.ReshapeMutation import ReshapeMutation
 from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
 from mo.back.replacement import BackReplacementPattern
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
 from mo.graph.graph import Graph
-from mo.ops.const import Const
 from mo.ops.reshape import Reshape
 from mo.ops.strided_slice import StridedSlice
@@ -53,32 +53,27 @@ class ProposalMutation(BackReplacementPattern):
                         'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
                         'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
                         extra={'is_warning': True})
-            begin = Const(graph, {'value': np.array([0, 0], dtype=np.int32)}).create_node()
-            end = Const(graph, {'value': np.array([1, 3], dtype=np.int32)}).create_node()
-            stride = Const(graph, {'value': np.array([1, 1], dtype=np.int32)}).create_node()
-
-            cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
-                                                   'begin_mask': int64_array([1, 1]),
-                                                   'end_mask': int64_array([1, 1]),
-                                                   'new_axis_mask': int64_array([0]),
-                                                   'shrink_axis_mask': int64_array([0]),
-                                                   'ellipsis_mask': int64_array([0]),
-                                                   'override_output_shape': True,
-                                                   }).create_node()
+            cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([0, 0], dtype=np.int32),
+                                                                                2: np.array([1, 3], dtype=np.int32),
+                                                                                3: np.array([1, 1], dtype=np.int32)},
+                                                          {'name': 'cropped_im_info',
+                                                           'begin_mask': int64_array([1, 1]),
+                                                           'end_mask': int64_array([1, 1]),
+                                                           'new_axis_mask': int64_array([0]),
+                                                           'shrink_axis_mask': int64_array([0]),
+                                                           'ellipsis_mask': int64_array([0]),
+                                                           'override_output_shape': True,
+                                                           })
 
             node.in_port(2).get_connection().insert_node(cropped_im_info)
-            begin.out_port(0).connect(cropped_im_info.in_port(1))
-            end.out_port(0).connect(cropped_im_info.in_port(2))
-            stride.out_port(0).connect(cropped_im_info.in_port(3))
 
             # update the im_info_shape so the next 'if' statement become true
             im_info_shape = int64_array([1, 3])
 
         if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
-            reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
-            const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
+            reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]], {'name': 'im_info/Reshape'})
             node.in_port(2).get_connection().set_destination(reshape.in_port(0))
-            const.out_port(0).connect(reshape.in_port(1))
             reshape.out_port(0).connect(node.in_port(2))
 
         if node.has_port('out', 1) and not node.out_port(1).disconnected():

View File

@@ -25,7 +25,6 @@ from mo.front.tf.graph_utils import create_op_with_const_inputs
 from mo.graph.graph import Graph
 from mo.graph.graph import Node
 from mo.ops.concat import Concat
-from mo.ops.const import Const
 from mo.ops.op import Op, PermuteAttrs
@@ -358,11 +357,7 @@ class DecomposeReverseChannels(BackReplacementPattern):
         axis = node.axis
         order = node.order
 
-        indices = Const(graph, {'name': name + '/reverse_order', 'value': order}).create_node()
-        axis_const = Const(graph, {'value': int64_array(axis)}).create_node()
-        gather = Gather(graph, {'name': name}).create_node()
-        gather.in_port(1).connect(indices.out_port(0))
-        gather.in_port(2).connect(axis_const.out_port(0))
+        gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})
 
         node.out_port(0).get_connection().set_source(gather.out_port(0))
         node.in_port(0).get_connection().set_destination(gather.in_port(0))

View File

@@ -107,8 +107,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
         initial_fake_quantize = match['quantize']
+        initial_fake_quantize_name = initial_fake_quantize.soft_get('name', initial_fake_quantize.id)
 
-        new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize.name + '/Copy',
+        new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize_name + '/Copy',
                                                                  stop_value_propagation=False), graph)
 
         initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
@@ -121,9 +122,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
         i_min = np.array([0.], dtype=dst_type)
         i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
 
-        new_out_low_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_low',
+        new_out_low_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_low',
                                              value=i_min)).create_node()
-        new_out_high_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_high',
+        new_out_high_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_high',
                                               value=i_max)).create_node()
 
         new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
@@ -131,7 +132,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
         new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
         new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))
 
-        cast_node = Cast(graph, dict(name=initial_fake_quantize.name + "/Convert_to_float", dst_type=dst_type,
+        cast_node = Cast(graph, dict(name=initial_fake_quantize_name + "/Convert_to_float", dst_type=dst_type,
                                      stop_value_propagation=True)).create_node()
         new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
         initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))

View File

@@ -49,12 +49,12 @@ class PriorboxMutation(BackReplacementPattern):
         assert len(node.in_ports()) == 2
 
-        begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
-        end = Const(graph, {'value': np.array([4], dtype=np.int32)}).create_node()
-        stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
+        begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
+        end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
+        stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()
 
-        shape_0 = Shape(graph, {'name': node.name + '/0_port'}).create_node()
-        ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
+        shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
+        ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
                                     'begin_mask': np.array([1], dtype=np.int32),
                                     'end_mask': np.array([0], dtype=np.int32),
                                     'new_axis_mask': np.array([0], dtype=np.int32),
@@ -71,8 +71,8 @@ class PriorboxMutation(BackReplacementPattern):
         source.connect(shape_0.in_port(0))
         ss_0.out_port(0).connect(node.in_port(0))
 
-        shape_1 = Shape(graph, {'name': node.name + '/1_port'}).create_node()
-        ss_1 = StridedSlice(graph, {'name': node.name + '/ss_1_port',
+        shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
+        ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
                                     'begin_mask': np.array([1], dtype=np.int32),
                                     'end_mask': np.array([0], dtype=np.int32),
                                     'new_axis_mask': np.array([0], dtype=np.int32),

View File

@@ -92,6 +92,8 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
         output_low = quantize.in_node(3)
         output_high = quantize.in_node(4)
 
+        quantize_name = quantize.soft_get('name', quantize.id)
+
         if not output_low.has_valid('value') and not output_high.has_valid('value'):
             return
@@ -115,8 +117,9 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
         mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
 
-        new_shape = Const(graph, {'value': int64_array([-1, 1, 1])}).create_node_with_data()
-        reshape = Reshape(graph, {}).create_node_with_data([mult_term, new_shape])
+        new_shape = Const(graph, {'name': quantize_name + '/Reshape/Shape',
+                                  'value': int64_array([-1, 1, 1])}).create_node_with_data()
+        reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])
 
         # Patch inflow path (by diving by mult_term)
         # Put a new Pow/Mul combination here:
@@ -125,15 +128,16 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
         if len(match['quantized'].out_nodes()) > 1:
             log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
             return
-        power_of_exponent = Const(graph, {'value': np.array(-1.0)}).create_node_with_data()
-        div_op = Pow(graph, {'name': quantize.name + '/DivNormalize'})
+        power_of_exponent = Const(graph, {'name': quantize_name + '/DivNormalize/Power',
+                                          'value': np.array(-1.0)}).create_node_with_data()
+        div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
         div_output = div_op.create_node_with_data([mult_term, power_of_exponent])
 
         for i in [3, 4]:
             match['quantize'].insert_node_with_data_before(
                 match['quantize'].in_node(i),
                 Mul,
-                dict(name=quantize.name + '/MulNormalize'),
+                dict(name=quantize_name + '/MulNormalize'),
                 additional_inputs=[div_output],
             )

View File

@@ -19,9 +19,9 @@ import logging as log
 import numpy as np
 
 from extensions.ops.elementwise import Mul, Add
+from mo.front.tf.graph_utils import create_op_node_with_second_input
 from mo.graph.graph import Graph
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 
 
 class ConvToBinaryConv(MiddleReplacementPattern):
@@ -91,12 +91,10 @@ class ConvToBinaryConv(MiddleReplacementPattern):
             weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
             weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1])  # FIXME: works for NCHW only
 
-            add_term = Const(graph, {'value': weights_reduced}).create_node()
-            add = Add(graph, {}).create_node()
-            add.in_port(1).connect(add_term.out_port(0))
-            mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
-            mul = Mul(graph, {}).create_node()
-            mul.in_port(1).connect(mul_term.out_port(0))
+            operator_name = operator.soft_get('name', operator.id)
+            add = create_op_node_with_second_input(graph, Add, weights_reduced, {'name': operator_name + '/Add_'})
+            mul = create_op_node_with_second_input(graph, Mul, np.array(0.5), {'name': operator_name + '/Mul_'})
             add.out_port(0).connect(mul.in_port(0))
 
             operator.out_port(0).get_connection().set_source(mul.out_port(0))

View File

@@ -189,9 +189,13 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
                 log.debug("Removed: {}".format(node.id))
 
         # 2. Create Split layer and reorder outputs
-        axis_const = Const(graph, {'value': int64_array(split_channel_dim)}).create_node_with_data()
-        size_splits_const = Const(graph, {'value': int64_array(size_splits)}).create_node_with_data()
-        split = VariadicSplit(graph, dict(name=name_for_future_split + "/Split", out_ports_count=len(size_splits)))
+        name = name_for_future_split + "/Split"
+        axis_const = Const(graph, {'value': int64_array(split_channel_dim),
+                                   'name': name + '/Axis'}).create_node_with_data()
+        size_splits_const = Const(graph, {'value': int64_array(size_splits),
+                                          'name': name + '/Sizes'}).create_node_with_data()
+        split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))
 
         split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
                                     data_nodes=final_data_nodes_list)

View File

@@ -36,6 +36,7 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
     def find_and_replace_pattern(self, graph: Graph):
         for node in list(graph.nodes()):
             node = Node(graph, node)
+            node_name = node.soft_get('name', node.id)
             # Check that node layout mismatch with graph layout
             # For example: NHWC and NCHW or NCDHW and NDHWC
             if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
@@ -63,8 +64,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
                     edge_attrs = graph.get_edge_data(input.id, node.id)[0]
                     graph.remove_edge(input.id, node.id)
 
-                    input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
-                    input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
+                    input_permute_name = node_name + '/input_transpose'
+                    input_order_const = Const(graph, {'name': input_permute_name + '/order',
+                                                      'value': permutation.perm}).create_node_with_data()
+                    input_permute_op = Transpose(graph, {'name': input_permute_name})
                     input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
 
                     graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
@@ -77,8 +80,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
                     input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
                                                           edge_attrs)
 
-                    output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
-                    output_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_')
+                    output_permute_name = node_name + '/output_transpose'
+                    output_order_const = Const(graph, {'name': output_permute_name + '/order',
+                                                       'value': permutation.inv}).create_node_with_data()
+                    output_permute_op = Transpose(graph, {'name': output_permute_name}
                                                   ).create_node_with_data([input_data_node, output_order_const],
                                                                           data_nodes=output)

View File

@@ -46,9 +46,11 @@ class Deconvolution3rdInputNormalization(MiddleReplacementPattern):
         data_node = node.in_node(2)
 
-        const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
-        axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
-        gather = Gather(graph, {'name': node.name + '/ShapeGather',
+        name = node.soft_get('name', node.id) + '/ShapeGather'
+        const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
+                              'need_shape_inference': True}).create_node_with_data()
+        axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
+        gather = Gather(graph, {'name': name,
                                 'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
 
         attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()

View File

@@ -160,7 +160,8 @@ class DilatedConvolution1DConverter(MiddleReplacementPattern):
         unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
         for port_id in [1, 2]:
             current_value = pad.in_port(port_id).get_connection().data.get_value()
-            new_value_node = Const(pad.graph, {'value': np.insert(current_value, unsqueeze_axis.item(), 0),
+            new_value_node = Const(pad.graph, {'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
+                                               'value': np.insert(current_value, unsqueeze_axis.item(), 0),
                                                'override_output_shape': True}).create_node()
             pad.in_port(port_id).disconnect()
             pad.in_port(port_id).connect(new_value_node.out_port(0))

View File

@@ -106,8 +106,10 @@ class EltwiseInputReshape(MiddleReplacementPattern):
             # Insert Reshape layer between data node and consumer
             for shape_key in mapping.keys():
                 shape = list(shape_key)
-                reshape = Reshape(graph, attrs={'name': 'EltwiseReshapeNormalization'})
-                reshape_dim = Const(graph, {'value': shape}).create_node_with_data()
+                reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
+                reshape = Reshape(graph, attrs={'name': reshape_name})
+                reshape_dim = Const(graph,
+                                    {'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
                 reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
 
                 # Iterate over consumers and reconnect them to Reshape layer output

View File

@@ -19,9 +19,9 @@ import numpy as np
 from extensions.ops.gather import Gather
 from mo.front.common.partial_infer.utils import int64_array
-from mo.graph.graph import Graph
+from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
+from mo.graph.graph import Graph, rename_node
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 from mo.ops.reshape import Reshape
@@ -68,6 +68,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
     def replace_pattern(self, graph: Graph, match: dict):
         gather = match['GatherNd']
+        gather_name = gather.soft_get('name', gather.id)
         input_shape = gather.in_node(0).shape
         indices = gather.in_node(1).value
         if indices is None:
@@ -77,26 +78,26 @@ class GatherNdNormalize(MiddleReplacementPattern):
         # 0. All needed checks that we can replace GatherNd by Gather
         gather_idx = self.indices_check(indices, input_shape)
         if gather_idx is None:
-            log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather.name))
+            log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name))
             return
 
         # 1. Add Reshape and connect
         new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
-        reshape = Reshape(graph, {'name': gather.name + '/Reshape_for_GatherNd/'}).create_node()
-        reshape_const_node = Const(graph, {'name': reshape.name + '/Dim', 'value': new_shape}).create_node()
+        reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
+                                                   {'name': gather_name + '/Reshape_for_GatherNd/'})
         gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
-        reshape.in_port(1).connect(reshape_const_node.out_port(0))
 
         # 2. Change indices from Nd to 1d:
         new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
-        new_indices_const = Const(graph, dict(value=new_indices)).create_node()
-        axis_const = Const(graph, {'value': int64_array(0)}).create_node()
+
+        rename_node(gather, gather_name + '/to_delete')
 
         # 3. Create new Gather operation and reconnect all inputs/outputs
-        new_gather = Gather(graph, {'name': gather.name + '/NewGather/'}).create_node()
+        new_gather = create_op_with_const_inputs(graph, Gather, {1: new_indices, 2: int64_array(0)},
+                                                 {'name': gather_name})
+        rename_node(new_gather, gather_name)
+
         reshape.out_port(0).connect(new_gather.in_port(0))
-        new_indices_const.out_port(0).connect(new_gather.in_port(1))
-        axis_const.out_port(0).connect(new_gather.in_port(2))
 
         gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

View File

@@ -16,9 +16,9 @@
 from extensions.middle.pass_separator import PostMiddleStart
 from extensions.ops.transpose import Transpose
 from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 from mo.ops.op import PermuteAttrs
@@ -69,6 +69,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
                len(node.out_port(0).data.get_shape()) >= 4
 
     def find_and_replace_pattern(self, graph: Graph):
+        # we need to import these functions here to avoid circular dependent imports
+        from mo.front.tf.graph_utils import create_op_node_with_second_input
+
         if graph.graph['layout'] != 'NHWC':
             # we check it here because this transformation is called explicitly from the pipeline
             return
@@ -80,15 +84,14 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
                           reinterp_shape_node_id, graph.dump_graph_for_graphviz())
             input_shape = reinterp_shape_node.in_node(0).shape
             if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
-                order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
-                                            }).create_node()
-                permute_node = Transpose(graph,
-                                         {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
-                                          }).create_node()
+                permute_node = create_op_node_with_second_input(
+                    graph, Transpose, PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm,
+                    {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}
+                )
                 reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
-                order_const.out_port(0).connect(permute_node.in_port(1))
-                order_const.infer(order_const)
+
+                order_const = permute_node.in_port(1).get_source().node
+                order_const.infer(order_const)
                 # do not infer the Transpose node because it should have input data node in NCHW layout (but currently
                 # it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
                 # (which is true at this moment)
@@ -107,11 +110,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
                           reinterp_shape_node_id, graph.dump_graph_for_graphviz())
             output_shape = reinterp_shape_node.out_node(0).shape
             if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
-                order_const = Const(graph, {
-                    'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
-                permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
+                permute_node = create_op_node_with_second_input(
+                    graph, Transpose, PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm,
+                    {'name': reinterp_shape_node.id + '/Transpose'})
                 reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
-                order_const.out_port(0).connect(permute_node.in_port(1))
                 # the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
                 # will convert it to the NCHW

View File

@@ -20,9 +20,9 @@ import numpy as np
 from extensions.ops.normalize import NormalizeOp
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_node_with_second_input
 from mo.graph.graph import Graph, rename_node
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 
 
 class L2NormToNorm(MiddleReplacementPattern):
@@ -87,13 +87,13 @@ class L2NormToNorm(MiddleReplacementPattern):
         normalizel2_name = output_name + '/normalizel2'
         rename_node(match['l2_normalize'], normalizel2_name)
 
-        normalize_node = NormalizeOp(graph, {'name': output_name, 'eps': y,
-                                             'across_spatial': 0, 'channel_shared': 0}).create_node()
+        normalize_node = create_op_node_with_second_input(graph, NormalizeOp,
+                                                          np.ones(shape=int64_array([match['input'].shape[-1]]),
+                                                                  dtype=match['input'].data_type),
+                                                          {'name': output_name, 'eps': y,
+                                                           'across_spatial': 0, 'channel_shared': 0})
         rename_node(normalize_node, output_name)
 
-        weights_node = Const(graph, {'value': np.ones(shape=int64_array([match['input'].shape[-1]]),
-                                                      dtype=match['input'].data_type)}).create_node()
-
         match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
         match['square'].in_port(0).disconnect()
@@ -102,5 +102,4 @@ class L2NormToNorm(MiddleReplacementPattern):
         else:
             match['l2_normalize'].in_port(0).disconnect()
 
-        weights_node.out_port(0).get_connection().set_destination(normalize_node.in_port(1))
         match['l2_normalize'].out_port(0).get_connection().set_source(normalize_node.out_port(0))

View File

@@ -192,7 +192,7 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
         input = match['input']
         if not lstm.has_num_directions:
             return
-        old_data_node =lstm.out_node(0)
+        old_data_node = lstm.out_node(0)
         num_directions = 2 if lstm.direction in ['bidirectional'] else 1
         mxnet_shape = lstm.out_node(0).shape.copy()
@@ -206,18 +206,21 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
         if lstm.has_num_directions:
             mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
 
-        new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
+        lstm_name = lstm.soft_get('name', lstm.id)
+
+        new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
         graph.remove_edge(lstm.id, old_data_node.id)
         graph.add_edge(lstm.id, new_data.id, key=0, out=0)
 
         # Add Transpose
-        permute_order = Const(graph, dict(value=int64_array([0, 2, 1, 3]))).create_node_with_data()
-        permute_data = Transpose(graph, dict(name=lstm.name + '/Transpose_mxnet/')
+        permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
+                                      'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
+        permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'}
                                  ).create_node_with_data([new_data, permute_order])
 
         # Add Reshape
-        reshape = Reshape(graph, dict(name=lstm.name + '/Reshape_mxnet/'))
-        reshape_dim_data = Const(graph, {'name': lstm.name + '/Reshape_mxnet_dim',
+        reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
+        reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
                                          'value': mxnet_shape}).create_node_with_data()
         reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])

View File

@@ -41,7 +41,8 @@ def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
         if value is None:
             log.debug('Can not duplicate due no data for in_port {} of node {}'.format(port_id, node.name))
         for node, idxs in dst_port_map.items():
-            const = Const(graph, {'value': np.array(value)}).create_node()
+            const = Const(graph, {'value': np.array(value),
+                                  'name': node.soft_get('name', node.id) + '/duplicated_'}).create_node()
             for idx in idxs:
                 node.in_port(idx).disconnect()
                 const.out_port(0).connect(node.in_port(idx))

View File

@@ -34,6 +34,6 @@ class ReverseTransposeNormalization(MiddleReplacementPattern):
         node = match['transpose']
         assert len(node.in_nodes()) == 1
         order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
-        const = Const(graph, {'value': order}).create_node()
+        const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
         node.add_input_port(1, skip_if_exist=True)
         const.out_port(0).connect(node.in_port(1))

View File

@@ -16,9 +16,9 @@
 import numpy as np
 
 from extensions.ops.reverse_sequence import ReverseSequence
-from mo.graph.graph import Graph
+from mo.front.tf.graph_utils import create_op_node_with_second_input
+from mo.graph.graph import Graph, rename_node
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 from mo.utils.error import Error
@@ -57,14 +57,15 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
         # 1. For ReverseSequence 1-port input is seq_lengths => create this input node
         seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
-        const = Const(graph, dict(value=seq_lengths)).create_node()
+
+        reverse_name = reverse.soft_get('name', reverse.id)
+        rename_node(reverse, reverse_name + '/to_delete')
 
         # 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
-        reverse_sequence = ReverseSequence(graph, {'name': reverse.name + '/ReverseSequence/',
-                                                   'seq_axis': seq_axis, 'batch_axis': batch_axis}).create_node()
+        reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
+                                                            {'name': reverse_name, 'seq_axis': seq_axis,
+                                                             'batch_axis': batch_axis})
+        rename_node(reverse_sequence, reverse_name)
         reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
-        const.out_port(0).connect(reverse_sequence.in_port(1))
         reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
 
         # 3. Delete old Reverse node

View File

@@ -36,7 +36,7 @@ class SwapAxisMiddleReplacer(MiddleReplacementPattern):
         order = swapaxis.order
         swapaxis.add_input_port(1)
-        const = Const(graph, {'value': order}).create_node()
+        const = Const(graph, {'value': order, 'name': swapaxis.soft_get('name', swapaxis.id) + '/Order'}).create_node()
         const.out_port(0).connect(swapaxis.in_port(1))
 
         Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})

View File

@@ -25,9 +25,9 @@ from extensions.ops.elementwise import Mul
 from extensions.ops.interpolate import Interpolate
 from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
 from mo.front.common.partial_infer.utils import int64_array
+from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
 from mo.graph.graph import Graph, Node
 from mo.middle.replacement import MiddleReplacementPattern
-from mo.ops.const import Const
 from mo.ops.shape import Shape
 from mo.ops.strided_slice import StridedSlice
@@ -55,6 +55,7 @@ class UpsampleToResample(MiddleReplacementPattern):
     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
         log.debug('UpsampleToResample is triggered')
         upsample = match['upsample']
+        upsample_name = upsample.soft_get('name', upsample.id)
         input_shape = upsample.in_port(0).data.get_shape()
         input_shape_rank = len(input_shape)
         if input_shape_rank not in [4, 5]:
@@ -67,8 +68,7 @@ class UpsampleToResample(MiddleReplacementPattern):
                 return
             scales = upsample.in_node(1).value
             assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
-                len(scales), upsample.soft_get('name', upsample.id)
-            )
+                len(scales), upsample_name)
             if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
                 return
             height_scale = scales[2]
@@ -81,45 +81,50 @@ class UpsampleToResample(MiddleReplacementPattern):
         if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
             log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
-                width_scale, height_scale, upsample.soft_get('name')))
+                width_scale, height_scale, upsample_name))
             return
         if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
             log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
-                depth_scale, height_scale, upsample.soft_get('name')))
+                depth_scale, height_scale, upsample_name))
             return
 
         if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
             upsample.in_port(1).disconnect()
 
-        shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()
+        shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
 
         layout = graph.graph['layout']
+
         if input_shape_rank == 4:
-            begin = Const(graph, {'value': int64_array([get_height_dim(layout, input_shape_rank)])}).create_node()
-            factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
+            begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
+            factor_value = np.array([height_scale, width_scale])
         else:
-            begin = Const(graph, {'value': int64_array([get_depth_dim(layout, input_shape_rank)])}).create_node()
-            factor = Const(graph, {'value': np.array([depth_scale, height_scale, width_scale])}).create_node()
-        end = Const(graph, {'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])}).create_node()
+            begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
+            factor_value = np.array([depth_scale, height_scale, width_scale])
 
-        stride = Const(graph, {'value': int64_array([1])}).create_node()
-        ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port',
-                                  'begin_mask': int64_array([1]),
-                                  'end_mask': int64_array([1]),
-                                  'new_axis_mask': int64_array([0]),
-                                  'shrink_axis_mask': int64_array([0]),
-                                  'ellipsis_mask': int64_array([0])}).create_node()
+        ss = create_op_with_const_inputs(graph, StridedSlice,
+                                         {1: begin_value,
+                                          2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
+                                          3: int64_array([1])
+                                          },
+                                         {'name': upsample_name + '/ss_0_port',
+                                          'begin_mask': int64_array([1]),
+                                          'end_mask': int64_array([1]),
+                                          'new_axis_mask': int64_array([0]),
+                                          'shrink_axis_mask': int64_array([0]),
+                                          'ellipsis_mask': int64_array([0])
+                                          }
+                                         )
 
-        mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()
+        mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
 
         source = upsample.in_port(0).get_connection().get_source()
         source.connect(shape.in_port(0))
         shape.out_port(0).connect(ss.in_port(0))
-        begin.out_port(0).connect(ss.in_port(1))
-        end.out_port(0).connect(ss.in_port(2))
-        stride.out_port(0).connect(ss.in_port(3))
         ss.out_port(0).connect(mul.in_port(0))
-        factor.out_port(0).connect(mul.in_port(1))
 
         # Create Interpolate operation
         if input_shape_rank == 4:
@@ -130,7 +135,7 @@ class UpsampleToResample(MiddleReplacementPattern):
                                get_height_dim(layout, input_shape_rank),
                                get_width_dim(layout, input_shape_rank)])
 
-        resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
+        resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
                                               axes=axes, mode=upsample.attrs()['mode'],
                                               antialias=0, convert_to_resample=True)).create_node()