Kaldi priors (#6258)

* add priors to loader and counts transformation

* fixes in Select insertion for the case of context gathering - LSTM - context gathering;
fix for an edge parallel to ReadValue;
extend the counts option to the case of priors stored inside the .mdl model file

* fixed tests

* fixed typo

* fixed issue with input names

* fix priors loading + comments

* fix e2e test: error about a transformation that was not found

* print debug info for dependency graph - should be reverted

* should be reverted: debug commit

* Revert "fix e2e test: error with not found transformation"

This reverts commit 8320fa99bf.

* revert debug commits

* fixes after review

* review fixes

* review change

* review changes

Svetlana Dolinina 2021-07-02 13:18:23 +03:00 committed by GitHub
parent fc7f80a34e
commit ccf786438b
8 changed files with 159 additions and 65 deletions


@@ -50,7 +50,7 @@ The following list provides the Kaldi\*-specific parameters.
```sh
Kaldi-specific parameters:
--counts COUNTS A file name with full path to the counts file
--counts COUNTS A file name with full path to the counts file or empty string to utilize count values from the model file
--remove_output_softmax
Removes the Softmax that is the output layer
--remove_memory Remove the Memory layer and add new inputs and outputs instead
@@ -78,6 +78,8 @@ python3 mo.py --input_model wsj_dnn5b_smbr.nnet --counts wsj_dnn5b_smbr.counts -
\f$|C|\f$ - number of elements in the counts array;
* The normalized counts are subtracted from the biases of the last or next-to-last layer (if the last layer is SoftMax).
> **NOTE:** Model Optimizer will show a warning if the model contains counts values and the `--counts` option is not used.
* If you want to remove the last SoftMax layer in the topology, launch the Model Optimizer with the
`--remove_output_softmax` flag.
```sh


@@ -51,7 +51,7 @@ def apply_biases_to_last_layer(graph, counts):
outputs_ids = find_outputs(graph)
for output in outputs_ids.copy():
node = Node(graph, output)
if node.op != 'Assign':
if node.op != 'Assign' and node.op != "MemoryOffset":
continue
outputs_ids.remove(output)
@@ -82,6 +82,11 @@ def read_counts_file(file_path):
except TypeError:
raise Error('Expect counts file to contain list of floats.' +
refer_to_faq_msg(90))
return counts
def counts_to_priors(counts):
cutoff = 1.00000001e-10
cutoff_idxs = np.where(counts < cutoff)
counts[cutoff_idxs] = cutoff
@@ -111,10 +116,21 @@ class ApplyCountsFilePattern(FrontReplacementSubgraph):
]
def find_and_replace_pattern(self, graph: Graph):
try:
counts = read_counts_file(graph.graph['cmd_params'].counts)
except Exception as e:
raise Error('Model Optimizer is not able to read counts file {}'.format(graph.graph['cmd_params'].counts) +
refer_to_faq_msg(92)) from e
# if counts is an empty string, read priors from the model itself (at the loader stage)
if graph.graph['cmd_params'].counts == "":
assert isinstance(graph.graph['priors'], (list, np.ndarray)) and len(graph.graph['priors']) != 0, \
"Model file does not contain Priors tag with counts values, use separate file instead"
counts = graph.graph['priors'].copy()
else:
# read counts from given file
try:
counts = read_counts_file(graph.graph['cmd_params'].counts)
except Exception as e:
raise Error('Model Optimizer is not able to read counts file {}'.format(graph.graph['cmd_params'].counts) +
refer_to_faq_msg(92)) from e
# calculate normalized counts as follows:
# c_i=log(c_i/sum(c_j))
# set max_float/2 for almost zero c_i (< 1.e-10)
counts = counts_to_priors(counts)
apply_biases_to_last_layer(graph, counts)
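
For reference, a minimal numpy sketch of the normalization described by the comments above (an illustration of c_i = log(c_i / sum(c_j)) with the near-zero cutoff, not the exact Model Optimizer implementation; the function name is hypothetical):

```python
import numpy as np


def counts_to_priors_sketch(counts: np.ndarray) -> np.ndarray:
    """Illustrative only: normalize counts as c_i = log(c_i / sum(c_j)).

    Almost-zero counts (< ~1e-10) are clamped before the log and then set to
    max_float/2, so the corresponding outputs are effectively disabled once
    the values are subtracted from the biases.
    """
    cutoff = 1.00000001e-10
    counts = counts.astype(np.float64)
    cutoff_idxs = np.where(counts < cutoff)
    counts[cutoff_idxs] = cutoff                         # avoid log(0)
    counts = np.log(counts / np.sum(counts))             # c_i = log(c_i / sum(c_j))
    counts[cutoff_idxs] = np.finfo(np.float32).max / 2   # mark "never seen" classes
    return counts
```

The resulting vector is what `apply_biases_to_last_layer` subtracts from the biases of the last (or next-to-last, if the last layer is SoftMax) layer.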


@@ -1,6 +1,7 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import networkx as nx
import numpy as np
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
@@ -15,15 +16,27 @@ from mo.ops.concat import Concat
from mo.ops.crop import Crop
from mo.ops.read_value import ReadValue
from mo.ops.result import Result
from mo.utils.graph import bfs_search
from mo.utils.error import Error
from mo.utils.graph import invert_sub_graph_between_nodes
class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
"""
Add Select before saving state with Memory to avoid garbage saving
Add Select before saving state with Memory to avoid saving garbage.
We need to know the delay at each node where a Select is added. For that we traverse the whole graph and set a
frame time for each node using the following rules:
* Splice increases frame time by the length of its context minus one. If a Crop follows a Splice, it takes one
concrete moment of time, so the frame time equals the delay of that selected moment
Example:
node ---> Splice(-5, -4, ... 0) ---> node
frame time: 0 ---> 5 ---> 5
node ---> Splice(-5, -4, ... 0) ---> Crop(offset = 2, dim = 1) ---> node
frame time: 0 ---> 5 ---> 3 ---> 3
* Nodes with several inputs have frame time = max(frame time of each input)
* A node with one input has the same frame time as its input
"""
enabled = True
graph_condition = [lambda graph: graph.graph['fw'] == 'kaldi']
def run_after(self):
from extensions.middle.ReplaceMemoryOffsetWithSplice import ReplaceMemoryOffsetWithMemoryNodePattern
@@ -36,38 +49,63 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
return [ReplaceSpliceNodePattern]
@staticmethod
def pattern():
return dict(
nodes=[('op', dict(op='Assign'))],
edges=[])
def calculate_frame_time(graph: Graph):
# there are either one or two inputs in Kaldi. Only main input can change delay in network.
# Usually ivector input has name 'ivector'.
inputs = graph.get_op_nodes(op='Parameter')
if len(inputs) == 1:
inp_name = inputs[0].name
elif len(inputs) == 2:
if inputs[0].name == 'ivector':
inp_name = inputs[1].name
elif inputs[1].name == 'ivector':
inp_name = inputs[0].name
else:
raise Error("There are 2 inputs for Kaldi model but we can't find out which one is ivector. " +
"Use name \'ivector\' for the corresponding input")
else:
raise Error("There are {} inputs for Kaldi model but we expect only 1 or 2".format(len(inputs)))
# sort nodes to calculate delays
nodes = list(bfs_search(graph, [inp_name]))
nx.set_node_attributes(G=graph, name='frame_time', values=-1)
for n in nodes:
node = Node(graph, n)
# just ignore data nodes
if node.kind != 'op':
continue
# calculate frame_time (delay) that was not calculated
if node.frame_time < 0:
# Splice increases frame delay
if node.op == "Splice":
node.frame_time = node.in_port(0).get_source().node.frame_time + len(node.context) - 1
# crop often used to get concrete time frame, set frame_time correctly for this case
elif node.op == 'Crop':
if node.in_port(0).get_connection().get_source().node.op == 'Splice':
splice_node = node.in_port(0).get_source().node
assert len(node.offset) == 1
assert len(node.dim) == 1
new_delay = splice_node.context[node.offset[0] // node.dim[0]] - splice_node.context[0]
node.frame_time = splice_node.in_port(0).get_source().node.frame_time + new_delay
else:
node.frame_time = node.in_port(0).get_source().node.frame_time
# for node with several inputs frame_time = maximum of delays from branches
else:
# find out maximum of delay and check that we have at least one branch with another delay
node.frame_time = 0
for inp in node.in_ports():
if node.in_port(inp).disconnected():
continue
in_node = node.in_port(inp).get_source().node
if in_node.frame_time > node.frame_time:
node.frame_time = in_node.frame_time
@staticmethod
def replace_pattern(graph: Graph, match: dict):
node = match['op']
if node.name == 'iteration_number_out':
return
# calculate length of context when state of inference becomes meaningful
inputs = []
for n in graph.get_op_nodes(**{'op': 'Parameter'}):
inputs.append(n)
in_nodes = []
for inp in inputs:
for ins in inp.out_port(0).get_destinations():
in_nodes.append(ins.node.name)
context_len = 1
try:
subgraph = invert_sub_graph_between_nodes(graph, [node.in_port(0).get_source().node.name], in_nodes)
except Error:
return
for n in subgraph:
n_node = Node(graph, n)
if n_node.kind == 'op' and n_node.op == 'Splice':
context_len += len(n_node.context) - 1
def insert_select(graph: Graph, node: Node):
context_len = node.frame_time + 1
if context_len == 1:
return
@@ -87,7 +125,7 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
('mem_in_data', dict(shape=int64_array([context_len]))),
('crop_mem_in', dict(op='Crop', axis=int64_array([1]),
offset=int64_array([1]),
dim=int64_array([context_len-1]))),
dim=int64_array([context_len - 1]))),
('crop_mem_in_data', dict()),
('concat', dict(op='Concat', axis=1)),
('concat_data', dict()),
@@ -115,17 +153,17 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
else:
init_value_mem_out = create_const_with_batch_from_input(in_node_port, context_len, precision=np.int32)
mem_out = ReadValue(graph, {'name': 'iteration_number',
'variable_id': 'iteration_'+node.name}).create_node()
'variable_id': 'iteration_' + node.name}).create_node()
mem_out.in_port(0).connect(init_value_mem_out.out_port(0))
cut_first = Crop(graph, {'name': 'cut_first', 'axis': int64_array([1]),
'offset': int64_array([1]), 'dim': int64_array([context_len-1])}).create_node()
'offset': int64_array([1]), 'dim': int64_array([context_len - 1])}).create_node()
cut_first.in_port(0).connect(mem_out.out_port(0))
ones = create_const_with_batch_from_input(in_node_port, 1, 1, np.int32)
concat = Concat(graph, {'name': 'concat_ones', 'in_ports_count': 2, 'axis': 1}).create_node()
concat.in_port(0).connect(cut_first.out_port(0))
concat.in_port(1).connect(ones.out_port(0))
mem_in = Assign(graph, {'name': 'iteration_number_out',
'variable_id': 'iteration_'+node.name}).create_node()
'variable_id': 'iteration_' + node.name}).create_node()
mem_in.in_port(0).connect(concat.out_port(0))
res = Result(graph, {}).create_node()
mem_in.out_port(0).connect(res.in_port(0))
@@ -143,3 +181,19 @@ class AddSelectBeforeMemoryNodePattern(MiddleReplacementPattern):
select_node.in_port(0).connect(cast_in.out_port(0))
select_node.out_port(0).connect(node.in_port(0))
select_node.out_port(0).data.set_shape(in_node_shape)
def find_and_replace_pattern(self, graph: Graph):
if np.all([node.soft_get('name', node.id) == 'iteration_number_out'
for node in graph.get_op_nodes(op='Assign')]):
return
self.calculate_frame_time(graph)
for node in graph.get_op_nodes(op='Assign'):
if node.soft_get('name', node.id) == 'iteration_number_out':
continue
self.insert_select(graph, node)
for node in graph.get_op_nodes():
if 'frame_time' in node:
del node['frame_time']


@@ -77,7 +77,7 @@ class ReplaceMemoryOffsetNodePattern(MiddleReplacementPattern):
outs = input_node_out_port.get_destinations()
for in_port in outs:
out_ = in_port.node
if out_['op'] != 'MemoryOffset' and out_['op'] != 'Splice':
if out_.op == 'Concat' and out_ == out_node_in_ports[0].node:
crop_input = Crop(graph, {'name': 'Splice_Crop',
'axis': int64_array([1]),
'offset': int64_array([-min(0, in_shape[1] * node_t)]),


@@ -14,6 +14,7 @@ from mo.front.extractor import add_outputs_identity
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
from mo.front.kaldi.utils import read_binary_vector
from mo.graph.graph import Node, Graph
from mo.ops.const import Const
from mo.utils.error import Error
@@ -222,6 +223,20 @@ def load_kaldi_nnet3_model(graph, file_descr, nnet_name):
o_n['parameters']['element_size'] = int64_array([1, node.shape[1]])
load_components(file_descr, graph, component_layer_map)
load_priors(file_descr, graph)
def load_priors(file_descr, graph):
try:
collect_until_token(file_descr, b'<Priors>')
except Error:
# just ignore if priors were not found
return
if graph.graph['cmd_params'].counts is not None:
graph.graph['priors'] = read_binary_vector(file_descr)
else:
log.error("Model contains Prior values, if you want to embed them into the generated IR add option --counts=\"\" to command line",
extra={'is_warning': True})
def load_components(file_descr, graph, component_layer_map=None):
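
Taken together with the `ApplyCountsFilePattern` change above, the intended contract can be summarized as follows (a simplified sketch; `resolve_counts` is a hypothetical helper, the real logic lives in the loader and the front transformation shown in this diff):

```python
import numpy as np


def resolve_counts(graph, read_counts_file):
    """Hypothetical helper summarizing where counts come from after this PR.

    read_counts_file is the existing parser from the counts transformation,
    passed in only to keep the sketch self-contained.
    """
    counts_arg = graph.graph['cmd_params'].counts
    if counts_arg == "":
        # --counts="" means: use the <Priors> vector stored by load_priors()
        priors = graph.graph.get('priors')
        assert isinstance(priors, (list, np.ndarray)) and len(priors) != 0, \
            "Model file does not contain Priors tag with counts values, use separate file instead"
        return np.array(priors).copy()
    # otherwise read the counts from the file given on the command line
    return read_counts_file(counts_arg)
```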


@@ -75,6 +75,17 @@ class CanonicalizePathCheckExistenceAction(CanonicalizePathAction):
' but "{}" does not exist.'.format(self.dest, name))
class CanonicalizePathCheckExistenceIfNeededAction(CanonicalizePathCheckExistenceAction):
def __call__(self, parser, namespace, values, option_string=None):
if values is not None:
if isinstance(values, str):
if values != "":
super().__call__(parser, namespace, values, option_string)
else:
setattr(namespace, self.dest, values)
class DeprecatedCanonicalizePathCheckExistenceAction(CanonicalizePathCheckExistenceAction):
def __call__(self, parser, namespace, values, option_string=None):
super().__call__(parser, namespace, values, option_string)
@@ -415,7 +426,7 @@ def get_mxnet_cli_options():
def get_kaldi_cli_options():
d = {
'counts': '- A file name with full path to the counts file',
'counts': '- A file name with full path to the counts file or empty string if you want to use counts from model',
'remove_output_softmax': '- Removes the SoftMax layer that is the output layer',
'remove_memory': '- Removes the Memory layer and use additional inputs and outputs instead'
}
@@ -609,7 +620,7 @@ def get_kaldi_cli_parser(parser: argparse.ArgumentParser = None):
kaldi_group.add_argument("--counts",
help="Path to the counts file",
default=None,
action=CanonicalizePathCheckExistenceAction)
action=CanonicalizePathCheckExistenceIfNeededAction)
kaldi_group.add_argument("--remove_output_softmax",
help="Removes the SoftMax layer that is the output layer",


@@ -16,11 +16,11 @@ class InsertSelectTests(unittest.TestCase):
# graph have no splices - selects should not be inserted
def test_insert_select_0(self):
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'memory')
],
nodes_with_edges_only=True)
@@ -33,7 +33,7 @@ class InsertSelectTests(unittest.TestCase):
# graph contains 1 splice with context length 5, should be inserted select with memory as counter with length 5
def test_insert_select_1(self):
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
@@ -41,7 +41,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'memory')
@@ -49,7 +49,7 @@ class InsertSelectTests(unittest.TestCase):
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 13]},
@@ -109,7 +109,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
('placeholder_data_2', 'select', {'in': 1}),
@@ -168,7 +168,7 @@ class InsertSelectTests(unittest.TestCase):
# should be inserted select with memory as counter with length 5
def test_insert_select_2(self):
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@@ -178,7 +178,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
@@ -187,7 +187,7 @@ class InsertSelectTests(unittest.TestCase):
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@@ -249,7 +249,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign'},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('placeholder_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_1', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
@@ -308,7 +308,7 @@ class InsertSelectTests(unittest.TestCase):
# should be inserted select with memory as counter with length 7
def test_insert_select_3(self):
graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@@ -318,7 +318,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),
@@ -327,7 +327,7 @@ class InsertSelectTests(unittest.TestCase):
nodes_with_edges_only=True)
AddSelectBeforeMemoryNodePattern().find_and_replace_pattern(graph)
ref_graph = build_graph({
'placeholder_1': {'kind': 'op', 'op': 'Parameter'},
'input': {'kind': 'op', 'op': 'Parameter'},
'placeholder_data_1': {'kind': 'data', 'shape': [1, 13]},
'splice_1': {'kind': 'op', 'op': 'Splice', 'context': np.array([-2, -1, 0, 1, 2])},
'splice_data_1': {'kind': 'data', 'shape': [1, 65]},
@@ -389,7 +389,7 @@ class InsertSelectTests(unittest.TestCase):
'placeholder_data_2': {'kind': 'data', 'shape': [1, 26]},
'memory': {'kind': 'op', 'op': 'Assign', 'index': 0},
},
[('placeholder_1', 'placeholder_data_1'),
[('input', 'placeholder_data_1'),
('placeholder_data_1', 'splice_1'), ('splice_1', 'splice_data_1'),
('splice_data_1', 'splice_2'), ('splice_2', 'splice_data_2'),
('splice_data_2', 'placeholder_2'), ('placeholder_2', 'placeholder_data_2'),


@@ -107,8 +107,6 @@ class ReplaceMemoryOffsetNodePatternTests(unittest.TestCase):
'splice': {'kind': 'op', 'op': 'Splice', 'context': range(-5, 1)},
'splice_data': {'kind': 'data', 'shape': [1, 78]},
'crop': {'kind': 'op', 'op': 'Crop', 'offset': 0, 'dim': 13},
'crop_input': {'kind': 'op', 'op': 'Crop', 'offset': 65, 'dim': 13},
'crop_input_data': {'kind': 'data', 'shape': [1, 13]},
'memoryoffset_2_data': {'kind': 'data', 'shape': [1, 13]},
'out_placeholder': {'kind': 'op', 'op': 'placeholder'},
},
@@ -118,10 +116,8 @@ class ReplaceMemoryOffsetNodePatternTests(unittest.TestCase):
('splice', 'splice_data'),
('splice_data', 'crop'),
('crop', 'memoryoffset_2_data'),
('splice_data', 'crop_input'),
('crop_input', 'crop_input_data'),
('memoryoffset_2_data', 'out_placeholder'),
('crop_input_data', 'out_placeholder')
('in_node', 'out_placeholder')
]
)