Fix const node non-deterministic names (part 1) (#996)
* Update node names
This commit is contained in:
parent
0cdc549911
commit
5aa9ffbfe3
@ -15,16 +15,15 @@
|
||||
"""
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.MatMul_normalizer import FullyConnectedDecomposer
|
||||
from extensions.front.kaldi.add_reshape_around_convolution import ReplaceConvolutionReshape
|
||||
from extensions.middle.TensorIteratorMerge import op_type
|
||||
from extensions.ops.activation_ops import activation_ops
|
||||
from extensions.ops.transpose import Transpose
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class ReplaceConvolutionTranspose(FrontReplacementSubgraph):
|
||||
@ -61,10 +60,9 @@ class ReplaceConvolutionTranspose(FrontReplacementSubgraph):
|
||||
convolution_nodes = [node for node in nodes_with_weights if Node(graph, node).op == 'Convolution']
|
||||
for convolution_node in convolution_nodes:
|
||||
target_node = self.search_target_node(Node(graph, convolution_node))
|
||||
order_const = Const(graph, dict(value=np.array([0, 3, 2, 1]))).create_node()
|
||||
permute_node = Transpose(graph, dict(name=target_node.name + '/Transpose')).create_node()
|
||||
permute_node = create_op_with_const_inputs(graph, Transpose, {1: int64_array([0, 3, 2, 1])},
|
||||
{'name': target_node.name + '/Transpose'})
|
||||
target_node.insert_node_after(permute_node, 0)
|
||||
order_const.out_port(0).connect(permute_node.in_port(1))
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.flatten_to_reshape import FlattenToReshape
|
||||
|
@ -39,11 +39,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
|
||||
('conv', 'reshape_conv'),
|
||||
('reshape_conv', 'scale_shift'),
|
||||
])
|
||||
graph.stage = 'front'
|
||||
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
|
||||
conv_node = Node(graph, graph.nodes['conv']['name'])
|
||||
permute = conv_node.out_node()
|
||||
self.assertEqual(permute.op, 'Transpose')
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
|
||||
|
||||
def test_conv_pool(self):
|
||||
graph = build_graph(self.nodes_attributes, [
|
||||
@ -53,11 +54,12 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
|
||||
('pool', 'reshape_after_pool'),
|
||||
('reshape_after_pool', 'fc'),
|
||||
])
|
||||
graph.stage = 'front'
|
||||
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
|
||||
pool_node = Node(graph, graph.nodes['pool']['name'])
|
||||
permute = pool_node.out_node()
|
||||
self.assertEqual(permute.op, 'Transpose')
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
|
||||
|
||||
def test_conv_act_pool(self):
|
||||
graph = build_graph(self.nodes_attributes, [
|
||||
@ -68,8 +70,9 @@ class ReplaceConvolutionTransposeTests(unittest.TestCase):
|
||||
('pool', 'reshape_after_pool'),
|
||||
('reshape_after_pool', 'fc'),
|
||||
])
|
||||
graph.stage = 'front'
|
||||
ReplaceConvolutionTranspose().find_and_replace_pattern(graph)
|
||||
pool_node = Node(graph, graph.nodes['pool']['name'])
|
||||
permute = pool_node.out_node()
|
||||
self.assertEqual(permute.op, 'Transpose')
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).in_node().value, np.array([0, 3, 2, 1])))
|
||||
self.assertTrue(np.array_equal(permute.in_node(1).value, np.array([0, 3, 2, 1])))
|
||||
|
@ -53,35 +53,38 @@ class ReplaceConvolutionReshape(FrontReplacementPattern):
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['conv']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
# create Reshape before convolution
|
||||
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
|
||||
shape = Shape(graph, {}).create_node()
|
||||
shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||
shape.in_port(0).connect(node.in_port(0).get_source())
|
||||
|
||||
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
|
||||
conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride])}).create_node()
|
||||
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
|
||||
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
|
||||
{'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
|
||||
|
||||
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/patch_stride/inverse'})
|
||||
conv_patch_stride = Const(graph, {'value': int64_array([node.patch_stride]),
|
||||
'name': node_name + '/patch_stride/'}).create_node()
|
||||
pow_node.in_port(0).connect(conv_patch_stride.out_port(0))
|
||||
|
||||
mul = Mul(graph, {}).create_node()
|
||||
mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
|
||||
mul.in_port(0).connect(split.out_port(1))
|
||||
mul.in_port(1).connect(pow_node.out_port(0))
|
||||
|
||||
const_1 = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
|
||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||
|
||||
concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
|
||||
concat.in_port(0).connect(split.out_port(0))
|
||||
concat.in_port(1).connect(mul.out_port(0))
|
||||
concat.in_port(2).connect(const_1.out_port(0))
|
||||
concat.in_port(3).connect(conv_patch_stride.out_port(0))
|
||||
|
||||
reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
|
||||
reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
|
||||
reshape_in.in_port(1).connect(concat.out_port(0))
|
||||
|
||||
# create Reshape after Convolution
|
||||
reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
|
||||
{'name': node.name + '/Reshape/'})
|
||||
{'name': node_name + '/reshape_out'})
|
||||
|
||||
# connect input_reshape_node
|
||||
source = node.in_port(0).get_source()
|
||||
|
@ -49,38 +49,41 @@ class ReplacePoolingReshape(FrontReplacementPattern):
|
||||
@staticmethod
|
||||
def replace_pattern(graph: Graph, match: dict):
|
||||
node = match['pool']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
if node.pool_step is None:
|
||||
node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])
|
||||
|
||||
# create Reshape before convolution
|
||||
# shape = [in_shape[0], in_shape[1]/patch_stride, 1, patch_stride]
|
||||
shape = Shape(graph, {}).create_node()
|
||||
shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
|
||||
shape.in_port(0).connect(node.in_port(0).get_source())
|
||||
|
||||
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])}, {'out_ports_count': 2}, shape)
|
||||
node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride])}).create_node()
|
||||
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]))
|
||||
split = create_op_with_const_inputs(graph, VariadicSplit, {1: int64_array(0), 2: int64_array([1, -1])},
|
||||
{'name': shape.name + '/split_batch', 'out_ports_count': 2}, shape)
|
||||
|
||||
pow_node = create_op_node_with_second_input(graph, Pow, int64_array([-1]), {'name': node_name + '/pool_stride/inverse'})
|
||||
node_pool_stride = Const(graph, {'value': int64_array([node.pool_stride]),
|
||||
'name': node_name + '/pool_stride/'}).create_node()
|
||||
pow_node.in_port(0).connect(node_pool_stride.out_port(0))
|
||||
|
||||
mul = Mul(graph, {}).create_node()
|
||||
mul = Mul(graph, {'name': node_name + '/mul_inverse_stride_h'}).create_node()
|
||||
mul.in_port(0).connect(split.out_port(1))
|
||||
mul.in_port(1).connect(pow_node.out_port(0))
|
||||
|
||||
const_1 = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
concat = create_op_with_const_inputs(graph, Concat, {2: int64_array([1])},
|
||||
{'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
|
||||
|
||||
concat = Concat(graph, {'in_ports_count': 4, 'axis': 0}).create_node()
|
||||
concat.in_port(0).connect(split.out_port(0))
|
||||
concat.in_port(3).connect(mul.out_port(0))
|
||||
concat.in_port(2).connect(const_1.out_port(0))
|
||||
concat.in_port(1).connect(node_pool_stride.out_port(0))
|
||||
|
||||
reshape_in = Reshape(graph, {'name': '/Reshape/' + node.name}).create_node()
|
||||
reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
|
||||
reshape_in.in_port(1).connect(concat.out_port(0))
|
||||
|
||||
# create Reshape after Convolution
|
||||
reshape_out = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
|
||||
{'name': node.name + '/Reshape/'})
|
||||
{'name': node_name + '/reshape_out'})
|
||||
|
||||
# connect input_reshape_node
|
||||
source = node.in_port(0).get_source()
|
||||
|
@ -16,13 +16,13 @@
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import Sigmoid, Tanh
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.caffe.extractors.utils import input_as_const
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
|
||||
|
||||
@ -40,80 +40,80 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
|
||||
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
# split input to (i_part, f_part, c_part, o_part, ct_1)
|
||||
split_node_axis = Const(graph, {'value': np.int64(1)}).create_node()
|
||||
split_node = Split(graph, {'name': 'Split_lstm_input_',
|
||||
'num_splits': 5}).create_node()
|
||||
node_name = node.soft_get('name', node.id)
|
||||
split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)},
|
||||
{'name': node_name + '/split_lstm_input',
|
||||
'num_splits': 5})
|
||||
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
|
||||
split_node.in_port(1).connect(split_node_axis.out_port(0))
|
||||
|
||||
# i_t = Sigmoid(i_part + w_ic*ct_1)
|
||||
i_scale_attrs = {'name': 'i_scaleshift',
|
||||
i_scale_attrs = {'name': node_name + '/i_scaleshift',
|
||||
'bias_term': False}
|
||||
i_scale = ScaleShiftOp(graph, i_scale_attrs).create_node()
|
||||
input_as_const(i_scale, i_scale_attrs, 1, 'weights', node.i_weights)
|
||||
split_node.out_port(4).connect(i_scale.in_port(0))
|
||||
|
||||
sum_i_c = Add(graph, {'name': 'sum_i_c_'}).create_node()
|
||||
sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
|
||||
split_node.out_port(0).connect(sum_i_c.in_port(0))
|
||||
i_scale.out_port(0).connect(sum_i_c.in_port(1))
|
||||
|
||||
i_sigmoid = Sigmoid(graph, {'name': 'i_sigmoid'}).create_node()
|
||||
i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node()
|
||||
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
|
||||
|
||||
# f_t = Sigmoid(f_part + w_fc*ct_1)
|
||||
f_scale_attrs = {'name': 'f_scaleshift',
|
||||
f_scale_attrs = {'name': node_name + '/f_scaleshift',
|
||||
'bias_term': False}
|
||||
f_scale = ScaleShiftOp(graph, f_scale_attrs).create_node()
|
||||
input_as_const(f_scale, f_scale_attrs, 1, 'weights', node.f_weights)
|
||||
split_node.out_port(4).connect(f_scale.in_port(0))
|
||||
|
||||
sum_f_c = Add(graph, {'name': 'sum_f_c_'}).create_node()
|
||||
sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
|
||||
split_node.out_port(1).connect(sum_f_c.in_port(0))
|
||||
f_scale.out_port(0).connect(sum_f_c.in_port(1))
|
||||
|
||||
f_sigmoid = Sigmoid(graph, {'name': 'f_sigmoid'}).create_node()
|
||||
f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node()
|
||||
sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))
|
||||
|
||||
# c_t = f_t*ct_1 + i_t * tanh(c_part)
|
||||
c_tanh = Tanh(graph, {'name': 'c_tanh'}).create_node()
|
||||
c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
|
||||
split_node.out_port(2).connect(c_tanh.in_port(0))
|
||||
|
||||
prod_i_c_tanh = Mul(graph, {'name': 'prod_i_c_tanh_'}).create_node()
|
||||
prod_i_c_tanh = Mul(graph, {'name': node_name + '/prod_i_c_tanh_'}).create_node()
|
||||
i_sigmoid.out_port(0).connect(prod_i_c_tanh.in_port(0))
|
||||
c_tanh.out_port(0).connect(prod_i_c_tanh.in_port(1))
|
||||
|
||||
prod_f_ct_1 = Mul(graph, {'name': 'prod_f_ct_1_'}).create_node()
|
||||
prod_f_ct_1 = Mul(graph, {'name': node_name + '/prod_f_ct_1_'}).create_node()
|
||||
f_sigmoid.out_port(0).connect(prod_f_ct_1.in_port(0))
|
||||
split_node.out_port(4).connect(prod_f_ct_1.in_port(1))
|
||||
|
||||
sum_f_i = Add(graph, {'name': 'sum_f_i_'}).create_node()
|
||||
sum_f_i = Add(graph, {'name': node_name + '/sum_f_i_'}).create_node()
|
||||
prod_f_ct_1.out_port(0).connect(sum_f_i.in_port(0))
|
||||
prod_i_c_tanh.out_port(0).connect(sum_f_i.in_port(1))
|
||||
|
||||
# o_t = Sigmoid(o_part + w_oc*c_t)
|
||||
o_scale_attrs = {'name': 'o_scaleshift',
|
||||
o_scale_attrs = {'name': node_name + '/o_scaleshift',
|
||||
'bias_term': False}
|
||||
o_scale = ScaleShiftOp(graph, o_scale_attrs).create_node()
|
||||
input_as_const(o_scale, o_scale_attrs, 1, 'weights', node.o_weights)
|
||||
sum_f_i.out_port(0).connect(o_scale.in_port(0))
|
||||
|
||||
sum_o_c = Add(graph, {'name': 'sum_o_c_'}).create_node()
|
||||
sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
|
||||
split_node.out_port(3).connect(sum_o_c.in_port(0))
|
||||
o_scale.out_port(0).connect(sum_o_c.in_port(1))
|
||||
|
||||
o_sigmoid = Sigmoid(graph, {'name': 'o_sigmoid'}).create_node()
|
||||
o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node()
|
||||
sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))
|
||||
|
||||
# m_t = o_t * Tanh(c_t)
|
||||
c_t_tanh = Tanh(graph, {'name': 'c_t_tanh'}).create_node()
|
||||
c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
|
||||
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
|
||||
|
||||
prod_o_c_t_tanh = Mul(graph, {'name': 'prod_o_c_t_tanh_'}).create_node()
|
||||
prod_o_c_t_tanh = Mul(graph, {'name': node_name + '/prod_o_c_t_tanh_'}).create_node()
|
||||
o_sigmoid.out_port(0).connect(prod_o_c_t_tanh.in_port(0))
|
||||
c_t_tanh.out_port(0).connect(prod_o_c_t_tanh.in_port(1))
|
||||
|
||||
# add concat to create 1 output
|
||||
concat = Concat(graph, {'name': 'Concat_c_m'}).create_node()
|
||||
concat = Concat(graph, {'name': node_name + '/concat_c_m'}).create_node()
|
||||
concat.add_sequence_of_ports('in', range(2))
|
||||
sum_f_i.out_port(0).connect(concat.in_port(0))
|
||||
prod_o_c_t_tanh.out_port(0).connect(concat.in_port(1))
|
||||
|
@ -16,8 +16,8 @@
|
||||
from extensions.ops.gather import Gather
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class GatherFrontReplacer(FrontReplacementOp):
|
||||
@ -26,10 +26,10 @@ class GatherFrontReplacer(FrontReplacementOp):
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
node = match['op']
|
||||
gather_node = Gather(graph, dict(name=node.id + '/embedding_',
|
||||
symbol_dict={'name': node.id + '/embedding_'})).create_node()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
|
||||
|
||||
gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)},
|
||||
{'name': node.soft_get('name', node.id) + '/embedding_'})
|
||||
|
||||
node.in_port(0).get_connection().set_destination(gather_node.in_port(1))
|
||||
node.in_port(1).get_connection().set_destination(gather_node.in_port(0))
|
||||
axis_const.out_port(0).connect(gather_node.in_port(2))
|
||||
node.out_port(0).get_connection().set_source(gather_node.out_port(0))
|
||||
|
@ -51,10 +51,12 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
|
||||
assert node.has_valid('axis'), 'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
|
||||
axis = node.axis
|
||||
|
||||
reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
|
||||
|
||||
if axis == 0:
|
||||
dim = Const(graph, {'value': int64_array([1, -1])}).create_node()
|
||||
dim = Const(graph, {'value': int64_array([1, -1]), 'name': reshape_node.name + '/shape'}).create_node()
|
||||
elif axis == 1:
|
||||
dim = Const(graph, {'value': int64_array([0, -1])}).create_node()
|
||||
dim = Const(graph, {'value': int64_array([0, -1]), 'name': reshape_node.name + '/shape'}).create_node()
|
||||
else:
|
||||
shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
|
||||
|
||||
@ -62,8 +64,8 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
|
||||
|
||||
axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
|
||||
first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
|
||||
{'keep_dims': True})
|
||||
second_dims = Const(graph, {'value': int64_array([-1])}).create_node()
|
||||
{'name': name + '/first_dims', 'keep_dims': True})
|
||||
second_dims = Const(graph, {'value': int64_array([-1]), 'name': name + '/second_dims'}).create_node()
|
||||
|
||||
node.in_port(0).get_source().connect(shape.in_port(0))
|
||||
axis_shape_portion.out_port(0).connect(first_dims.in_port(0))
|
||||
@ -72,7 +74,6 @@ class FlattenONNXToReshape(FrontReplacementSubgraph):
|
||||
|
||||
dim = new_shape_node_from_shape_nodes(order_of_dims)
|
||||
|
||||
reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()
|
||||
reshape_node.in_port(1).connect(dim.out_port(0))
|
||||
|
||||
node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
|
||||
|
@ -20,7 +20,7 @@ from extensions.ops.hard_sigmoid import HardSigmoid
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.const import Const
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
|
||||
|
||||
class HardSigmoidFrontExtractor(FrontReplacementOp):
|
||||
@ -30,11 +30,9 @@ class HardSigmoidFrontExtractor(FrontReplacementOp):
|
||||
def replace_op(self, graph: Graph, node: Node):
|
||||
alpha = onnx_attr(node, 'alpha', 'f', default=0.2)
|
||||
beta = onnx_attr(node, 'beta', 'f', default=0.5)
|
||||
alpha_node = Const(graph, {'value': np.array(alpha)}).create_node()
|
||||
beta_node = Const(graph, {'value': np.array(beta)}).create_node()
|
||||
|
||||
hard_sigmoid = HardSigmoid(graph, {'name': node.name + '/HardSigmoid_'}).create_node()
|
||||
hard_sigmoid = create_op_with_const_inputs(graph, HardSigmoid, {1: np.array(alpha), 2: np.array(beta)},
|
||||
{'name': node.name + '/HardSigmoid_'})
|
||||
|
||||
node.in_port(0).get_connection().set_destination(hard_sigmoid.in_port(0))
|
||||
alpha_node.out_port(0).connect(hard_sigmoid.in_port(1))
|
||||
beta_node.out_port(0).connect(hard_sigmoid.in_port(2))
|
||||
return [hard_sigmoid.id]
|
||||
|
@ -88,7 +88,8 @@ class FlattenToReshapeableReshape(FrontReplacementSubgraph):
|
||||
return
|
||||
|
||||
reshape_node.in_port(1).disconnect()
|
||||
reshape_const_node = Const(graph, {'value': int64_array([0, -1])}).create_node()
|
||||
reshape_const_node = Const(graph, {'value': int64_array([0, -1]),
|
||||
'name': reshape_node.soft_get('name', reshape_node.id) + '/shape'}).create_node()
|
||||
reshape_node.in_port(1).connect(reshape_const_node.out_port(0))
|
||||
reshape_node['special_zero'] = True
|
||||
log.debug('The node "{}" is actually a Flatten node'.format(reshape_node.soft_get('name')))
|
||||
|
@ -72,10 +72,10 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
|
||||
|
||||
axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
|
||||
|
||||
const = Const(graph,
|
||||
{'value': np.array([input_height * height_scale, input_width * width_scale])}).create_node()
|
||||
resample_op = Interpolate(graph, {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes})
|
||||
resample_node = resample_op.create_node([match['op']])
|
||||
const = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
|
||||
'name': resample_node.name + '/target_shape'}).create_node()
|
||||
|
||||
match['reshape_2'].replace_node(resample_node)
|
||||
|
||||
|
@ -47,7 +47,8 @@ class PadTFToPad(FrontReplacementPattern):
|
||||
if not tfpad.in_port(2).disconnected():
|
||||
tfpad.in_port(2).get_connection().set_destination(new_pad.in_port(3))
|
||||
else:
|
||||
new_pad.in_port(3).connect(Const(graph, {'value': 0.0}).create_node().out_port(0))
|
||||
new_pad.in_port(3).connect(Const(graph, {'value': 0.0, 'name': new_pad.name + '/value'}
|
||||
).create_node().out_port(0))
|
||||
|
||||
# convert TF representation of the pads as [N, 2] to MO representation: [N] and [N]
|
||||
transposed_pads = create_op_with_const_inputs(graph, Transpose, {1: int64_array([1, 0])})
|
||||
|
@ -21,6 +21,7 @@ from extensions.ops.transpose import Transpose
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.const import Const
|
||||
|
||||
@ -33,26 +34,26 @@ class CopyFrontExtractor(FrontReplacementOp):
|
||||
pb = node.parameters
|
||||
weights_size = read_binary_integer32_token(pb)
|
||||
weights = read_blob(pb, weights_size, dtype=np.int32) - 1
|
||||
|
||||
node_name = node.soft_get('name', node.id)
|
||||
const_attrs = {
|
||||
'name': 'indexes/{}'.format(node.id),
|
||||
'name': node_name + '/indexes',
|
||||
'value': np.array(weights),
|
||||
'shape': [weights_size],
|
||||
'data_type': np.int32
|
||||
}
|
||||
indexes_node = Const(graph).create_node(attrs=const_attrs)
|
||||
|
||||
perm_in_1 = Const(graph, {'value': np.array([1, 0], dtype=np.int64), 'shape': [2], 'data_type': np.int64}).create_node()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node()
|
||||
perm1_node = Transpose(graph, {'name': 'input_permute'}).create_node([node.in_node(0)])
|
||||
perm_in_1 = Const(graph, {'value': int64_array([1, 0]), 'name': node_name + '/order'}).create_node()
|
||||
perm1_node = Transpose(graph, {'name': node_name + '/input_permute'}).create_node([node.in_node(0)])
|
||||
perm1_node.in_port(0).connect(node.in_port(0).get_source())
|
||||
perm1_node.in_port(1).connect(perm_in_1.out_port(0))
|
||||
|
||||
gather_node = Gather(graph, {}).create_node()
|
||||
gather_node = create_op_with_const_inputs(graph, Gather, {2: int64_array(0)}, {'name': node_name + '/gather'})
|
||||
gather_node.in_port(0).connect(perm1_node.out_port(0))
|
||||
gather_node.in_port(1).connect(indexes_node.out_port(0))
|
||||
gather_node.in_port(2).connect(axis_const.out_port(0))
|
||||
|
||||
perm2_node = Transpose(graph, {'name': 'output_permute'}).create_node()
|
||||
perm2_node = Transpose(graph, {'name': node_name + '/output_permute'}).create_node()
|
||||
perm2_node.in_port(0).connect(gather_node.out_port(0))
|
||||
perm2_node.in_port(1).connect(perm_in_1.out_port(0))
|
||||
|
||||
|
@ -56,9 +56,11 @@ def axis(op_node: Node, port_info: str, input_port: int):
|
||||
|
||||
data_node = op_node.in_node(input_port)
|
||||
|
||||
const = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': op_node.name + '/AxisGather', 'need_shape_inference': True}).create_node_with_data(
|
||||
gather_name = op_node.soft_get('name', op_node.id) + '/AxisGather'
|
||||
const = Const(graph, {'value': permutation.inv, 'name': gather_name + '/const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': gather_name, 'need_shape_inference': True}).create_node_with_data(
|
||||
[const, data_node, axis_const])
|
||||
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
|
||||
graph.add_edge(gather.id, op_node.id, **attrs)
|
||||
@ -103,14 +105,18 @@ def order(op_node: Node, port_info: str, input_port: int):
|
||||
|
||||
data_node = op_node.in_node(input_port)
|
||||
|
||||
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': op_node.name + '/OrderGather_1',
|
||||
gather_name = op_node.soft_get('name', op_node.id) + '/OrderGather_1'
|
||||
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': gather_name,
|
||||
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
|
||||
|
||||
const_1 = Const(graph, {'value': permutation.inv, 'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const_1 = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
||||
gather_1 = Gather(graph, {'name': op_node.name + '/OrderGather_2',
|
||||
gather_1_name = op_node.soft_get('name', op_node.id) + '/OrderGather_2'
|
||||
const_1 = Const(graph, {'value': permutation.inv, 'name': gather_1_name + '/const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const_1 = Const(graph, {'value': int64_array(0), 'name': gather_1_name + '/axis'}).create_node_with_data()
|
||||
gather_1 = Gather(graph, {'name': gather_1_name,
|
||||
'need_shape_inference': True}).create_node_with_data([const_1, gather, axis_const_1])
|
||||
|
||||
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
|
||||
@ -131,9 +137,11 @@ def shape(op_node: Node, port_info: str, input_port: int):
|
||||
|
||||
data_node = op_node.in_node(input_port)
|
||||
|
||||
const = Const(graph, {'value': permutation.perm, 'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0)}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': op_node.name + '/ShapeGather',
|
||||
gather_name = op_node.soft_get('name', op_node.id) + '/ShapeGather'
|
||||
const = Const(graph, {'value': permutation.perm, 'name': gather_name + '/const',
|
||||
'need_shape_inference': True}).create_node_with_data()
|
||||
axis_const = Const(graph, {'value': int64_array(0), 'name': gather_name + '/axis'}).create_node_with_data()
|
||||
gather = Gather(graph, {'name': gather_name,
|
||||
'need_shape_inference': True}).create_node_with_data([data_node, const, axis_const])
|
||||
attrs = graph.get_edge_data(data_node.id, op_node.id, key=0).copy()
|
||||
|
||||
|
@ -108,9 +108,10 @@ def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True)
|
||||
value = np.reshape(value, shape)
|
||||
|
||||
# Weights multiplication
|
||||
mul_const = Const(graph, {'value': value}).create_node()
|
||||
w_mul = node.copy_node({'in_ports_count': len(node.in_ports()), 'out_ports_count': len(node.out_ports()),
|
||||
'can_be_fused': False})
|
||||
mul_name = node.name + '_copy'
|
||||
mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
|
||||
w_mul = node.copy_node({'name': mul_name, 'in_ports_count': len(node.in_ports()),
|
||||
'out_ports_count': len(node.out_ports()), 'can_be_fused': False})
|
||||
w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
|
||||
w_const = weights_port.get_source()
|
||||
weights_port.get_connection().set_source(w_mul.out_port(0))
|
||||
|
@ -65,14 +65,14 @@ def get_range_node_of_idxs(rank: Node, begin: int, end: int,
|
||||
end_idx = get_canonical_axis_index_node(rank, end)
|
||||
|
||||
if not include_begin:
|
||||
const = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
const = Const(graph, {'value': int64_array([1]), 'name': name + '/exclude_begin/value'}).create_node()
|
||||
add = Add(graph, {'name': name + '/exclude_begin'}).create_node()
|
||||
start_idx.out_port(0).connect(add.in_port(0))
|
||||
const.out_port(0).connect(add.in_port(1))
|
||||
start_idx = add
|
||||
|
||||
if include_end:
|
||||
const = Const(graph, {'value': int64_array([1])}).create_node()
|
||||
const = Const(graph, {'value': int64_array([1]), 'name': name + '/including_end/value'}).create_node()
|
||||
add = Add(graph, {'name': name + '/including_end'}).create_node()
|
||||
end_idx.out_port(0).connect(add.in_port(0))
|
||||
const.out_port(0).connect(add.in_port(1))
|
||||
@ -187,7 +187,10 @@ def new_shape_node_from_shape_nodes(input_shape_nodes: list):
|
||||
:return: the node producing concatenated values of nodes from the "input_shape_nodes"
|
||||
"""
|
||||
assert len(input_shape_nodes) > 0, 'The list of input shape nodes should be non-empty'
|
||||
new_shape_node = Concat(input_shape_nodes[0].graph, {'axis': 0}).create_node()
|
||||
new_shape_node = Concat(input_shape_nodes[0].graph,
|
||||
{'axis': 0,
|
||||
'name': input_shape_nodes[0].soft_get('name', input_shape_nodes[0].id) + '/shapes_concat'}
|
||||
).create_node()
|
||||
|
||||
for ind, input_node in enumerate(input_shape_nodes):
|
||||
new_shape_node.add_input_port(ind)
|
||||
|
Loading…
Reference in New Issue
Block a user