fix batch adding to init value of read value (#4187)

* fix batch adding to init value of read value

* fix for batch in Kaldi models

* added broadcast to be able reshape in IE

* test fixes, added batch broadcasting to created constants

* pep fixes

* move all changes to 1 transformation

* added unit test and fix insertSelect transformation

* added comments

* remove unneeded params search

* fix element_size to send correct batch

* fix update batch in element_size

* couple fixes

* update BOM file

* fix review comments

* review fixes

* review fixes

* fix license headers
This commit is contained in:
Svetlana Dolinina 2021-03-31 11:32:36 +03:00 committed by GitHub
parent b58c648d2d
commit ffa467a5ad
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 486 additions and 227 deletions

View File

@ -155,6 +155,7 @@ extensions/front/kaldi/add_reshape_around_pooling.py
extensions/front/kaldi/apply_counts.py
extensions/front/kaldi/logsoftmax_component_ext.py
extensions/front/kaldi/memory_offset_adjustment.py
extensions/front/kaldi/memoryoffset_batch_update.py
extensions/front/kaldi/replace_eltwise_nin1.py
extensions/front/kaldi/replace_lstm_node_pattern.py
extensions/front/kaldi/replace_lstm_nonlinearity.py
@ -578,6 +579,7 @@ extensions/middle/L2NormFusing.py
extensions/middle/LayoutChangeForConstantShapePaths.py
extensions/middle/LeakyReluPattern.py
extensions/middle/LSTMRNNSequenceToTensorIterator.py
extensions/middle/MakeKaldiConstReshapable.py
extensions/middle/MarkSubgraphsWithCorrectLayout.py
extensions/middle/MoveConstToLoopBody.py
extensions/middle/MulFakeQuantizeFuse.py

View File

@ -53,7 +53,7 @@ def align_frame_time(graph: Graph, node: Node, frame_time_max):
'splitted': False}).create_node()
# add element_size for MemoryOffset after Parameter for infer
if in_node.op == 'Parameter':
memory_align['element_size'] = in_node.shape[1]
memory_align['element_size'] = in_node.shape
in_port.get_connection().set_source(memory_align.out_port(0))
memory_align.in_port(0).connect(in_node_out_port)
memory_align['frame_time'] = memory_align.t

View File

@ -0,0 +1,27 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph
class MemoryOffsetBatchUpdate(FrontReplacementPattern):
"""
Update batch for MemoryOffset nodes with set element_size.
element_size is set in loader according to shape saved in model (for example Parameter node have shape in attribute).
But batch can be changed on front stage if user set batch through command line. So, element_size should be updated
accordingly.
"""
enabled = True
run_not_recursively = True
def run_after(self):
from extensions.front.user_data_repack import UserDataRepack
from extensions.front.kaldi.split_recurrent_memoryoffset import SplitRecurrentMemoryOffset
return [UserDataRepack, SplitRecurrentMemoryOffset]
def find_and_replace_pattern(self, graph: Graph):
batch = graph.get_op_nodes(op="Parameter")[0].shape[0]
for memoryoffset_node in graph.get_op_nodes(op='MemoryOffset'):
if memoryoffset_node.has_valid('element_size'):
memoryoffset_node.element_size[0] = batch

View File

