support parallel nested nnet for Kaldi (#1194)

* supported nested nnet1 for Kaldi
Pavel Esir 2020-07-23 15:37:41 +03:00 committed by GitHub
parent f90f242626
commit e56c8a2bc7
4 changed files with 39 additions and 117 deletions
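
In short: the Kaldi loader previously handled a ParallelComponent by hand-building a Caffe-style 'slice' node, packing the slice points into a binary blob with struct and io.BytesIO, and decoding that blob again in a dedicated slice extractor. This commit instead creates an AttributedVariadicSplit node directly in the graph, tracks each nested subnetwork's input and output nodes explicitly, and deletes the now-unneeded slice extractor together with its unit test.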

automation/package_BOM.txt View File

@ -784,7 +784,6 @@ mo/front/kaldi/extractors/pnorm_component_ext.py
mo/front/kaldi/extractors/rectified_linear_component_ext.py
mo/front/kaldi/extractors/rescale_ext.py
mo/front/kaldi/extractors/scale_component_ext.py
mo/front/kaldi/extractors/slice_ext.py
mo/front/kaldi/extractors/softmax_ext.py
mo/front/kaldi/extractors/splice_component_ext.py
mo/front/kaldi/loader/__init__.py

mo/front/kaldi/extractors/slice_ext.py View File

@ -1,42 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from mo.front.common.partial_infer.slice import caffe_slice_infer
from mo.front.extractor import FrontExtractorOp
from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob
from mo.ops.slice import Slice
class SliceFrontExtractor(FrontExtractorOp):
    op = 'slice'
    enabled = True

    @classmethod
    def extract(cls, node):
        pb = node.parameters
        num_slice_points = read_binary_integer32_token(pb)
        mapping_rule = {
            'axis': 1,
            'slice_point': read_blob(pb, num_slice_points, np.int32),
            'batch_dims': 0,
            'spatial_dims': 1,
            'out_ports_count': num_slice_points + 1,
            'infer': caffe_slice_infer
        }
        node.parameters.close()
        Slice.update_node_stat(node, mapping_rule)
        return cls.enabled

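Note: once the loader change below stops emitting a Kaldi 'slice' op, nothing in the Kaldi front end produces this operation anymore, which is why the extractor above and its test can be deleted outright.
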
mo/front/kaldi/extractors/slice_ext_test.py View File

@ -1,35 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.kaldi.extractors.common_ext_test import KaldiFrontExtractorTest
from mo.front.kaldi.extractors.slice_ext import SliceFrontExtractor
from mo.ops.op import Op
from mo.ops.slice import Slice
from mo.utils.unittest.extractors import FakeMultiParam
class SliceFrontExtractorTest(KaldiFrontExtractorTest):
    @classmethod
    def register_op(cls):
        Op.registered_ops['Slice'] = Slice
        cls.slice_params = {
            'slice_point': [99, 1320],
            'axis': 1
        }
        cls.test_node['pb'] = FakeMultiParam(cls.slice_params)

    def test_assertion_no_pb(self):
        self.assertRaises(AttributeError, SliceFrontExtractor.extract, None)

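The removed extractor described the partition as Caffe-style slice points along axis 1, while the new loader passes explicit chunk sizes to VariadicSplit. A minimal numpy sketch of the correspondence, reusing the slice points [99, 1320] from the test above and a hypothetical total width of 1419:

    import numpy as np

    x = np.zeros((1, 1419))                  # [batch, features]; the width is made up
    parts = np.split(x, [99, 1320], axis=1)  # old semantics: cut at slice points
    size_splits = [p.shape[1] for p in parts]
    print(size_splits)                       # [99, 1221, 99] -> VariadicSplit 'size_splits'
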
mo/front/kaldi/loader/loader.py View File

@ -13,14 +13,13 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import io
import logging as log
import struct
from io import IOBase
import networkx as nx
import numpy as np
from extensions.ops.split import AttributedVariadicSplit
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
    find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
    collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
@ -33,7 +32,7 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
"""
Load ParallelComponent of the Kaldi model.
ParallelComponent contains parallel nested networks.
Slice is inserted before nested networks.
VariadicSplit is inserted before nested networks.
Outputs of nested networks concatenate with layer Concat.
:param file_descr: descriptor of the model file
@ -44,27 +43,23 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
    nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
    log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))

    slice_id = graph.unique_id(prefix='Slice')
    graph.add_node(slice_id, parameters=None, op='slice', kind='op')
    slice_node = Node(graph, slice_id)
    Node(graph, prev_layer_id).add_output_port(0)
    slice_node.add_input_port(0)
    graph.create_edge(Node(graph, prev_layer_id), slice_node, 0, 0)
    slices_points = []
    split_points = []

    outputs = []
    inputs = []
    for i in range(nnet_count):
        read_token_value(file_descr, b'<NestedNnet>')
        collect_until_token(file_descr, b'<Nnet>')
        g = Graph()
        load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))

        input_nodes = [n for n in graph.nodes(data=True) if n[1]['op'] == 'Parameter']
        shape = input_nodes[0][1]['shape']
        if i != nnet_count - 1:
            slices_points.append(shape[1])
        g.remove_node(input_nodes[0][0])

        # input to nnet1 models is of a rank 1 but we also insert batch_size to 0th axis
        # 1st axis contains input_size of the nested subnetwork
        # we split input from the main network to subnetworks
        input_node = Node(g, 'Parameter')
        split_points.append(input_node['shape'][1])
        g.remove_node(input_node.id)

        mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
        g = nx.relabel_nodes(g, mapping)
        for val in mapping.values():
@ -72,24 +67,28 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
        graph.add_nodes_from(g.nodes(data=True))
        graph.add_edges_from(g.edges(data=True))
        sorted_nodes = tuple(nx.topological_sort(g))

        edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
        edge_attrs['out'] = i
        Node(graph, slice_id).add_output_port(i)
        Node(graph, sorted_nodes[0]).add_input_port(len(Node(graph, sorted_nodes[0]).in_ports()))
        graph.create_edge(Node(graph, slice_id), Node(graph, sorted_nodes[0]), i, 0)
        outputs.append(sorted_nodes[-1])
    packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points))
    for i in slices_points:
        packed_sp += struct.pack("I", i)
    slice_node.parameters = io.BytesIO(packed_sp)
        outputs.append(Node(graph, sorted_nodes[-1]))
        inputs.append(Node(graph, sorted_nodes[0]))

    split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
    attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id}
    variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()

    prev_layer_node = Node(graph, prev_layer_id)
    prev_layer_node.add_output_port(0)
    graph.create_edge(prev_layer_node, variadic_split_node, 0, 0)

    concat_id = graph.unique_id(prefix='Concat')
    graph.add_node(concat_id, parameters=None, op='concat', kind='op')

    for i, output in enumerate(outputs):
        edge_attrs = create_edge_attrs(output, concat_id)
        edge_attrs['in'] = i
        Node(graph, output).add_output_port(0)
        Node(graph, concat_id).add_input_port(i)
        graph.create_edge(Node(graph, output), Node(graph, concat_id), 0, i)
    concat_node = Node(graph, concat_id)

    # Connect each output of variadic_split_node to each subnetwork's inputs in ParallelComponent
    # and each subnetwork's output to concat_node
    for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
        output_node.add_output_port(0)
        concat_node.add_input_port(i)
        graph.create_edge(output_node, concat_node, 0, i)
        graph.create_edge(variadic_split_node, input_node, i, 0)

    return concat_id
@ -145,6 +144,7 @@ def load_kalid_nnet1_model(graph, file_descr, name):
        if component_type == 'parallelcomponent':
            prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id)
            find_end_of_component(file_descr, component_type)
            continue

        start_index = file_descr.tell()
@ -231,7 +231,7 @@ def load_components(file_descr, graph, component_layer_map=None):
        file_descr.seek(start_index)
        dim = 0
        try:
            collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index-start_index)
            collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index - start_index)
            cur_index = file_descr.tell()
            if start_index < cur_index < end_index:
                dim = read_binary_integer32_token(file_descr)
@ -284,9 +284,9 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
        return False
    tokens = s.split(b' ')
    if tokens[0] == b'input-node':
        in_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
        in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        in_name = str(in_name).strip('b').replace('\'', "")
        in_shape = np.array([1, s[s.find(b'dim=')+len(b'dim='):].split(b' ')[0]], dtype=np.int)
        in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=np.int)

        if in_name not in layer_node_map:
            graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape)
@ -295,7 +295,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
            Node(graph, in_name)['op'] = 'Parameter'
            Node(graph, in_name)['shape'] = in_shape
    elif tokens[0] == b'component-node':
        layer_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
        layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
        layer_name = str(layer_name).strip('b').replace('\'', "")
        component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0]
@ -315,7 +315,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
            component_layer_map[component_name] = [node_name]

        # parse input
        in_node_id = parse_input_for_node(s[s.find(b'input=')+6:], graph, layer_node_map)
        in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
        out_port = len(Node(graph, in_node_id).out_nodes())
        in_port = len(Node(graph, node_name).in_nodes())
@ -331,7 +331,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
                       parameters=None,
                       op='Identity',
                       kind='op')
        out_name = graph.unique_id(prefix=node_name+"_out")
        out_name = graph.unique_id(prefix=node_name + "_out")
        graph.add_node(out_name,
                       parameters=None,
                       op='Result',
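
For reference, after this change a ParallelComponent with N nested nets loads into the following subgraph (a sketch; d_i denotes the width taken from nested net i's Parameter shape[1]):

    prev_layer --> VariadicSplit(axis=1, size_splits=[d_0, ..., d_{N-1}])
                     out 0   --> Nested_net_0       --> Concat in 0
                     out 1   --> Nested_net_1       --> Concat in 1
                     ...
                     out N-1 --> Nested_net_{N-1}   --> Concat in N-1

The Concat node id is returned as the new prev_layer_id for the rest of the nnet1 load.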