Fix const node non-deterministic names (part 2) (#1081)
* Fix non-deterministic node names generation in the Model Optimizer (part 2)
This commit is contained in:
parent
5d1c5ee6a9
commit
56916ace61
@ -20,7 +20,7 @@ from extensions.back.ReshapeMutation import ReshapeMutation
|
|||||||
from extensions.back.ReverseInputChannels import ApplyReverseChannels
|
from extensions.back.ReverseInputChannels import ApplyReverseChannels
|
||||||
from mo.back.replacement import BackReplacementPattern
|
from mo.back.replacement import BackReplacementPattern
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
@ -224,6 +224,7 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
|||||||
|
|
||||||
def replace_pattern(self, graph: Graph, match: dict):
|
def replace_pattern(self, graph: Graph, match: dict):
|
||||||
node = match['node']
|
node = match['node']
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
|
|
||||||
if 2 in node.in_ports() and not node.in_port(2).disconnected():
|
if 2 in node.in_ports() and not node.in_port(2).disconnected():
|
||||||
# Third input represents output shape. Cutting its value according to scheme:
|
# Third input represents output shape. Cutting its value according to scheme:
|
||||||
@ -233,22 +234,17 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
|||||||
shape_src = node.in_port(2).get_source()
|
shape_src = node.in_port(2).get_source()
|
||||||
node.in_port(2).disconnect()
|
node.in_port(2).disconnect()
|
||||||
|
|
||||||
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
|
ss_0 = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([2], dtype=np.int32),
|
||||||
end = Const(graph, {'value': np.array([in_rank], dtype=np.int32)}).create_node()
|
2: np.array([in_rank], dtype=np.int32),
|
||||||
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
|
3: np.array([1], dtype=np.int32)},
|
||||||
|
{'name': node_name + '/ss_0_port',
|
||||||
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
|
|
||||||
'begin_mask': np.array([1], dtype=np.int32),
|
'begin_mask': np.array([1], dtype=np.int32),
|
||||||
'end_mask': np.array([0], dtype=np.int32),
|
'end_mask': np.array([0], dtype=np.int32),
|
||||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||||
'shrink_axis_mask': np.array([0], dtype=np.int32),
|
'shrink_axis_mask': np.array([0], dtype=np.int32),
|
||||||
'ellipsis_mask': np.array([0], dtype=np.int32)}).create_node()
|
'ellipsis_mask': np.array([0], dtype=np.int32)})
|
||||||
|
|
||||||
shape_src.connect(ss_0.in_port(0))
|
shape_src.connect(ss_0.in_port(0))
|
||||||
begin.out_port(0).connect(ss_0.in_port(1))
|
|
||||||
end.out_port(0).connect(ss_0.in_port(2))
|
|
||||||
stride.out_port(0).connect(ss_0.in_port(3))
|
|
||||||
|
|
||||||
ss_0.out_port(0).connect(node.in_port(2))
|
ss_0.out_port(0).connect(node.in_port(2))
|
||||||
|
|
||||||
# Specification: *padding amount* is deduced from relation of input and output spatial shapes
|
# Specification: *padding amount* is deduced from relation of input and output spatial shapes
|
||||||
@ -256,7 +252,8 @@ class DeconvolutionNormalizer(BackReplacementPattern):
|
|||||||
|
|
||||||
elif node.has_valid('original_output_spatial_shape'):
|
elif node.has_valid('original_output_spatial_shape'):
|
||||||
# node had fixed output spatial shape set in original framework, so we restore it here
|
# node had fixed output spatial shape set in original framework, so we restore it here
|
||||||
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape)}).create_node()
|
const = Const(graph, {'value': int64_array(node.original_output_spatial_shape),
|
||||||
|
'name': node_name + '/original_spatial_shape'}).create_node()
|
||||||
node.add_input_port(2, skip_if_exist=True)
|
node.add_input_port(2, skip_if_exist=True)
|
||||||
const.out_port(0).connect(node.in_port(2))
|
const.out_port(0).connect(node.in_port(2))
|
||||||
|
|
||||||
|
@ -55,43 +55,52 @@ class CropToStridedSlice(BackReplacementPattern):
|
|||||||
def replace_pattern(self, graph: Graph, match: [str, Node]):
|
def replace_pattern(self, graph: Graph, match: [str, Node]):
|
||||||
node = match['crop']
|
node = match['crop']
|
||||||
assert node.has_valid('axis')
|
assert node.has_valid('axis')
|
||||||
node.axis = self.list_to_ndarray(node.axis)
|
node_axis = self.list_to_ndarray(node.axis)
|
||||||
|
|
||||||
in_shape = node.in_port(0).data.get_shape()
|
in_shape = node.in_port(0).data.get_shape()
|
||||||
shape_rank = in_shape.size
|
shape_rank = in_shape.size
|
||||||
axis_mask = int64_array([1 if i in node.axis else 0 for i in range(shape_rank)])
|
axis_mask = int64_array([1 if i in node_axis else 0 for i in range(shape_rank)])
|
||||||
begin_mask = axis_mask.copy()
|
begin_mask = axis_mask.copy()
|
||||||
end_mask = axis_mask.copy()
|
end_mask = axis_mask.copy()
|
||||||
|
|
||||||
|
ss = StridedSlice(graph, {'name': node.soft_get('name', node.id) + '/strided_slice', 'begin_mask': begin_mask,
|
||||||
|
'end_mask': end_mask, 'new_axis_mask': np.array([0]),
|
||||||
|
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
||||||
|
|
||||||
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
|
if len(node.in_nodes()) == 2 and node.has_valid('offset'):
|
||||||
# Crop Type 1
|
# Crop Type 1
|
||||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
|
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node.offset),
|
||||||
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
|
'name': ss.name + '/begin'}).create_node()
|
||||||
end = Add(graph, {'name': node.name + '/end'}).create_node()
|
shape = Shape(graph, {'name': ss.name + '/shape_of_crop'}).create_node()
|
||||||
|
end = Add(graph, {'name': ss.name + '/end'}).create_node()
|
||||||
node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
|
node.in_port(1).get_connection().get_source().connect(shape.in_port(0))
|
||||||
node.in_port(1).disconnect()
|
node.in_port(1).disconnect()
|
||||||
shape.out_port(0).connect(end.in_port(0))
|
shape.out_port(0).connect(end.in_port(0))
|
||||||
begin.out_port(0).connect(end.in_port(1))
|
begin.out_port(0).connect(end.in_port(1))
|
||||||
elif node.has_valid('dim') and node.has_valid('offset'):
|
elif node.has_valid('dim') and node.has_valid('offset'):
|
||||||
# Crop Type 2
|
# Crop Type 2
|
||||||
node.dim = self.list_to_ndarray(node.dim)
|
node_dim = self.list_to_ndarray(node.dim)
|
||||||
node.offset = self.list_to_ndarray(node.offset)
|
node_offset = self.list_to_ndarray(node.offset)
|
||||||
assert node.dim.size == node.offset.size == node.axis.size
|
assert node_dim.size == node_offset.size == node_axis.size
|
||||||
|
|
||||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.offset)}).create_node()
|
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_offset),
|
||||||
end_values = np.array([node.offset[i] + node.dim[i] for i in range(len(node.dim))])
|
'name': ss.name + '/begin'}).create_node()
|
||||||
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, end_values)}).create_node()
|
end_values = np.array([node_offset[i] + node_dim[i] for i in range(len(node_dim))])
|
||||||
|
end = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, end_values),
|
||||||
|
'name': ss.name + '/end'}).create_node()
|
||||||
elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
|
elif node.has_valid('crop_begin') and node.has_valid('crop_end'):
|
||||||
# Crop Type 3
|
# Crop Type 3
|
||||||
node.crop_begin = self.list_to_ndarray(node.crop_begin)
|
node_crop_begin = self.list_to_ndarray(node.crop_begin)
|
||||||
node.crop_end = self.list_to_ndarray(node.crop_end)
|
node_crop_end = self.list_to_ndarray(node.crop_end)
|
||||||
assert len(node.crop_begin) == len(node.crop_end) == len(node.axis)
|
assert len(node_crop_begin) == len(node_crop_end) == len(node_axis)
|
||||||
|
|
||||||
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node.axis, node.crop_begin)}).create_node()
|
begin = Const(graph, {'value': self.mask_normalizer(shape_rank, node_axis, node_crop_begin),
|
||||||
shape = Shape(graph, {'name': node.name + '/shape_of_crop'}).create_node()
|
'name': ss.name + '/begin'}).create_node()
|
||||||
const = Const(graph,
|
shape = Shape(graph, {'name': ss.name + '/shape'}).create_node()
|
||||||
{'value': -1 * self.mask_normalizer(shape_rank, node.axis, node.crop_end)}).create_node()
|
|
||||||
end = Add(graph, {'name': node.name + '/end'}).create_node()
|
end = Add(graph, {'name': ss.name + '/end'}).create_node()
|
||||||
|
const = Const(graph, {'value': -1 * self.mask_normalizer(shape_rank, node_axis, node_crop_end),
|
||||||
|
'name': ss.name + '/const'}).create_node()
|
||||||
|
|
||||||
node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
|
node.in_port(0).get_connection().get_source().connect(shape.in_port(0))
|
||||||
shape.out_port(0).connect(end.in_port(0))
|
shape.out_port(0).connect(end.in_port(0))
|
||||||
@ -102,9 +111,8 @@ class CropToStridedSlice(BackReplacementPattern):
|
|||||||
|
|
||||||
source = node.in_port(0).get_connection().get_source()
|
source = node.in_port(0).get_connection().get_source()
|
||||||
|
|
||||||
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64)}).create_node()
|
stride = Const(graph, {'value': np.ones(shape_rank, dtype=np.int64),
|
||||||
ss = StridedSlice(graph, {'name': 'Crop_', 'begin_mask': begin_mask, 'end_mask': end_mask, 'new_axis_mask': np.array([0]),
|
'name': ss.name + '/stride'}).create_node()
|
||||||
'shrink_axis_mask': np.array([0]), 'ellipsis_mask': np.array([0])}).create_node()
|
|
||||||
|
|
||||||
source.connect(ss.in_port(0))
|
source.connect(ss.in_port(0))
|
||||||
begin.out_port(0).connect(ss.in_port(1))
|
begin.out_port(0).connect(ss.in_port(1))
|
||||||
|
@ -44,6 +44,7 @@ class GroupedConvWeightsNormalize(BackReplacementPattern):
|
|||||||
weights = match['weights']
|
weights = match['weights']
|
||||||
input_shape = conv.in_port(0).data.get_shape()
|
input_shape = conv.in_port(0).data.get_shape()
|
||||||
new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
|
new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group), input_shape[1] / conv.group, *weights.value.shape[2:]])
|
||||||
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
|
new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape),
|
||||||
|
'name': weights.soft_get('name', weights.id) + '_new'}).create_node()
|
||||||
weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
|
weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
|
||||||
new_weights.infer(new_weights)
|
new_weights.infer(new_weights)
|
||||||
|
@ -18,7 +18,7 @@ import numpy as np
|
|||||||
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
|
from extensions.back.ForceStrictPrecision import ForceStrictPrecision
|
||||||
from extensions.ops.prelu import PreluOp
|
from extensions.ops.prelu import PreluOp
|
||||||
from mo.back.replacement import BackReplacementPattern
|
from mo.back.replacement import BackReplacementPattern
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph, rename_node
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
|
|
||||||
|
|
||||||
@ -39,11 +39,16 @@ class LeakyReLUMutation(BackReplacementPattern):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def replace_pattern(graph: Graph, match: dict):
|
def replace_pattern(graph: Graph, match: dict):
|
||||||
relu = match['leakyrelu']
|
relu = match['leakyrelu']
|
||||||
|
relu_name = relu.soft_get('name', relu.id)
|
||||||
if not relu.has_valid('negative_slope'):
|
if not relu.has_valid('negative_slope'):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
rename_node(relu, relu_name + '/to_delete')
|
||||||
# Create PReLU op and reconnect input/output from LeakyReLU to PReLU
|
# Create PReLU op and reconnect input/output from LeakyReLU to PReLU
|
||||||
prelu = PreluOp(graph, dict(name=relu.name)).create_node()
|
prelu = PreluOp(graph, dict(name=relu_name)).create_node()
|
||||||
const = Const(graph, dict(name=relu.name + "/weights", value=np.array([relu.negative_slope]))).create_node()
|
rename_node(prelu, relu_name)
|
||||||
|
|
||||||
|
const = Const(graph, dict(name=relu_name + "/weights", value=np.array([relu.negative_slope]))).create_node()
|
||||||
|
|
||||||
relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
|
relu.in_port(0).get_connection().set_destination(prelu.in_port(0))
|
||||||
const.out_port(0).connect(prelu.in_port(1))
|
const.out_port(0).connect(prelu.in_port(1))
|
||||||
|
@ -19,6 +19,7 @@ import numpy as np
|
|||||||
from extensions.ops.transpose import Transpose
|
from extensions.ops.transpose import Transpose
|
||||||
from mo.back.replacement import BackReplacementPattern
|
from mo.back.replacement import BackReplacementPattern
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.ops.const import Const
|
from mo.ops.const import Const
|
||||||
from mo.ops.unsqueeze import Unsqueeze
|
from mo.ops.unsqueeze import Unsqueeze
|
||||||
@ -56,13 +57,12 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
|
|||||||
transpose_order = list(range(port_shape.size))
|
transpose_order = list(range(port_shape.size))
|
||||||
transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
|
transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
|
||||||
|
|
||||||
order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
|
transpose = create_op_node_with_second_input(graph, Transpose, int64_array(transpose_order),
|
||||||
transpose = Transpose(graph, {'name': name + '/{}_port_transpose'.format(in_port_idx)}).create_node()
|
{'name': name + '/{}_port_transpose'.format(in_port_idx)})
|
||||||
|
|
||||||
port_source = in_port.get_source()
|
port_source = in_port.get_source()
|
||||||
in_port.get_connection().set_source(transpose.out_port(0))
|
in_port.get_connection().set_source(transpose.out_port(0))
|
||||||
transpose.in_port(0).connect(port_source)
|
transpose.in_port(0).connect(port_source)
|
||||||
transpose.in_port(1).connect(order.out_port(0))
|
|
||||||
|
|
||||||
transpose['override_output_shape'] = True
|
transpose['override_output_shape'] = True
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@ from extensions.back.ReshapeMutation import ReshapeMutation
|
|||||||
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
from extensions.back.StridedSliceMasksNormalizer import StridedSliceMasksNormalizer
|
||||||
from mo.back.replacement import BackReplacementPattern
|
from mo.back.replacement import BackReplacementPattern
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
from mo.ops.strided_slice import StridedSlice
|
from mo.ops.strided_slice import StridedSlice
|
||||||
|
|
||||||
@ -53,32 +53,27 @@ class ProposalMutation(BackReplacementPattern):
|
|||||||
'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
|
'implementation of the Proposal layer uses only 4 first values (indices 0, 1, 2 and 3). '
|
||||||
'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
|
'Elements with indices 4 and 5 will be ignored.'.format(node.soft_get('name', node.id)),
|
||||||
extra={'is_warning': True})
|
extra={'is_warning': True})
|
||||||
begin = Const(graph, {'value': np.array([0, 0], dtype=np.int32)}).create_node()
|
|
||||||
end = Const(graph, {'value': np.array([1, 3], dtype=np.int32)}).create_node()
|
|
||||||
stride = Const(graph, {'value': np.array([1, 1], dtype=np.int32)}).create_node()
|
|
||||||
|
|
||||||
cropped_im_info = StridedSlice(graph, {'name': 'cropped_im_info',
|
cropped_im_info = create_op_with_const_inputs(graph, StridedSlice, {1: np.array([0, 0], dtype=np.int32),
|
||||||
|
2: np.array([1, 3], dtype=np.int32),
|
||||||
|
3: np.array([1, 1], dtype=np.int32)},
|
||||||
|
{'name': 'cropped_im_info',
|
||||||
'begin_mask': int64_array([1, 1]),
|
'begin_mask': int64_array([1, 1]),
|
||||||
'end_mask': int64_array([1, 1]),
|
'end_mask': int64_array([1, 1]),
|
||||||
'new_axis_mask': int64_array([0]),
|
'new_axis_mask': int64_array([0]),
|
||||||
'shrink_axis_mask': int64_array([0]),
|
'shrink_axis_mask': int64_array([0]),
|
||||||
'ellipsis_mask': int64_array([0]),
|
'ellipsis_mask': int64_array([0]),
|
||||||
'override_output_shape': True,
|
'override_output_shape': True,
|
||||||
}).create_node()
|
})
|
||||||
|
|
||||||
node.in_port(2).get_connection().insert_node(cropped_im_info)
|
node.in_port(2).get_connection().insert_node(cropped_im_info)
|
||||||
begin.out_port(0).connect(cropped_im_info.in_port(1))
|
|
||||||
end.out_port(0).connect(cropped_im_info.in_port(2))
|
|
||||||
stride.out_port(0).connect(cropped_im_info.in_port(3))
|
|
||||||
|
|
||||||
# update the im_info_shape so the next 'if' statement become true
|
# update the im_info_shape so the next 'if' statement become true
|
||||||
im_info_shape = int64_array([1, 3])
|
im_info_shape = int64_array([1, 3])
|
||||||
|
|
||||||
if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
|
if np.array_equal(im_info_shape, [1, 3]) or np.array_equal(im_info_shape, [1, 4]):
|
||||||
reshape = Reshape(graph, dict(name="im_info/Reshape")).create_node()
|
reshape = create_op_node_with_second_input(graph, Reshape, [im_info_shape[1]], {'name': 'im_info/Reshape'})
|
||||||
const = Const(graph, dict(value=[im_info_shape[1]])).create_node()
|
|
||||||
node.in_port(2).get_connection().set_destination(reshape.in_port(0))
|
node.in_port(2).get_connection().set_destination(reshape.in_port(0))
|
||||||
const.out_port(0).connect(reshape.in_port(1))
|
|
||||||
reshape.out_port(0).connect(node.in_port(2))
|
reshape.out_port(0).connect(node.in_port(2))
|
||||||
|
|
||||||
if node.has_port('out', 1) and not node.out_port(1).disconnected():
|
if node.has_port('out', 1) and not node.out_port(1).disconnected():
|
||||||
|
@ -25,7 +25,6 @@ from mo.front.tf.graph_utils import create_op_with_const_inputs
|
|||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.graph.graph import Node
|
from mo.graph.graph import Node
|
||||||
from mo.ops.concat import Concat
|
from mo.ops.concat import Concat
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.ops.op import Op, PermuteAttrs
|
from mo.ops.op import Op, PermuteAttrs
|
||||||
|
|
||||||
|
|
||||||
@ -358,11 +357,7 @@ class DecomposeReverseChannels(BackReplacementPattern):
|
|||||||
axis = node.axis
|
axis = node.axis
|
||||||
order = node.order
|
order = node.order
|
||||||
|
|
||||||
indices = Const(graph, {'name': name + '/reverse_order', 'value': order}).create_node()
|
gather = create_op_with_const_inputs(graph, Gather, {1: order, 2: int64_array(axis)}, {'name': name})
|
||||||
axis_const = Const(graph, {'value': int64_array(axis)}).create_node()
|
|
||||||
gather = Gather(graph, {'name': name}).create_node()
|
|
||||||
gather.in_port(1).connect(indices.out_port(0))
|
|
||||||
gather.in_port(2).connect(axis_const.out_port(0))
|
|
||||||
|
|
||||||
node.out_port(0).get_connection().set_source(gather.out_port(0))
|
node.out_port(0).get_connection().set_source(gather.out_port(0))
|
||||||
node.in_port(0).get_connection().set_destination(gather.in_port(0))
|
node.in_port(0).get_connection().set_destination(gather.in_port(0))
|
||||||
|
@ -107,8 +107,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
|||||||
|
|
||||||
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
||||||
initial_fake_quantize = match['quantize']
|
initial_fake_quantize = match['quantize']
|
||||||
|
initial_fake_quantize_name = initial_fake_quantize.soft_get('name', initial_fake_quantize.id)
|
||||||
|
|
||||||
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize.name + '/Copy',
|
new_fake_quantize = initial_fake_quantize.copy_node(dict(name=initial_fake_quantize_name + '/Copy',
|
||||||
stop_value_propagation=False), graph)
|
stop_value_propagation=False), graph)
|
||||||
|
|
||||||
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
|
initial_fake_quantize.in_port(1).get_connection().set_destination(new_fake_quantize.in_port(1))
|
||||||
@ -121,9 +122,9 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
|||||||
i_min = np.array([0.], dtype=dst_type)
|
i_min = np.array([0.], dtype=dst_type)
|
||||||
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
|
i_max = np.array([initial_fake_quantize.levels - 1.], dtype=dst_type)
|
||||||
|
|
||||||
new_out_low_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_low',
|
new_out_low_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_low',
|
||||||
value=i_min)).create_node()
|
value=i_min)).create_node()
|
||||||
new_out_high_node = Const(graph, dict(name=initial_fake_quantize.name + '/Copy/out_high',
|
new_out_high_node = Const(graph, dict(name=initial_fake_quantize_name + '/Copy/out_high',
|
||||||
value=i_max)).create_node()
|
value=i_max)).create_node()
|
||||||
|
|
||||||
new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
|
new_out_low_node.out_port(0).connect(new_fake_quantize.in_port(3))
|
||||||
@ -131,7 +132,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
|
|||||||
new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
|
new_out_low_node.out_port(0).connect(initial_fake_quantize.in_port(1))
|
||||||
new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))
|
new_out_high_node.out_port(0).connect(initial_fake_quantize.in_port(2))
|
||||||
|
|
||||||
cast_node = Cast(graph, dict(name=initial_fake_quantize.name + "/Convert_to_float", dst_type=dst_type,
|
cast_node = Cast(graph, dict(name=initial_fake_quantize_name + "/Convert_to_float", dst_type=dst_type,
|
||||||
stop_value_propagation=True)).create_node()
|
stop_value_propagation=True)).create_node()
|
||||||
new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
|
new_fake_quantize.out_port(0).connect(cast_node.in_port(0))
|
||||||
initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))
|
initial_fake_quantize.in_port(0).get_connection().set_destination(new_fake_quantize.in_port(0))
|
||||||
|
@ -49,12 +49,12 @@ class PriorboxMutation(BackReplacementPattern):
|
|||||||
|
|
||||||
assert len(node.in_ports()) == 2
|
assert len(node.in_ports()) == 2
|
||||||
|
|
||||||
begin = Const(graph, {'value': np.array([2], dtype=np.int32)}).create_node()
|
begin = Const(graph, {'value': np.array([2], dtype=np.int32), 'name': name + '/ss_begin'}).create_node()
|
||||||
end = Const(graph, {'value': np.array([4], dtype=np.int32)}).create_node()
|
end = Const(graph, {'value': np.array([4], dtype=np.int32), 'name': name + '/ss_end'}).create_node()
|
||||||
stride = Const(graph, {'value': np.array([1], dtype=np.int32)}).create_node()
|
stride = Const(graph, {'value': np.array([1], dtype=np.int32), 'name': name + '/ss_stride'}).create_node()
|
||||||
|
|
||||||
shape_0 = Shape(graph, {'name': node.name + '/0_port'}).create_node()
|
shape_0 = Shape(graph, {'name': name + '/0_port'}).create_node()
|
||||||
ss_0 = StridedSlice(graph, {'name': node.name + '/ss_0_port',
|
ss_0 = StridedSlice(graph, {'name': name + '/ss_0_port',
|
||||||
'begin_mask': np.array([1], dtype=np.int32),
|
'begin_mask': np.array([1], dtype=np.int32),
|
||||||
'end_mask': np.array([0], dtype=np.int32),
|
'end_mask': np.array([0], dtype=np.int32),
|
||||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||||
@ -71,8 +71,8 @@ class PriorboxMutation(BackReplacementPattern):
|
|||||||
source.connect(shape_0.in_port(0))
|
source.connect(shape_0.in_port(0))
|
||||||
ss_0.out_port(0).connect(node.in_port(0))
|
ss_0.out_port(0).connect(node.in_port(0))
|
||||||
|
|
||||||
shape_1 = Shape(graph, {'name': node.name + '/1_port'}).create_node()
|
shape_1 = Shape(graph, {'name': name + '/1_port'}).create_node()
|
||||||
ss_1 = StridedSlice(graph, {'name': node.name + '/ss_1_port',
|
ss_1 = StridedSlice(graph, {'name': name + '/ss_1_port',
|
||||||
'begin_mask': np.array([1], dtype=np.int32),
|
'begin_mask': np.array([1], dtype=np.int32),
|
||||||
'end_mask': np.array([0], dtype=np.int32),
|
'end_mask': np.array([0], dtype=np.int32),
|
||||||
'new_axis_mask': np.array([0], dtype=np.int32),
|
'new_axis_mask': np.array([0], dtype=np.int32),
|
||||||
|
@ -92,6 +92,8 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
|||||||
output_low = quantize.in_node(3)
|
output_low = quantize.in_node(3)
|
||||||
output_high = quantize.in_node(4)
|
output_high = quantize.in_node(4)
|
||||||
|
|
||||||
|
quantize_name = quantize.soft_get('name', quantize.id)
|
||||||
|
|
||||||
if not output_low.has_valid('value') and not output_high.has_valid('value'):
|
if not output_low.has_valid('value') and not output_high.has_valid('value'):
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -115,8 +117,9 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
|||||||
|
|
||||||
mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
|
mult_term = quantize.in_node(3) if np.all(output_high == 0) else quantize.in_node(4)
|
||||||
|
|
||||||
new_shape = Const(graph, {'value': int64_array([-1, 1, 1])}).create_node_with_data()
|
new_shape = Const(graph, {'name': quantize_name + '/Reshape/Shape',
|
||||||
reshape = Reshape(graph, {}).create_node_with_data([mult_term, new_shape])
|
'value': int64_array([-1, 1, 1])}).create_node_with_data()
|
||||||
|
reshape = Reshape(graph, {'name': quantize_name + '/Reshape'}).create_node_with_data([mult_term, new_shape])
|
||||||
|
|
||||||
# Patch inflow path (by diving by mult_term)
|
# Patch inflow path (by diving by mult_term)
|
||||||
# Put a new Pow/Mul combination here:
|
# Put a new Pow/Mul combination here:
|
||||||
@ -125,15 +128,16 @@ class BinarizeWeightsM1P1(MiddleReplacementPattern):
|
|||||||
if len(match['quantized'].out_nodes()) > 1:
|
if len(match['quantized'].out_nodes()) > 1:
|
||||||
log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
|
log.debug('BinarizeWeightsM1P1: len(match[\'quantized\'].out_nodes()) > 1')
|
||||||
return
|
return
|
||||||
power_of_exponent = Const(graph, {'value': np.array(-1.0)}).create_node_with_data()
|
power_of_exponent = Const(graph, {'name': quantize_name + '/DivNormalize/Power',
|
||||||
div_op = Pow(graph, {'name': quantize.name + '/DivNormalize'})
|
'value': np.array(-1.0)}).create_node_with_data()
|
||||||
|
div_op = Pow(graph, {'name': quantize_name + '/DivNormalize'})
|
||||||
div_output = div_op.create_node_with_data([mult_term, power_of_exponent])
|
div_output = div_op.create_node_with_data([mult_term, power_of_exponent])
|
||||||
|
|
||||||
for i in [3, 4]:
|
for i in [3, 4]:
|
||||||
match['quantize'].insert_node_with_data_before(
|
match['quantize'].insert_node_with_data_before(
|
||||||
match['quantize'].in_node(i),
|
match['quantize'].in_node(i),
|
||||||
Mul,
|
Mul,
|
||||||
dict(name=quantize.name + '/MulNormalize'),
|
dict(name=quantize_name + '/MulNormalize'),
|
||||||
additional_inputs=[div_output],
|
additional_inputs=[div_output],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -19,9 +19,9 @@ import logging as log
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.elementwise import Mul, Add
|
from extensions.ops.elementwise import Mul, Add
|
||||||
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph
|
from mo.graph.graph import Graph
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
|
|
||||||
|
|
||||||
class ConvToBinaryConv(MiddleReplacementPattern):
|
class ConvToBinaryConv(MiddleReplacementPattern):
|
||||||
@ -91,12 +91,10 @@ class ConvToBinaryConv(MiddleReplacementPattern):
|
|||||||
weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
|
weights_reduced = np.add.reduce(weights, axis=tuple(reduction_indices))
|
||||||
weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1]) # FIXME: works for NCHW only
|
weights_reduced = weights_reduced.reshape([len(weights_reduced), 1, 1]) # FIXME: works for NCHW only
|
||||||
|
|
||||||
add_term = Const(graph, {'value': weights_reduced}).create_node()
|
operator_name = operator.soft_get('name', operator.id)
|
||||||
add = Add(graph, {}).create_node()
|
add = create_op_node_with_second_input(graph, Add, weights_reduced, {'name': operator_name + '/Add_'})
|
||||||
add.in_port(1).connect(add_term.out_port(0))
|
mul = create_op_node_with_second_input(graph, Mul, np.array(0.5), {'name': operator_name + '/Mul_'})
|
||||||
mul_term = Const(graph, {'value': np.array(0.5)}).create_node()
|
|
||||||
mul = Mul(graph, {}).create_node()
|
|
||||||
mul.in_port(1).connect(mul_term.out_port(0))
|
|
||||||
add.out_port(0).connect(mul.in_port(0))
|
add.out_port(0).connect(mul.in_port(0))
|
||||||
|
|
||||||
operator.out_port(0).get_connection().set_source(mul.out_port(0))
|
operator.out_port(0).get_connection().set_source(mul.out_port(0))
|
||||||
|
@ -189,9 +189,13 @@ class ConvertGroupedStridedSlice(MiddleReplacementPattern):
|
|||||||
log.debug("Removed: {}".format(node.id))
|
log.debug("Removed: {}".format(node.id))
|
||||||
|
|
||||||
# 2. Create Split layer and reorder outputs
|
# 2. Create Split layer and reorder outputs
|
||||||
axis_const = Const(graph, {'value': int64_array(split_channel_dim)}).create_node_with_data()
|
name = name_for_future_split + "/Split"
|
||||||
size_splits_const = Const(graph, {'value': int64_array(size_splits)}).create_node_with_data()
|
axis_const = Const(graph, {'value': int64_array(split_channel_dim),
|
||||||
split = VariadicSplit(graph, dict(name=name_for_future_split + "/Split", out_ports_count=len(size_splits)))
|
'name': name + '/Axis'}).create_node_with_data()
|
||||||
|
size_splits_const = Const(graph, {'value': int64_array(size_splits),
|
||||||
|
'name': name + '/Sizes'}).create_node_with_data()
|
||||||
|
split = VariadicSplit(graph, dict(name=name, out_ports_count=len(size_splits)))
|
||||||
|
|
||||||
split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
|
split.create_node_with_data(inputs=[input_data, axis_const, size_splits_const],
|
||||||
data_nodes=final_data_nodes_list)
|
data_nodes=final_data_nodes_list)
|
||||||
|
|
||||||
|
@ -36,6 +36,7 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
|||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
for node in list(graph.nodes()):
|
for node in list(graph.nodes()):
|
||||||
node = Node(graph, node)
|
node = Node(graph, node)
|
||||||
|
node_name = node.soft_get('name', node.id)
|
||||||
# Check that node layout mismatch with graph layout
|
# Check that node layout mismatch with graph layout
|
||||||
# For example: NHWC and NCHW or NCDHW and NDHWC
|
# For example: NHWC and NCHW or NCDHW and NDHWC
|
||||||
if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
|
if node.kind == 'op' and node.has_valid('layout') and node.layout != indices_mapping[len(node.layout)][
|
||||||
@ -63,8 +64,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
|||||||
edge_attrs = graph.get_edge_data(input.id, node.id)[0]
|
edge_attrs = graph.get_edge_data(input.id, node.id)[0]
|
||||||
graph.remove_edge(input.id, node.id)
|
graph.remove_edge(input.id, node.id)
|
||||||
|
|
||||||
input_order_const = Const(graph, {'value': permutation.perm}).create_node_with_data()
|
input_permute_name = node_name + '/input_transpose'
|
||||||
input_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_'))
|
input_order_const = Const(graph, {'name': input_permute_name + '/order',
|
||||||
|
'value': permutation.perm}).create_node_with_data()
|
||||||
|
input_permute_op = Transpose(graph, {'name': input_permute_name})
|
||||||
input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
|
input_permute_data_node = input_permute_op.create_node_with_data([input, input_order_const])
|
||||||
|
|
||||||
graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
|
graph.add_edge(input_permute_data_node.id, node.id, **edge_attrs)
|
||||||
@ -77,8 +80,10 @@ class ConvertLayoutDependentOperations(MiddleReplacementPattern):
|
|||||||
input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
|
input_data_node = Op.create_data_node(graph, node, {'shape': output.shape[permutation.perm]},
|
||||||
edge_attrs)
|
edge_attrs)
|
||||||
|
|
||||||
output_order_const = Const(graph, {'value': permutation.inv}).create_node_with_data()
|
output_permute_name = node_name + '/output_transpose'
|
||||||
output_permute_op = Transpose(graph, dict(name=node.name + '/Transpose_')
|
output_order_const = Const(graph, {'name': output_permute_name + '/order',
|
||||||
|
'value': permutation.inv}).create_node_with_data()
|
||||||
|
output_permute_op = Transpose(graph, {'name': output_permute_name}
|
||||||
).create_node_with_data([input_data_node, output_order_const],
|
).create_node_with_data([input_data_node, output_order_const],
|
||||||
data_nodes=output)
|
data_nodes=output)
|
||||||
|
|
||||||
|
@ -46,9 +46,11 @@ class Deconvolution3rdInputNormalization(MiddleReplacementPattern):
|
|||||||
|
|
||||||
data_node = node.in_node(2)
|
data_node = node.in_node(2)
|
||||||
|
|
||||||
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
|
name = node.soft_get('name', node.id) + '/ShapeGather'
|
||||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
const = Const(graph, {'value': permutation.perm, 'name': name + '/Const',
|
||||||
gather = Gather(graph, {'name': node.name + '/ShapeGather',
|
'need_shape_inference': True}).create_node_with_data()
|
||||||
|
axis_const = Const(graph, {'value': int64_array(0), 'name': name + '/Axis'}).create_node_with_data()
|
||||||
|
gather = Gather(graph, {'name': name,
|
||||||
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
|
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
|
||||||
attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()
|
attrs = graph.get_edge_data(data_node.id, node.id, key=0).copy()
|
||||||
|
|
||||||
|
@ -160,7 +160,8 @@ class DilatedConvolution1DConverter(MiddleReplacementPattern):
|
|||||||
unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
|
unsqueeze_axis = unsqueeze.in_port(1).data.get_value()
|
||||||
for port_id in [1, 2]:
|
for port_id in [1, 2]:
|
||||||
current_value = pad.in_port(port_id).get_connection().data.get_value()
|
current_value = pad.in_port(port_id).get_connection().data.get_value()
|
||||||
new_value_node = Const(pad.graph, {'value': np.insert(current_value, unsqueeze_axis.item(), 0),
|
new_value_node = Const(pad.graph, {'name': pad.soft_get('name', pad.id) + '/value_{}'.format(port_id),
|
||||||
|
'value': np.insert(current_value, unsqueeze_axis.item(), 0),
|
||||||
'override_output_shape': True}).create_node()
|
'override_output_shape': True}).create_node()
|
||||||
pad.in_port(port_id).disconnect()
|
pad.in_port(port_id).disconnect()
|
||||||
pad.in_port(port_id).connect(new_value_node.out_port(0))
|
pad.in_port(port_id).connect(new_value_node.out_port(0))
|
||||||
|
@ -106,8 +106,10 @@ class EltwiseInputReshape(MiddleReplacementPattern):
|
|||||||
# Insert Reshape layer between data node and consumer
|
# Insert Reshape layer between data node and consumer
|
||||||
for shape_key in mapping.keys():
|
for shape_key in mapping.keys():
|
||||||
shape = list(shape_key)
|
shape = list(shape_key)
|
||||||
reshape = Reshape(graph, attrs={'name': 'EltwiseReshapeNormalization'})
|
reshape_name = node.soft_get('name', node.id) + '/EltwiseReshape'
|
||||||
reshape_dim = Const(graph, {'value': shape}).create_node_with_data()
|
reshape = Reshape(graph, attrs={'name': reshape_name})
|
||||||
|
reshape_dim = Const(graph,
|
||||||
|
{'value': shape, 'name': reshape_name + '/Shape'}).create_node_with_data()
|
||||||
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
|
reshape_data = reshape.create_node_with_data(inputs=[node, reshape_dim])
|
||||||
|
|
||||||
# Iterate over consumers and reconnect them to Reshape layer output
|
# Iterate over consumers and reconnect them to Reshape layer output
|
||||||
|
@ -19,9 +19,9 @@ import numpy as np
|
|||||||
|
|
||||||
from extensions.ops.gather import Gather
|
from extensions.ops.gather import Gather
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
from mo.graph.graph import Graph
|
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||||
|
from mo.graph.graph import Graph, rename_node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
|
|
||||||
|
|
||||||
@ -68,6 +68,7 @@ class GatherNdNormalize(MiddleReplacementPattern):
|
|||||||
|
|
||||||
def replace_pattern(self, graph: Graph, match: dict):
|
def replace_pattern(self, graph: Graph, match: dict):
|
||||||
gather = match['GatherNd']
|
gather = match['GatherNd']
|
||||||
|
gather_name = gather.soft_get('name', gather.id)
|
||||||
input_shape = gather.in_node(0).shape
|
input_shape = gather.in_node(0).shape
|
||||||
indices = gather.in_node(1).value
|
indices = gather.in_node(1).value
|
||||||
if indices is None:
|
if indices is None:
|
||||||
@ -77,26 +78,26 @@ class GatherNdNormalize(MiddleReplacementPattern):
|
|||||||
# 0. All needed checks that we can replace GatherNd by Gather
|
# 0. All needed checks that we can replace GatherNd by Gather
|
||||||
gather_idx = self.indices_check(indices, input_shape)
|
gather_idx = self.indices_check(indices, input_shape)
|
||||||
if gather_idx is None:
|
if gather_idx is None:
|
||||||
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather.name))
|
log.warning('Node {} with op=GatherNd can\'t be normalized to op=Gather.'.format(gather_name))
|
||||||
return
|
return
|
||||||
|
|
||||||
# 1. Add Reshape and connect
|
# 1. Add Reshape and connect
|
||||||
new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
|
new_shape = int64_array([-1] + list(input_shape[indices.shape[-1]:]))
|
||||||
reshape = Reshape(graph, {'name': gather.name + '/Reshape_for_GatherNd/'}).create_node()
|
reshape = create_op_node_with_second_input(graph, Reshape, new_shape,
|
||||||
reshape_const_node = Const(graph, {'name': reshape.name + '/Dim', 'value': new_shape}).create_node()
|
{'name': gather_name + '/Reshape_for_GatherNd/'})
|
||||||
gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
|
gather.in_port(0).get_connection().set_destination(reshape.in_port(0))
|
||||||
reshape.in_port(1).connect(reshape_const_node.out_port(0))
|
|
||||||
|
|
||||||
# 2. Change indices from Nd to 1d:
|
# 2. Change indices from Nd to 1d:
|
||||||
new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
|
new_indices = np.reshape(np.take(indices, indices=[gather_idx], axis=-1), [-1])
|
||||||
new_indices_const = Const(graph, dict(value=new_indices)).create_node()
|
|
||||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
|
rename_node(gather, gather_name + '/to_delete')
|
||||||
|
|
||||||
# 3. Create new Gather operation and reconnect all inputs/outputs
|
# 3. Create new Gather operation and reconnect all inputs/outputs
|
||||||
new_gather = Gather(graph, {'name': gather.name + '/NewGather/'}).create_node()
|
new_gather = create_op_with_const_inputs(graph, Gather, {1: new_indices, 2: int64_array(0)},
|
||||||
|
{'name': gather_name})
|
||||||
|
rename_node(new_gather, gather_name)
|
||||||
|
|
||||||
reshape.out_port(0).connect(new_gather.in_port(0))
|
reshape.out_port(0).connect(new_gather.in_port(0))
|
||||||
new_indices_const.out_port(0).connect(new_gather.in_port(1))
|
|
||||||
axis_const.out_port(0).connect(new_gather.in_port(2))
|
|
||||||
|
|
||||||
gather.out_port(0).get_connection().set_source(new_gather.out_port(0))
|
gather.out_port(0).get_connection().set_source(new_gather.out_port(0))
|
||||||
|
|
||||||
|
@ -16,9 +16,9 @@
|
|||||||
|
|
||||||
from extensions.middle.pass_separator import PostMiddleStart
|
from extensions.middle.pass_separator import PostMiddleStart
|
||||||
from extensions.ops.transpose import Transpose
|
from extensions.ops.transpose import Transpose
|
||||||
|
|
||||||
from mo.graph.graph import Graph, Node
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.ops.op import PermuteAttrs
|
from mo.ops.op import PermuteAttrs
|
||||||
|
|
||||||
|
|
||||||
@ -69,6 +69,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
|||||||
len(node.out_port(0).data.get_shape()) >= 4
|
len(node.out_port(0).data.get_shape()) >= 4
|
||||||
|
|
||||||
def find_and_replace_pattern(self, graph: Graph):
|
def find_and_replace_pattern(self, graph: Graph):
|
||||||
|
|
||||||
|
# we need to import these functions here to avoid circular dependent imports
|
||||||
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||||
|
|
||||||
if graph.graph['layout'] != 'NHWC':
|
if graph.graph['layout'] != 'NHWC':
|
||||||
# we check it here because this transformation is called explicitly from the pipeline
|
# we check it here because this transformation is called explicitly from the pipeline
|
||||||
return
|
return
|
||||||
@ -80,15 +84,14 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
|||||||
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
||||||
input_shape = reinterp_shape_node.in_node(0).shape
|
input_shape = reinterp_shape_node.in_node(0).shape
|
||||||
if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
|
if self.is_nchw_to_nhwc_transpose_needed(reinterp_shape_node):
|
||||||
order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
|
permute_node = create_op_node_with_second_input(
|
||||||
}).create_node()
|
graph, Transpose, PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm,
|
||||||
permute_node = Transpose(graph,
|
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}
|
||||||
{'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
|
)
|
||||||
}).create_node()
|
|
||||||
reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
|
reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
|
||||||
order_const.out_port(0).connect(permute_node.in_port(1))
|
|
||||||
order_const.infer(order_const)
|
|
||||||
|
|
||||||
|
order_const = permute_node.in_port(1).get_source().node
|
||||||
|
order_const.infer(order_const)
|
||||||
# do not infer the Transpose node because it should have input data node in NCHW layout (but currently
|
# do not infer the Transpose node because it should have input data node in NCHW layout (but currently
|
||||||
# it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
|
# it is NHWC because data node attributes has not been permuted yet) and produce output in NHWC layout
|
||||||
# (which is true at this moment)
|
# (which is true at this moment)
|
||||||
@ -107,11 +110,10 @@ class InsertLayoutPropagationTranspose(MiddleReplacementPattern):
|
|||||||
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
reinterp_shape_node_id, graph.dump_graph_for_graphviz())
|
||||||
output_shape = reinterp_shape_node.out_node(0).shape
|
output_shape = reinterp_shape_node.out_node(0).shape
|
||||||
if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
|
if self.is_nhwc_to_nchw_transpose_needed(reinterp_shape_node):
|
||||||
order_const = Const(graph, {
|
permute_node = create_op_node_with_second_input(
|
||||||
'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
|
graph, Transpose, PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm,
|
||||||
permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
|
{'name': reinterp_shape_node.id + '/Transpose'})
|
||||||
reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
|
reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
|
||||||
order_const.out_port(0).connect(permute_node.in_port(1))
|
|
||||||
|
|
||||||
# the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
|
# the Reshape and Transpose operations should work in original (NHWC layout) so the Transpose
|
||||||
# will convert it to the NCHW
|
# will convert it to the NCHW
|
||||||
|
@ -20,9 +20,9 @@ import numpy as np
|
|||||||
|
|
||||||
from extensions.ops.normalize import NormalizeOp
|
from extensions.ops.normalize import NormalizeOp
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph, rename_node
|
from mo.graph.graph import Graph, rename_node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
|
|
||||||
|
|
||||||
class L2NormToNorm(MiddleReplacementPattern):
|
class L2NormToNorm(MiddleReplacementPattern):
|
||||||
@ -87,13 +87,13 @@ class L2NormToNorm(MiddleReplacementPattern):
|
|||||||
normalizel2_name = output_name + '/normalizel2'
|
normalizel2_name = output_name + '/normalizel2'
|
||||||
rename_node(match['l2_normalize'], normalizel2_name)
|
rename_node(match['l2_normalize'], normalizel2_name)
|
||||||
|
|
||||||
normalize_node = NormalizeOp(graph, {'name': output_name, 'eps': y,
|
normalize_node = create_op_node_with_second_input(graph, NormalizeOp,
|
||||||
'across_spatial': 0, 'channel_shared': 0}).create_node()
|
np.ones(shape=int64_array([match['input'].shape[-1]]),
|
||||||
|
dtype=match['input'].data_type),
|
||||||
|
{'name': output_name, 'eps': y,
|
||||||
|
'across_spatial': 0, 'channel_shared': 0})
|
||||||
rename_node(normalize_node, output_name)
|
rename_node(normalize_node, output_name)
|
||||||
|
|
||||||
weights_node = Const(graph, {'value': np.ones(shape=int64_array([match['input'].shape[-1]]),
|
|
||||||
dtype=match['input'].data_type)}).create_node()
|
|
||||||
|
|
||||||
match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
|
match['square'].in_port(0).get_source().connect(normalize_node.in_port(0))
|
||||||
|
|
||||||
match['square'].in_port(0).disconnect()
|
match['square'].in_port(0).disconnect()
|
||||||
@ -102,5 +102,4 @@ class L2NormToNorm(MiddleReplacementPattern):
|
|||||||
else:
|
else:
|
||||||
match['l2_normalize'].in_port(0).disconnect()
|
match['l2_normalize'].in_port(0).disconnect()
|
||||||
|
|
||||||
weights_node.out_port(0).get_connection().set_destination(normalize_node.in_port(1))
|
|
||||||
match['l2_normalize'].out_port(0).get_connection().set_source(normalize_node.out_port(0))
|
match['l2_normalize'].out_port(0).get_connection().set_source(normalize_node.out_port(0))
|
||||||
|
@ -192,7 +192,7 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
|
|||||||
input = match['input']
|
input = match['input']
|
||||||
if not lstm.has_num_directions:
|
if not lstm.has_num_directions:
|
||||||
return
|
return
|
||||||
old_data_node =lstm.out_node(0)
|
old_data_node = lstm.out_node(0)
|
||||||
num_directions = 2 if lstm.direction in ['bidirectional'] else 1
|
num_directions = 2 if lstm.direction in ['bidirectional'] else 1
|
||||||
mxnet_shape = lstm.out_node(0).shape.copy()
|
mxnet_shape = lstm.out_node(0).shape.copy()
|
||||||
|
|
||||||
@ -206,18 +206,21 @@ class MXNetRNNSequenceNormalize(MiddleReplacementPattern):
|
|||||||
if lstm.has_num_directions:
|
if lstm.has_num_directions:
|
||||||
mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
|
mo_shape = np.insert(mo_shape, 1, np.int64(num_directions))
|
||||||
|
|
||||||
new_data = Op._create_data_node(graph, name=lstm.name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
|
lstm_name = lstm.soft_get('name', lstm.id)
|
||||||
|
|
||||||
|
new_data = Op._create_data_node(graph, name=lstm_name + '/Data/Reshape_mxnet/', attrs={'shape': mo_shape})
|
||||||
graph.remove_edge(lstm.id, old_data_node.id)
|
graph.remove_edge(lstm.id, old_data_node.id)
|
||||||
graph.add_edge(lstm.id, new_data.id, key=0, out=0)
|
graph.add_edge(lstm.id, new_data.id, key=0, out=0)
|
||||||
|
|
||||||
# Add Transpose
|
# Add Transpose
|
||||||
permute_order = Const(graph, dict(value=int64_array([0, 2, 1, 3]))).create_node_with_data()
|
permute_order = Const(graph, {'name': lstm_name + '/Transpose_mxnet_order',
|
||||||
permute_data = Transpose(graph, dict(name=lstm.name + '/Transpose_mxnet/')
|
'value': int64_array([0, 2, 1, 3])}).create_node_with_data()
|
||||||
|
permute_data = Transpose(graph, {'name': lstm_name + '/Transpose_mxnet/'}
|
||||||
).create_node_with_data([new_data, permute_order])
|
).create_node_with_data([new_data, permute_order])
|
||||||
|
|
||||||
# Add Reshape
|
# Add Reshape
|
||||||
reshape = Reshape(graph, dict(name=lstm.name + '/Reshape_mxnet/'))
|
reshape = Reshape(graph, {'name': lstm_name + '/Reshape_mxnet/'})
|
||||||
reshape_dim_data = Const(graph, {'name': lstm.name + '/Reshape_mxnet_dim',
|
reshape_dim_data = Const(graph, {'name': lstm_name + '/Reshape_mxnet_dim',
|
||||||
'value': mxnet_shape}).create_node_with_data()
|
'value': mxnet_shape}).create_node_with_data()
|
||||||
|
|
||||||
reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])
|
reshape.create_node_with_data([permute_data, reshape_dim_data], dict(), data_nodes=[old_data_node])
|
||||||
|
@ -41,7 +41,8 @@ def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
|
|||||||
if value is None:
|
if value is None:
|
||||||
log.debug('Can not duplicate due no data for in_port {} of node {}'.format(port_id, node.name))
|
log.debug('Can not duplicate due no data for in_port {} of node {}'.format(port_id, node.name))
|
||||||
for node, idxs in dst_port_map.items():
|
for node, idxs in dst_port_map.items():
|
||||||
const = Const(graph, {'value': np.array(value)}).create_node()
|
const = Const(graph, {'value': np.array(value),
|
||||||
|
'name': node.soft_get('name', node.id) + '/duplicated_'}).create_node()
|
||||||
for idx in idxs:
|
for idx in idxs:
|
||||||
node.in_port(idx).disconnect()
|
node.in_port(idx).disconnect()
|
||||||
const.out_port(0).connect(node.in_port(idx))
|
const.out_port(0).connect(node.in_port(idx))
|
||||||
|
@ -34,6 +34,6 @@ class ReverseTransposeNormalization(MiddleReplacementPattern):
|
|||||||
node = match['transpose']
|
node = match['transpose']
|
||||||
assert len(node.in_nodes()) == 1
|
assert len(node.in_nodes()) == 1
|
||||||
order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
|
order = np.arange(len(node.in_port(0).data.get_shape()))[::-1]
|
||||||
const = Const(graph, {'value': order}).create_node()
|
const = Const(graph, {'value': order, 'name': node.soft_get('name', node.id) + '/Order'}).create_node()
|
||||||
node.add_input_port(1, skip_if_exist=True)
|
node.add_input_port(1, skip_if_exist=True)
|
||||||
const.out_port(0).connect(node.in_port(1))
|
const.out_port(0).connect(node.in_port(1))
|
||||||
|
@ -16,9 +16,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from extensions.ops.reverse_sequence import ReverseSequence
|
from extensions.ops.reverse_sequence import ReverseSequence
|
||||||
from mo.graph.graph import Graph
|
from mo.front.tf.graph_utils import create_op_node_with_second_input
|
||||||
|
from mo.graph.graph import Graph, rename_node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.utils.error import Error
|
from mo.utils.error import Error
|
||||||
|
|
||||||
|
|
||||||
@ -57,14 +57,15 @@ class ReverseToReverseSequence(MiddleReplacementPattern):
|
|||||||
|
|
||||||
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
|
# 1. For ReverseSequence 1-port input is seq_lengths => create this input node
|
||||||
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
|
seq_lengths = np.ones(input_data_shape[batch_axis]) * input_data_shape[seq_axis]
|
||||||
const = Const(graph, dict(value=seq_lengths)).create_node()
|
|
||||||
|
|
||||||
|
reverse_name = reverse.soft_get('name', reverse.id)
|
||||||
|
rename_node(reverse, reverse_name + '/to_delete')
|
||||||
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
|
# 2. Create new ReverseSequence node and reconnect all inputs/outputs to it
|
||||||
reverse_sequence = ReverseSequence(graph, {'name': reverse.name + '/ReverseSequence/',
|
reverse_sequence = create_op_node_with_second_input(graph, ReverseSequence, seq_lengths,
|
||||||
'seq_axis': seq_axis, 'batch_axis': batch_axis}).create_node()
|
{'name': reverse_name, 'seq_axis': seq_axis,
|
||||||
|
'batch_axis': batch_axis})
|
||||||
|
rename_node(reverse_sequence, reverse_name)
|
||||||
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
|
reverse.in_port(0).get_connection().set_destination(reverse_sequence.in_port(0))
|
||||||
const.out_port(0).connect(reverse_sequence.in_port(1))
|
|
||||||
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
|
reverse.out_port(0).get_connection().set_source(reverse_sequence.out_port(0))
|
||||||
|
|
||||||
# 3. Delete old Reverse node
|
# 3. Delete old Reverse node
|
||||||
|
@ -36,7 +36,7 @@ class SwapAxisMiddleReplacer(MiddleReplacementPattern):
|
|||||||
order = swapaxis.order
|
order = swapaxis.order
|
||||||
|
|
||||||
swapaxis.add_input_port(1)
|
swapaxis.add_input_port(1)
|
||||||
const = Const(graph, {'value': order}).create_node()
|
const = Const(graph, {'value': order, 'name': swapaxis.soft_get('name', swapaxis.id) + '/Order'}).create_node()
|
||||||
const.out_port(0).connect(swapaxis.in_port(1))
|
const.out_port(0).connect(swapaxis.in_port(1))
|
||||||
|
|
||||||
Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})
|
Transpose.update_node_stat(swapaxis, {'need_shape_inference': True})
|
||||||
|
@ -25,9 +25,9 @@ from extensions.ops.elementwise import Mul
|
|||||||
from extensions.ops.interpolate import Interpolate
|
from extensions.ops.interpolate import Interpolate
|
||||||
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
|
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
|
||||||
from mo.front.common.partial_infer.utils import int64_array
|
from mo.front.common.partial_infer.utils import int64_array
|
||||||
|
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||||
from mo.graph.graph import Graph, Node
|
from mo.graph.graph import Graph, Node
|
||||||
from mo.middle.replacement import MiddleReplacementPattern
|
from mo.middle.replacement import MiddleReplacementPattern
|
||||||
from mo.ops.const import Const
|
|
||||||
from mo.ops.shape import Shape
|
from mo.ops.shape import Shape
|
||||||
from mo.ops.strided_slice import StridedSlice
|
from mo.ops.strided_slice import StridedSlice
|
||||||
|
|
||||||
@ -55,6 +55,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
|||||||
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
|
||||||
log.debug('UpsampleToResample is triggered')
|
log.debug('UpsampleToResample is triggered')
|
||||||
upsample = match['upsample']
|
upsample = match['upsample']
|
||||||
|
upsample_name = upsample.soft_get('name', upsample.id)
|
||||||
input_shape = upsample.in_port(0).data.get_shape()
|
input_shape = upsample.in_port(0).data.get_shape()
|
||||||
input_shape_rank = len(input_shape)
|
input_shape_rank = len(input_shape)
|
||||||
if input_shape_rank not in [4, 5]:
|
if input_shape_rank not in [4, 5]:
|
||||||
@ -67,8 +68,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
|||||||
return
|
return
|
||||||
scales = upsample.in_node(1).value
|
scales = upsample.in_node(1).value
|
||||||
assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
|
assert len(scales) in (4, 5), 'Supported scales rank is 4 or 5, but it is {} for node {}'.format(
|
||||||
len(scales), upsample.soft_get('name', upsample.id)
|
len(scales), upsample_name)
|
||||||
)
|
|
||||||
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
|
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
|
||||||
return
|
return
|
||||||
height_scale = scales[2]
|
height_scale = scales[2]
|
||||||
@ -81,45 +81,50 @@ class UpsampleToResample(MiddleReplacementPattern):
|
|||||||
|
|
||||||
if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
|
if not math.isclose(height_scale, width_scale, rel_tol=1e-5):
|
||||||
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
|
log.debug('Width and height scales are not equal: {} vs {} for node {}'.format(
|
||||||
width_scale, height_scale, upsample.soft_get('name')))
|
width_scale, height_scale, upsample_name))
|
||||||
return
|
return
|
||||||
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
|
if depth_scale is not None and not math.isclose(height_scale, depth_scale, rel_tol=1e-5):
|
||||||
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
|
log.debug('Depth and height scales are not equal: {} vs {} for node {}'.format(
|
||||||
depth_scale, height_scale, upsample.soft_get('name')))
|
depth_scale, height_scale, upsample_name))
|
||||||
return
|
return
|
||||||
|
|
||||||
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
|
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
|
||||||
upsample.in_port(1).disconnect()
|
upsample.in_port(1).disconnect()
|
||||||
|
|
||||||
shape = Shape(graph, {'name': upsample.name + '/0_port'}).create_node()
|
shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
|
||||||
|
|
||||||
layout = graph.graph['layout']
|
layout = graph.graph['layout']
|
||||||
if input_shape_rank == 4:
|
|
||||||
begin = Const(graph, {'value': int64_array([get_height_dim(layout, input_shape_rank)])}).create_node()
|
|
||||||
factor = Const(graph, {'value': np.array([height_scale, width_scale])}).create_node()
|
|
||||||
else:
|
|
||||||
begin = Const(graph, {'value': int64_array([get_depth_dim(layout, input_shape_rank)])}).create_node()
|
|
||||||
factor = Const(graph, {'value': np.array([depth_scale, height_scale, width_scale])}).create_node()
|
|
||||||
end = Const(graph, {'value': int64_array([get_width_dim(layout, input_shape_rank) + 1])}).create_node()
|
|
||||||
|
|
||||||
stride = Const(graph, {'value': int64_array([1])}).create_node()
|
if input_shape_rank == 4:
|
||||||
ss = StridedSlice(graph, {'name': upsample.name + '/ss_0_port',
|
begin_value = int64_array([get_height_dim(layout, input_shape_rank)])
|
||||||
|
factor_value = np.array([height_scale, width_scale])
|
||||||
|
else:
|
||||||
|
begin_value = int64_array([get_depth_dim(layout, input_shape_rank)])
|
||||||
|
factor_value = np.array([depth_scale, height_scale, width_scale])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||||
|
{1: begin_value,
|
||||||
|
2: int64_array([get_width_dim(layout, input_shape_rank) + 1]),
|
||||||
|
3: int64_array([1])
|
||||||
|
},
|
||||||
|
{'name': upsample_name + '/ss_0_port',
|
||||||
'begin_mask': int64_array([1]),
|
'begin_mask': int64_array([1]),
|
||||||
'end_mask': int64_array([1]),
|
'end_mask': int64_array([1]),
|
||||||
'new_axis_mask': int64_array([0]),
|
'new_axis_mask': int64_array([0]),
|
||||||
'shrink_axis_mask': int64_array([0]),
|
'shrink_axis_mask': int64_array([0]),
|
||||||
'ellipsis_mask': int64_array([0])}).create_node()
|
'ellipsis_mask': int64_array([0])
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
mul = Mul(graph, {'name': upsample.name + '/factor_mul_'}).create_node()
|
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
|
||||||
|
|
||||||
source = upsample.in_port(0).get_connection().get_source()
|
source = upsample.in_port(0).get_connection().get_source()
|
||||||
source.connect(shape.in_port(0))
|
source.connect(shape.in_port(0))
|
||||||
shape.out_port(0).connect(ss.in_port(0))
|
shape.out_port(0).connect(ss.in_port(0))
|
||||||
begin.out_port(0).connect(ss.in_port(1))
|
|
||||||
end.out_port(0).connect(ss.in_port(2))
|
|
||||||
stride.out_port(0).connect(ss.in_port(3))
|
|
||||||
ss.out_port(0).connect(mul.in_port(0))
|
ss.out_port(0).connect(mul.in_port(0))
|
||||||
factor.out_port(0).connect(mul.in_port(1))
|
|
||||||
|
|
||||||
# Create Interpolate operation
|
# Create Interpolate operation
|
||||||
if input_shape_rank == 4:
|
if input_shape_rank == 4:
|
||||||
@ -130,7 +135,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
|||||||
get_height_dim(layout, input_shape_rank),
|
get_height_dim(layout, input_shape_rank),
|
||||||
get_width_dim(layout, input_shape_rank)])
|
get_width_dim(layout, input_shape_rank)])
|
||||||
|
|
||||||
resample_op = Interpolate(graph, dict(name='Interpolate/{}'.format(upsample.name),
|
resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
|
||||||
axes=axes, mode=upsample.attrs()['mode'],
|
axes=axes, mode=upsample.attrs()['mode'],
|
||||||
antialias=0, convert_to_resample=True)).create_node()
|
antialias=0, convert_to_resample=True)).create_node()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user