Fix const node non-deterministic names (part 2) (#1081)

* Fix non-deterministic node names generation in the Model Optimizer (part 2)
This commit is contained in:
Anton Chetverikov 2020-07-07 09:37:48 +03:00 committed by GitHub
parent 5d1c5ee6a9
commit 56916ace61
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 205 additions and 175 deletions

View File

@ -20,7 +20,7 @@ from extensions.back.ReshapeMutation import ReshapeMutation
from extensions.back.ReverseInputChannels import ApplyReverseChannels
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.reshape import Reshape
@ -224,6 +224,7 @@ class DeconvolutionNormalizer(BackReplacementPattern):
def replace_pattern(self, graph: Graph, match: dict):
node = match['node']
node_name = node.soft_get('name', node.id)
if 2 in node.in_ports() and not node.in_port(2).disconnected():
# Third input represents output shape. Cutting its value according to scheme:
@ -233,22 +234,17 @@ class DeconvolutionNormalizer(BackReplacementPattern):
shape_src = node.in_port(2).get_source()
node.in_port(2).disconnect()
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
end = Const(graph, {'value': np.array([in_rank], dtype=np.int32)}).create_node()
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
ss_0 = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([2], dtype=np.int32),
2: np.array([in_rank], dtype=np.int32),
3: np.array([1], dtype=np.int32)},
{'name': node_name + '/ss_0_port',
'begin_mask': np.array([1], dtype=np.int32),
'end_mask': np.array([0], dtype=np.int32),
'new_axis_mask': np.array([0], dtype=np.int32),
'shrink_axis_mask': np.array([0], dtype=np.int32),
'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()
'ellipsis_mask': np.array([0], dtype=np.int32)})
shape_src.connect(ss_0.in_port(0))
begin.out_port(0).connect(ss_0.in_port(1))
end.out_port(0).connect(ss_0.in_port(2))
stride.out_port(0).connect(ss_0.in_port(3))
ss_0.out_port(0).connect(node.in_port(2))
# Specification: *padding amount* is deduced from relation of input and output spatial shapes
@ -256,7 +252,8 @@ class DeconvolutionNormalizer(BackReplacementPattern):
elif node.has_valid('original_output_spatial_shape'):
# node had fixed output spatial shape set in original framework, so we restore it here
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape)}).create_node()
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
'name': node_name + '/original_spatial_shape'}).create_node()
node.add_input_port(2, skip_if_exist=True)
const.out_port(0).connect(node.in_port(2))

View File

@ -55,43 +55,52 @@ class CropToStridedSlice(BackReplacementPattern):
def replace_pattern(self, graph: Graph, match: [str, Node]):
node = match['crop']
assert node.has_valid('axis')
node.axis = self.list_to_ndarray(node.axis)
node_axis = self.list_to_ndarray(node.axis)
in_shape = node.in_port(0).data.get_shape()
shape_rank = in_shape.size
axis_mask = int64_array([1 if i in node.axis else 0 for i in range(shape_rank)])
axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
begin_mask = axis_mask.copy()
end_mask = axis_mask.copy()
ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
# Crop Type 1
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
end = Add(graph, {'name': node.name + '/end'}).create_node()
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
'name': ss.name + '/begin'}).create_node()
shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
end = Add(graph, {'name': ss.name + '/end'}).create_node()
node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
node.in_port(1).disconnect()
shape.out_port(0).connect(end.in_port(0))
begin.out_port(0).connect(end.in_port(1))
elif node.has_valid('dim') and node.has_valid('offset'):
# Crop Type 2
node.dim = self.list_to_ndarray(node.dim)
node.offset = self.list_to_ndarray(node.offset)
assert node.dim.size == node.offset.size == node.axis.size
node_dim = self.list_to_ndarray(node.dim)
node_offset = self.list_to_ndarray(node.offset)
assert node_dim.size == node_offset.size == node_axis.size
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
end_values = np.array([node.offset[i] + node.dim[i] for i in range(len(node.dim))])
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, end_values)}).create_node()
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
'name': ss.name + '/begin'}).create_node()
end_values = np.array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
'name': ss.name + '/end'}).create_node()
elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
# Crop Type 3
node.crop_begin = self.list_to_ndarray(node.crop_begin)
node.crop_end = self.list_to_ndarray(node.crop_end)
assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)
node_crop_begin = self.list_to_ndarray(node.crop_begin)
node_crop_end = self.list_to_ndarray(node.crop_end)
assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.crop_begin)}).create_node()
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
const = Const(graph,
{'value': -1 * self.mask_normalizer(shape_rank, node.axis, node.crop_end)}).create_node()
end = Add(graph, {'name': node.name + '/end'}).create_node()
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
'name': ss.name + '/begin'}).create_node()
shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()
end = Add(graph, {'name': ss.name + '/end'}).create_node()
const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
'name': ss.name + '/const'}).create_node()
node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
shape.out_port(0).connect(end.in_port(0))
@ -102,9 +111,8 @@ class CropToStridedSlice(BackReplacementPattern):
source = node.in_port(0).get_connection().get_source()
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64)}).create_node()
ss = StridedSlice(graph, {'name': 'Crop_', 'begin_mask': begin_mask, 'end_mask': end_mask, 'new_axis_mask': np.array([0]),
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
'name': ss.name + '/stride'}).create_node()
source.connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))