@ -3,25 +3,21 @@
import numpy as np
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
from extensions.ops.MatMul import FullyConnected
from extensions.ops.activation_ops import Tanh, Sigmoid
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Node, Graph, Port
from mo.graph.graph import Node, Graph
from mo.ops.assign import Assign
from mo.ops.broadcast import Broadcast
from mo.ops.clamp import Clamp
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.ops.scale_shift import ScaleShiftOp
from mo.ops.shape import Shape
def unique_id(prefix: str = 'id') -> str:
@ -41,35 +37,6 @@ def unique_id(prefix: str = 'id') -> str:
unique_id.names = []
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32):
# create init_graph connected to ReadValue
graph = input_out_port.node.graph
input_name = input_out_port.node.name
shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
shape_of_input.in_port(0).connect(input_out_port)
dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
'axis': int64_array([0]), 'offset': int64_array([0])
}).create_node()
get_batch.in_port(0).connect(shape_of_input.out_port(0))
get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
'value': int64_array([second_dim]),
'shape': int64_array([1])}).create_node()
mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
'axis': 0, 'in_ports_count': 2}).create_node()
mem_shape.in_port(0).connect(get_batch.out_port(0))
mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
fill_value = Const(graph, {'name': 'fill_value/'+input_name,
'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
}).create_node()
init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
return init_value_prev_lstm_output
class ReplaceLSTMNodePattern(FrontReplacementOp):
op = "LSTMCell"
enabled = True
@ -110,8 +77,8 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
node.gifo_r_weights_shape[1])
init_value_prev_lstm_output = create_const_with_batch_from_input(input_out_port,
node.gifo_r_weights_shape[1])
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
'variable_id': memory_pair_input
}).create_node()
@ -150,14 +117,8 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
# 'id': memory_pair_output,
# 'index': 1,
# 'size': 2,
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
# }).create_node()
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
node.input_gate_weights.shape[0])
init_value_prev_lstm_state = create_const_with_batch_from_input(split_joined_input.out_port(0),
node.input_gate_weights.shape[0])
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
'variable_id': memory_pair_output}).create_node()
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))

View File

@ -3,6 +3,7 @@
import networkx as nx
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph
from mo.ops.memoryoffset import MemoryOffset
@ -51,7 +52,7 @@ class SplitRecurrentMemoryOffset(FrontReplacementSubgraph):
# check if previous layer contains information about its shape in out-size
# out-size is set in extractor of some nodes like affinecomponent based on weight's size
if offset_node.in_port(0).get_source().node.has_valid('out-size'):
offset_node['element_size'] = offset_node.in_port(0).get_source().node['out-size']
offset_node['element_size'] = int64_array([1, offset_node.in_port(0).get_source().node['out-size']])
else:
raise Error("In a recurrent block 'element_size' for node {} is not set".format(offset_node.id))
SplitRecurrentMemoryOffset.split_offset(offset_node)

View File

@ -3,7 +3,7 @@
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
from extensions.ops.elementwise import Equal
from extensions.ops.select import Select
from mo.front.common.partial_infer.utils import int64_array
@ -12,7 +12,6 @@ from mo.middle.pattern_match import find_pattern_matches, inverse_dict
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
@ -79,7 +78,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
# add Select before saving state to avoid saving garbage
select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
select_node.in_port(1).connect(in_node_port)
select_node.in_port(2).connect(zero_else.out_port(0))
@ -114,14 +113,14 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
ones = Node(graph, inverse_dict(counter_match)['const_1'])
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
else:
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
mem_out = ReadValue(graph, {'name': 'iteration_number',
'variable_id': 'iteration_'+node.name}).create_node()
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
cut_first.in_port(0).connect(mem_out.out_port(0))
ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
concat.in_port(0).connect(cut_first.out_port(0))
concat.in_port(1).connect(ones.out_port(0))

View File

