support parallel nested nnet for Kaldi (#1194)
* supported nested nnet1 for Kaldi
This commit is contained in:
parent
f90f242626
commit
e56c8a2bc7
@ -784,7 +784,6 @@ mo/front/kaldi/extractors/pnorm_component_ext.py
|
||||
mo/front/kaldi/extractors/rectified_linear_component_ext.py
|
||||
mo/front/kaldi/extractors/rescale_ext.py
|
||||
mo/front/kaldi/extractors/scale_component_ext.py
|
||||
mo/front/kaldi/extractors/slice_ext.py
|
||||
mo/front/kaldi/extractors/softmax_ext.py
|
||||
mo/front/kaldi/extractors/splice_component_ext.py
|
||||
mo/front/kaldi/loader/__init__.py
|
||||
|
@ -1,42 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.slice import caffe_slice_infer
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.kaldi.loader.utils import read_binary_integer32_token, read_blob
|
||||
from mo.ops.slice import Slice
|
||||
|
||||
|
||||
class SliceFrontExtractor(FrontExtractorOp):
    """Front extractor for the Kaldi 'slice' op.

    Reads slice points from the node's parameter stream and records them,
    together with fixed axis/port attributes, through Slice.update_node_stat.
    """
    op = 'slice'
    enabled = True

    @classmethod
    def extract(cls, node):
        """Fill ``node`` with Slice attributes read from ``node.parameters``.

        :param node: graph node whose ``parameters`` attribute is a readable stream
        :return: the value of ``cls.enabled``
        """
        stream = node.parameters
        # An int32 count is read first, then that many int32 slice points.
        points_count = read_binary_integer32_token(stream)
        slice_points = read_blob(stream, points_count, np.int32)
        # Everything needed has been read, so the stream is released here.
        node.parameters.close()

        attrs = dict(
            axis=1,
            slice_point=slice_points,
            batch_dims=0,
            spatial_dims=1,
            # One more output than there are slice points.
            out_ports_count=points_count + 1,
            infer=caffe_slice_infer,
        )
        Slice.update_node_stat(node, attrs)
        return cls.enabled
|
@ -1,35 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from mo.front.kaldi.extractors.common_ext_test import KaldiFrontExtractorTest
|
||||
from mo.front.kaldi.extractors.slice_ext import SliceFrontExtractor
|
||||
from mo.ops.op import Op
|
||||
from mo.ops.slice import Slice
|
||||
from mo.utils.unittest.extractors import FakeMultiParam
|
||||
|
||||
|
||||
class SliceFrontExtractorTest(KaldiFrontExtractorTest):
    """Unit tests for SliceFrontExtractor."""

    @classmethod
    def register_op(cls):
        """Register the Slice op and attach fake slice parameters to the test node."""
        Op.registered_ops['Slice'] = Slice
        cls.slice_params = dict(slice_point=[99, 1320], axis=1)
        cls.test_node['pb'] = FakeMultiParam(cls.slice_params)

    def test_assertion_no_pb(self):
        """Calling extract with ``None`` instead of a node raises AttributeError."""
        with self.assertRaises(AttributeError):
            SliceFrontExtractor.extract(None)
|
@ -13,14 +13,13 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import io
|
||||
import logging as log
|
||||
import struct
|
||||
from io import IOBase
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.split import AttributedVariadicSplit
|
||||
from mo.front.kaldi.loader.utils import find_next_tag, read_placeholder, find_next_component, get_name_from_path, \
|
||||
find_end_of_component, end_of_nnet_tag, read_binary_integer32_token, get_parameters, read_token_value, \
|
||||
collect_until_token, collect_until_token_and_read, create_edge_attrs, get_args_for_specifier
|
||||
@ -33,7 +32,7 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
|
||||
"""
|
||||
Load ParallelComponent of the Kaldi model.
|
||||
ParallelComponent contains parallel nested networks.
|
||||
Slice is inserted before nested networks.
|
||||
VariadicSplit is inserted before nested networks.
|
||||
Outputs of the nested networks are concatenated by a Concat layer.
|
||||
|
||||
:param file_descr: descriptor of the model file
|
||||
@ -44,27 +43,23 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
|
||||
nnet_count = read_token_value(file_descr, b'<NestedNnetCount>')
|
||||
log.debug('Model contains parallel component with {} nested networks'.format(nnet_count))
|
||||
|
||||
slice_id = graph.unique_id(prefix='Slice')
|
||||
graph.add_node(slice_id, parameters=None, op='slice', kind='op')
|
||||
|
||||
slice_node = Node(graph, slice_id)
|
||||
Node(graph, prev_layer_id).add_output_port(0)
|
||||
slice_node.add_input_port(0)
|
||||
graph.create_edge(Node(graph, prev_layer_id), slice_node, 0, 0)
|
||||
slices_points = []
|
||||
|
||||
split_points = []
|
||||
outputs = []
|
||||
inputs = []
|
||||
|
||||
for i in range(nnet_count):
|
||||
read_token_value(file_descr, b'<NestedNnet>')
|
||||
collect_until_token(file_descr, b'<Nnet>')
|
||||
g = Graph()
|
||||
load_kalid_nnet1_model(g, file_descr, 'Nested_net_{}'.format(i))
|
||||
input_nodes = [n for n in graph.nodes(data=True) if n[1]['op'] == 'Parameter']
|
||||
shape = input_nodes[0][1]['shape']
|
||||
if i != nnet_count - 1:
|
||||
slices_points.append(shape[1])
|
||||
g.remove_node(input_nodes[0][0])
|
||||
|
||||
# the input to an nnet1 model has rank 1, but we also insert batch_size as the 0th axis
|
||||
# 1st axis contains input_size of the nested subnetwork
|
||||
# we split input from the main network to subnetworks
|
||||
input_node = Node(g, 'Parameter')
|
||||
split_points.append(input_node['shape'][1])
|
||||
g.remove_node(input_node.id)
|
||||
|
||||
mapping = {node: graph.unique_id(node) for node in g.nodes(data=False) if node in graph}
|
||||
g = nx.relabel_nodes(g, mapping)
|
||||
for val in mapping.values():
|
||||
@ -72,24 +67,28 @@ def load_parallel_component(file_descr, graph: Graph, prev_layer_id):
|
||||
graph.add_nodes_from(g.nodes(data=True))
|
||||
graph.add_edges_from(g.edges(data=True))
|
||||
sorted_nodes = tuple(nx.topological_sort(g))
|
||||
edge_attrs = create_edge_attrs(slice_id, sorted_nodes[0])
|
||||
edge_attrs['out'] = i
|
||||
Node(graph, slice_id).add_output_port(i)
|
||||
Node(graph, sorted_nodes[0]).add_input_port(len(Node(graph, sorted_nodes[0]).in_ports()))
|
||||
graph.create_edge(Node(graph, slice_id), Node(graph, sorted_nodes[0]), i, 0)
|
||||
outputs.append(sorted_nodes[-1])
|
||||
packed_sp = struct.pack("B", 4) + struct.pack("I", len(slices_points))
|
||||
for i in slices_points:
|
||||
packed_sp += struct.pack("I", i)
|
||||
slice_node.parameters = io.BytesIO(packed_sp)
|
||||
|
||||
outputs.append(Node(graph, sorted_nodes[-1]))
|
||||
inputs.append(Node(graph, sorted_nodes[0]))
|
||||
|
||||
split_id = graph.unique_id(prefix='NestedNets/VariadicSplit')
|
||||
attrs = {'out_ports_count': nnet_count, 'size_splits': split_points, 'axis': 1, 'name': split_id}
|
||||
variadic_split_node = AttributedVariadicSplit(graph, attrs).create_node()
|
||||
prev_layer_node = Node(graph, prev_layer_id)
|
||||
prev_layer_node.add_output_port(0)
|
||||
graph.create_edge(prev_layer_node, variadic_split_node, 0, 0)
|
||||
|
||||
concat_id = graph.unique_id(prefix='Concat')
|
||||
graph.add_node(concat_id, parameters=None, op='concat', kind='op')
|
||||
for i, output in enumerate(outputs):
|
||||
edge_attrs = create_edge_attrs(output, concat_id)
|
||||
edge_attrs['in'] = i
|
||||
Node(graph, output).add_output_port(0)
|
||||
Node(graph, concat_id).add_input_port(i)
|
||||
graph.create_edge(Node(graph, output), Node(graph, concat_id), 0, i)
|
||||
concat_node = Node(graph, concat_id)
|
||||
|
||||
# Connect each output of variadic_split_node to each subnetwork's inputs in ParallelComponent
|
||||
# and each subnetwork's output to concat_node
|
||||
for i, (input_node, output_node) in enumerate(zip(inputs, outputs)):
|
||||
output_node.add_output_port(0)
|
||||
concat_node.add_input_port(i)
|
||||
graph.create_edge(output_node, concat_node, 0, i)
|
||||
graph.create_edge(variadic_split_node, input_node, i, 0)
|
||||
return concat_id
|
||||
|
||||
|
||||
@ -145,6 +144,7 @@ def load_kalid_nnet1_model(graph, file_descr, name):
|
||||
|
||||
if component_type == 'parallelcomponent':
|
||||
prev_layer_id = load_parallel_component(file_descr, graph, prev_layer_id)
|
||||
find_end_of_component(file_descr, component_type)
|
||||
continue
|
||||
|
||||
start_index = file_descr.tell()
|
||||
@ -231,7 +231,7 @@ def load_components(file_descr, graph, component_layer_map=None):
|
||||
file_descr.seek(start_index)
|
||||
dim = 0
|
||||
try:
|
||||
collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index-start_index)
|
||||
collect_until_token(file_descr, b'<Dim>', size_search_zone=end_index - start_index)
|
||||
cur_index = file_descr.tell()
|
||||
if start_index < cur_index < end_index:
|
||||
dim = read_binary_integer32_token(file_descr)
|
||||
@ -284,9 +284,9 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
return False
|
||||
tokens = s.split(b' ')
|
||||
if tokens[0] == b'input-node':
|
||||
in_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
|
||||
in_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
|
||||
in_name = str(in_name).strip('b').replace('\'', "")
|
||||
in_shape = np.array([1, s[s.find(b'dim=')+len(b'dim='):].split(b' ')[0]], dtype=np.int)
|
||||
in_shape = np.array([1, s[s.find(b'dim=') + len(b'dim='):].split(b' ')[0]], dtype=np.int)
|
||||
|
||||
if in_name not in layer_node_map:
|
||||
graph.add_node(in_name, name=in_name, kind='op', op='Parameter', parameters=None, shape=in_shape)
|
||||
@ -295,7 +295,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
Node(graph, in_name)['op'] = 'Parameter'
|
||||
Node(graph, in_name)['shape'] = in_shape
|
||||
elif tokens[0] == b'component-node':
|
||||
layer_name = s[s.find(b'name=')+len(b'name='):].split(b' ')[0]
|
||||
layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
|
||||
layer_name = str(layer_name).strip('b').replace('\'', "")
|
||||
|
||||
component_name = s[s.find(b'component=') + len(b'component='):].split(b' ')[0]
|
||||
@ -315,7 +315,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
component_layer_map[component_name] = [node_name]
|
||||
|
||||
# parse input
|
||||
in_node_id = parse_input_for_node(s[s.find(b'input=')+6:], graph, layer_node_map)
|
||||
in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
|
||||
out_port = len(Node(graph, in_node_id).out_nodes())
|
||||
in_port = len(Node(graph, node_name).in_nodes())
|
||||
|
||||
@ -331,7 +331,7 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
|
||||
parameters=None,
|
||||
op='Identity',
|
||||
kind='op')
|
||||
out_name = graph.unique_id(prefix=node_name+"_out")
|
||||
out_name = graph.unique_id(prefix=node_name + "_out")
|
||||
graph.add_node(out_name,
|
||||
parameters=None,
|
||||
op='Result',
|
||||
|
Loading…
Reference in New Issue
Block a user