View File

@ -44,6 +44,7 @@ class GroupedConvWeightsNormalize(BackReplacementPattern):
weights = match['weights']
input_shape = conv.in_port(0).data.get_shape()
new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape),
'name': weights.soft_get('name', weights.id) + '_new'}).create_node()
weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
new_weights.infer(new_weights)

View File

@ -18,7 +18,7 @@ import numpy as np
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
from extensions.ops.prelu import PreluOp
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph
from mo.graph.graph import Graph, rename_node
from mo.ops.const import Const
@ -39,11 +39,16 @@ class LeakyReLUMutation(BackReplacementPattern):
@staticmethod
def replace_pattern(graph: Graph, match: dict):
relu = match['leakyrelu']
relu_name = relu.soft_get('name', relu.id)
if not relu.has_valid('negative_slope'):
return
rename_node(relu, relu_name + '/to_delete')
# Create PReLU op and reconnect input/output from LeakyReLU to PReLU
prelu = PreluOp(graph, dict(name=relu.name)).create_node()
const = Const(graph, dict(name=relu.name + "/weights", value=np.array([relu.negative_slope]))).create_node()
prelu = PreluOp(graph, dict(name=relu_name)).create_node()
rename_node(prelu, relu_name)
const = Const(graph, dict(name=relu_name + "/weights", value=np.array([relu.negative_slope]))).create_node()
relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
const.out_port(0).connect(prelu.in_port(1))

View File

@ -19,6 +19,7 @@ import numpy as np
from extensions.ops.transpose import Transpose
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.unsqueeze import Unsqueeze
@ -56,13 +57,12 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
transpose_order = list(range(port_shape.size))
transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()
transpose = create_op_node_with_second_input(graph, Transpose, int64_array(transpose_order),
{'name': name + '/{}_port_transpose'.format(in_port_idx)})
port_source = in_port.get_source()
in_port.get_connection().set_source(transpose.out_port(0))
transpose.in_port(0).connect(port_source)
transpose.in_port(1).connect(order.out_port(0))
transpose['override_output_shape'] = True

View File