@ -15,12 +15,12 @@ class InsertSelectTests(unittest.TestCase):
# graph have no splices - selects should not be inserted
def test_insert_select_0(self):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True)
@ -32,8 +32,8 @@ class InsertSelectTests(unittest.TestCase):
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
def test_insert_select_1(self):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
@ -41,35 +41,53 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'memory')
],
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
'placeholder_2': {'kind': 'op', 'op': None},
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data_mem_1': {'kind': 'data'},
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_1': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_1': {'kind': 'data'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
'second_dim_data_mem_2': {'kind': 'data'},
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_2': {'kind': 'data'},
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones_2': {'kind': 'data'},
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_2': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
@ -85,18 +103,46 @@ class InsertSelectTests(unittest.TestCase):
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
'const_0_data': {'kind': 'data'},
'const_1': {'kind': 'op', 'op': 'Const'},
'const_1_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('placeholder_data_2', 'select', {'in': 1}),
('second_dim_mem_1', 'second_dim_data_mem_1'),
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
('fill_value', 'fill_value_data'),
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
('broadcast_mem_1', 'broadcast_data_mem_1'),
('broadcast_data_mem_1', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('second_dim_mem_2', 'second_dim_data_mem_2'),
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
('fill_value_ones_2', 'fill_value_data_ones_2'),
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
('broadcast_mem_2', 'broadcast_data_mem_2'),
('broadcast_data_mem_2', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
@ -104,21 +150,11 @@ class InsertSelectTests(unittest.TestCase):
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('fill_value_ones', 'fill_value_data_ones'),
('fill_value_data_ones', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('broadcast_data', 'select', {'in': 2}),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('select', 'select_out_data'),
('select_out_data', 'memory')
],
@ -131,8 +167,8 @@ class InsertSelectTests(unittest.TestCase):
# graph contains 1 splice with context length 5 on the path to memory and 1 out of path,
# should be inserted select with memory as counter with length 5
def test_insert_select_2(self):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@ -142,7 +178,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
@ -150,8 +186,8 @@ class InsertSelectTests(unittest.TestCase):
],
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@ -159,6 +195,15 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data_mem_1': {'kind': 'data'},
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_1': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_1': {'kind': 'data'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
@ -169,14 +214,23 @@ class InsertSelectTests(unittest.TestCase):
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
'second_dim_data_mem_2': {'kind': 'data'},
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_2': {'kind': 'data'},
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones_2': {'kind': 'data'},
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_2': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
@ -189,55 +243,72 @@ class InsertSelectTests(unittest.TestCase):
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
'const_0_data': {'kind': 'data'},
'const_1': {'kind': 'op', 'op': 'Const'},
'const_1_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('second_dim_mem_1', 'second_dim_data_mem_1'),
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
('fill_value', 'fill_value_data'),
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
('broadcast_mem_1', 'broadcast_data_mem_1'),
('broadcast_data_mem_1', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('second_dim_mem_2', 'second_dim_data_mem_2'),
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
('fill_value_ones_2', 'fill_value_data_ones_2'),
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
('broadcast_mem_2', 'broadcast_data_mem_2'),
('broadcast_data_mem_2', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('fill_value_ones', 'fill_value_data_ones'),
('fill_value_data_ones', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('broadcast_data', 'select', {'in': 2}),
('select', 'select_out_data'),
('select_out_data', 'memory')
],
nodes_with_edges_only=True
)
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
self.assertTrue(flag, resp)
# graph contains 2 splices with sum context length 8 on the path to memory,
# should be inserted select with memory as counter with length 7
def test_insert_select_3(self):
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@ -247,7 +318,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
@ -255,8 +326,8 @@ class InsertSelectTests(unittest.TestCase):
],
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
'placeholder_1': {'kind': 'op', 'op': None},
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@ -264,27 +335,45 @@ class InsertSelectTests(unittest.TestCase):
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
'placeholder_2': {'kind': 'op', 'op': None},
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data_mem_1': {'kind': 'data'},
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_1': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_1': {'kind': 'data'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
'second_dim_data_mem_2': {'kind': 'data'},
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data_mem_2': {'kind': 'data'},
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data_ones_2': {'kind': 'data'},
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data_mem_2': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
'memory_in_data': {'kind': 'data'},
'memory_out': {'kind': 'op', 'op': 'Assign'},
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
'crop_in_data': {'kind': 'data'},
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
'crop_out_data': {'kind': 'data'},
@ -294,40 +383,58 @@ class InsertSelectTests(unittest.TestCase):
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
'const_0': {'kind': 'op', 'op': 'Const'},
'const_0_data': {'kind': 'data'},
'const_1': {'kind': 'op', 'op': 'Const'},
'const_1_data': {'kind': 'data'},
'concat': {'kind': 'op', 'op': 'Concat'},
'concat_data': {'kind': 'data'},
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
[('placeholder_1', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
('second_dim_mem_1', 'second_dim_data_mem_1'),
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
('fill_value', 'fill_value_data'),
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
('broadcast_mem_1', 'broadcast_data_mem_1'),
('broadcast_data_mem_1', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('second_dim_mem_2', 'second_dim_data_mem_2'),
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
('fill_value_ones_2', 'fill_value_data_ones_2'),
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
('broadcast_mem_2', 'broadcast_data_mem_2'),
('broadcast_data_mem_2', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('fill_value_ones', 'fill_value_data_ones'),
('fill_value_data_ones', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
('concat', 'concat_data'), ('concat_data', 'memory_out'),
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
('equal', 'equal_data'),
('equal_data', 'select', {'in': 0}),
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
('broadcast_data', 'select', {'in': 2}),
('select', 'select_out_data'),
('select_out_data', 'memory')

View File

@ -0,0 +1,118 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
from mo.graph.graph import Graph, Port
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.broadcast import Broadcast
from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.shape import Shape
def create_const_with_batch_from_input(producer_port: Port, second_dim, value=0, precision=np.float32):
"""
Create const with batch taken from input_out_port and second dimension equals second_dim
:param producer_port: take batch from this port
:param second_dim: second dimension for created constant
:param value: value to initialize constant
:param precision: precision for constant
:return created constant node
"""
graph = producer_port.node.graph
input_name = producer_port.node.soft_get('name', producer_port.node.id)
shape_of_input = None
for dest in producer_port.get_destinations():
if dest.node.soft_get('op') == "ShapeOf":
shape_of_input = dest.node
break
if shape_of_input is None:
shape_of_input = Shape(graph, {'name': input_name + '/Shape'}).create_node()
shape_of_input.in_port(0).connect(producer_port)
get_batch = None
for dest in shape_of_input.out_port(0).get_destinations():
if dest.node.soft_get('op') == "Crop" and \
dest.node.in_port(1).get_source().node.soft_get('value', []) == int64_array([1]):
get_batch = dest.node
break
if get_batch is None:
get_batch = create_op_node_with_second_input(graph, Crop, int64_array([1]),
{'name': shape_of_input.name + '/Crop',
'axis': int64_array([0]), 'offset': int64_array([0])},
shape_of_input)
mem_shape = None
for dest in get_batch.out_port(0).get_destinations():
if dest.node.soft_get('op') == "Concat" and \
dest.node.in_port(1).get_source().node.soft_get('value', []) == int64_array([second_dim]):
mem_shape = dest.node
break
if mem_shape is None:
mem_shape = create_op_node_with_second_input(graph, Concat, int64_array([second_dim]),
{'name': get_batch.name + '/Concat', 'axis': 0,
'in_ports_count': 2}, get_batch)
init_value_prev_lstm_output = None
for dest in mem_shape.out_port(0).get_destinations():
if dest.node.soft_get('op') == "Broadcast" and \
dest.node.in_port(1).get_source().node.soft_get('value', []) == np.array([value], dtype=precision):
init_value_prev_lstm_output = dest.node
break
if init_value_prev_lstm_output is None:
init_value_prev_lstm_output = create_op_with_const_inputs(graph, Broadcast,
{0: np.array([value], dtype=precision)},
{'name': mem_shape.name + '/Broadcast'})
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
return init_value_prev_lstm_output
class MakeKaldiConstReshapable(MiddleReplacementPattern):
"""
Add broadcasting of constant nodes based on batch from Parameter node. This approach works only for Kaldi,
because it has the same batch in whole graph due to framework specific.
"""
enabled = True
graph_condition = [lambda graph: graph.graph['fw'] == "kaldi"]
def run_after(self):
from extensions.middle.InsertSelect import AddSelectBeforeMemoryNodePattern
from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetWithMemoryNodePattern
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
return [AddSelectBeforeMemoryNodePattern, ReplaceMemoryOffsetWithMemoryNodePattern,
ReplaceSpliceNodePattern]
def find_and_replace_pattern(self, graph: Graph):
params = graph.get_op_nodes(op="Parameter")
batch = params[0].shape[0]
# check that all Parameters have the same batch
for p in params:
assert(p.shape[0] == batch,
"Parameter {} have batch different from the {}".format(p.soft_get('name', p.id),
params[0].soft_get('name', params[0].id)))
# make constants for initialization of ReadValue reshapable
for read in graph.get_op_nodes(op='ReadValue'):
input_node = read.in_port(0).get_source().node
if input_node.soft_get('op') == "Const":
const_shape = input_node.out_port(0).data.get_shape()
# extra check to be sure that we don't break shapes compatibility in graph
# in Kaldi models we have only 2 dimensions
# and batch should be set the same as we will get from Parameter
# otherwise just skip such node
if len(const_shape) != 2 or const_shape[0] != batch:
continue
new_const = create_const_with_batch_from_input(params[0].out_port(0),
const_shape[1],
value=input_node.value[0], precision=input_node.data_type)
input_node.out_port(0).get_connection().set_source(new_const.out_port(0))

View File

@ -0,0 +1,104 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import unittest
import numpy as np
from extensions.middle.MakeKaldiConstReshapable import MakeKaldiConstReshapable
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, connect
nodes = {
**regular_op_with_shaped_data('placeholder_1', [1, 13], {'kind': 'op', 'op': 'Parameter', 'shape': [1, 13]}),
**regular_op_with_shaped_data('splice_1', [1, 13], {'kind': 'op', 'op': 'Splice',
'context': np.array([-2, -1, 0, 1, 2])}),
**regular_op_with_shaped_data('placeholder_2', [1, 26], {'kind': 'op', 'op': None}),
**regular_op_with_shaped_data('memory_in', [1, 5], {'kind': 'op', 'op': 'ReadValue',
'shape': int64_array([1, 5])}),
**regular_op_with_shaped_data('memory_out', [1, 5], {'kind': 'op', 'op': 'Assign', 'shape': int64_array([1, 5])}),
**result('result'),
**regular_op_with_shaped_data('crop_in', [1, 4], {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4}),
**regular_op_with_shaped_data('crop_out', [1, 1], {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1}),
**regular_op_with_shaped_data('equal', [1, 1], {'kind': 'op', 'op': 'Equal'}),
**regular_op_with_shaped_data('select', [1, 26], {'kind': 'op', 'op': 'Select'}),
**regular_op_with_shaped_data('const_0', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
'value': [0], 'data_type': np.float32}),
**regular_op_with_shaped_data('const_1', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
'value': [0], 'data_type': np.float32}),
**regular_op_with_shaped_data('concat', [1, 5], {'kind': 'op', 'op': 'Concat'}),
**regular_op_with_shaped_data('memory', [1, 26], {'kind': 'op', 'op': 'Assign'}),
**regular_op_with_shaped_data('shape', None, {'kind': 'op', 'op': 'ShapeOf'}),
**regular_op_with_shaped_data('crop_batch', None, {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])}),
**regular_op_with_shaped_data('crop_batch_dim', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
'value': [1], 'data_type': np.int64}),
**regular_op_with_shaped_data('second_dim', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
'value': [5], 'data_type': np.int64}),
**regular_op_with_shaped_data('gather_shape', None, {'kind': 'op', 'op': 'Concat'}),
**regular_op_with_shaped_data('fill_value', [1, 5], {'kind': 'op', 'op': 'Const', 'shape': [1, 5],
'value': np.zeros([1, 5]), 'data_type': np.float32}),
**regular_op_with_shaped_data('fill_value_2', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
'value': [0], 'data_type': np.float32}),
**regular_op_with_shaped_data('broadcast', [1, 5], {'kind': 'op', 'op': 'Broadcast'}),
**regular_op_with_shaped_data('fill_value_ones', [1, 26], {'kind': 'op', 'op': 'Const', 'shape': [1, 26],
'value': np.zeros([1, 26]), 'data_type': np.int64}),
**regular_op_with_shaped_data('fill_value_ones_2', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
'value': [1], 'data_type': np.int64}),
}
class MakeKaldiConstReshapableTests(unittest.TestCase):
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
def test_reshapable_const(self):
graph = build_graph(nodes,
[*connect('placeholder_1', 'splice_1'),
*connect('splice_1', 'placeholder_2'),
*connect('placeholder_2', '1:select'),
*connect('fill_value', 'memory_in'),
*connect('memory_in', 'crop_in'),
*connect('crop_in', '0:concat'),
*connect('fill_value_ones_2:0', '1:concat'),
*connect('concat', 'memory_out'),
*connect('memory_out', 'result'),
*connect('concat', 'crop_out'),
*connect('crop_out', '1:equal'),
*connect('fill_value_ones_2:0', '0:equal'),
*connect('equal', '0:select'),
*connect('fill_value_ones', '2:select'),
*connect('select', 'memory')
],
nodes_with_edges_only=True)
graph.strict_mode = False
MakeKaldiConstReshapable().find_and_replace_pattern(graph)
ref_graph = build_graph(nodes,
[*connect('placeholder_1:0', 'splice_1'),
*connect('splice_1', 'placeholder_2'),
*connect('placeholder_2', '1:select'),
*connect('placeholder_1:0', 'shape', skip_data=True),
*connect('shape', '0:crop_batch'),
*connect('crop_batch_dim', '1:crop_batch'),
*connect('second_dim', '1:gather_shape'),
*connect('crop_batch', '0:gather_shape'),
*connect('fill_value_2', '0:broadcast'),
*connect('gather_shape', '1:broadcast'),
*connect('broadcast', 'memory_in'),
*connect('memory_in', 'crop_in'),
*connect('crop_in', '0:concat'),
*connect('fill_value_ones_2', '1:concat'),
*connect('concat', 'memory_out'),
*connect('memory_out', 'result'),
*connect('concat', 'crop_out'),
*connect('crop_out', '1:equal'),
*connect('fill_value_ones_2', '0:equal'),
*connect('equal', '0:select'),
*connect('const_0', '2:select'),
*connect('fill_value_ones', '2:select'),
*connect('select', 'memory')
], nodes_with_edges_only=True)
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
self.assertTrue(flag, resp)

View File

@ -2,15 +2,14 @@
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import logging as log
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
from extensions.ops.splice import Splice
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph, Node
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
@ -134,7 +133,9 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
in_shape = input_port.data.get_shape()
node_t = abs(node.t)
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t])),
'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
@ -163,14 +164,6 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
memory_in.out_port(0).connect(out.in_port(0))
out_port.get_connection().set_source(memory_out.out_port(0))
if not graph.graph['cmd_params'].static_shape:
log.error(
"Model can not be translated in a reshape-able way.\n"
"Model Optimizer key static_shape was turned on to prevent related errors.\n"
"There will be no success changing input shapes of the model with the help of "
"InferenceEngine reshape method", extra={'is_warning': True})
graph.graph['cmd_params'].static_shape = True
graph.remove_node(op_output_id)
graph.remove_node(node.id)
graph.remove_node(pair_node.id)

