Fix const node non-deterministic names (part 2) (#1081)
* Fix non-deterministic node names generation in the Model Optimizer (part 2)
This commit is contained in:
parent
5d1c5ee6a9
commit
56916ace61
@ -20,7 +20,7 @@ from extensions.back.ReshapeMutation import ReshapeMutation
|
||||
from extensions.back.ReverseInputChannels import ApplyReverseChannels
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.reshape import Reshape
|
||||
@ -224,6 +224,7 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
node = match['node']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
if 2 in node.in_ports() and not node.in_port(2).disconnected():
|
||||
# Third input represents output shape. Cutting its value according to scheme:
|
||||
@ -233,22 +234,17 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
||||
shape_src = node.in_port(2).get_source()
|
||||
node.in_port(2).disconnect()
|
||||
|
||||
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
|
||||
end = Const(graph, {'value': np.array([in_rank], dtype=np.int32)}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
|
||||
|
||||
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
|
||||
ss_0 = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([2], dtype=np.int32),
|
||||
2: np.array([in_rank], dtype=np.int32),
|
||||
3: np.array([1], dtype=np.int32)},
|
||||
{'name': node_name + '/ss_0_port',
|
||||
'begin_mask': np.array([1], dtype=np.int32),
|
||||
'end_mask': np.array([0], dtype=np.int32),
|
||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||
'shrink_axis_mask': np.array([0], dtype=np.int32),
|
||||
'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()
|
||||
'ellipsis_mask': np.array([0], dtype=np.int32)})
|
||||
|
||||
shape_src.connect(ss_0.in_port(0))
|
||||
begin.out_port(0).connect(ss_0.in_port(1))
|
||||
end.out_port(0).connect(ss_0.in_port(2))
|
||||
stride.out_port(0).connect(ss_0.in_port(3))
|
||||
|
||||
ss_0.out_port(0).connect(node.in_port(2))
|
||||
|
||||
# Specification: *padding amount* is deduced from relation of input and output spatial shapes
|
||||
@ -256,7 +252,8 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
||||
|
||||
elif node.has_valid('original_output_spatial_shape'):
|
||||
# node had fixed output spatial shape set in original framework, so we restore it here
|
||||
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape)}).create_node()
|
||||
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
|
||||
'name': node_name + '/original_spatial_shape'}).create_node()
|
||||
node.add_input_port(2, skip_if_exist=True)
|
||||
const.out_port(0).connect(node.in_port(2))
|
||||
|
||||
|
@ -55,43 +55,52 @@ class CropToStridedSlice(BackReplacementPattern):
|
||||
def replace_pattern(self, graph: Graph, match: [str, Node]):
|
||||
node = match['crop']
|
||||
assert node.has_valid('axis')
|
||||
node.axis = self.list_to_ndarray(node.axis)
|
||||
node_axis = self.list_to_ndarray(node.axis)
|
||||
|
||||
in_shape = node.in_port(0).data.get_shape()
|
||||
shape_rank = in_shape.size
|
||||
axis_mask = int64_array([1 if i in node.axis else 0 for i in range(shape_rank)])
|
||||
axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
|
||||
begin_mask = axis_mask.copy()
|
||||
end_mask = axis_mask.copy()
|
||||
|
||||
ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
|
||||
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
|
||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||
|
||||
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
|
||||
# Crop Type 1
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
|
||||
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
|
||||
end = Add(graph, {'name': node.name + '/end'}).create_node()
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
|
||||
'name': ss.name + '/begin'}).create_node()
|
||||
shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
|
||||
end = Add(graph, {'name': ss.name + '/end'}).create_node()
|
||||
node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
|
||||
node.in_port(1).disconnect()
|
||||
shape.out_port(0).connect(end.in_port(0))
|
||||
begin.out_port(0).connect(end.in_port(1))
|
||||
elif node.has_valid('dim') and node.has_valid('offset'):
|
||||
# Crop Type 2
|
||||
node.dim = self.list_to_ndarray(node.dim)
|
||||
node.offset = self.list_to_ndarray(node.offset)
|
||||
assert node.dim.size == node.offset.size == node.axis.size
|
||||
node_dim = self.list_to_ndarray(node.dim)
|
||||
node_offset = self.list_to_ndarray(node.offset)
|
||||
assert node_dim.size == node_offset.size == node_axis.size
|
||||
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
|
||||
end_values = np.array([node.offset[i] + node.dim[i] for i in range(len(node.dim))])
|
||||
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, end_values)}).create_node()
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
|
||||
'name': ss.name + '/begin'}).create_node()
|
||||
end_values = np.array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
|
||||
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
|
||||
'name': ss.name + '/end'}).create_node()
|
||||
elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
|
||||
# Crop Type 3
|
||||
node.crop_begin = self.list_to_ndarray(node.crop_begin)
|
||||
node.crop_end = self.list_to_ndarray(node.crop_end)
|
||||
assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)
|
||||
node_crop_begin = self.list_to_ndarray(node.crop_begin)
|
||||
node_crop_end = self.list_to_ndarray(node.crop_end)
|
||||
assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)
|
||||
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.crop_begin)}).create_node()
|
||||
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
|
||||
const = Const(graph,
|
||||
{'value': -1 * self.mask_normalizer(shape_rank, node.axis, node.crop_end)}).create_node()
|
||||
end = Add(graph, {'name': node.name + '/end'}).create_node()
|
||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
|
||||
'name': ss.name + '/begin'}).create_node()
|
||||
shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()
|
||||
|
||||
end = Add(graph, {'name': ss.name + '/end'}).create_node()
|
||||
const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
|
||||
'name': ss.name + '/const'}).create_node()
|
||||
|
||||
node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
|
||||
shape.out_port(0).connect(end.in_port(0))
|
||||
@ -102,9 +111,8 @@ class CropToStridedSlice(BackReplacementPattern):
|
||||
|
||||
source = node.in_port(0).get_connection().get_source()
|
||||
|
||||
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64)}).create_node()
|
||||
ss = StridedSlice(graph, {'name': 'Crop_', 'begin_mask': begin_mask, 'end_mask': end_mask, 'new_axis_mask': np.array([0]),
|
||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
|
||||
'name': ss.name + '/stride'}).create_node()
|
||||
|
||||
source.connect(ss.in_port(0))
|
||||
begin.out_port(0).connect(ss.in_port(1))
|
||||
|
@ -44,6 +44,7 @@ class GroupedConvWeightsNormalize(BackReplacementPattern):
|
||||
weights = match['weights']
|
||||
input_shape = conv.in_port(0).data.get_shape()
|
||||
new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
|
||||
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
|
||||
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape),
|
||||
'name': weights.soft_get('name', weights.id) + '_new'}).create_node()
|
||||
weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
|
||||
new_weights.infer(new_weights)
|
||||
|
@ -18,7 +18,7 @@ import numpy as np
|
||||
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
|
||||
from extensions.ops.prelu import PreluOp
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
@ -39,11 +39,16 @@ class LeakyReLUMutation(BackReplacementPattern):
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
relu = match['leakyrelu']
|
||||
relu_name = relu.soft_get('name', relu.id)
|
||||
if not relu.has_valid('negative_slope'):
|
||||
return
|
||||
|
||||
rename_node(relu, relu_name + '/to_delete')
|
||||
# Create PReLU op and reconnect input/output from LeakyReLU to PReLU
|
||||
prelu = PreluOp(graph, dict(name=relu.name)).create_node()
|
||||
const = Const(graph, dict(name=relu.name + "/weights", value=np.array([relu.negative_slope]))).create_node()
|
||||
prelu = PreluOp(graph, dict(name=relu_name)).create_node()
|
||||
rename_node(prelu, relu_name)
|
||||
|
||||
const = Const(graph, dict(name=relu_name + "/weights", value=np.array([relu.negative_slope]))).create_node()
|
||||
|
||||
relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
|
||||
const.out_port(0).connect(prelu.in_port(1))
|
||||
|
@ -19,6 +19,7 @@ import numpy as np
|
||||
from extensions.ops.transpose import Transpose
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.unsqueeze import Unsqueeze
|
||||
@ -56,13 +57,12 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
|
||||
transpose_order = list(range(port_shape.size))
|
||||
transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
|
||||
|
||||
order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
|
||||
transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()
|
||||
transpose = create_op_node_with_second_input(graph, Transpose, int64_array(transpose_order),
|
||||
{'name': name + '/{}_port_transpose'.format(in_port_idx)})
|
||||
|
||||
port_source = in_port.get_source()
|
||||
in_port.get_connection().set_source(transpose.out_port(0))
|
||||
transpose.in_port(0).connect(port_source)
|
||||
transpose.in_port(1).connect(order.out_port(0))
|
||||
|
||||
transpose['override_output_shape'] = True
|
||||
|
||||
|
@ -21,8 +21,8 @@ from extensions.back.ReshapeMutation import ReshapeMutation
|
||||
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
||||
from mo.back.replacement import BackReplacementPattern
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
@ -53,32 +53,27 @@ class ProposalMutation(BackReplacementPattern):
|
||||
'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
|
||||
'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
|
||||
extra={'is_warning': True})
|
||||
begin = Const(graph, {'value': np.array([0, 0], dtype=np.int32)}).create_node()
|
||||
end = Const(graph, {'value': np.array([1, 3], dtype=np.int32)}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1, 1], dtype=np.int32)}).create_node()
|
||||
|
||||
cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
|
||||
cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([0, 0], dtype=np.int32),
|
||||
2: np.array([1, 3], dtype=np.int32),
|
||||
3: np.array([1, 1], dtype=np.int32)},
|
||||
{'name': 'cropped_im_info',
|
||||
'begin_mask': int64_array([1, 1]),
|
||||
'end_mask': int64_array([1, 1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0]),
|
||||
'override_output_shape': True,
|
||||
}).create_node()
|
||||
})
|
||||
|
||||
node.in_port(2).get_connection().insert_node(cropped_im_info)
|
||||
begin.out_port(0).connect(cropped_im_info.in_port(1))
|
||||
end.out_port(0).connect(cropped_im_info.in_port(2))
|
||||
stride.out_port(0).connect(cropped_im_info.in_port(3))
|
||||
|
||||
# update the im_info_shape so the next 'if' statement become true
|
||||
im_info_shape = int64_array([1, 3])
|
||||
|
||||
if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
|
||||
reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
|
||||
const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
|
||||
reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]], {'name': 'im_info/Reshape'})
|
||||
node.in_port(2).get_connection().set_destination(reshape.in_port(0))
|
||||
const.out_port(0).connect(reshape.in_port(1))
|
||||
reshape.out_port(0).connect(node.in_port(2))
|
||||
|
||||
if node.has_port('out', 1) and not node.out_port(1).disconnected():
|
||||
|
@ -25,7 +25,6 @@ from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.graph.graph import Node
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.op import Op, PermuteAttrs
|
||||
|
||||
|
||||
@ -358,11 +357,7 @@ class DecomposeReverseChannels(BackReplacementPattern):
|
||||
axis = node.axis
|
||||
order = node.order
|
||||
|
||||
indices = Const(graph, {'name': name + '/reverse_order', 'value': order}).create_node()
|
||||
axis_const = Const(graph, {'value': int64_array(axis)}).create_node()
|
||||
gather = Gather(graph, {'name': name}).create_node()
|
||||
gather.in_port(1).connect(indices.out_port(0))
|
||||
gather.in_port(2).connect(axis_const.out_port(0))
|
||||
gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})
|
||||
|
||||
node.out_port(0).get_connection().set_source(gather.out_port(0))
|
||||
node.in_port(0).get_connection().set_destination(gather.in_port(0))
|
||||
|
@ -107,8 +107,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
||||
initial_fake_quantize = match['quantize']
|
||||
initial_fake_quantize_name = initial_fake_quantize.soft_get('name', initial_fake_quantize.id)
|
||||
|
||||
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize.name + '/Copy',
|
||||
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize_name + '/Copy',
|
||||
stop_value_propagation=False), graph)
|
||||
|
||||
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
|
||||
@ -121,9 +122,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
i_min = np.array([0.], dtype=dst_type)
|
||||
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
|
||||
|
||||
new_out_low_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_low',
|
||||
new_out_low_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_low',
|
||||
value=i_min)).create_node()
|
||||
new_out_high_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_high',
|
||||
new_out_high_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_high',
|
||||
value=i_max)).create_node()
|
||||
|
||||
new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
|
||||
@ -131,7 +132,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
||||
new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
|
||||
new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))
|
||||
|
||||
cast_node = Cast(graph, dict(name=initial_fake_quantize.name + "/Convert_to_float", dst_type=dst_type,
|
||||
cast_node = Cast(graph, dict(name=initial_fake_quantize_name + "/Convert_to_float", dst_type=dst_type,
|
||||
stop_value_propagation=True)).create_node()
|
||||
new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
|
||||
initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))
|
||||
|
@ -49,12 +49,12 @@ class PriorboxMutation(BackReplacementPattern):
|
||||
|
||||
assert len(node.in_ports()) == 2
|
||||
|
||||
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
|
||||
end = Const(graph, {'value': np.array([4], dtype=np.int32)}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
|
||||
begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
|
||||
end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
|
||||
stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()
|
||||
|
||||
shape_0 = Shape(graph, {'name': node.name + '/0_port'}).create_node()
|
||||
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
|
||||
shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
|
||||
ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
|
||||
'begin_mask': np.array([1], dtype=np.int32),
|
||||
'end_mask': np.array([0], dtype=np.int32),
|
||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||
@ -71,8 +71,8 @@ class PriorboxMutation(BackReplacementPattern):
|
||||
source.connect(shape_0.in_port(0))
|
||||
ss_0.out_port(0).connect(node.in_port(0))
|
||||
|
||||
shape_1 = Shape(graph, {'name': node.name + '/1_port'}).create_node()
|
||||
ss_1 = StridedSlice(graph, {'name': node.name + '/ss_1_port',
|
||||
shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
|
||||
ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
|
||||
'begin_mask': np.array([1], dtype=np.int32),
|
||||
'end_mask': np.array([0], dtype=np.int32),
|
||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||
|
@ -92,6 +92,8 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
||||
output_low = quantize.in_node(3)
|
||||
output_high = quantize.in_node(4)
|
||||
|
||||
quantize_name = quantize.soft_get('name', quantize.id)
|
||||
|
||||
if not output_low.has_valid('value') and not output_high.has_valid('value'):
|
||||
return
|
||||
|
||||
@ -115,8 +117,9 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
||||
|
||||
mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
|
||||
|
||||
new_shape = Const(graph, {'value': int64_array([-1, 1, 1])}).create_node_with_data()
|
||||
reshape = Reshape(graph, {}).create_node_with_data([mult_term, new_shape])
|
||||
new_shape = Const(graph, {'name': quantize_name + '/Reshape/Shape',
|
||||
'value': int64_array([-1, 1, 1])}).create_node_with_data()
|
||||
reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])
|
||||
|
||||
# Patch inflow path (by diving by mult_term)
|
||||
# Put a new Pow/Mul combination here:
|
||||
@ -125,15 +128,16 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
||||
if len(match['quantized'].out_nodes()) > 1:
|
||||
log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
|
||||
return
|
||||
power_of_exponent = Const(graph, {'value': np.array(-1.0)}).create_node_with_data()
|
||||
div_op = Pow(graph, {'name': quantize.name + '/DivNormalize'})
|
||||
power_of_exponent = Const(graph, {'name': quantize_name + '/DivNormalize/Power',
|
||||
'value': np.array(-1.0)}).create_node_with_data()
|
||||
div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
|
||||
div_output = div_op.create_node_with_data([mult_term, power_of_exponent])
|
||||
|
||||
for i in [3, 4]:
|
||||
match['quantize'].insert_node_with_data_before(
|
||||
match['quantize'].in_node(i),
|
||||
Mul,
|
||||
dict(name=quantize.name + '/MulNormalize'),
|
||||
dict(name=quantize_name + '/MulNormalize'),
|
||||
additional_inputs=[div_output],
|
||||
)
|
||||
|
||||
|
@ -19,9 +19,9 @@ import logging as log
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Mul, Add
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class ConvToBinaryConv(MiddleReplacementPattern):
|
||||
@ -91,12 +91,10 @@ class ConvToBinaryConv(MiddleReplacementPattern):
|
||||
weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
|
||||
weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1]) # FIXME: works for NCHW only
|
||||
|
||||
add_term = Const(graph, {'value': weights_reduced}).create_node()
|
||||
add = Add(graph, {}).create_node()
|
||||
add.in_port(1).connect(add_term.out_port(0))
|
||||
mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
|
||||
mul = Mul(graph, {}).create_node()
|
||||
mul.in_port(1).connect(mul_term.out_port(0))
|
||||
operator_name = operator.soft_get('name', operator.id)
|
||||
add = create_op_node_with_second_input(graph, Add, weights_reduced, {'name': operator_name + '/Add_'})
|
||||
mul = create_op_node_with_second_input(graph, Mul, np.array(0.5), {'name': operator_name + '/Mul_'})
|
||||
|
||||
add.out_port(0).connect(mul.in_port(0))
|
||||
|
||||
operator.out_port(0).get_connection().set_source(mul.out_port(0))
|
||||
|
@ -189,9 +189,13 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
|
||||
log.debug("Removed: {}".format(node.id))
|
||||
|
||||
# 2. Create Split layer and reorder outputs
|
||||
axis_const = Const(graph, {'value': int64_array(split_channel_dim)}).create_node_with_data()
|
||||
size_splits_const = Const(graph, {'value': int64_array(size_splits)}).create_node_with_data()
|
||||
split = VariadicSplit(graph, dict(name=name_for_future_split + "/Split", out_ports_count=len(size_splits)))
|
||||
name = name_for_future_split + "/Split"
|
||||
axis_const = Const(graph, {'value': int64_array(split_channel_dim),
|
||||
'name': name + '/Axis'}).create_node_with_data()
|
||||
size_splits_const = Const(graph, {'value': int64_array(size_splits),
|
||||
'name': name + '/Sizes'}).create_node_with_data()
|
||||
split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))
|
||||
|
||||
split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
|
||||
data_nodes=final_data_nodes_list)
|
||||
|
||||
|
@ -36,6 +36,7 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in list(graph.nodes()):
|
||||
node = Node(graph, node)
|
||||
node_name = node.soft_get('name', node.id)
|
||||
# Check that node layout mismatch with graph layout
|
||||
# For example: NHWC and NCHW or NCDHW and NDHWC
|
||||
if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
|
||||
@ -63,8 +64,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
||||
edge_attrs = graph.get_edge_data(input.id, node.id)[0]
|
||||
graph.remove_edge(input.id, node.id)
|
||||
|
||||
input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
|
||||
input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
|
||||
input_permute_name = node_name + '/input_transpose'
|
||||
input_order_const = Const(graph, {'name': input_permute_name + '/order',
|
||||
'value': permutation.perm}).create_node_with_data()
|
||||
input_permute_op = Transpose(graph, {'name': input_permute_name})
|
||||
input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
|
||||
|
||||
graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
|
||||
@ -77,8 +80,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
||||
input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
|
||||
edge_attrs)
|
||||
|
||||
output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
|
||||
output_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_')
|
||||
output_permute_name = node_name + '/output_transpose'
|
||||
output_order_const = Const(graph, {'name': output_permute_name + '/order',
|
||||
'value': permutation.inv}).create_node_with_data()
|
||||
output_permute_op = Transpose(graph, {'name': output_permute_name}
|
||||
).create_node_with_data([input_data_node, output_order_const],
|
||||
data_nodes=output)
|
||||
|
||||
|
@ -46,9 +46,11 @@ class Deconvolution3rdInputNormalization(MiddleReplacementPattern):
|
||||
|
||||
data_node = node.in_node(2)
|
||||
|
||||
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': node.name + '/ShapeGather',
|
||||
name = node.soft_get('name', node.id) + '/ShapeGather'
|
||||
const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': name,
|
||||
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
|
||||
attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()
|
||||
|
||||
|
@ -160,7 +160,8 @@ class DilatedConvolution1DConverter(MiddleReplacementPattern):
|
||||
unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
|
||||
for port_id in [1, 2]:
|
||||
current_value = pad.in_port(port_id).get_connection().data.get_value()
|
||||
new_value_node = Const(pad.graph, {'value': np.insert(current_value, unsqueeze_axis.item(), 0),
|
||||
new_value_node = Const(pad.graph, {'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
|
||||
'value': np.insert(current_value, unsqueeze_axis.item(), 0),
|
||||
'override_output_shape': True}).create_node()
|
||||
pad.in_port(port_id).disconnect()
|
||||
pad.in_port(port_id).connect(new_value_node.out_port(0))
|
||||
|
@ -106,8 +106,10 @@ class EltwiseInputReshape(MiddleReplacementPattern):
|
||||
# Insert Reshape layer between data node and consumer
|
||||
for shape_key in mapping.keys():
|
||||
shape = list(shape_key)
|
||||
reshape = Reshape(graph, attrs={'name': 'EltwiseReshapeNormalization'})
|
||||
reshape_dim = Const(graph, {'value': shape}).create_node_with_data()
|
||||
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
|
||||
reshape = Reshape(graph, attrs={'name': reshape_name})
|
||||
reshape_dim = Const(graph,
|
||||
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
|
||||
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
|
||||
|
||||
# Iterate over consumers and reconnect them to Reshape layer output
|
||||
|
@ -19,9 +19,9 @@ import numpy as np
|
||||
|
||||
from extensions.ops.gather import Gather
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.reshape import Reshape
|
||||
|
||||
|
||||
@ -68,6 +68,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
|
||||
|
||||
def replace_pattern(self, graph: Graph, match: dict):
|
||||
gather = match['GatherNd']
|
||||
gather_name = gather.soft_get('name', gather.id)
|
||||
input_shape = gather.in_node(0).shape
|
||||
indices = gather.in_node(1).value
|
||||
if indices is None:
|
||||
@ -77,26 +78,26 @@ class GatherNdNormalize(MiddleReplacementPattern):
|
||||
# 0. All needed checks that we can replace GatherNd by Gather
|
||||
gather_idx = self.indices_check(indices, input_shape)
|
||||
if gather_idx is None:
|
||||
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather.name))
|
||||
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name))
|
||||
return
|
||||
|
||||
# 1. Add Reshape and connect
|
||||
new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
|
||||
reshape = Reshape(graph, {'name': gather.name + '/Reshape_for_GatherNd/'}).create_node()
|
||||
reshape_const_node = Const(graph, {'name': reshape.name + '/Dim', 'value': new_shape}).create_node()
|
||||
reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
|
||||
{'name': gather_name + '/Reshape_for_GatherNd/'})
|
||||
gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
|
||||
reshape.in_port(1).connect(reshape_const_node.out_port(0))
|
||||
|
||||
# 2. Change indices from Nd to 1d:
|
||||
new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
|
||||
new_indices_const = Const(graph, dict(value=new_indices)).create_node()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
|
||||
|
||||
rename_node(gather, gather_name + '/to_delete')
|
||||
|
||||
# 3. Create new Gather operation and reconnect all inputs/outputs
|
||||
new_gather = Gather(graph, {'name': gather.name + '/NewGather/'}).create_node()
|
||||
new_gather = create_op_with_const_inputs(graph, Gather, {1: new_indices, 2: int64_array(0)},
|
||||
{'name': gather_name})
|
||||
rename_node(new_gather, gather_name)
|
||||
|
||||
reshape.out_port(0).connect(new_gather.in_port(0))
|
||||
new_indices_const.out_port(0).connect(new_gather.in_port(1))
|
||||
axis_const.out_port(0).connect(new_gather.in_port(2))
|
||||
|
||||
gather.out_port(0).get_connection().set_source(new_gather.out_port(0))
|
||||
|
||||
|
@ -16,9 +16,9 @@
|
||||
|
||||
from extensions.middle.pass_separator import PostMiddleStart
|
||||
from extensions.ops.transpose import Transpose
|
||||
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.op import PermuteAttrs
|
||||
|
||||
|
||||
@ -69,6 +69,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
len(node.out_port(0).data.get_shape()) >= 4
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
|
||||
# we need to import these functions here to avoid circular dependent imports
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
|
||||
if graph.graph['layout'] != 'NHWC':
|
||||
# we check it here because this transformation is called explicitly from the pipeline
|
||||
return
|
||||
@ -80,15 +84,14 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
||||
input_shape = reinterp_shape_node.in_node(0).shape
|
||||
if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
|
||||
order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
|
||||
}).create_node()
|
||||
permute_node = Transpose(graph,
|
||||
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
|
||||
}).create_node()
|
||||
permute_node = create_op_node_with_second_input(
|
||||
graph, Transpose, PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm,
|
||||
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}
|
||||
)
|
||||
reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
|
||||
order_const.out_port(0).connect(permute_node.in_port(1))
|
||||
order_const.infer(order_const)
|
||||
|
||||
order_const = permute_node.in_port(1).get_source().node
|
||||
order_const.infer(order_const)
|
||||
# do not infer the Transpose node because it should have input data node in NCHW layout (but currently
|
||||
# it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
|
||||
# (which is true at this moment)
|
||||
@ -107,11 +110,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
||||
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
||||
output_shape = reinterp_shape_node.out_node(0).shape
|
||||
if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
|
||||
order_const = Const(graph, {
|
||||
'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
|
||||
permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
|
||||
permute_node = create_op_node_with_second_input(
|
||||
graph, Transpose, PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm,
|
||||
{'name': reinterp_shape_node.id + '/Transpose'})
|
||||
reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
|
||||
order_const.out_port(0).connect(permute_node.in_port(1))
|
||||
|
||||
# the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
|
||||
# will convert it to the NCHW
|
||||
|
@ -20,9 +20,9 @@ import numpy as np
|
||||
|
||||
from extensions.ops.normalize import NormalizeOp
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class L2NormToNorm(MiddleReplacementPattern):
|
||||
@ -87,13 +87,13 @@ class L2NormToNorm(MiddleReplacementPattern):
|
||||
normalizel2_name = output_name + '/normalizel2'
|
||||
rename_node(match['l2_normalize'], normalizel2_name)
|
||||
|
||||
normalize_node = NormalizeOp(graph, {'name': output_name, 'eps': y,
|
||||
'across_spatial': 0, 'channel_shared': 0}).create_node()
|
||||
normalize_node = create_op_node_with_second_input(graph, NormalizeOp,
|
||||
np.ones(shape=int64_array([match['input'].shape[-1]]),
|
||||
dtype=match['input'].data_type),
|
||||
{'name': output_name, 'eps': y,
|
||||
'across_spatial': 0, 'channel_shared': 0})
|
||||
rename_node(normalize_node, output_name)
|
||||
|
||||
weights_node = Const(graph, {'value': np.ones(shape=int64_array([match['input'].shape[-1]]),
|
||||
dtype=match['input'].data_type)}).create_node()
|
||||
|
||||
match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
|
||||
|
||||
match['square'].in_port(0).disconnect()
|
||||
@ -102,5 +102,4 @@ class L2NormToNorm(MiddleReplacementPattern):
|
||||
else:
|
||||
match['l2_normalize'].in_port(0).disconnect()
|
||||
|
||||
weights_node.out_port(0).get_connection().set_destination(normalize_node.in_port(1))
|
||||
match['l2_normalize'].out_port(0).get_connection().set_source(normalize_node.out_port(0))
|
||||
|
@ -192,7 +192,7 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
|
||||
input = match['input']
|
||||
if not lstm.has_num_directions:
|
||||
return
|
||||
old_data_node =lstm.out_node(0)
|
||||
old_data_node = lstm.out_node(0)
|
||||
num_directions = 2 if lstm.direction in ['bidirectional'] else 1
|
||||
mxnet_shape = lstm.out_node(0).shape.copy()
|
||||
|
||||
@ -206,18 +206,21 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
|
||||
if lstm.has_num_directions:
|
||||
mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
|
||||
|
||||
new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
|
||||
lstm_name = lstm.soft_get('name', lstm.id)
|
||||
|
||||
new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
|
||||
graph.remove_edge(lstm.id, old_data_node.id)
|
||||
graph.add_edge(lstm.id, new_data.id, key=0, out=0)
|
||||
|
||||
# Add Transpose
|
||||
permute_order = Const(graph, dict(value=int64_array([0, 2, 1, 3]))).create_node_with_data()
|
||||
permute_data = Transpose(graph, dict(name=lstm.name + '/Transpose_mxnet/')
|
||||
permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
|
||||
'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
|
||||
permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'}
|
||||
).create_node_with_data([new_data, permute_order])
|
||||
|
||||
# Add Reshape
|
||||
reshape = Reshape(graph, dict(name=lstm.name + '/Reshape_mxnet/'))
|
||||
reshape_dim_data = Const(graph, {'name': lstm.name + '/Reshape_mxnet_dim',
|
||||
reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
|
||||
reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
|
||||
'value': mxnet_shape}).create_node_with_data()
|
||||
|
||||
reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])
|
||||
|
@ -41,7 +41,8 @@ def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
|
||||
if value is None:
|
||||
log.debug('Can not duplicate due no data for in_port {} of node {}'.format(port_id, node.name))
|
||||
for node, idxs in dst_port_map.items():
|
||||
const = Const(graph, {'value': np.array(value)}).create_node()
|
||||
const = Const(graph, {'value': np.array(value),
|
||||
'name': node.soft_get('name', node.id) + '/duplicated_'}).create_node()
|
||||
for idx in idxs:
|
||||
node.in_port(idx).disconnect()
|
||||
const.out_port(0).connect(node.in_port(idx))
|
||||
|
@ -34,6 +34,6 @@ class ReverseTransposeNormalization(MiddleReplacementPattern):
|
||||
node = match['transpose']
|
||||
assert len(node.in_nodes()) == 1
|
||||
order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
|
||||
const = Const(graph, {'value': order}).create_node()
|
||||
const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
|
||||
node.add_input_port(1, skip_if_exist=True)
|
||||
const.out_port(0).connect(node.in_port(1))
|
||||
|
@ -16,9 +16,9 @@
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.reverse_sequence import ReverseSequence
|
||||
from mo.graph.graph import Graph
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, rename_node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
@ -57,14 +57,15 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
|
||||
|
||||
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
|
||||
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
|
||||
const = Const(graph, dict(value=seq_lengths)).create_node()
|
||||
|
||||
reverse_name = reverse.soft_get('name', reverse.id)
|
||||
rename_node(reverse, reverse_name + '/to_delete')
|
||||
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
|
||||
reverse_sequence = ReverseSequence(graph, {'name': reverse.name + '/ReverseSequence/',
|
||||
'seq_axis': seq_axis, 'batch_axis': batch_axis}).create_node()
|
||||
|
||||
reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
|
||||
{'name': reverse_name, 'seq_axis': seq_axis,
|
||||
'batch_axis': batch_axis})
|
||||
rename_node(reverse_sequence, reverse_name)
|
||||
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
|
||||
const.out_port(0).connect(reverse_sequence.in_port(1))
|
||||
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
|
||||
|
||||
# 3. Delete old Reverse node
|
||||
|
@ -36,7 +36,7 @@ class SwapAxisMiddleReplacer(MiddleReplacementPattern):
|
||||
order = swapaxis.order
|
||||
|
||||
swapaxis.add_input_port(1)
|
||||
const = Const(graph, {'value': order}).create_node()
|
||||
const = Const(graph, {'value': order, 'name': swapaxis.soft_get('name', swapaxis.id) + '/Order'}).create_node()
|
||||
const.out_port(0).connect(swapaxis.in_port(1))
|
||||
|
||||
Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})
|
||||
|
@ -25,9 +25,9 @@ from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
@ -55,6 +55,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
||||
log.debug('UpsampleToResample is triggered')
|
||||
upsample = match['upsample']
|
||||
upsample_name = upsample.soft_get('name', upsample.id)
|
||||
input_shape = upsample.in_port(0).data.get_shape()
|
||||
input_shape_rank = len(input_shape)
|
||||
if input_shape_rank not in [4, 5]:
|
||||
@ -67,8 +68,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
return
|
||||
scales = upsample.in_node(1).value
|
||||
assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
|
||||
len(scales), upsample.soft_get('name', upsample.id)
|
||||
)
|
||||
len(scales), upsample_name)
|
||||
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
|
||||
return
|
||||
height_scale = scales[2]
|
||||
@ -81,45 +81,50 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
|
||||
if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
|
||||
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
|
||||
width_scale, height_scale, upsample.soft_get('name')))
|
||||
width_scale, height_scale, upsample_name))
|
||||
return
|
||||
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
|
||||
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
|
||||
depth_scale, height_scale, upsample.soft_get('name')))
|
||||
depth_scale, height_scale, upsample_name))
|
||||
return
|
||||
|
||||
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
|
||||
upsample.in_port(1).disconnect()
|
||||
|
||||
shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()
|
||||
shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
|
||||
|
||||
layout = graph.graph['layout']
|
||||
if input_shape_rank == 4:
|
||||
begin = Const(graph, {'value': int64_array([get_height_dim(layout, input_shape_rank)])}).create_node()
|
||||
factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
|
||||
else:
|
||||
begin = Const(graph, {'value': int64_array([get_depth_dim(layout, input_shape_rank)])}).create_node()
|
||||
factor = Const(graph, {'value': np.array([depth_scale, height_scale, width_scale])}).create_node()
|
||||
end = Const(graph, {'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])}).create_node()
|
||||
|
||||
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port',
|
||||
if input_shape_rank == 4:
|
||||
begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
|
||||
factor_value = np.array([height_scale, width_scale])
|
||||
else:
|
||||
begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
|
||||
factor_value = np.array([depth_scale, height_scale, width_scale])
|
||||
|
||||
|
||||
|
||||
ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||
{1: begin_value,
|
||||
2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
|
||||
3: int64_array([1])
|
||||
},
|
||||
{'name': upsample_name + '/ss_0_port',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])}).create_node()
|
||||
'ellipsis_mask': int64_array([0])
|
||||
}
|
||||
)
|
||||
|
||||
mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()
|
||||
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
|
||||
|
||||
source = upsample.in_port(0).get_connection().get_source()
|
||||
source.connect(shape.in_port(0))
|
||||
shape.out_port(0).connect(ss.in_port(0))
|
||||
begin.out_port(0).connect(ss.in_port(1))
|
||||
end.out_port(0).connect(ss.in_port(2))
|
||||
stride.out_port(0).connect(ss.in_port(3))
|
||||
|
||||
ss.out_port(0).connect(mul.in_port(0))
|
||||
factor.out_port(0).connect(mul.in_port(1))
|
||||
|
||||
# Create Interpolate operation
|
||||
if input_shape_rank == 4:
|
||||
@ -130,7 +135,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
get_height_dim(layout, input_shape_rank),
|
||||
get_width_dim(layout, input_shape_rank)])
|
||||
|
||||
resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
|
||||
resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
|
||||
axes=axes, mode=upsample.attrs()['mode'],
|
||||
antialias=0, convert_to_resample=True)).create_node()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user