@ -21,8 +21,8 @@ from extensions.back.ReshapeMutation import ReshapeMutation
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.ops.const import Const
from mo.ops.reshape import Reshape
from mo.ops.strided_slice import StridedSlice
@ -53,32 +53,27 @@ class ProposalMutation(BackReplacementPattern):
'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
extra={'is_warning': True})
begin = Const(graph, {'value': np.array([0, 0], dtype=np.int32)}).create_node()
end = Const(graph, {'value': np.array([1, 3], dtype=np.int32)}).create_node()
stride = Const(graph, {'value': np.array([1, 1], dtype=np.int32)}).create_node()
cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([0, 0], dtype=np.int32),
2: np.array([1, 3], dtype=np.int32),
3: np.array([1, 1], dtype=np.int32)},
{'name': 'cropped_im_info',
'begin_mask': int64_array([1, 1]),
'end_mask': int64_array([1, 1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0]),
'override_output_shape': True,
}).create_node()
})
node.in_port(2).get_connection().insert_node(cropped_im_info)
begin.out_port(0).connect(cropped_im_info.in_port(1))
end.out_port(0).connect(cropped_im_info.in_port(2))
stride.out_port(0).connect(cropped_im_info.in_port(3))
# update the im_info_shape so the next 'if' statement become true
im_info_shape = int64_array([1, 3])
if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]], {'name': 'im_info/Reshape'})
node.in_port(2).get_connection().set_destination(reshape.in_port(0))
const.out_port(0).connect(reshape.in_port(1))
reshape.out_port(0).connect(node.in_port(2))
if node.has_port('out', 1) and not node.out_port(1).disconnected():

View File

@ -25,7 +25,6 @@ from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.graph.graph import Node
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.op import Op, PermuteAttrs
@ -358,11 +357,7 @@ class DecomposeReverseChannels(BackReplacementPattern):
axis = node.axis
order = node.order
indices = Const(graph, {'name': name + '/reverse_order', 'value': order}).create_node()
axis_const = Const(graph, {'value': int64_array(axis)}).create_node()
gather = Gather(graph, {'name': name}).create_node()
gather.in_port(1).connect(indices.out_port(0))
gather.in_port(2).connect(axis_const.out_port(0))
gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})
node.out_port(0).get_connection().set_source(gather.out_port(0))
node.in_port(0).get_connection().set_destination(gather.in_port(0))

View File

@ -107,8 +107,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
initial_fake_quantize = match['quantize']
initial_fake_quantize_name = initial_fake_quantize.soft_get('name', initial_fake_quantize.id)
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize.name + '/Copy',
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize_name + '/Copy',
stop_value_propagation=False), graph)
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
@ -121,9 +122,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
i_min = np.array([0.], dtype=dst_type)
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
new_out_low_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_low',
new_out_low_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_low',
value=i_min)).create_node()
new_out_high_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_high',
new_out_high_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_high',
value=i_max)).create_node()
new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
@ -131,7 +132,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))
cast_node = Cast(graph, dict(name=initial_fake_quantize.name + "/Convert_to_float", dst_type=dst_type,
cast_node = Cast(graph, dict(name=initial_fake_quantize_name + "/Convert_to_float", dst_type=dst_type,
stop_value_propagation=True)).create_node()
new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))

View File

@ -49,12 +49,12 @@ class PriorboxMutation(BackReplacementPattern):
assert len(node.in_ports()) == 2
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
end = Const(graph, {'value': np.array([4], dtype=np.int32)}).create_node()
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()
shape_0 = Shape(graph, {'name': node.name + '/0_port'}).create_node()
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
'begin_mask': np.array([1], dtype=np.int32),
'end_mask': np.array([0], dtype=np.int32),
'new_axis_mask': np.array([0], dtype=np.int32),
@ -71,8 +71,8 @@ class PriorboxMutation(BackReplacementPattern):
source.connect(shape_0.in_port(0))
ss_0.out_port(0).connect(node.in_port(0))
shape_1 = Shape(graph, {'name': node.name + '/1_port'}).create_node()
ss_1 = StridedSlice(graph, {'name': node.name + '/ss_1_port',
shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
'begin_mask': np.array([1], dtype=np.int32),
'end_mask': np.array([0], dtype=np.int32),
'new_axis_mask': np.array([0], dtype=np.int32),

View File

@ -92,6 +92,8 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
output_low = quantize.in_node(3)
output_high = quantize.in_node(4)
quantize_name = quantize.soft_get('name', quantize.id)
if not output_low.has_valid('value') and not output_high.has_valid('value'):
return
@ -115,8 +117,9 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
new_shape = Const(graph, {'value': int64_array([-1, 1, 1])}).create_node_with_data()
reshape = Reshape(graph, {}).create_node_with_data([mult_term, new_shape])
new_shape = Const(graph, {'name': quantize_name + '/Reshape/Shape',
'value': int64_array([-1, 1, 1])}).create_node_with_data()
reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])
# Patch inflow path (by diving by mult_term)
# Put a new Pow/Mul combination here:
@ -125,15 +128,16 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
if len(match['quantized'].out_nodes()) > 1:
log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
return
power_of_exponent = Const(graph, {'value': np.array(-1.0)}).create_node_with_data()
div_op = Pow(graph, {'name': quantize.name + '/DivNormalize'})
power_of_exponent = Const(graph, {'name': quantize_name + '/DivNormalize/Power',
'value': np.array(-1.0)}).create_node_with_data()
div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
div_output = div_op.create_node_with_data([mult_term, power_of_exponent])
for i in [3, 4]:
match['quantize'].insert_node_with_data_before(
match['quantize'].in_node(i),
Mul,
dict(name=quantize.name + '/MulNormalize'),
dict(name=quantize_name + '/MulNormalize'),
additional_inputs=[div_output],
)

