diff --git a/model-optimizer/extensions/front/kaldi/add_permute_after_convolution.py b/model-optimizer/extensions/front/kaldi/add_permute_after_convolution.py index 5f30fe7158c..e3611d474b6 100644 --- a/model-optimizer/extensions/front/kaldi/add_permute_after_convolution.py +++ b/model-optimizer/extensions/front/kaldi/add_permute_after_convolution.py @@ -15,16 +15,15 @@ """ from collections import deque -import numpy as np - from extensions.front.MatMul_normalizer import FullyConnectedDecomposer from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape from extensions.middle.TensorIteratorMerge import op_type from extensions.ops.activation_ops import activation_ops from extensions.ops.transpose import Transpose +from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementSubgraph +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Node, Graph -from mo.ops.const import Const class ReplaceConvolutionTranspose(FrontReplacementSubgraph): @@ -61,10 +60,9 @@ class ReplaceConvolutionTranspose(FrontReplacementSubgraph): convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution'] for convolution_node in convolution_nodes: target_node = self.search_target_node(Node(graph, convolution_node)) - order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node() - permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node() + permute_node = create_op_with_const_inputs(graph, Transpose, {1: int64_array([0, 3, 2, 1])}, + {'name': target_node.name + '/Transpose'}) target_node.insert_node_after(permute_node, 0) - order_const.out_port(0).connect(permute_node.in_port(1)) def run_after(self): from extensions.front.flatten_to_reshape import FlattenToReshape diff --git a/model-optimizer/extensions/front/kaldi/add_permute_after_convolution_test.py b/model-optimizer/extensions/front/kaldi/add_permute_after_convolution_test.py index d93ac86e0c9..9ffc65eab99 100644 --- a/model-optimizer/extensions/front/kaldi/add_permute_after_convolution_test.py +++ b/model-optimizer/extensions/front/kaldi/add_permute_after_convolution_test.py @@ -39,11 +39,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase): ('conv', 'reshape_conv'), ('reshape_conv', 'scale_shift'), ]) + graph.stage = 'front' ReplaceConvolutionTranspose().find_and_replace_pattern(graph) conv_node = Node(graph, graph.nodes['conv']['name']) permute = conv_node.out_node() self.assertEqual(permute.op, 'Transpose') - self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1]))) + self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1]))) def test_conv_pool(self): graph = build_graph(self.nodes_attributes, [ @@ -53,11 +54,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase): ('pool', 'reshape_after_pool'), ('reshape_after_pool', 'fc'), ]) + graph.stage = 'front' ReplaceConvolutionTranspose().find_and_replace_pattern(graph) pool_node = Node(graph, graph.nodes['pool']['name']) permute = pool_node.out_node() self.assertEqual(permute.op, 'Transpose') - self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1]))) + self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1]))) def test_conv_act_pool(self): graph = build_graph(self.nodes_attributes, [ @@ -68,8 +70,9 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase): ('pool', 'reshape_after_pool'), ('reshape_after_pool', 'fc'), ]) + graph.stage = 'front' ReplaceConvolutionTranspose().find_and_replace_pattern(graph) pool_node = Node(graph, graph.nodes['pool']['name']) permute = pool_node.out_node() self.assertEqual(permute.op, 'Transpose') - self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1]))) + self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1]))) diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py index d533770d412..15bf12bc9fc 100644 --- a/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py +++ b/model-optimizer/extensions/front/kaldi/add_reshape_around_convolution.py @@ -53,35 +53,38 @@ class ReplaceConvolutionReshape(FrontReplacementPattern): @staticmethod def replace_pattern(graph: Graph, match: dict): node = match['conv'] + node_name = node.soft_get('name', node.id) # create Reshape before convolution # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride] - shape = Shape(graph, {}).create_node() + shape = Shape(graph, {'name': node_name + '/Shape'}).create_node() shape.in_port(0).connect(node.in_port(0).get_source()) - split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape) - conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride])}).create_node() - pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1])) + split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, + {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape) + + pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/patch_stride/inverse'}) + conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]), + 'name': node_name + '/patch_stride/'}).create_node() pow_node.in_port(0).connect(conv_patch_stride.out_port(0)) - mul = Mul(graph, {}).create_node() + mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node() mul.in_port(0).connect(split.out_port(1)) mul.in_port(1).connect(pow_node.out_port(0)) - const_1 = Const(graph, {'value': int64_array([1])}).create_node() + concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])}, + {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0}) - concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node() concat.in_port(0).connect(split.out_port(0)) concat.in_port(1).connect(mul.out_port(0)) - concat.in_port(2).connect(const_1.out_port(0)) concat.in_port(3).connect(conv_patch_stride.out_port(0)) - reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node() + reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node() reshape_in.in_port(1).connect(concat.out_port(0)) # create Reshape after Convolution reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]), - {'name': node.name + '/Reshape/'}) + {'name': node_name + '/reshape_out'}) # connect input_reshape_node source = node.in_port(0).get_source() diff --git a/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py b/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py index 421ddbeb086..197879011b7 100644 --- a/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py +++ b/model-optimizer/extensions/front/kaldi/add_reshape_around_pooling.py @@ -49,38 +49,41 @@ class ReplacePoolingReshape(FrontReplacementPattern): @staticmethod def replace_pattern(graph: Graph, match: dict): node = match['pool'] + node_name = node.soft_get('name', node.id) if node.pool_step is None: node.stride = int64_array([1, 1, node.window[-1], node.window[-1]]) # create Reshape before convolution # shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride] - shape = Shape(graph, {}).create_node() + shape = Shape(graph, {'name': node_name + '/Shape'}).create_node() shape.in_port(0).connect(node.in_port(0).get_source()) - split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape) - node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node() - pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1])) + split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, + {'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape) + + pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/pool_stride/inverse'}) + node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride]), + 'name': node_name + '/pool_stride/'}).create_node() pow_node.in_port(0).connect(node_pool_stride.out_port(0)) - mul = Mul(graph, {}).create_node() + mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node() mul.in_port(0).connect(split.out_port(1)) mul.in_port(1).connect(pow_node.out_port(0)) - const_1 = Const(graph, {'value': int64_array([1])}).create_node() + concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])}, + {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0}) - concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node() concat.in_port(0).connect(split.out_port(0)) concat.in_port(3).connect(mul.out_port(0)) - concat.in_port(2).connect(const_1.out_port(0)) concat.in_port(1).connect(node_pool_stride.out_port(0)) - reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node() + reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node() reshape_in.in_port(1).connect(concat.out_port(0)) # create Reshape after Convolution reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]), - {'name': node.name + '/Reshape/'}) + {'name': node_name + '/reshape_out'}) # connect input_reshape_node source = node.in_port(0).get_source() diff --git a/model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py b/model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py index c583a193d1b..9bd5de07b82 100644 --- a/model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py +++ b/model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py @@ -16,13 +16,13 @@ import numpy as np from extensions.ops.activation_ops import Sigmoid, Tanh +from extensions.ops.elementwise import Add, Mul from extensions.ops.split import Split from mo.front.caffe.extractors.utils import input_as_const from mo.front.common.replacement import FrontReplacementOp +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Node, Graph from mo.ops.concat import Concat -from mo.ops.const import Const -from extensions.ops.elementwise import Add, Mul from mo.ops.scale_shift import ScaleShiftOp @@ -40,80 +40,80 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp): def replace_op(self, graph: Graph, node: Node): # split input to (i_part, f_part, c_part, o_part, ct_1) - split_node_axis = Const(graph, {'value': np.int64(1)}).create_node() - split_node = Split(graph, {'name': 'Split_lstm_input_', - 'num_splits': 5}).create_node() + node_name = node.soft_get('name', node.id) + split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)}, + {'name': node_name + '/split_lstm_input', + 'num_splits': 5}) node.in_port(0).get_connection().set_destination(split_node.in_port(0)) - split_node.in_port(1).connect(split_node_axis.out_port(0)) # i_t = Sigmoid(i_part + w_ic*ct_1) - i_scale_attrs = {'name': 'i_scaleshift', + i_scale_attrs = {'name': node_name + '/i_scaleshift', 'bias_term': False} i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node() input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights) split_node.out_port(4).connect(i_scale.in_port(0)) - sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node() + sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node() split_node.out_port(0).connect(sum_i_c.in_port(0)) i_scale.out_port(0).connect(sum_i_c.in_port(1)) - i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node() + i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node() sum_i_c.out_port(0).connect(i_sigmoid.in_port(0)) # f_t = Sigmoid(f_part + w_fc*ct_1) - f_scale_attrs = {'name': 'f_scaleshift', + f_scale_attrs = {'name': node_name + '/f_scaleshift', 'bias_term': False} f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node() input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights) split_node.out_port(4).connect(f_scale.in_port(0)) - sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node() + sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node() split_node.out_port(1).connect(sum_f_c.in_port(0)) f_scale.out_port(0).connect(sum_f_c.in_port(1)) - f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node() + f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node() sum_f_c.out_port(0).connect(f_sigmoid.in_port(0)) # c_t = f_t*ct_1 + i_t * tanh(c_part) - c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node() + c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node() split_node.out_port(2).connect(c_tanh.in_port(0)) - prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node() + prod_i_c_tanh = Mul(graph, {'name': node_name + '/prod_i_c_tanh_'}).create_node() i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0)) c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1)) - prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node() + prod_f_ct_1 = Mul(graph, {'name': node_name + '/prod_f_ct_1_'}).create_node() f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0)) split_node.out_port(4).connect(prod_f_ct_1.in_port(1)) - sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node() + sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node() prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0)) prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1)) # o_t = Sigmoid(o_part + w_oc*c_t) - o_scale_attrs = {'name': 'o_scaleshift', + o_scale_attrs = {'name': node_name + '/o_scaleshift', 'bias_term': False} o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node() input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights) sum_f_i.out_port(0).connect(o_scale.in_port(0)) - sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node() + sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node() split_node.out_port(3).connect(sum_o_c.in_port(0)) o_scale.out_port(0).connect(sum_o_c.in_port(1)) - o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node() + o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node() sum_o_c.out_port(0).connect(o_sigmoid.in_port(0)) # m_t = o_t * Tanh(c_t) - c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node() + c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node() sum_f_i.out_port(0).connect(c_t_tanh.in_port(0)) - prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node() + prod_o_c_t_tanh = Mul(graph, {'name': node_name + '/prod_o_c_t_tanh_'}).create_node() o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0)) c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1)) # add concat to create 1 output - concat = Concat(graph, {'name': 'Concat_c_m'}).create_node() + concat = Concat(graph, {'name': node_name + '/concat_c_m'}).create_node() concat.add_sequence_of_ports('in', range(2)) sum_f_i.out_port(0).connect(concat.in_port(0)) prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1)) diff --git a/model-optimizer/extensions/front/mxnet/gather.py b/model-optimizer/extensions/front/mxnet/gather.py index 9b6a2f0e51a..e00aaefac4d 100644 --- a/model-optimizer/extensions/front/mxnet/gather.py +++ b/model-optimizer/extensions/front/mxnet/gather.py @@ -16,8 +16,8 @@ from extensions.ops.gather import Gather from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementOp +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Graph -from mo.ops.const import Const class GatherFrontReplacer(FrontReplacementOp): @@ -26,10 +26,10 @@ class GatherFrontReplacer(FrontReplacementOp): def replace_sub_graph(self, graph: Graph, match: dict): node = match['op'] - gather_node = Gather(graph, dict(name=node.id + '/embedding_', - symbol_dict={'name': node.id + '/embedding_'})).create_node() - axis_const = Const(graph, {'value': int64_array(0)}).create_node() + + gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, + {'name': node.soft_get('name', node.id) + '/embedding_'}) + node.in_port(0).get_connection().set_destination(gather_node.in_port(1)) node.in_port(1).get_connection().set_destination(gather_node.in_port(0)) - axis_const.out_port(0).connect(gather_node.in_port(2)) node.out_port(0).get_connection().set_source(gather_node.out_port(0)) diff --git a/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py b/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py index d4ebcdb675b..bffd69d9a79 100644 --- a/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py +++ b/model-optimizer/extensions/front/onnx/flattenONNX_to_reshape.py @@ -51,10 +51,12 @@ class FlattenONNXToReshape(FrontReplacementSubgraph): assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name) axis = node.axis + reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node() + if axis == 0: - dim = Const(graph, {'value': int64_array([1, -1])}).create_node() + dim = Const(graph, {'value': int64_array([1, -1]), 'name': reshape_node.name + '/shape'}).create_node() elif axis == 1: - dim = Const(graph, {'value': int64_array([0, -1])}).create_node() + dim = Const(graph, {'value': int64_array([0, -1]), 'name': reshape_node.name + '/shape'}).create_node() else: shape = Shape(graph, {'name': name + '/input_shape'}).create_node() @@ -62,8 +64,8 @@ class FlattenONNXToReshape(FrontReplacementSubgraph): axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs) first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), - {'keep_dims': True}) - second_dims = Const(graph, {'value': int64_array([-1])}).create_node() + {'name': name + '/first_dims', 'keep_dims': True}) + second_dims = Const(graph, {'value': int64_array([-1]), 'name': name + '/second_dims'}).create_node() node.in_port(0).get_source().connect(shape.in_port(0)) axis_shape_portion.out_port(0).connect(first_dims.in_port(0)) @@ -72,7 +74,6 @@ class FlattenONNXToReshape(FrontReplacementSubgraph): dim = new_shape_node_from_shape_nodes(order_of_dims) - reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node() reshape_node.in_port(1).connect(dim.out_port(0)) node.out_port(0).get_connection().set_source(reshape_node.out_port(0)) diff --git a/model-optimizer/extensions/front/onnx/hard_sigmoid_ext.py b/model-optimizer/extensions/front/onnx/hard_sigmoid_ext.py index 981960684b5..ac5297b14d5 100644 --- a/model-optimizer/extensions/front/onnx/hard_sigmoid_ext.py +++ b/model-optimizer/extensions/front/onnx/hard_sigmoid_ext.py @@ -20,7 +20,7 @@ from extensions.ops.hard_sigmoid import HardSigmoid from mo.front.common.replacement import FrontReplacementOp from mo.front.onnx.extractors.utils import onnx_attr from mo.graph.graph import Node, Graph -from mo.ops.const import Const +from mo.front.tf.graph_utils import create_op_with_const_inputs class HardSigmoidFrontExtractor(FrontReplacementOp): @@ -30,11 +30,9 @@ class HardSigmoidFrontExtractor(FrontReplacementOp): def replace_op(self, graph: Graph, node: Node): alpha = onnx_attr(node, 'alpha', 'f', default=0.2) beta = onnx_attr(node, 'beta', 'f', default=0.5) - alpha_node = Const(graph, {'value': np.array(alpha)}).create_node() - beta_node = Const(graph, {'value': np.array(beta)}).create_node() - hard_sigmoid = HardSigmoid(graph, {'name': node.name + '/HardSigmoid_'}).create_node() + hard_sigmoid = create_op_with_const_inputs(graph, HardSigmoid, {1: np.array(alpha), 2: np.array(beta)}, + {'name': node.name + '/HardSigmoid_'}) + node.in_port(0).get_connection().set_destination(hard_sigmoid.in_port(0)) - alpha_node.out_port(0).connect(hard_sigmoid.in_port(1)) - beta_node.out_port(0).connect(hard_sigmoid.in_port(2)) return [hard_sigmoid.id] diff --git a/model-optimizer/extensions/front/tf/FlattenToReshape.py b/model-optimizer/extensions/front/tf/FlattenToReshape.py index 656a503d83b..00927ba503a 100644 --- a/model-optimizer/extensions/front/tf/FlattenToReshape.py +++ b/model-optimizer/extensions/front/tf/FlattenToReshape.py @@ -88,7 +88,8 @@ class FlattenToReshapeableReshape(FrontReplacementSubgraph): return reshape_node.in_port(1).disconnect() - reshape_const_node = Const(graph, {'value': int64_array([0, -1])}).create_node() + reshape_const_node = Const(graph, {'value': int64_array([0, -1]), + 'name': reshape_node.soft_get('name', reshape_node.id) + '/shape'}).create_node() reshape_node.in_port(1).connect(reshape_const_node.out_port(0)) reshape_node['special_zero'] = True log.debug('The node "{}" is actually a Flatten node'.format(reshape_node.soft_get('name'))) diff --git a/model-optimizer/extensions/front/tf/nearest_neighbor_upsampling.py b/model-optimizer/extensions/front/tf/nearest_neighbor_upsampling.py index 80ec52cd8f4..92177cacb73 100644 --- a/model-optimizer/extensions/front/tf/nearest_neighbor_upsampling.py +++ b/model-optimizer/extensions/front/tf/nearest_neighbor_upsampling.py @@ -72,10 +72,10 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph): axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2]) - const = Const(graph, - {'value': np.array([input_height * height_scale, input_width * width_scale])}).create_node() resample_op = Interpolate(graph, {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes}) resample_node = resample_op.create_node([match['op']]) + const = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]), + 'name': resample_node.name + '/target_shape'}).create_node() match['reshape_2'].replace_node(resample_node) diff --git a/model-optimizer/extensions/front/tf/pad_tf_to_pad.py b/model-optimizer/extensions/front/tf/pad_tf_to_pad.py index 0d884b1983f..087199e5763 100644 --- a/model-optimizer/extensions/front/tf/pad_tf_to_pad.py +++ b/model-optimizer/extensions/front/tf/pad_tf_to_pad.py @@ -47,7 +47,8 @@ class PadTFToPad(FrontReplacementPattern): if not tfpad.in_port(2).disconnected(): tfpad.in_port(2).get_connection().set_destination(new_pad.in_port(3)) else: - new_pad.in_port(3).connect(Const(graph, {'value': 0.0}).create_node().out_port(0)) + new_pad.in_port(3).connect(Const(graph, {'value': 0.0, 'name': new_pad.name + '/value'} + ).create_node().out_port(0)) # convert TF representation of the pads as [N, 2] to MO representation: [N] and [N] transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])}) diff --git a/model-optimizer/mo/front/kaldi/extractors/copy_ext.py b/model-optimizer/mo/front/kaldi/extractors/copy_ext.py index fb6bda68eb7..9acde3a6a0c 100644 --- a/model-optimizer/mo/front/kaldi/extractors/copy_ext.py +++ b/model-optimizer/mo/front/kaldi/extractors/copy_ext.py @@ -21,6 +21,7 @@ from extensions.ops.transpose import Transpose from mo.front.common.partial_infer.utils import int64_array from mo.front.common.replacement import FrontReplacementOp from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob +from mo.front.tf.graph_utils import create_op_with_const_inputs from mo.graph.graph import Node, Graph from mo.ops.const import Const @@ -33,26 +34,26 @@ class CopyFrontExtractor(FrontReplacementOp): pb = node.parameters weights_size = read_binary_integer32_token(pb) weights = read_blob(pb, weights_size, dtype=np.int32) - 1 + + node_name = node.soft_get('name', node.id) const_attrs = { - 'name': 'indexes/{}'.format(node.id), + 'name': node_name + '/indexes', 'value': np.array(weights), 'shape': [weights_size], 'data_type': np.int32 } indexes_node = Const(graph).create_node(attrs=const_attrs) - perm_in_1 = Const(graph, {'value': np.array([1, 0], dtype=np.int64), 'shape': [2], 'data_type': np.int64}).create_node() - axis_const = Const(graph, {'value': int64_array(0)}).create_node() - perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)]) + perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node() + perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)]) perm1_node.in_port(0).connect(node.in_port(0).get_source()) perm1_node.in_port(1).connect(perm_in_1.out_port(0)) - gather_node = Gather(graph, {}).create_node() + gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'}) gather_node.in_port(0).connect(perm1_node.out_port(0)) gather_node.in_port(1).connect(indexes_node.out_port(0)) - gather_node.in_port(2).connect(axis_const.out_port(0)) - perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node() + perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node() perm2_node.in_port(0).connect(gather_node.out_port(0)) perm2_node.in_port(1).connect(perm_in_1.out_port(0)) diff --git a/model-optimizer/mo/graph/perm_inputs.py b/model-optimizer/mo/graph/perm_inputs.py index 63bf2ed667f..16457993119 100644 --- a/model-optimizer/mo/graph/perm_inputs.py +++ b/model-optimizer/mo/graph/perm_inputs.py @@ -56,9 +56,11 @@ def axis(op_node: Node, port_info: str, input_port: int): data_node = op_node.in_node(input_port) - const = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data() - axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data() - gather = Gather(graph, {'name': op_node.name + '/AxisGather', 'need_shape_inference': True}).create_node_with_data( + gather_name = op_node.soft_get('name', op_node.id) + '/AxisGather' + const = Const(graph, {'value': permutation.inv, 'name': gather_name + '/const', + 'need_shape_inference': True}).create_node_with_data() + axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data() + gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data( [const, data_node, axis_const]) attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy() graph.add_edge(gather.id, op_node.id, **attrs) @@ -103,14 +105,18 @@ def order(op_node: Node, port_info: str, input_port: int): data_node = op_node.in_node(input_port) - const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data() - axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data() - gather = Gather(graph, {'name': op_node.name + '/OrderGather_1', + gather_name = op_node.soft_get('name', op_node.id) + '/OrderGather_1' + const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const', + 'need_shape_inference': True}).create_node_with_data() + axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data() + gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const]) - const_1 = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data() - axis_const_1 = Const(graph, {'value': int64_array(0)}).create_node_with_data() - gather_1 = Gather(graph, {'name': op_node.name + '/OrderGather_2', + gather_1_name = op_node.soft_get('name', op_node.id) + '/OrderGather_2' + const_1 = Const(graph, {'value': permutation.inv, 'name': gather_1_name + '/const', + 'need_shape_inference': True}).create_node_with_data() + axis_const_1 = Const(graph, {'value': int64_array(0), 'name': gather_1_name + '/axis'}).create_node_with_data() + gather_1 = Gather(graph, {'name': gather_1_name, 'need_shape_inference': True}).create_node_with_data([const_1, gather, axis_const_1]) attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy() @@ -131,9 +137,11 @@ def shape(op_node: Node, port_info: str, input_port: int): data_node = op_node.in_node(input_port) - const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data() - axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data() - gather = Gather(graph, {'name': op_node.name + '/ShapeGather', + gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather' + const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const', + 'need_shape_inference': True}).create_node_with_data() + axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data() + gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const]) attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy() diff --git a/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py b/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py index cb833c143dc..70730d45f3b 100644 --- a/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py +++ b/model-optimizer/mo/middle/passes/fusing/fuse_linear_ops.py @@ -108,9 +108,10 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True) value = np.reshape(value, shape) # Weights multiplication - mul_const = Const(graph, {'value': value}).create_node() - w_mul = node.copy_node({'in_ports_count': len(node.in_ports()), 'out_ports_count': len(node.out_ports()), - 'can_be_fused': False}) + mul_name = node.name + '_copy' + mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node() + w_mul = node.copy_node({'name': mul_name, 'in_ports_count': len(node.in_ports()), + 'out_ports_count': len(node.out_ports()), 'can_be_fused': False}) w_mul.in_port(const_port.idx).connect(mul_const.out_port(0)) w_const = weights_port.get_source() weights_port.get_connection().set_source(w_mul.out_port(0)) diff --git a/model-optimizer/mo/utils/shape.py b/model-optimizer/mo/utils/shape.py index c4ae63babb4..ec725a276ff 100644 --- a/model-optimizer/mo/utils/shape.py +++ b/model-optimizer/mo/utils/shape.py @@ -65,14 +65,14 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int, end_idx = get_canonical_axis_index_node(rank, end) if not include_begin: - const = Const(graph, {'value': int64_array([1])}).create_node() + const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node() add = Add(graph, {'name': name + '/exclude_begin'}).create_node() start_idx.out_port(0).connect(add.in_port(0)) const.out_port(0).connect(add.in_port(1)) start_idx = add if include_end: - const = Const(graph, {'value': int64_array([1])}).create_node() + const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node() add = Add(graph, {'name': name + '/including_end'}).create_node() end_idx.out_port(0).connect(add.in_port(0)) const.out_port(0).connect(add.in_port(1)) @@ -187,7 +187,10 @@ def new_shape_node_from_shape_nodes(input_shape_nodes: list): :return: the node producing concatenated values of nodes from the "input_shape_nodes" """ assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty' - new_shape_node = Concat(input_shape_nodes[0].graph, {'axis': 0}).create_node() + new_shape_node = Concat(input_shape_nodes[0].graph, + {'axis': 0, + 'name': input_shape_nodes[0].soft_get('name', input_shape_nodes[0].id) + '/shapes_concat'} + ).create_node() for ind, input_node in enumerate(input_shape_nodes): new_shape_node.add_input_port(ind)