Fix const node non-deterministic names (part 1) (#996)

* Update node names
Anton Chetverikov 2020-06-26 13:41:49 +03:00 committed by GitHub
parent 0cdc549911
commit 5aa9ffbfe3
15 changed files with 117 additions and 96 deletions
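Context for the hunks below: a Const created without an explicit 'name' attribute falls back to an auto-generated unique id, which is not stable across conversions of the same model (hence the title). The fix, applied uniformly in every file, derives each Const name from the op it feeds. A minimal sketch of the before/after pattern, assuming a graph and a matched node are in scope:

from mo.front.common.partial_infer.utils import int64_array
from mo.ops.const import Const

node_name = node.soft_get('name', node.id)

# before: the Const falls back to an auto-generated id, unstable across runs
order = Const(graph, {'value': int64_array([0, 3, 2, 1])}).create_node()

# after: the name is derived from the consuming op, so it is deterministic
order = Const(graph, {'value': int64_array([0, 3, 2, 1]),
                      'name': node_name + '/order'}).create_node()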

View File

@@ -15,16 +15,15 @@
"""
from collections import deque
import numpy as np
from extensions.front.MatMul_normalizer import FullyConnectedDecomposer
from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
from extensions.middle.TensorIteratorMerge import op_type
from extensions.ops.activation_ops import activation_ops
from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph
from mo.ops.const import Const
class ReplaceConvolutionTranspose(FrontReplacementSubgraph):
@@ -61,10 +60,9 @@ class ReplaceConvolutionTranspose(FrontReplacementSubgraph):
convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
for convolution_node in convolution_nodes:
target_node = self.search_target_node(Node(graph, convolution_node))
order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
permute_node = create_op_with_const_inputs(graph, Transpose, {1: int64_array([0, 3, 2, 1])},
{'name': target_node.name + '/Transpose'})
target_node.insert_node_after(permute_node, 0)
order_const.out_port(0).connect(permute_node.in_port(1))
def run_after(self):
from extensions.front.flatten_to_reshape import FlattenToReshape
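The replaced lines wired an anonymous Const into the Transpose by hand; the new code uses the create_op_with_const_inputs helper, which creates the op together with Const inputs on the listed ports. A sketch of the equivalence, using the names from the hunk above:

# new form: Transpose plus a Const on input port 1, created in one call
permute_node = create_op_with_const_inputs(graph, Transpose, {1: int64_array([0, 3, 2, 1])},
                                           {'name': target_node.name + '/Transpose'})

# old form it replaces: the same wiring, but with an anonymous Const
permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
order_const.out_port(0).connect(permute_node.in_port(1))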

View File

@@ -39,11 +39,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
('conv', 'reshape_conv'),
('reshape_conv', 'scale_shift'),
])
graph.stage = 'front'
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
conv_node = Node(graph, graph.nodes['conv']['name'])
permute = conv_node.out_node()
self.assertEqual(permute.op, 'Transpose')
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
def test_conv_pool(self):
graph = build_graph(self.nodes_attributes, [
@@ -53,11 +54,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
('pool', 'reshape_after_pool'),
('reshape_after_pool', 'fc'),
])
graph.stage = 'front'
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
pool_node = Node(graph, graph.nodes['pool']['name'])
permute = pool_node.out_node()
self.assertEqual(permute.op, 'Transpose')
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
def test_conv_act_pool(self):
graph = build_graph(self.nodes_attributes, [
@@ -68,8 +70,9 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
('pool', 'reshape_after_pool'),
('reshape_after_pool', 'fc'),
])
graph.stage = 'front'
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
pool_node = Node(graph, graph.nodes['pool']['name'])
permute = pool_node.out_node()
self.assertEqual(permute.op, 'Transpose')
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))

View File

@@ -53,35 +53,38 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['conv']
node_name = node.soft_get('name', node.id)
# create Reshape before convolution
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
shape = Shape(graph, {}).create_node()
shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
shape.in_port(0).connect(node.in_port(0).get_source())
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride])}).create_node()
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
{'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/patch_stride/inverse'})
conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]),
'name': node_name + '/patch_stride/'}).create_node()
pow_node.in_port(0).connect(conv_patch_stride.out_port(0))
mul = Mul(graph, {}).create_node()
mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
mul.in_port(0).connect(split.out_port(1))
mul.in_port(1).connect(pow_node.out_port(0))
const_1 = Const(graph, {'value': int64_array([1])}).create_node()
concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
concat.in_port(0).connect(split.out_port(0))
concat.in_port(1).connect(mul.out_port(0))
concat.in_port(2).connect(const_1.out_port(0))
concat.in_port(3).connect(conv_patch_stride.out_port(0))
reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
reshape_in.in_port(1).connect(concat.out_port(0))
# create Reshape after Convolution
reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
{'name': node.name + '/Reshape/'})
{'name': node_name + '/reshape_out'})
# connect input_reshape_node
source = node.in_port(0).get_source()
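For reference, the shape subgraph assembled above (Shape, VariadicSplit, Pow, Mul, Concat) evaluates to the commented formula. A worked NumPy example with hypothetical numbers (patch_stride = 8):

import numpy as np

in_shape = np.array([1, 264], dtype=np.int64)  # hypothetical [N, C] input shape
patch_stride = 8                               # hypothetical node.patch_stride

batch, rest = in_shape[:1], in_shape[1:]       # VariadicSplit(axis=0, sizes=[1, -1])
height = rest * float(patch_stride) ** -1      # Pow(-1) followed by Mul
new_shape = np.concatenate([batch, height, [1], [patch_stride]]).astype(np.int64)
print(new_shape)  # [ 1 33  1  8] == [N, C/patch_stride, 1, patch_stride]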

View File

@@ -49,38 +49,41 @@ class ReplacePoolingReshape(FrontReplacementPattern):
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['pool']
node_name = node.soft_get('name', node.id)
if node.pool_step is None:
node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])
# create Reshape before pooling
# shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
shape = Shape(graph, {}).create_node()
shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
shape.in_port(0).connect(node.in_port(0).get_source())
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
{'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/pool_stride/inverse'})
node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride]),
'name': node_name + '/pool_stride/'}).create_node()
pow_node.in_port(0).connect(node_pool_stride.out_port(0))
mul = Mul(graph, {}).create_node()
mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
mul.in_port(0).connect(split.out_port(1))
mul.in_port(1).connect(pow_node.out_port(0))
const_1 = Const(graph, {'value': int64_array([1])}).create_node()
concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
concat.in_port(0).connect(split.out_port(0))
concat.in_port(3).connect(mul.out_port(0))
concat.in_port(2).connect(const_1.out_port(0))
concat.in_port(1).connect(node_pool_stride.out_port(0))
reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
reshape_in.in_port(1).connect(concat.out_port(0))
# create Reshape after pooling
reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
{'name': node.name + '/Reshape/'})
{'name': node_name + '/reshape_out'})
# connect input_reshape_node
source = node.in_port(0).get_source()

View File

@@ -16,13 +16,13 @@
import numpy as np
from extensions.ops.activation_ops import Sigmoid, Tanh
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph
from mo.ops.concat import Concat
from mo.ops.const import Const
from extensions.ops.elementwise import Add, Mul
from mo.ops.scale_shift import ScaleShiftOp
@@ -40,80 +40,80 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
def replace_op(self, graph: Graph, node: Node):
# split input to (i_part, f_part, c_part, o_part, ct_1)
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
split_node = Split(graph, {'name': 'Split_lstm_input_',
'num_splits': 5}).create_node()
node_name = node.soft_get('name', node.id)
split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)},
{'name': node_name + '/split_lstm_input',
'num_splits': 5})
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
split_node.in_port(1).connect(split_node_axis.out_port(0))
# i_t = Sigmoid(i_part + w_ic*ct_1)
i_scale_attrs = {'name': 'i_scaleshift',
i_scale_attrs = {'name': node_name + '/i_scaleshift',
'bias_term': False}
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
split_node.out_port(4).connect(i_scale.in_port(0))
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
split_node.out_port(0).connect(sum_i_c.in_port(0))
i_scale.out_port(0).connect(sum_i_c.in_port(1))
i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node()
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
# f_t = Sigmoid(f_part + w_fc*ct_1)
f_scale_attrs = {'name': 'f_scaleshift',
f_scale_attrs = {'name': node_name + '/f_scaleshift',
'bias_term': False}
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
split_node.out_port(4).connect(f_scale.in_port(0))
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
split_node.out_port(1).connect(sum_f_c.in_port(0))
f_scale.out_port(0).connect(sum_f_c.in_port(1))
f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node()
sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))
# c_t = f_t*ct_1 + i_t * tanh(c_part)
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
split_node.out_port(2).connect(c_tanh.in_port(0))
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
prod_i_c_tanh = Mul(graph, {'name': node_name + '/prod_i_c_tanh_'}).create_node()
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
prod_f_ct_1 = Mul(graph, {'name': node_name + '/prod_f_ct_1_'}).create_node()
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
# o_t = Sigmoid(o_part + w_oc*c_t)
o_scale_attrs = {'name': 'o_scaleshift',
o_scale_attrs = {'name': node_name + '/o_scaleshift',
'bias_term': False}
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
sum_f_i.out_port(0).connect(o_scale.in_port(0))
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
split_node.out_port(3).connect(sum_o_c.in_port(0))
o_scale.out_port(0).connect(sum_o_c.in_port(1))
o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node()
sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))
# m_t = o_t * Tanh(c_t)
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
prod_o_c_t_tanh = Mul(graph, {'name': node_name + '/prod_o_c_t_tanh_'}).create_node()
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
# add concat to create 1 output
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
concat = Concat(graph, {'name': node_name + '/concat_c_m'}).create_node()
concat.add_sequence_of_ports('in', range(2))
sum_f_i.out_port(0).connect(concat.in_port(0))
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
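For reference, the subgraph assembled above computes the standard peephole LSTM cell nonlinearity spelled out in the comments. A NumPy sketch of the same math, where w_ic, w_fc and w_oc stand for the peephole weight vectors taken from node.i_weights, node.f_weights and node.o_weights:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_nonlinearity(x, w_ic, w_fc, w_oc):
    # x holds (i_part, f_part, c_part, o_part, ct_1) concatenated along axis 1
    i_part, f_part, c_part, o_part, ct_1 = np.split(x, 5, axis=1)
    i_t = sigmoid(i_part + w_ic * ct_1)          # i_scaleshift + sum_i_c
    f_t = sigmoid(f_part + w_fc * ct_1)          # f_scaleshift + sum_f_c
    c_t = f_t * ct_1 + i_t * np.tanh(c_part)     # sum_f_i
    o_t = sigmoid(o_part + w_oc * c_t)           # o_scaleshift + sum_o_c
    m_t = o_t * np.tanh(c_t)                     # prod_o_c_t_tanh
    return np.concatenate([c_t, m_t], axis=1)    # concat_c_m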

View File

@@ -16,8 +16,8 @@
from extensions.ops.gather import Gather
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph
from mo.ops.const import Const
class GatherFrontReplacer(FrontReplacementOp):
@@ -26,10 +26,10 @@ class GatherFrontReplacer(FrontReplacementOp):
def replace_sub_graph(self, graph: Graph, match: dict):
node = match['op']
gather_node = Gather(graph, dict(name=node.id + '/embedding_',
symbol_dict={'name': node.id + '/embedding_'})).create_node()
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
{'name': node.soft_get('name', node.id) + '/embedding_'})
node.in_port(0).get_connection().set_destination(gather_node.in_port(1))
node.in_port(1).get_connection().set_destination(gather_node.in_port(0))
axis_const.out_port(0).connect(gather_node.in_port(2))
node.out_port(0).get_connection().set_source(gather_node.out_port(0))

View File

@@ -51,10 +51,12 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
axis = node.axis
reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
if axis == 0:
dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
dim = Const(graph, {'value': int64_array([1, -1]), 'name': reshape_node.name + '/shape'}).create_node()
elif axis == 1:
dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
dim = Const(graph, {'value': int64_array([0, -1]), 'name': reshape_node.name + '/shape'}).create_node()
else:
shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
@@ -62,8 +64,8 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
{'keep_dims': True})
second_dims = Const(graph, {'value': int64_array([-1])}).create_node()
{'name': name + '/first_dims', 'keep_dims': True})
second_dims = Const(graph, {'value': int64_array([-1]), 'name': name + '/second_dims'}).create_node()
node.in_port(0).get_source().connect(shape.in_port(0))
axis_shape_portion.out_port(0).connect(first_dims.in_port(0))
@@ -72,7 +74,6 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
dim = new_shape_node_from_shape_nodes(order_of_dims)
reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
reshape_node.in_port(1).connect(dim.out_port(0))
node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
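In the general-axis branch above, the subgraph reproduces ONNX Flatten semantics: dimensions before axis collapse into one, everything from axis on collapses into another; the axis == 0 and axis == 1 branches short-circuit this with the static shapes [1, -1] and [0, -1]. In NumPy terms, with a hypothetical shape:

import numpy as np

x = np.zeros((2, 3, 4, 5))
axis = 2
first_dims = int(np.prod(x.shape[:axis]))  # ReduceProd over dims [0, axis)
flat = x.reshape(first_dims, -1)           # second_dims = [-1]
print(flat.shape)                          # (6, 20)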

View File

@@ -20,7 +20,7 @@ from extensions.ops.hard_sigmoid import HardSigmoid
from mo.front.common.replacement import FrontReplacementOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.graph.graph import Node, Graph
from mo.ops.const import Const
from mo.front.tf.graph_utils import create_op_with_const_inputs
class HardSigmoidFrontExtractor(FrontReplacementOp):
@@ -30,11 +30,9 @@ class HardSigmoidFrontExtractor(FrontReplacementOp):
def replace_op(self, graph: Graph, node: Node):
alpha = onnx_attr(node, 'alpha', 'f', default=0.2)
beta = onnx_attr(node, 'beta', 'f', default=0.5)
alpha_node = Const(graph, {'value': np.array(alpha)}).create_node()
beta_node = Const(graph, {'value': np.array(beta)}).create_node()
hard_sigmoid = HardSigmoid(graph, {'name': node.name + '/HardSigmoid_'}).create_node()
hard_sigmoid = create_op_with_const_inputs(graph, HardSigmoid, {1: np.array(alpha), 2: np.array(beta)},
{'name': node.name + '/HardSigmoid_'})
node.in_port(0).get_connection().set_destination(hard_sigmoid.in_port(0))
alpha_node.out_port(0).connect(hard_sigmoid.in_port(1))
beta_node.out_port(0).connect(hard_sigmoid.in_port(2))
return [hard_sigmoid.id]
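The replaced op computes HardSigmoid(x) = max(0, min(1, alpha * x + beta)) per the ONNX spec; the extracted alpha and beta become the two new Const inputs. A NumPy reference:

import numpy as np

def hard_sigmoid(x, alpha=0.2, beta=0.5):  # ONNX defaults, as extracted above
    return np.clip(alpha * x + beta, 0.0, 1.0)

print(hard_sigmoid(np.array([-5.0, 0.0, 5.0])))  # [0.  0.5 1. ]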

View File

@@ -88,7 +88,8 @@ class FlattenToReshapeableReshape(FrontReplacementSubgraph):
return
reshape_node.in_port(1).disconnect()
reshape_const_node = Const(graph, {'value': int64_array([0, -1])}).create_node()
reshape_const_node = Const(graph, {'value': int64_array([0, -1]),
'name': reshape_node.soft_get('name', reshape_node.id) + '/shape'}).create_node()
reshape_node.in_port(1).connect(reshape_const_node.out_port(0))
reshape_node['special_zero'] = True
log.debug('The node "{}" is actually a Flatten node'.format(reshape_node.soft_get('name')))

View File

@@ -72,10 +72,10 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
const = Const(graph,
{'value': np.array([input_height * height_scale, input_width * width_scale])}).create_node()
resample_op = Interpolate(graph, {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes})
resample_node = resample_op.create_node([match['op']])
const = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
'name': resample_node.name + '/target_shape'}).create_node()
match['reshape_2'].replace_node(resample_node)

View File

@@ -47,7 +47,8 @@ class PadTFToPad(FrontReplacementPattern):
if not tfpad.in_port(2).disconnected():
tfpad.in_port(2).get_connection().set_destination(new_pad.in_port(3))
else:
new_pad.in_port(3).connect(Const(graph, {'value': 0.0}).create_node().out_port(0))
new_pad.in_port(3).connect(Const(graph, {'value': 0.0, 'name': new_pad.name + '/value'}
).create_node().out_port(0))
# convert TF representation of the pads as [N, 2] to MO representation: [N] and [N]
transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
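The Transpose with order [1, 0] flips the TF pads matrix so that the begin and end columns become rows, ready to be split into the two vectors MO's Pad expects. In NumPy terms, with hypothetical pads:

import numpy as np

tf_pads = np.array([[0, 0], [1, 2], [3, 4], [0, 0]])  # TF layout: [N, 2]
pads_begin, pads_end = tf_pads.T                      # Transpose(order=[1, 0])
print(pads_begin)  # [0 1 3 0]
print(pads_end)    # [0 2 4 0]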

View File

@@ -21,6 +21,7 @@ from extensions.ops.transpose import Transpose
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph
from mo.ops.const import Const
@@ -33,26 +34,26 @@ class CopyFrontExtractor(FrontReplacementOp):
pb = node.parameters
weights_size = read_binary_integer32_token(pb)
weights = read_blob(pb, weights_size, dtype=np.int32) - 1
node_name = node.soft_get('name', node.id)
const_attrs = {
'name': 'indexes/{}'.format(node.id),
'name': node_name + '/indexes',
'value': np.array(weights),
'shape': [weights_size],
'data_type': np.int32
}
indexes_node = Const(graph).create_node(attrs=const_attrs)
perm_in_1 = Const(graph, {'value': np.array([1, 0], dtype=np.int64), 'shape': [2], 'data_type': np.int64}).create_node()
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)])
perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
perm1_node.in_port(0).connect(node.in_port(0).get_source())
perm1_node.in_port(1).connect(perm_in_1.out_port(0))
gather_node = Gather(graph, {}).create_node()
gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'})
gather_node.in_port(0).connect(perm1_node.out_port(0))
gather_node.in_port(1).connect(indexes_node.out_port(0))
gather_node.in_port(2).connect(axis_const.out_port(0))
perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
perm2_node.in_port(0).connect(gather_node.out_port(0))
perm2_node.in_port(1).connect(perm_in_1.out_port(0))
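The permute-gather-permute chain above selects input columns by the zero-based indices read from the Kaldi blob (weights - 1): transpose so columns become rows, Gather along axis 0, transpose back. A NumPy equivalent with hypothetical data:

import numpy as np

x = np.arange(12).reshape(3, 4)
indexes = np.array([2, 0])     # weights - 1, as computed above
out = x.T[indexes].T           # input_permute -> gather(axis=0) -> output_permute
assert np.array_equal(out, x[:, indexes])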

View File

@@ -56,9 +56,11 @@ def axis(op_node: Node, port_info: str, input_port: int):
data_node = op_node.in_node(input_port)
const = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
gather = Gather(graph, {'name': op_node.name + '/AxisGather', 'need_shape_inference': True}).create_node_with_data(
gather_name = op_node.soft_get('name', op_node.id) + '/AxisGather'
const = Const(graph, {'value': permutation.inv, 'name': gather_name + '/const',
'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data(
[const, data_node, axis_const])
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
graph.add_edge(gather.id, op_node.id, **attrs)
@@ -103,14 +105,18 @@ def order(op_node: Node, port_info: str, input_port: int):
data_node = op_node.in_node(input_port)
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
gather = Gather(graph, {'name': op_node.name + '/OrderGather_1',
gather_name = op_node.soft_get('name', op_node.id) + '/OrderGather_1'
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
gather = Gather(graph, {'name': gather_name,
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
const_1 = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data()
axis_const_1 = Const(graph, {'value': int64_array(0)}).create_node_with_data()
gather_1 = Gather(graph, {'name': op_node.name + '/OrderGather_2',
gather_1_name = op_node.soft_get('name', op_node.id) + '/OrderGather_2'
const_1 = Const(graph, {'value': permutation.inv, 'name': gather_1_name + '/const',
'need_shape_inference': True}).create_node_with_data()
axis_const_1 = Const(graph, {'value': int64_array(0), 'name': gather_1_name + '/axis'}).create_node_with_data()
gather_1 = Gather(graph, {'name': gather_1_name,
'need_shape_inference': True}).create_node_with_data([const_1, gather, axis_const_1])
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
@@ -131,9 +137,11 @@ def shape(op_node: Node, port_info: str, input_port: int):
data_node = op_node.in_node(input_port)
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
gather = Gather(graph, {'name': op_node.name + '/ShapeGather',
gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
'need_shape_inference': True}).create_node_with_data()
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
gather = Gather(graph, {'name': gather_name,
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
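All three helpers patched here (axis, order, shape) rely on the same trick: a Gather that reorders an attribute vector by the layout permutation, with permutation.inv undoing permutation.perm. A sketch, assuming a hypothetical NCHW-to-NHWC permutation:

import numpy as np

perm = np.array([0, 2, 3, 1])            # hypothetical permutation.perm
inv = np.argsort(perm)                   # permutation.inv
shape_nchw = np.array([1, 3, 224, 224])
print(shape_nchw[perm])                  # [  1 224 224   3], Gather by perm
print(shape_nchw[perm][inv])             # [  1   3 224 224], inv restores order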

View File

@@ -108,9 +108,10 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
value = np.reshape(value, shape)
# Weights multiplication
mul_const = Const(graph, {'value': value}).create_node()
w_mul = node.copy_node({'in_ports_count': len(node.in_ports()), 'out_ports_count': len(node.out_ports()),
'can_be_fused': False})
mul_name = node.name + '_copy'
mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
w_mul = node.copy_node({'name': mul_name, 'in_ports_count': len(node.in_ports()),
'out_ports_count': len(node.out_ports()), 'can_be_fused': False})
w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
w_const = weights_port.get_source()
weights_port.get_connection().set_source(w_mul.out_port(0))
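The fusion itself is plain arithmetic: scaling the output of a linear op by value is the same as scaling its weights, so the Mul is folded into the copied node's new weights Const. A NumPy check with hypothetical shapes:

import numpy as np

W = np.random.rand(4, 3)                       # hypothetical [out, in] weights
x = np.random.rand(3)
value = np.array([0.5, 2.0, 1.0, 3.0])         # per-output-channel scale

assert np.allclose((W @ x) * value,            # Mul applied after the op
                   (value[:, None] * W) @ x)   # Mul folded into the weights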

View File

@@ -65,14 +65,14 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int,
end_idx = get_canonical_axis_index_node(rank, end)
if not include_begin:
const = Const(graph, {'value': int64_array([1])}).create_node()
const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node()
add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
start_idx.out_port(0).connect(add.in_port(0))
const.out_port(0).connect(add.in_port(1))
start_idx = add
if include_end:
const = Const(graph, {'value': int64_array([1])}).create_node()
const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node()
add = Add(graph, {'name': name + '/including_end'}).create_node()
end_idx.out_port(0).connect(add.in_port(0))
const.out_port(0).connect(add.in_port(1))
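The two Add nodes implement the range's inclusivity options: include_begin=False bumps the start index by one, include_end=True bumps the end by one. A plain-Python sketch, assuming canonicalization wraps negative indices against rank:

def range_of_idxs(rank, begin, end, include_begin=True, include_end=False):
    start = begin % rank if begin < 0 else begin   # get_canonical_axis_index_node
    stop = end % rank if end < 0 else end
    if not include_begin:
        start += 1                                 # the '/exclude_begin' Add
    if include_end:
        stop += 1                                  # the '/including_end' Add
    return list(range(start, stop))

print(range_of_idxs(4, 0, -1, include_end=True))   # [0, 1, 2, 3]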
@@ -187,7 +187,10 @@ def new_shape_node_from_shape_nodes(input_shape_nodes: list):
:return: the node producing concatenated values of nodes from the "input_shape_nodes"
"""
assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'
new_shape_node = Concat(input_shape_nodes[0].graph, {'axis': 0}).create_node()
new_shape_node = Concat(input_shape_nodes[0].graph,
{'axis': 0,
'name': input_shape_nodes[0].soft_get('name', input_shape_nodes[0].id) + '/shapes_concat'}
).create_node()
for ind, input_node in enumerate(input_shape_nodes):
new_shape_node.add_input_port(ind)