View File

@ -1,7 +1,9 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
import numpy as np
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
from extensions.ops.split import VariadicSplit
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs
@ -9,6 +11,7 @@ from mo.graph.graph import Graph
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.assign import Assign
from mo.ops.concat import Concat
from mo.ops.const import Const
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
@ -93,8 +96,11 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
# create separate splice construction for const_dim
memory_pair_id = unique_id('memory_for_const_dim')
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
memory_size_constdim)
init_value_input_memory_const_dim = Const(graph, {'name': 'init_value_const_dim_in_memory',
'value': np.zeros(int64_array([in_shape[0],
memory_size_constdim])),
'shape': int64_array([in_shape[0],
memory_size_constdim])}).create_node()
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
'variable_id': memory_pair_id}).create_node()
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
@ -129,14 +135,16 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
concat_const.in_port(1).connect(crop_first.out_port(0))
concat_const.in_port(0).connect(concat_node.out_port(0))
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
memory_size)
init_value_input_memory = Const(graph, {'name': 'init_value_' + node.name,
'value': np.zeros(int64_array([in_shape[0], memory_size])),
'shape': int64_array([in_shape[0], memory_size])}).create_node()
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(split.in_port(0))
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
else:
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
memory_size)
init_value_input_memory = Const(graph, {'name': 'init_value_' + node.name,
'value': np.zeros(int64_array([in_shape[0], memory_size])),
'shape': int64_array([in_shape[0], memory_size])}).create_node()
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
node.out_port(0).get_connection().set_source(concat_node.out_port(0))