View File

@ -19,9 +19,9 @@ import logging as log
import numpy as np
from extensions.ops.elementwise import Mul, Add
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
class ConvToBinaryConv(MiddleReplacementPattern):
@ -91,12 +91,10 @@ class ConvToBinaryConv(MiddleReplacementPattern):
weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1]) # FIXME: works for NCHW only
add_term = Const(graph, {'value': weights_reduced}).create_node()
add = Add(graph, {}).create_node()
add.in_port(1).connect(add_term.out_port(0))
mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
mul = Mul(graph, {}).create_node()
mul.in_port(1).connect(mul_term.out_port(0))
operator_name = operator.soft_get('name', operator.id)
add = create_op_node_with_second_input(graph, Add, weights_reduced, {'name': operator_name + '/Add_'})
mul = create_op_node_with_second_input(graph, Mul, np.array(0.5), {'name': operator_name + '/Mul_'})
add.out_port(0).connect(mul.in_port(0))
operator.out_port(0).get_connection().set_source(mul.out_port(0))

View File

@ -189,9 +189,13 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
log.debug("Removed: {}".format(node.id))
# 2. Create Split layer and reorder outputs
axis_const = Const(graph, {'value': int64_array(split_channel_dim)}).create_node_with_data()
size_splits_const = Const(graph, {'value': int64_array(size_splits)}).create_node_with_data()
split = VariadicSplit(graph, dict(name=name_for_future_split + "/Split", out_ports_count=len(size_splits)))
name = name_for_future_split + "/Split"
axis_const = Const(graph, {'value': int64_array(split_channel_dim),
'name': name + '/Axis'}).create_node_with_data()
size_splits_const = Const(graph, {'value': int64_array(size_splits),
'name': name + '/Sizes'}).create_node_with_data()
split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))
split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
data_nodes=final_data_nodes_list)

View File

@ -36,6 +36,7 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
def find_and_replace_pattern(self, graph: Graph):
for node in list(graph.nodes()):
node = Node(graph, node)
node_name = node.soft_get('name', node.id)
# Check that node layout mismatch with graph layout
# For example: NHWC and NCHW or NCDHW and NDHWC
if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
@ -63,8 +64,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
edge_attrs = graph.get_edge_data(input.id, node.id)[0]
graph.remove_edge(input.id, node.id)
input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
input_permute_name = node_name + '/input_transpose'
input_order_const = Const(graph, {'name': input_permute_name + '/order',
'value': permutation.perm}).create_node_with_data()
input_permute_op = Transpose(graph, {'name': input_permute_name})
input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
@ -77,8 +80,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
edge_attrs)
output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
output_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_')
output_permute_name = node_name + '/output_transpose'
output_order_const = Const(graph, {'name': output_permute_name + '/order',
'value': permutation.inv}).create_node_with_data()
output_permute_op = Transpose(graph, {'name': output_permute_name}
).create_node_with_data([input_data_node, output_order_const],
data_nodes=output)

View File

