fix batch adding to init value of read value (#4187)
* fix batch adding to init value of read value * fix for batch in Kaldi models * added broadcast to be able reshape in IE * test fixes, added batch broadcasting to created constants * pep fixes * move all changes to 1 transformation * added unit test and fix insertSelect transformation * added comments * remove unneeded params search * fix element_size to send correct batch * fix update batch in element_size * couple fixes * update BOM file * fix review comments * review fixes * review fixes * fix license headers
This commit is contained in:
parent
b58c648d2d
commit
ffa467a5ad
@ -155,6 +155,7 @@ extensions/front/kaldi/add_reshape_around_pooling.py
|
||||
extensions/front/kaldi/apply_counts.py
|
||||
extensions/front/kaldi/logsoftmax_component_ext.py
|
||||
extensions/front/kaldi/memory_offset_adjustment.py
|
||||
extensions/front/kaldi/memoryoffset_batch_update.py
|
||||
extensions/front/kaldi/replace_eltwise_nin1.py
|
||||
extensions/front/kaldi/replace_lstm_node_pattern.py
|
||||
extensions/front/kaldi/replace_lstm_nonlinearity.py
|
||||
@ -578,6 +579,7 @@ extensions/middle/L2NormFusing.py
|
||||
extensions/middle/LayoutChangeForConstantShapePaths.py
|
||||
extensions/middle/LeakyReluPattern.py
|
||||
extensions/middle/LSTMRNNSequenceToTensorIterator.py
|
||||
extensions/middle/MakeKaldiConstReshapable.py
|
||||
extensions/middle/MarkSubgraphsWithCorrectLayout.py
|
||||
extensions/middle/MoveConstToLoopBody.py
|
||||
extensions/middle/MulFakeQuantizeFuse.py
|
||||
|
@ -53,7 +53,7 @@ def align_frame_time(graph: Graph, node: Node, frame_time_max):
|
||||
'splitted': False}).create_node()
|
||||
# add element_size for MemoryOffset after Parameter for infer
|
||||
if in_node.op == 'Parameter':
|
||||
memory_align['element_size'] = in_node.shape[1]
|
||||
memory_align['element_size'] = in_node.shape
|
||||
in_port.get_connection().set_source(memory_align.out_port(0))
|
||||
memory_align.in_port(0).connect(in_node_out_port)
|
||||
memory_align['frame_time'] = memory_align.t
|
||||
|
@ -0,0 +1,27 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.graph.graph import Graph
|
||||
|
||||
|
||||
class MemoryOffsetBatchUpdate(FrontReplacementPattern):
|
||||
"""
|
||||
Update batch for MemoryOffset nodes with set element_size.
|
||||
element_size is set in loader according to shape saved in model (for example Parameter node have shape in attribute).
|
||||
But batch can be changed on front stage if user set batch through command line. So, element_size should be updated
|
||||
accordingly.
|
||||
"""
|
||||
enabled = True
|
||||
run_not_recursively = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.user_data_repack import UserDataRepack
|
||||
from extensions.front.kaldi.split_recurrent_memoryoffset import SplitRecurrentMemoryOffset
|
||||
return [UserDataRepack, SplitRecurrentMemoryOffset]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
batch = graph.get_op_nodes(op="Parameter")[0].shape[0]
|
||||
for memoryoffset_node in graph.get_op_nodes(op='MemoryOffset'):
|
||||
if memoryoffset_node.has_valid('element_size'):
|
||||
memoryoffset_node.element_size[0] = batch
|
@ -3,25 +3,21 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
|
||||
from extensions.ops.MatMul import FullyConnected
|
||||
from extensions.ops.activation_ops import Tanh, Sigmoid
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.split import Split
|
||||
from mo.front.caffe.extractors.utils import input_as_const
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Node, Graph, Port
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.broadcast import Broadcast
|
||||
from mo.ops.clamp import Clamp
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
from mo.ops.scale_shift import ScaleShiftOp
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
def unique_id(prefix: str = 'id') -> str:
|
||||
@ -41,35 +37,6 @@ def unique_id(prefix: str = 'id') -> str:
|
||||
unique_id.names = []
|
||||
|
||||
|
||||
def create_zero_value_with_batch_from_input(input_out_port: Port, second_dim, precision=np.float32):
|
||||
# create init_graph connected to ReadValue
|
||||
graph = input_out_port.node.graph
|
||||
input_name = input_out_port.node.name
|
||||
shape_of_input = Shape(graph, {'name': 'shape/' + input_name}).create_node()
|
||||
shape_of_input.in_port(0).connect(input_out_port)
|
||||
dim_for_get_batch = Const(graph, {'name': 'dim/crop_batch/'+shape_of_input.name,
|
||||
'value': int64_array([1]), 'shape': int64_array([1])}).create_node()
|
||||
get_batch = Crop(graph, {'name': 'crop_batch/' + shape_of_input.name,
|
||||
'axis': int64_array([0]), 'offset': int64_array([0])
|
||||
}).create_node()
|
||||
get_batch.in_port(0).connect(shape_of_input.out_port(0))
|
||||
get_batch.in_port(1).connect(dim_for_get_batch.out_port(0))
|
||||
mem_shape_2nd_dim = Const(graph, {'name': 'gifo_r_weights_shape/'+input_name,
|
||||
'value': int64_array([second_dim]),
|
||||
'shape': int64_array([1])}).create_node()
|
||||
mem_shape = Concat(graph, {'name': 'gather_memory_shape/' + input_name,
|
||||
'axis': 0, 'in_ports_count': 2}).create_node()
|
||||
mem_shape.in_port(0).connect(get_batch.out_port(0))
|
||||
mem_shape.in_port(1).connect(mem_shape_2nd_dim.out_port(0))
|
||||
fill_value = Const(graph, {'name': 'fill_value/'+input_name,
|
||||
'value': np.array([0.0], precision), 'shape': int64_array([1])}).create_node()
|
||||
init_value_prev_lstm_output = Broadcast(graph, {'name': 'init_value/'+input_name,
|
||||
}).create_node()
|
||||
init_value_prev_lstm_output.in_port(0).connect(fill_value.out_port(0))
|
||||
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
|
||||
return init_value_prev_lstm_output
|
||||
|
||||
|
||||
class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
op = "LSTMCell"
|
||||
enabled = True
|
||||
@ -110,8 +77,8 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 1, 'weights', node.gifo_x_weights)
|
||||
input_as_const(fc_layer_after_input, fc_layer_after_input_attrs, 2, 'biases', node.gifo_biases)
|
||||
|
||||
init_value_prev_lstm_output = create_zero_value_with_batch_from_input(input_out_port,
|
||||
node.gifo_r_weights_shape[1])
|
||||
init_value_prev_lstm_output = create_const_with_batch_from_input(input_out_port,
|
||||
node.gifo_r_weights_shape[1])
|
||||
prev_lstm_output = ReadValue(graph, {'name': 'prev_memory_output',
|
||||
'variable_id': memory_pair_input
|
||||
}).create_node()
|
||||
@ -150,14 +117,8 @@ class ReplaceLSTMNodePattern(FrontReplacementOp):
|
||||
split_joined_input.in_port(0).connect(join_input_prev_state_sum.out_port(0))
|
||||
split_joined_input.in_port(1).connect(split_joined_input_axis.out_port(0))
|
||||
|
||||
# prev_lstm_state = Memory(graph, {'name': 'prev_memory_state',
|
||||
# 'id': memory_pair_output,
|
||||
# 'index': 1,
|
||||
# 'size': 2,
|
||||
# 'shape': np.array([node.input_gate_weights.shape[0]], dtype=np.int64)
|
||||
# }).create_node()
|
||||
init_value_prev_lstm_state = create_zero_value_with_batch_from_input(split_joined_input.out_port(0),
|
||||
node.input_gate_weights.shape[0])
|
||||
init_value_prev_lstm_state = create_const_with_batch_from_input(split_joined_input.out_port(0),
|
||||
node.input_gate_weights.shape[0])
|
||||
prev_lstm_state = ReadValue(graph, {'name': 'prev_memory_state',
|
||||
'variable_id': memory_pair_output}).create_node()
|
||||
prev_lstm_state.in_port(0).connect(init_value_prev_lstm_state.out_port(0))
|
||||
|
@ -3,6 +3,7 @@
|
||||
|
||||
import networkx as nx
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.memoryoffset import MemoryOffset
|
||||
@ -51,7 +52,7 @@ class SplitRecurrentMemoryOffset(FrontReplacementSubgraph):
|
||||
# check if previous layer contains information about its shape in out-size
|
||||
# out-size is set in extractor of some nodes like affinecomponent based on weight's size
|
||||
if offset_node.in_port(0).get_source().node.has_valid('out-size'):
|
||||
offset_node['element_size'] = offset_node.in_port(0).get_source().node['out-size']
|
||||
offset_node['element_size'] = int64_array([1, offset_node.in_port(0).get_source().node['out-size']])
|
||||
else:
|
||||
raise Error("In a recurrent block 'element_size' for node {} is not set".format(offset_node.id))
|
||||
SplitRecurrentMemoryOffset.split_offset(offset_node)
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
|
||||
from extensions.ops.elementwise import Equal
|
||||
from extensions.ops.select import Select
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
@ -12,7 +12,6 @@ from mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
@ -79,7 +78,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
|
||||
# add Select before saving state to avoid saving garbage
|
||||
select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
|
||||
zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
|
||||
zero_else = create_const_with_batch_from_input(in_node_port, in_node_shape[1])
|
||||
select_node.in_port(1).connect(in_node_port)
|
||||
select_node.in_port(2).connect(zero_else.out_port(0))
|
||||
|
||||
@ -114,14 +113,14 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
|
||||
ones = Node(graph, inverse_dict(counter_match)['const_1'])
|
||||
input_port = Node(graph, inverse_dict(counter_match)['crop_out']).out_port(0)
|
||||
else:
|
||||
init_value_mem_out = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
|
||||
init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
|
||||
mem_out = ReadValue(graph, {'name': 'iteration_number',
|
||||
'variable_id': 'iteration_'+node.name}).create_node()
|
||||
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
|
||||
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
|
||||
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
|
||||
cut_first.in_port(0).connect(mem_out.out_port(0))
|
||||
ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
|
||||
ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
|
||||
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
|
||||
concat.in_port(0).connect(cut_first.out_port(0))
|
||||
concat.in_port(1).connect(ones.out_port(0))
|
||||
|
@ -15,12 +15,12 @@ class InsertSelectTests(unittest.TestCase):
|
||||
|
||||
# graph have no splices - selects should not be inserted
|
||||
def test_insert_select_0(self):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
@ -32,8 +32,8 @@ class InsertSelectTests(unittest.TestCase):
|
||||
|
||||
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
|
||||
def test_insert_select_1(self):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
@ -41,35 +41,53 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data_mem_1': {'kind': 'data'},
|
||||
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_1': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_1': {'kind': 'data'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim':{'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
|
||||
'second_dim_data_mem_2': {'kind': 'data'},
|
||||
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_2': {'kind': 'data'},
|
||||
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones_2': {'kind': 'data'},
|
||||
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_2': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
|
||||
@ -85,18 +103,46 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
'const_0_data': {'kind': 'data'},
|
||||
'const_1': {'kind': 'op', 'op': 'Const'},
|
||||
'const_1_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('second_dim_mem_1', 'second_dim_data_mem_1'),
|
||||
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
|
||||
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
|
||||
('fill_value', 'fill_value_data'),
|
||||
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
|
||||
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
|
||||
('broadcast_mem_1', 'broadcast_data_mem_1'),
|
||||
('broadcast_data_mem_1', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
|
||||
('second_dim_mem_2', 'second_dim_data_mem_2'),
|
||||
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
|
||||
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
|
||||
('fill_value_ones_2', 'fill_value_data_ones_2'),
|
||||
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
|
||||
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
|
||||
('broadcast_mem_2', 'broadcast_data_mem_2'),
|
||||
('broadcast_data_mem_2', 'concat', {'in': 1}),
|
||||
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
@ -104,21 +150,11 @@ class InsertSelectTests(unittest.TestCase):
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('fill_value_ones', 'fill_value_data_ones'),
|
||||
('fill_value_data_ones', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
('broadcast_data', 'select', {'in': 2}),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('select', 'select_out_data'),
|
||||
('select_out_data', 'memory')
|
||||
],
|
||||
@ -131,8 +167,8 @@ class InsertSelectTests(unittest.TestCase):
|
||||
# graph contains 1 splice with context length 5 on the path to memory and 1 out of path,
|
||||
# should be inserted select with memory as counter with length 5
|
||||
def test_insert_select_2(self):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -142,7 +178,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
@ -150,8 +186,8 @@ class InsertSelectTests(unittest.TestCase):
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -159,6 +195,15 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data_mem_1': {'kind': 'data'},
|
||||
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_1': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_1': {'kind': 'data'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
@ -169,14 +214,23 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
|
||||
'second_dim_data_mem_2': {'kind': 'data'},
|
||||
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_2': {'kind': 'data'},
|
||||
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones_2': {'kind': 'data'},
|
||||
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_2': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
@ -189,55 +243,72 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
'const_0_data': {'kind': 'data'},
|
||||
'const_1': {'kind': 'op', 'op': 'Const'},
|
||||
'const_1_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign'},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('second_dim_mem_1', 'second_dim_data_mem_1'),
|
||||
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
|
||||
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
|
||||
('fill_value', 'fill_value_data'),
|
||||
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
|
||||
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
|
||||
('broadcast_mem_1', 'broadcast_data_mem_1'),
|
||||
('broadcast_data_mem_1', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
|
||||
('second_dim_mem_2', 'second_dim_data_mem_2'),
|
||||
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
|
||||
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
|
||||
('fill_value_ones_2', 'fill_value_data_ones_2'),
|
||||
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
|
||||
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
|
||||
('broadcast_mem_2', 'broadcast_data_mem_2'),
|
||||
('broadcast_data_mem_2', 'concat', {'in': 1}),
|
||||
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('fill_value_ones', 'fill_value_data_ones'),
|
||||
('fill_value_data_ones', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('broadcast_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
('select_out_data', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True
|
||||
)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
# graph contains 2 splices with sum context length 8 on the path to memory,
|
||||
# should be inserted select with memory as counter with length 7
|
||||
def test_insert_select_3(self):
|
||||
graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -247,7 +318,7 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
@ -255,8 +326,8 @@ class InsertSelectTests(unittest.TestCase):
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph({'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
'placeholder_1': {'kind': 'op', 'op': None},
|
||||
ref_graph = build_graph({
|
||||
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
|
||||
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
|
||||
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
|
||||
@ -264,27 +335,45 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'splice_data_2': {'kind': 'data', 'shape': [1, 39]},
|
||||
'placeholder_2': {'kind': 'op', 'op': None},
|
||||
|
||||
'second_dim_mem_1': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data_mem_1': {'kind': 'data'},
|
||||
'gather_shape_mem_1': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_1': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast_mem_1': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_1': {'kind': 'data'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([7])},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([5])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'fill_value_ones': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'second_dim_mem_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([26])},
|
||||
'second_dim_data_mem_2': {'kind': 'data'},
|
||||
'gather_shape_mem_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data_mem_2': {'kind': 'data'},
|
||||
'fill_value_ones_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data_ones_2': {'kind': 'data'},
|
||||
'broadcast_mem_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data_mem_2': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue', 'shape': int64_array([5])},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign'},
|
||||
'memory_out': {'kind': 'op', 'op': 'Assign', 'shape': int64_array([5])},
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 6},
|
||||
'crop_in': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4},
|
||||
'crop_in_data': {'kind': 'data'},
|
||||
'crop_out': {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1},
|
||||
'crop_out_data': {'kind': 'data'},
|
||||
@ -294,40 +383,58 @@ class InsertSelectTests(unittest.TestCase):
|
||||
'select_out_data': {'kind': 'data', 'shape': [1, 26]},
|
||||
'const_0': {'kind': 'op', 'op': 'Const'},
|
||||
'const_0_data': {'kind': 'data'},
|
||||
'const_1': {'kind': 'op', 'op': 'Const'},
|
||||
'const_1_data': {'kind': 'data'},
|
||||
'concat': {'kind': 'op', 'op': 'Concat'},
|
||||
'concat_data': {'kind': 'data'},
|
||||
|
||||
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
|
||||
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
|
||||
},
|
||||
[('in_node', 'placeholder_1'), ('placeholder_1', 'placeholder_data_1'),
|
||||
[('placeholder_1', 'placeholder_data_1'),
|
||||
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
|
||||
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
|
||||
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
|
||||
('placeholder_data_2', 'select', {'in': 1}),
|
||||
|
||||
('second_dim_mem_1', 'second_dim_data_mem_1'),
|
||||
('second_dim_data_mem_1', 'gather_shape_mem_1', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_1', {'in': 0}),
|
||||
('gather_shape_mem_1', 'gather_shape_data_mem_1'),
|
||||
('fill_value', 'fill_value_data'),
|
||||
('fill_value_data', 'broadcast_mem_1', {'in': 0}),
|
||||
('gather_shape_data_mem_1', 'broadcast_mem_1', {'in': 1}),
|
||||
('broadcast_mem_1', 'broadcast_data_mem_1'),
|
||||
('broadcast_data_mem_1', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
|
||||
('second_dim_mem_2', 'second_dim_data_mem_2'),
|
||||
('second_dim_data_mem_2', 'gather_shape_mem_2', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape_mem_2', {'in': 0}),
|
||||
('gather_shape_mem_2', 'gather_shape_data_mem_2'),
|
||||
('fill_value_ones_2', 'fill_value_data_ones_2'),
|
||||
('fill_value_data_ones_2', 'broadcast_mem_2', {'in': 0}),
|
||||
('gather_shape_data_mem_2', 'broadcast_mem_2', {'in': 1}),
|
||||
('broadcast_mem_2', 'broadcast_data_mem_2'),
|
||||
('broadcast_data_mem_2', 'concat', {'in': 1}),
|
||||
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('broadcast_data_mem_2', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
|
||||
('placeholder_data_2', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}), ('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('fill_value_ones', 'fill_value_data_ones'),
|
||||
('fill_value_data_ones', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'), ('memory_in_data', 'crop_in'),
|
||||
('crop_in', 'crop_in_data'), ('crop_in_data', 'concat', {'in': 0}),
|
||||
('const_1', 'const_1_data'), ('const_1_data', 'concat', {'in': 1}),
|
||||
('concat', 'concat_data'), ('concat_data', 'memory_out'),
|
||||
('memory_out', 'memory_out_data'), ('memory_out_data', 'result'),
|
||||
('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
|
||||
('crop_out_data', 'equal', {'in': 1}), ('const_1_data', 'equal', {'in': 0}),
|
||||
('equal', 'equal_data'),
|
||||
('equal_data', 'select', {'in': 0}),
|
||||
('const_0', 'const_0_data'), ('const_0_data', 'select', {'in': 2}),
|
||||
('broadcast_data', 'select', {'in': 2}),
|
||||
|
||||
('select', 'select_out_data'),
|
||||
('select_out_data', 'memory')
|
||||
|
118
model-optimizer/extensions/middle/MakeKaldiConstReshapable.py
Normal file
118
model-optimizer/extensions/middle/MakeKaldiConstReshapable.py
Normal file
@ -0,0 +1,118 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_node_with_second_input, create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Port
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.broadcast import Broadcast
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.shape import Shape
|
||||
|
||||
|
||||
def create_const_with_batch_from_input(producer_port: Port, second_dim, value=0, precision=np.float32):
|
||||
"""
|
||||
Create const with batch taken from input_out_port and second dimension equals second_dim
|
||||
:param producer_port: take batch from this port
|
||||
:param second_dim: second dimension for created constant
|
||||
:param value: value to initialize constant
|
||||
:param precision: precision for constant
|
||||
:return created constant node
|
||||
"""
|
||||
graph = producer_port.node.graph
|
||||
input_name = producer_port.node.soft_get('name', producer_port.node.id)
|
||||
|
||||
shape_of_input = None
|
||||
for dest in producer_port.get_destinations():
|
||||
if dest.node.soft_get('op') == "ShapeOf":
|
||||
shape_of_input = dest.node
|
||||
break
|
||||
|
||||
if shape_of_input is None:
|
||||
shape_of_input = Shape(graph, {'name': input_name + '/Shape'}).create_node()
|
||||
shape_of_input.in_port(0).connect(producer_port)
|
||||
|
||||
get_batch = None
|
||||
for dest in shape_of_input.out_port(0).get_destinations():
|
||||
if dest.node.soft_get('op') == "Crop" and \
|
||||
dest.node.in_port(1).get_source().node.soft_get('value', []) == int64_array([1]):
|
||||
get_batch = dest.node
|
||||
break
|
||||
|
||||
if get_batch is None:
|
||||
get_batch = create_op_node_with_second_input(graph, Crop, int64_array([1]),
|
||||
{'name': shape_of_input.name + '/Crop',
|
||||
'axis': int64_array([0]), 'offset': int64_array([0])},
|
||||
shape_of_input)
|
||||
|
||||
mem_shape = None
|
||||
for dest in get_batch.out_port(0).get_destinations():
|
||||
if dest.node.soft_get('op') == "Concat" and \
|
||||
dest.node.in_port(1).get_source().node.soft_get('value', []) == int64_array([second_dim]):
|
||||
mem_shape = dest.node
|
||||
break
|
||||
|
||||
if mem_shape is None:
|
||||
mem_shape = create_op_node_with_second_input(graph, Concat, int64_array([second_dim]),
|
||||
{'name': get_batch.name + '/Concat', 'axis': 0,
|
||||
'in_ports_count': 2}, get_batch)
|
||||
|
||||
init_value_prev_lstm_output = None
|
||||
for dest in mem_shape.out_port(0).get_destinations():
|
||||
if dest.node.soft_get('op') == "Broadcast" and \
|
||||
dest.node.in_port(1).get_source().node.soft_get('value', []) == np.array([value], dtype=precision):
|
||||
init_value_prev_lstm_output = dest.node
|
||||
break
|
||||
|
||||
if init_value_prev_lstm_output is None:
|
||||
init_value_prev_lstm_output = create_op_with_const_inputs(graph, Broadcast,
|
||||
{0: np.array([value], dtype=precision)},
|
||||
{'name': mem_shape.name + '/Broadcast'})
|
||||
init_value_prev_lstm_output.in_port(1).connect(mem_shape.out_port(0))
|
||||
|
||||
return init_value_prev_lstm_output
|
||||
|
||||
|
||||
class MakeKaldiConstReshapable(MiddleReplacementPattern):
|
||||
"""
|
||||
Add broadcasting of constant nodes based on batch from Parameter node. This approach works only for Kaldi,
|
||||
because it has the same batch in whole graph due to framework specific.
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: graph.graph['fw'] == "kaldi"]
|
||||
|
||||
def run_after(self):
|
||||
from extensions.middle.InsertSelect import AddSelectBeforeMemoryNodePattern
|
||||
from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetWithMemoryNodePattern
|
||||
from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
|
||||
return [AddSelectBeforeMemoryNodePattern, ReplaceMemoryOffsetWithMemoryNodePattern,
|
||||
ReplaceSpliceNodePattern]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
params = graph.get_op_nodes(op="Parameter")
|
||||
batch = params[0].shape[0]
|
||||
|
||||
# check that all Parameters have the same batch
|
||||
for p in params:
|
||||
assert(p.shape[0] == batch,
|
||||
"Parameter {} have batch different from the {}".format(p.soft_get('name', p.id),
|
||||
params[0].soft_get('name', params[0].id)))
|
||||
|
||||
# make constants for initialization of ReadValue reshapable
|
||||
for read in graph.get_op_nodes(op='ReadValue'):
|
||||
input_node = read.in_port(0).get_source().node
|
||||
if input_node.soft_get('op') == "Const":
|
||||
const_shape = input_node.out_port(0).data.get_shape()
|
||||
# extra check to be sure that we don't break shapes compatibility in graph
|
||||
# in Kaldi models we have only 2 dimensions
|
||||
# and batch should be set the same as we will get from Parameter
|
||||
# otherwise just skip such node
|
||||
if len(const_shape) != 2 or const_shape[0] != batch:
|
||||
continue
|
||||
new_const = create_const_with_batch_from_input(params[0].out_port(0),
|
||||
const_shape[1],
|
||||
value=input_node.value[0], precision=input_node.data_type)
|
||||
input_node.out_port(0).get_connection().set_source(new_const.out_port(0))
|
@ -0,0 +1,104 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.MakeKaldiConstReshapable import MakeKaldiConstReshapable
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from mo.utils.unittest.graph import build_graph, result, regular_op_with_shaped_data, connect
|
||||
|
||||
nodes = {
|
||||
**regular_op_with_shaped_data('placeholder_1', [1, 13], {'kind': 'op', 'op': 'Parameter', 'shape': [1, 13]}),
|
||||
**regular_op_with_shaped_data('splice_1', [1, 13], {'kind': 'op', 'op': 'Splice',
|
||||
'context': np.array([-2, -1, 0, 1, 2])}),
|
||||
**regular_op_with_shaped_data('placeholder_2', [1, 26], {'kind': 'op', 'op': None}),
|
||||
**regular_op_with_shaped_data('memory_in', [1, 5], {'kind': 'op', 'op': 'ReadValue',
|
||||
'shape': int64_array([1, 5])}),
|
||||
**regular_op_with_shaped_data('memory_out', [1, 5], {'kind': 'op', 'op': 'Assign', 'shape': int64_array([1, 5])}),
|
||||
**result('result'),
|
||||
**regular_op_with_shaped_data('crop_in', [1, 4], {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 1, 'dim': 4}),
|
||||
**regular_op_with_shaped_data('crop_out', [1, 1], {'kind': 'op', 'op': 'Crop', 'axis': 1, 'offset': 0, 'dim': 1}),
|
||||
**regular_op_with_shaped_data('equal', [1, 1], {'kind': 'op', 'op': 'Equal'}),
|
||||
**regular_op_with_shaped_data('select', [1, 26], {'kind': 'op', 'op': 'Select'}),
|
||||
**regular_op_with_shaped_data('const_0', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
|
||||
'value': [0], 'data_type': np.float32}),
|
||||
**regular_op_with_shaped_data('const_1', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
|
||||
'value': [0], 'data_type': np.float32}),
|
||||
**regular_op_with_shaped_data('concat', [1, 5], {'kind': 'op', 'op': 'Concat'}),
|
||||
**regular_op_with_shaped_data('memory', [1, 26], {'kind': 'op', 'op': 'Assign'}),
|
||||
|
||||
**regular_op_with_shaped_data('shape', None, {'kind': 'op', 'op': 'ShapeOf'}),
|
||||
**regular_op_with_shaped_data('crop_batch', None, {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])}),
|
||||
**regular_op_with_shaped_data('crop_batch_dim', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
|
||||
'value': [1], 'data_type': np.int64}),
|
||||
**regular_op_with_shaped_data('second_dim', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
|
||||
'value': [5], 'data_type': np.int64}),
|
||||
**regular_op_with_shaped_data('gather_shape', None, {'kind': 'op', 'op': 'Concat'}),
|
||||
**regular_op_with_shaped_data('fill_value', [1, 5], {'kind': 'op', 'op': 'Const', 'shape': [1, 5],
|
||||
'value': np.zeros([1, 5]), 'data_type': np.float32}),
|
||||
**regular_op_with_shaped_data('fill_value_2', None, {'kind': 'op', 'op': 'Const', 'shape': [1],
|
||||
'value': [0], 'data_type': np.float32}),
|
||||
**regular_op_with_shaped_data('broadcast', [1, 5], {'kind': 'op', 'op': 'Broadcast'}),
|
||||
|
||||
**regular_op_with_shaped_data('fill_value_ones', [1, 26], {'kind': 'op', 'op': 'Const', 'shape': [1, 26],
|
||||
'value': np.zeros([1, 26]), 'data_type': np.int64}),
|
||||
**regular_op_with_shaped_data('fill_value_ones_2', [1, 1], {'kind': 'op', 'op': 'Const', 'shape': [1, 1],
|
||||
'value': [1], 'data_type': np.int64}),
|
||||
}
|
||||
|
||||
|
||||
class MakeKaldiConstReshapableTests(unittest.TestCase):
|
||||
|
||||
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
|
||||
def test_reshapable_const(self):
|
||||
graph = build_graph(nodes,
|
||||
[*connect('placeholder_1', 'splice_1'),
|
||||
*connect('splice_1', 'placeholder_2'),
|
||||
*connect('placeholder_2', '1:select'),
|
||||
*connect('fill_value', 'memory_in'),
|
||||
*connect('memory_in', 'crop_in'),
|
||||
*connect('crop_in', '0:concat'),
|
||||
*connect('fill_value_ones_2:0', '1:concat'),
|
||||
*connect('concat', 'memory_out'),
|
||||
*connect('memory_out', 'result'),
|
||||
*connect('concat', 'crop_out'),
|
||||
*connect('crop_out', '1:equal'),
|
||||
*connect('fill_value_ones_2:0', '0:equal'),
|
||||
*connect('equal', '0:select'),
|
||||
*connect('fill_value_ones', '2:select'),
|
||||
*connect('select', 'memory')
|
||||
],
|
||||
nodes_with_edges_only=True)
|
||||
graph.strict_mode = False
|
||||
MakeKaldiConstReshapable().find_and_replace_pattern(graph)
|
||||
ref_graph = build_graph(nodes,
|
||||
[*connect('placeholder_1:0', 'splice_1'),
|
||||
*connect('splice_1', 'placeholder_2'),
|
||||
*connect('placeholder_2', '1:select'),
|
||||
*connect('placeholder_1:0', 'shape', skip_data=True),
|
||||
*connect('shape', '0:crop_batch'),
|
||||
*connect('crop_batch_dim', '1:crop_batch'),
|
||||
*connect('second_dim', '1:gather_shape'),
|
||||
*connect('crop_batch', '0:gather_shape'),
|
||||
*connect('fill_value_2', '0:broadcast'),
|
||||
*connect('gather_shape', '1:broadcast'),
|
||||
*connect('broadcast', 'memory_in'),
|
||||
*connect('memory_in', 'crop_in'),
|
||||
*connect('crop_in', '0:concat'),
|
||||
*connect('fill_value_ones_2', '1:concat'),
|
||||
*connect('concat', 'memory_out'),
|
||||
*connect('memory_out', 'result'),
|
||||
*connect('concat', 'crop_out'),
|
||||
*connect('crop_out', '1:equal'),
|
||||
*connect('fill_value_ones_2', '0:equal'),
|
||||
*connect('equal', '0:select'),
|
||||
*connect('const_0', '2:select'),
|
||||
*connect('fill_value_ones', '2:select'),
|
||||
*connect('select', 'memory')
|
||||
], nodes_with_edges_only=True)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'memory')
|
||||
self.assertTrue(flag, resp)
|
@ -2,15 +2,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import logging as log
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
||||
from extensions.ops.splice import Splice
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
@ -134,7 +133,9 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
in_shape = input_port.data.get_shape()
|
||||
node_t = abs(node.t)
|
||||
|
||||
init_value_memory_out = create_zero_value_with_batch_from_input(input_port, in_shape[1]*node_t)
|
||||
init_value_memory_out = Const(graph, {'name': 'init_value_' + pair_name,
|
||||
'value': np.zeros(int64_array([in_shape[0], in_shape[1]*node_t])),
|
||||
'shape': int64_array([in_shape[0], in_shape[1]*node_t])}).create_node()
|
||||
memory_out = ReadValue(graph, {'name': pair_name, 'variable_id': node_name+pair_name}).create_node()
|
||||
init_value_memory_out.out_port(0).connect(memory_out.in_port(0))
|
||||
|
||||
@ -163,14 +164,6 @@ class ReplaceMemoryOffsetWithMemoryNodePattern(MiddleReplacementPattern):
|
||||
memory_in.out_port(0).connect(out.in_port(0))
|
||||
out_port.get_connection().set_source(memory_out.out_port(0))
|
||||
|
||||
if not graph.graph['cmd_params'].static_shape:
|
||||
log.error(
|
||||
"Model can not be translated in a reshape-able way.\n"
|
||||
"Model Optimizer key static_shape was turned on to prevent related errors.\n"
|
||||
"There will be no success changing input shapes of the model with the help of "
|
||||
"InferenceEngine reshape method", extra={'is_warning': True})
|
||||
graph.graph['cmd_params'].static_shape = True
|
||||
|
||||
graph.remove_node(op_output_id)
|
||||
graph.remove_node(node.id)
|
||||
graph.remove_node(pair_node.id)
|
||||
|
@ -1,7 +1,9 @@
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id, create_zero_value_with_batch_from_input
|
||||
import numpy as np
|
||||
|
||||
from extensions.front.kaldi.replace_lstm_node_pattern import unique_id
|
||||
from extensions.ops.split import VariadicSplit
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
@ -9,6 +11,7 @@ from mo.graph.graph import Graph
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.assign import Assign
|
||||
from mo.ops.concat import Concat
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.crop import Crop
|
||||
from mo.ops.read_value import ReadValue
|
||||
from mo.ops.result import Result
|
||||
@ -93,8 +96,11 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
|
||||
# create separate splice construction for const_dim
|
||||
memory_pair_id = unique_id('memory_for_const_dim')
|
||||
init_value_input_memory_const_dim = create_zero_value_with_batch_from_input(split.out_port(1),
|
||||
memory_size_constdim)
|
||||
init_value_input_memory_const_dim = Const(graph, {'name': 'init_value_const_dim_in_memory',
|
||||
'value': np.zeros(int64_array([in_shape[0],
|
||||
memory_size_constdim])),
|
||||
'shape': int64_array([in_shape[0],
|
||||
memory_size_constdim])}).create_node()
|
||||
input_memory_const_dim = ReadValue(graph, {'name': 'const_dim_in_memory',
|
||||
'variable_id': memory_pair_id}).create_node()
|
||||
init_value_input_memory_const_dim.out_port(0).connect(input_memory_const_dim.in_port(0))
|
||||
@ -129,14 +135,16 @@ class ReplaceSpliceNodePattern(MiddleReplacementPattern):
|
||||
concat_const.in_port(1).connect(crop_first.out_port(0))
|
||||
concat_const.in_port(0).connect(concat_node.out_port(0))
|
||||
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(split.out_port(0),
|
||||
memory_size)
|
||||
init_value_input_memory = Const(graph, {'name': 'init_value_' + node.name,
|
||||
'value': np.zeros(int64_array([in_shape[0], memory_size])),
|
||||
'shape': int64_array([in_shape[0], memory_size])}).create_node()
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(split.in_port(0))
|
||||
node.out_port(0).get_connection().set_source(concat_const.out_port(0))
|
||||
else:
|
||||
init_value_input_memory = create_zero_value_with_batch_from_input(node.in_port(0).get_source(),
|
||||
memory_size)
|
||||
init_value_input_memory = Const(graph, {'name': 'init_value_' + node.name,
|
||||
'value': np.zeros(int64_array([in_shape[0], memory_size])),
|
||||
'shape': int64_array([in_shape[0], memory_size])}).create_node()
|
||||
init_value_input_memory.out_port(0).connect(input_memory.in_port(0))
|
||||
node.in_port(0).get_connection().set_destination(concat_node.in_port(1))
|
||||
node.out_port(0).get_connection().set_source(concat_node.out_port(0))
|
||||
|
@ -32,20 +32,8 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
ref_graph = build_graph({'in_placeholder': {'kind': 'op', 'op': None},
|
||||
'in_node': {'kind': 'data', 'shape': [1, 13]},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([143])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
@ -61,16 +49,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
[
|
||||
('in_placeholder', 'in_node'),
|
||||
|
||||
('in_node', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
@ -104,20 +83,8 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
'split_data_0': {'kind': 'data'},
|
||||
'split_data_1': {'kind': 'data'},
|
||||
|
||||
'shape': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data'},
|
||||
'crop_batch': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_data': {'kind': 'data'},
|
||||
'crop_batch_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_data': {'kind': 'data'},
|
||||
'second_dim': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_data': {'kind': 'data'},
|
||||
'gather_shape': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_data': {'kind': 'data'},
|
||||
'fill_value': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_data': {'kind': 'data'},
|
||||
'broadcast': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_data': {'kind': 'data'},
|
||||
|
||||
'memory_in': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_data': {'kind': 'data'},
|
||||
@ -129,21 +96,10 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
'memory_out_data': {'kind': 'data'},
|
||||
'result': {'kind': 'op', 'op': 'Result'},
|
||||
|
||||
'shape_2': {'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_2_data': {'kind': 'data'},
|
||||
'crop_batch_2': {'kind': 'op', 'op': 'Crop', 'offset': int64_array([0])},
|
||||
'crop_batch_2_data': {'kind': 'data'},
|
||||
'crop_batch_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([1])},
|
||||
'crop_batch_dim_2_data': {'kind': 'data'},
|
||||
'second_dim_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([33])},
|
||||
'second_dim_2_data': {'kind': 'data'},
|
||||
'gather_shape_2': {'kind': 'op', 'op': 'Concat'},
|
||||
'gather_shape_2_data': {'kind': 'data'},
|
||||
|
||||
'fill_value_2': {'kind': 'op', 'op': 'Const', 'value': int64_array([0])},
|
||||
'fill_value_2_data': {'kind': 'data'},
|
||||
'broadcast_2': {'kind': 'op', 'op': 'Broadcast'},
|
||||
'broadcast_2_data': {'kind': 'data'},
|
||||
|
||||
\
|
||||
'memory_in_constdims': {'kind': 'op', 'op': 'ReadValue'},
|
||||
'memory_in_constdims_data': {'kind': 'data'},
|
||||
'crop_mem_constdims': {'kind': 'op', 'op': 'Crop', 'offset': 10, 'dim': 100},
|
||||
@ -171,16 +127,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('split', 'split_data_0', {'out': 0}),
|
||||
('split', 'split_data_1', {'out': 1}),
|
||||
|
||||
('split_data_0', 'shape'), ('shape', 'shape_data'),
|
||||
('shape_data', 'crop_batch'), ('crop_batch', 'crop_batch_data'),
|
||||
('crop_batch_dim', 'crop_batch_dim_data'),
|
||||
('crop_batch_dim_data', 'crop_batch', {'in': 1}),
|
||||
('second_dim', 'second_dim_data'), ('second_dim_data', 'gather_shape', {'in': 1}),
|
||||
('crop_batch_data', 'gather_shape', {'in': 0}),
|
||||
('gather_shape', 'gather_shape_data'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'broadcast', {'in': 0}),
|
||||
('gather_shape_data', 'broadcast', {'in': 1}), ('broadcast', 'broadcast_data'),
|
||||
('broadcast_data', 'memory_in'),
|
||||
('fill_value', 'fill_value_data'), ('fill_value_data', 'memory_in'),
|
||||
|
||||
('memory_in', 'memory_in_data'),
|
||||
('memory_in_data', 'crop_mem'),
|
||||
@ -192,16 +139,7 @@ class ReplaceSpliceNodePatternTests(unittest.TestCase):
|
||||
('memory_out', 'memory_out_data'),
|
||||
('memory_out_data', 'result'),
|
||||
|
||||
('split_data_1', 'shape_2'), ('shape_2', 'shape_2_data'),
|
||||
('shape_2_data', 'crop_batch_2'), ('crop_batch_2', 'crop_batch_2_data'),
|
||||
('crop_batch_dim_2', 'crop_batch_dim_2_data'),
|
||||
('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
|
||||
('second_dim_2', 'second_dim_2_data'), ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
|
||||
('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
|
||||
('gather_shape_2', 'gather_shape_2_data'),
|
||||
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'broadcast_2', {'in': 0}),
|
||||
('gather_shape_2_data', 'broadcast_2', {'in': 1}), ('broadcast_2', 'broadcast_2_data'),
|
||||
('broadcast_2_data', 'memory_in_constdims'),
|
||||
('fill_value_2', 'fill_value_2_data'), ('fill_value_2_data', 'memory_in_constdims'),
|
||||
|
||||
('memory_in_constdims', 'memory_in_constdims_data'),
|
||||
('memory_in_constdims_data', 'crop_mem_constdims'),
|
||||
|
@ -32,4 +32,4 @@ class SplitTdnnMemoryOffset(MiddleReplacementPattern):
|
||||
paired_node['element_size'] = offset_node['element_size']
|
||||
# Copy shape from previous node. Typically (but not always) for TDNN blocks this is the case
|
||||
else:
|
||||
paired_node['element_size'] = offset_node.in_port(0).data.get_shape()[1]
|
||||
paired_node['element_size'] = offset_node.in_port(0).data.get_shape()
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.split import AttributedVariadicSplit
|
||||
from mo.front.common.partial_infer.utils import float_array
|
||||
from mo.front.common.partial_infer.utils import float_array, int64_array
|
||||
from mo.front.extractor import add_outputs_identity
|
||||
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
|
||||
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
|
||||
@ -214,7 +214,9 @@ def load_kaldi_nnet3_model(graph, file_descr, nnet_name):
|
||||
for o_n_name, params in node.get_outputs():
|
||||
o_n = Node(graph, o_n_name)
|
||||
if o_n['op'] == 'MemoryOffset':
|
||||
o_n['parameters']['element_size'] = node['shape'][1]
|
||||
# don't take batch from Parameter, it will be overwritten
|
||||
# take only second dimension because we have only 2 dimensions
|
||||
o_n['parameters']['element_size'] = int64_array([1, node.shape[1]])
|
||||
|
||||
load_components(file_descr, graph, component_layer_map)
|
||||
|
||||
@ -268,7 +270,7 @@ def load_components(file_descr, graph, component_layer_map=None):
|
||||
for o_n_name, params in node.get_outputs():
|
||||
o_n = Node(graph, o_n_name)
|
||||
if o_n['op'] == 'MemoryOffset' and dim != 0:
|
||||
o_n['parameters']['element_size'] = dim
|
||||
o_n['parameters']['element_size'] = int64_array([1, dim])
|
||||
else:
|
||||
raise Error("Something wrong with layer {}".format(name))
|
||||
else:
|
||||
@ -401,7 +403,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
for o_n_name, params in node.get_outputs():
|
||||
o_n = Node(graph, o_n_name)
|
||||
if o_n['op'] == 'MemoryOffset':
|
||||
o_n['parameters']['element_size'] = dim
|
||||
o_n['parameters']['element_size'] = int64_array([1, dim])
|
||||
else:
|
||||
raise Error("Unsupported node specifier {}".format(tokens[0]))
|
||||
return True
|
||||
|
@ -17,17 +17,16 @@ class MemoryOffset(Op):
|
||||
'pair_name': None,
|
||||
'splitted': False,
|
||||
'has_default': False,
|
||||
'infer': __class__.infer,
|
||||
'infer': self.infer,
|
||||
'in_ports_count': 1,
|
||||
'out_ports_count': 1,
|
||||
}, attrs)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def infer(node: Node):
|
||||
if node.has_valid('element_size'):
|
||||
# element_size should be set by Kaldi loader or by MemoryOffsetAdjustment
|
||||
node.out_port(0).data.set_shape([1, node['element_size']])
|
||||
# element_size should be set by Kaldi loader or MemoryOffsetAdjustment or SplitRecurrentMemoryOffset
|
||||
node.out_port(0).data.set_shape(node.element_size)
|
||||
else:
|
||||
# for TDNN blocks
|
||||
copy_shape_infer(node)
|
||||
|
Loading…
Reference in New Issue
Block a user