View File

@ -32,20 +32,8 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
'in_node': {'kind': 'data', 'shape': [1, 13]},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
@ -61,16 +49,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
[
('in_placeholder', 'in_node'),
('in_node', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
@ -104,20 +83,8 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
'split_data_0': {'kind': 'data'},
'split_data_1': {'kind': 'data'},
'shape': {'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data'},
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_data': {'kind': 'data'},
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_data': {'kind': 'data'},
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_data': {'kind': 'data'},
'gather_shape': {'kind': 'op', 'op': 'Concat'},
'gather_shape_data': {'kind': 'data'},
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_data': {'kind': 'data'},
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_data': {'kind': 'data'},
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_data': {'kind': 'data'},
@ -129,21 +96,10 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
'memory_out_data': {'kind': 'data'},
'result': {'kind': 'op', 'op': 'Result'},
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
'shape_2_data': {'kind': 'data'},
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
'crop_batch_2_data': {'kind': 'data'},
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
'crop_batch_dim_2_data': {'kind': 'data'},
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
'second_dim_2_data': {'kind': 'data'},
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
'gather_shape_2_data': {'kind': 'data'},
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
'fill_value_2_data': {'kind': 'data'},
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
'broadcast_2_data': {'kind': 'data'},
\
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
'memory_in_constdims_data': {'kind': 'data'},
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
@ -171,16 +127,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('split', 'split_data_0', {'out': 0}),
('split', 'split_data_1', {'out': 1}),
('split_data_0', 'shape'), ('shape', 'shape_data'),
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
('crop_batch_dim', 'crop_batch_dim_data'),
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
('crop_batch_data', 'gather_shape', {'in': 0}),
('gather_shape', 'gather_shape_data'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
('broadcast_data', 'memory_in'),
('fill_value', 'fill_value_data'), ('fill_value_data', 'memory_in'),
('memory_in', 'memory_in_data'),
('memory_in_data', 'crop_mem'),
@ -192,16 +139,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
('memory_out', 'memory_out_data'),
('memory_out_data', 'result'),
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
('gather_shape_2', 'gather_shape_2_data'),
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
('broadcast_2_data', 'memory_in_constdims'),
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'memory_in_constdims'),
('memory_in_constdims', 'memory_in_constdims_data'),
('memory_in_constdims_data', 'crop_mem_constdims'),

View File

@ -32,4 +32,4 @@ class SplitTdnnMemoryOffset(MiddleReplacementPattern):
paired_node['element_size'] = offset_node['element_size']
# Copy shape from previous node. Typically (but not always) for TDNN blocks this is the case
else:
paired_node['element_size'] = offset_node.in_port(0).data.get_shape()[1]
paired_node['element_size'] = offset_node.in_port(0).data.get_shape()

View File

@ -9,7 +9,7 @@ import numpy as np
from extensions.ops.elementwise import Mul
from extensions.ops.split import AttributedVariadicSplit
from mo.front.common.partial_infer.utils import float_array
from mo.front.common.partial_infer.utils import float_array, int64_array
from mo.front.extractor import add_outputs_identity
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
@ -214,7 +214,9 @@ def load_kaldi_nnet3_model(graph, file_descr, nnet_name):
for o_n_name, params in node.get_outputs():
o_n = Node(graph, o_n_name)
if o_n['op'] == 'MemoryOffset':
o_n['parameters']['element_size'] = node['shape'][1]
# don't take batch from Parameter, it will be overwritten
# take only second dimension because we have only 2 dimensions
o_n['parameters']['element_size'] = int64_array([1, node.shape[1]])
load_components(file_descr, graph, component_layer_map)
@ -268,7 +270,7 @@ def load_components(file_descr, graph, component_layer_map=None):
for o_n_name, params in node.get_outputs():
o_n = Node(graph, o_n_name)
if o_n['op'] == 'MemoryOffset' and dim != 0:
o_n['parameters']['element_size'] = dim
o_n['parameters']['element_size'] = int64_array([1, dim])
else:
raise Error("Something wrong with layer {}".format(name))
else:
@ -401,7 +403,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
for o_n_name, params in node.get_outputs():
o_n = Node(graph, o_n_name)
if o_n['op'] == 'MemoryOffset':
o_n['parameters']['element_size'] = dim
o_n['parameters']['element_size'] = int64_array([1, dim])
else:
raise Error("Unsupported node specifier {}".format(tokens[0]))
return True

View File

@ -17,17 +17,16 @@ class MemoryOffset(Op):
'pair_name': None,
'splitted': False,
'has_default': False,
'infer': __class__.infer,
'infer': self.infer,
'in_ports_count': 1,
'out_ports_count': 1,
}, attrs)
@staticmethod
def infer(node: Node):
if node.has_valid('element_size'):
# element_size should be set by Kaldi loader or by MemoryOffsetAdjustment
node.out_port(0).data.set_shape([1, node['element_size']])
# element_size should be set by Kaldi loader or MemoryOffsetAdjustment or SplitRecurrentMemoryOffset
node.out_port(0).data.set_shape(node.element_size)
else:
# for TDNN blocks
copy_shape_infer(node)