@ -46,9 +46,11 @@ class Deconvolution3rdInputNormalization(MiddleReplacementPattern):
data_node = node.in_node(2)
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
gather = Gather(graph, {'name': node.name + '/ShapeGather',
name = node.soft_get('name', node.id) + '/ShapeGather'
const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
gather = Gather(graph, {'name': name,
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()

View File

@ -160,7 +160,8 @@ class DilatedConvolution1DConverter(MiddleReplacementPattern):
unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
for port_id in [1, 2]:
current_value = pad.in_port(port_id).get_connection().data.get_value()
new_value_node = Const(pad.graph, {'value': np.insert(current_value, unsqueeze_axis.item(), 0),
new_value_node = Const(pad.graph, {'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
'value': np.insert(current_value, unsqueeze_axis.item(), 0),
'override_output_shape': True}).create_node()
pad.in_port(port_id).disconnect()
pad.in_port(port_id).connect(new_value_node.out_port(0))

View File

@ -106,8 +106,10 @@ class EltwiseInputReshape(MiddleReplacementPattern):
# Insert Reshape layer between data node and consumer
for shape_key in mapping.keys():
shape = list(shape_key)
reshape = Reshape(graph, attrs={'name': 'EltwiseReshapeNormalization'})
reshape_dim = Const(graph, {'value': shape}).create_node_with_data()
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
reshape = Reshape(graph, attrs={'name': reshape_name})
reshape_dim = Const(graph,
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
# Iterate over consumers and reconnect them to Reshape layer output

View File

@ -19,9 +19,9 @@ import numpy as np
from extensions.ops.gather import Gather
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph, rename_node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.reshape import Reshape
@ -68,6 +68,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
def replace_pattern(self, graph: Graph, match: dict):
gather = match['GatherNd']
gather_name = gather.soft_get('name', gather.id)
input_shape = gather.in_node(0).shape
indices = gather.in_node(1).value
if indices is None:
@ -77,26 +78,26 @@ class GatherNdNormalize(MiddleReplacementPattern):
# 0. All needed checks that we can replace GatherNd by Gather
gather_idx = self.indices_check(indices, input_shape)
if gather_idx is None:
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather.name))
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name))
return
# 1. Add Reshape and connect
new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
reshape = Reshape(graph, {'name': gather.name + '/Reshape_for_GatherNd/'}).create_node()
reshape_const_node = Const(graph, {'name': reshape.name + '/Dim', 'value': new_shape}).create_node()
reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
{'name': gather_name + '/Reshape_for_GatherNd/'})
gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
reshape.in_port(1).connect(reshape_const_node.out_port(0))
# 2. Change indices from Nd to 1d:
new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
new_indices_const = Const(graph, dict(value=new_indices)).create_node()
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
rename_node(gather, gather_name + '/to_delete')
# 3. Create new Gather operation and reconnect all inputs/outputs
new_gather = Gather(graph, {'name': gather.name + '/NewGather/'}).create_node()
new_gather = create_op_with_const_inputs(graph, Gather, {1: new_indices, 2: int64_array(0)},
{'name': gather_name})
rename_node(new_gather, gather_name)
reshape.out_port(0).connect(new_gather.in_port(0))
new_indices_const.out_port(0).connect(new_gather.in_port(1))
axis_const.out_port(0).connect(new_gather.in_port(2))
gather.out_port(0).get_connection().set_source(new_gather.out_port(0))

View File

@ -16,9 +16,9 @@
from extensions.middle.pass_separator import PostMiddleStart
from extensions.ops.transpose import Transpose
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.op import PermuteAttrs
@ -69,6 +69,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
len(node.out_port(0).data.get_shape()) >= 4
def find_and_replace_pattern(self, graph: Graph):
# we need to import these functions here to avoid circular dependent imports
from mo.front.tf.graph_utils import create_op_node_with_second_input
if graph.graph['layout'] != 'NHWC':
# we check it here because this transformation is called explicitly from the pipeline
return
@ -80,15 +84,14 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
input_shape = reinterp_shape_node.in_node(0).shape
if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
}).create_node()
permute_node = Transpose(graph,
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
}).create_node()
permute_node = create_op_node_with_second_input(
graph, Transpose, PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm,
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}
)
reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
order_const.out_port(0).connect(permute_node.in_port(1))
order_const.infer(order_const)
order_const = permute_node.in_port(1).get_source().node
order_const.infer(order_const)
# do not infer the Transpose node because it should have input data node in NCHW layout (but currently
# it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
# (which is true at this moment)
@ -107,11 +110,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
output_shape = reinterp_shape_node.out_node(0).shape
if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
order_const = Const(graph, {
'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
permute_node = create_op_node_with_second_input(
graph, Transpose, PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm,
{'name': reinterp_shape_node.id + '/Transpose'})
reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
order_const.out_port(0).connect(permute_node.in_port(1))
# the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
# will convert it to the NCHW

View File

@ -20,9 +20,9 @@ import numpy as np
from extensions.ops.normalize import NormalizeOp
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph, rename_node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
class L2NormToNorm(MiddleReplacementPattern):
@ -87,13 +87,13 @@ class L2NormToNorm(MiddleReplacementPattern):
normalizel2_name = output_name + '/normalizel2'
rename_node(match['l2_normalize'], normalizel2_name)
normalize_node = NormalizeOp(graph, {'name': output_name, 'eps': y,
'across_spatial': 0, 'channel_shared': 0}).create_node()
normalize_node = create_op_node_with_second_input(graph, NormalizeOp,
np.ones(shape=int64_array([match['input'].shape[-1]]),
dtype=match['input'].data_type),
{'name': output_name, 'eps': y,
'across_spatial': 0, 'channel_shared': 0})
rename_node(normalize_node, output_name)
weights_node = Const(graph, {'value': np.ones(shape=int64_array([match['input'].shape[-1]]),
dtype=match['input'].data_type)}).create_node()
match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
match['square'].in_port(0).disconnect()
@ -102,5 +102,4 @@ class L2NormToNorm(MiddleReplacementPattern):
else:
match['l2_normalize'].in_port(0).disconnect()
weights_node.out_port(0).get_connection().set_destination(normalize_node.in_port(1))
match['l2_normalize'].out_port(0).get_connection().set_source(normalize_node.out_port(0))

View File

@ -206,18 +206,21 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
if lstm.has_num_directions:
mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
lstm_name = lstm.soft_get('name', lstm.id)
new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
graph.remove_edge(lstm.id, old_data_node.id)
graph.add_edge(lstm.id, new_data.id, key=0, out=0)
# Add Transpose
permute_order = Const(graph, dict(value=int64_array([0, 2, 1, 3]))).create_node_with_data()
permute_data = Transpose(graph, dict(name=lstm.name + '/Transpose_mxnet/')
permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'}
).create_node_with_data([new_data, permute_order])
# Add Reshape
reshape = Reshape(graph, dict(name=lstm.name + '/Reshape_mxnet/'))
reshape_dim_data = Const(graph, {'name': lstm.name + '/Reshape_mxnet_dim',
reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
'value': mxnet_shape}).create_node_with_data()
reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])

