159 lines
8.3 KiB
Python
159 lines
8.3 KiB
Python
"""
 Copyright (C) 2018-2020 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""
|
|
import numpy as np
|
|
|
|
from extensions.front.kaldi.replace_lstm_node_pattern import create_zero_value_with_batch_from_input
|
|
from extensions.ops.elementwise import Equal
|
|
from extensions.ops.select import Select
|
|
from mo.front.common.partial_infer.utils import int64_array
|
|
from mo.graph.graph import Graph, Node
|
|
from mo.middle.pattern_match import find_pattern_matches, inverse_dict
|
|
from mo.middle.replacement import MiddleReplacementPattern
|
|
from mo.ops.assign import Assign
|
|
from mo.ops.concat import Concat
|
|
from mo.ops.const import Const
|
|
from mo.ops.crop import Crop
|
|
from mo.ops.read_value import ReadValue
|
|
from mo.ops.result import Result
|
|
from mo.utils.error import Error
|
|
from mo.utils.graph import invert_sub_graph_between_nodes
|
|
|
|
|
|
class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
    """
    Add Select before saving state with Memory to avoid garbage saving.

    For every Assign whose input is reached from the graph inputs through Splice nodes, the value written
    to the state is gated by an iteration counter: Select(counter_flag == 1, input_data, zeros). Until the
    counter flag becomes 1 (i.e. until enough iterations have run to fill the Splice contexts), zeros are
    saved instead of the input data.
    """
    enabled = True

    def run_after(self):
        # Imported locally rather than at module level (presumably to avoid import cycles between passes).
        from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetWithMemoryNodePattern
        from extensions.middle.RemoveDuplicationMemory import MergeNeighborSplicePattern
        return [ReplaceMemoryOffsetWithMemoryNodePattern,
                MergeNeighborSplicePattern]

    def run_before(self):
        from extensions.middle.ReplaceSpliceNodePattern import ReplaceSpliceNodePattern
        return [ReplaceSpliceNodePattern]

    @staticmethod
    def pattern():
        # Match every Assign (state-saving) operation; each one is processed by replace_pattern.
        return dict(
            nodes=[('op', dict(op='Assign'))],
            edges=[])

    @staticmethod
    def _calculate_context_len(graph: Graph, node: Node):
        """
        Compute after how many iterations the state saved by `node` becomes meaningful.

        The result is 1 plus, for every Splice node in the sub-graph between the consumers of the graph
        Parameters and the data source of `node`, the length of its context minus one.

        :param graph: graph being transformed
        :param node: Assign node whose input is analysed
        :return: the context length, or None when invert_sub_graph_between_nodes raised Error
                 (no sub-graph could be built between the inputs and the node's source)
        """
        # Names of all consumers of every Parameter; used as the end points of the sub-graph search.
        in_nodes = [destination.node.name
                    for parameter in graph.get_op_nodes(op='Parameter')
                    for destination in parameter.out_port(0).get_destinations()]

        try:
            subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
        except Error:
            return None

        context_len = 1
        for node_id in subgraph:
            sub_node = Node(graph, node_id)
            if sub_node.kind == 'op' and sub_node.op == 'Splice':
                context_len += len(sub_node.context) - 1
        return context_len

    @staticmethod
    def _get_iteration_counter(graph: Graph, node: Node, in_node_port, context_len: int):
        """
        Reuse an existing iteration counter matching `context_len`, or build a new one.

        The counter is a ReadValue/Assign state holding `context_len` values along axis 1. On every
        iteration the first value is dropped (Crop, offset 1), a constant is appended (Concat on axis 1)
        and the result is saved (Assign). The first value of the new state (Crop, offset 0, dim 1) is the
        counter flag. Assuming the initial state is all zeros, the flag becomes 1 once `context_len`
        iterations have run.

        :param graph: graph being transformed
        :param node: Assign node the counter is created for; its name forms the counter's variable_id
        :param in_node_port: output port that fed `node`; passed to create_zero_value_with_batch_from_input
                             to build the counter's initial value (presumably taking the batch size from it)
        :param context_len: number of values kept in the counter state
        :return: tuple (node producing the constant the flag is compared with, output port of the flag)
        """
        # check if we have already appropriate iteration counter
        existing_counters = find_pattern_matches(graph, nodes=[('mem_in', dict(op='ReadValue')),
                                                               ('mem_in_data', dict(shape=int64_array([context_len]))),
                                                               ('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
                                                                                    offset=int64_array([1]),
                                                                                    dim=int64_array([context_len-1]))),
                                                               ('crop_mem_in_data', dict()),
                                                               ('concat', dict(op='Concat', axis=1)),
                                                               ('concat_data', dict()),
                                                               ('const_1', dict(op='Const')),
                                                               ('const_1_data', dict()),
                                                               ('mem_out', dict(op='Assign')),
                                                               ('crop_out', dict(op='Crop', axis=int64_array([1]),
                                                                                 offset=int64_array([0]),
                                                                                 dim=int64_array([1]))),
                                                               ('crop_out_data', dict()),
                                                               ('select', dict(op='Select'))
                                                               ],
                                                 edges=[('mem_in', 'mem_in_data'), ('mem_in_data', 'crop_mem_in'),
                                                        ('crop_mem_in', 'crop_mem_in_data'),
                                                        ('crop_mem_in_data', 'concat', {'in': 0}),
                                                        ('const_1', 'const_1_data'),
                                                        ('const_1_data', 'concat', {'in': 1}),
                                                        ('concat', 'concat_data'), ('concat_data', 'mem_out'),
                                                        ('concat_data', 'crop_out'), ('crop_out', 'crop_out_data'),
                                                        ('crop_out_data', 'select')])
        counter_match = next(existing_counters, None)
        if counter_match is not None:
            # The match maps graph node ids to pattern labels; invert it once to look up by label.
            matched_ids = inverse_dict(counter_match)
            ones = Node(graph, matched_ids['const_1'])
            return ones, Node(graph, matched_ids['crop_out']).out_port(0)

        init_value = create_zero_value_with_batch_from_input(in_node_port, context_len, np.int32)
        counter_read = ReadValue(graph, {'name': 'iteration_number',
                                         'variable_id': 'iteration_'+node.name}).create_node()
        counter_read.in_port(0).connect(init_value.out_port(0))

        # Drop the oldest value of the state.
        cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
                                 'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
        cut_first.in_port(0).connect(counter_read.out_port(0))

        # Append a 1 at the end of the state.
        ones = Const(graph, {'name': 'ones', 'value': np.ones([1, 1], dtype=np.int32)}).create_node()
        concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
        concat.in_port(0).connect(cut_first.out_port(0))
        concat.in_port(1).connect(ones.out_port(0))

        # Save the shifted state. The name 'iteration_number_out' lets replace_pattern skip this Assign.
        counter_assign = Assign(graph, {'name': 'iteration_number_out',
                                        'variable_id': 'iteration_'+node.name}).create_node()
        counter_assign.in_port(0).connect(concat.out_port(0))
        res = Result(graph, {}).create_node()
        counter_assign.out_port(0).connect(res.in_port(0))

        # The first value of the updated state is the counter flag.
        cut_last = Crop(graph, {'name': 'cut_last', 'axis': int64_array([1]),
                                'offset': int64_array([0]), 'dim': int64_array([1])}).create_node()
        cut_last.in_port(0).connect(concat.out_port(0))
        return ones, cut_last.out_port(0)

    @staticmethod
    def replace_pattern(graph: Graph, match: dict):
        """
        Gate the input of the matched Assign with Select(counter_flag == 1, input_data, zeros).

        Nothing is changed for the Assign of an iteration counter created by this pass, for Assigns whose
        source cannot be connected to the graph inputs, and for Assigns with no Splice context
        (context length 1).

        :param graph: graph being transformed
        :param match: pattern match; match['op'] is the Assign node
        """
        node = match['op']

        # Skip the Assign that stores an iteration counter built by _get_iteration_counter.
        if node.name == 'iteration_number_out':
            return

        # calculate length of context when state of inference becomes meaningful
        context_len = AddSelectBeforeMemoryNodePattern._calculate_context_len(graph, node)
        if context_len is None or context_len == 1:
            return

        in_node_port = node.in_port(0).get_source()
        in_node_shape = node.in_port(0).data.get_shape()
        node.in_port(0).disconnect()

        # add Select before saving state to avoid saving garbage: port 1 = real data, port 2 = zeros
        select_node = Select(graph, {'name': 'select_' + node.name}).create_node()
        zero_else = Const(graph, {'name': 'zero_else', 'value': np.zeros(in_node_shape)}).create_node()
        select_node.in_port(1).connect(in_node_port)
        select_node.in_port(2).connect(zero_else.out_port(0))

        ones, input_port = AddSelectBeforeMemoryNodePattern._get_iteration_counter(graph, node, in_node_port,
                                                                                   context_len)

        # Check if data from memory is 1
        # if it is True, we have correct data and should proceed with saving it to memory
        # else we have not gathered context and have garbage here, shouldn't change initial state of memory
        cast_in = Equal(graph, {'name': input_port.node.name + '/cast_to_bool'}).create_node()
        cast_in.in_port(0).connect(ones.out_port(0))
        cast_in.in_port(1).connect(input_port)

        select_node.in_port(0).connect(cast_in.out_port(0))
        select_node.out_port(0).connect(node.in_port(0))
        select_node.out_port(0).data.set_shape(in_node_shape)
|