View File

@ -41,7 +41,8 @@ def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
if value is None:
log.debug('Can not duplicate due no data for in_port {} of node {}'.format(port_id, node.name))
for node, idxs in dst_port_map.items():
const = Const(graph, {'value': np.array(value)}).create_node()
const = Const(graph, {'value': np.array(value),
'name': node.soft_get('name', node.id) + '/duplicated_'}).create_node()
for idx in idxs:
node.in_port(idx).disconnect()
const.out_port(0).connect(node.in_port(idx))

View File

@ -34,6 +34,6 @@ class ReverseTransposeNormalization(MiddleReplacementPattern):
node = match['transpose']
assert len(node.in_nodes()) == 1
order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
const = Const(graph, {'value': order}).create_node()
const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
node.add_input_port(1, skip_if_exist=True)
const.out_port(0).connect(node.in_port(1))

View File

@ -16,9 +16,9 @@
import numpy as np
from extensions.ops.reverse_sequence import ReverseSequence
from mo.graph.graph import Graph
from mo.front.tf.graph_utils import create_op_node_with_second_input
from mo.graph.graph import Graph, rename_node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.utils.error import Error
@ -57,14 +57,15 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
const = Const(graph, dict(value=seq_lengths)).create_node()
reverse_name = reverse.soft_get('name', reverse.id)
rename_node(reverse, reverse_name + '/to_delete')
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
reverse_sequence = ReverseSequence(graph, {'name': reverse.name + '/ReverseSequence/',
'seq_axis': seq_axis, 'batch_axis': batch_axis}).create_node()
reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
{'name': reverse_name, 'seq_axis': seq_axis,
'batch_axis': batch_axis})
rename_node(reverse_sequence, reverse_name)
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
const.out_port(0).connect(reverse_sequence.in_port(1))
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
# 3. Delete old Reverse node

View File

@ -36,7 +36,7 @@ class SwapAxisMiddleReplacer(MiddleReplacementPattern):
order = swapaxis.order
swapaxis.add_input_port(1)
const = Const(graph, {'value': order}).create_node()
const = Const(graph, {'value': order, 'name': swapaxis.soft_get('name', swapaxis.id) + '/Order'}).create_node()
const.out_port(0).connect(swapaxis.in_port(1))
Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})

View File

@ -25,9 +25,9 @@ from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
@ -55,6 +55,7 @@ class UpsampleToResample(MiddleReplacementPattern):
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
log.debug('UpsampleToResample is triggered')
upsample = match['upsample']
upsample_name = upsample.soft_get('name', upsample.id)
input_shape = upsample.in_port(0).data.get_shape()
input_shape_rank = len(input_shape)
if input_shape_rank not in [4, 5]:
@ -67,8 +68,7 @@ class UpsampleToResample(MiddleReplacementPattern):
return
scales = upsample.in_node(1).value
assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
len(scales), upsample.soft_get('name', upsample.id)
)
len(scales), upsample_name)
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
return
height_scale = scales[2]
@ -81,45 +81,50 @@ class UpsampleToResample(MiddleReplacementPattern):
if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
width_scale, height_scale, upsample.soft_get('name')))
width_scale, height_scale, upsample_name))
return
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
depth_scale, height_scale, upsample.soft_get('name')))
depth_scale, height_scale, upsample_name))
return
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
upsample.in_port(1).disconnect()
shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()
shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
layout = graph.graph['layout']
if input_shape_rank == 4:
begin = Const(graph, {'value': int64_array([get_height_dim(layout, input_shape_rank)])}).create_node()
factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
else:
begin = Const(graph, {'value': int64_array([get_depth_dim(layout, input_shape_rank)])}).create_node()
factor = Const(graph, {'value': np.array([depth_scale, height_scale, width_scale])}).create_node()
end = Const(graph, {'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])}).create_node()
stride = Const(graph, {'value': int64_array([1])}).create_node()
ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port',
if input_shape_rank == 4:
begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
factor_value = np.array([height_scale, width_scale])
else:
begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
factor_value = np.array([depth_scale, height_scale, width_scale])
ss = create_op_with_const_inputs(graph, StridedSlice,
{1: begin_value,
2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
3: int64_array([1])
},
{'name': upsample_name + '/ss_0_port',
'begin_mask': int64_array([1]),
'end_mask': int64_array([1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])}).create_node()
'ellipsis_mask': int64_array([0])
}
)
mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
source = upsample.in_port(0).get_connection().get_source()
source.connect(shape.in_port(0))
shape.out_port(0).connect(ss.in_port(0))
begin.out_port(0).connect(ss.in_port(1))
end.out_port(0).connect(ss.in_port(2))
stride.out_port(0).connect(ss.in_port(3))
ss.out_port(0).connect(mul.in_port(0))
factor.out_port(0).connect(mul.in_port(1))
# Create Interpolate operation
if input_shape_rank == 4:
@ -130,7 +135,7 @@ class UpsampleToResample(MiddleReplacementPattern):
get_height_dim(layout, input_shape_rank),
get_width_dim(layout, input_shape_rank)])
resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
axes=axes, mode=upsample.attrs()['mode'],
antialias=0, convert_to_resample=True)).create_node()