Added Kaldi DropoutMask extraction and extended the Kaldi LstmNonlinearity replacer to handle the dropout case

This commit is contained in:
sadolini 2021-09-09 00:27:48 +03:00
parent 667ca3c3f6
commit 0dd05f8053
10 changed files with 185 additions and 16 deletions

View File

@ -162,6 +162,7 @@ extensions/front/kaldi/apply_counts.py
extensions/front/kaldi/logsoftmax_component_ext.py
extensions/front/kaldi/memory_offset_adjustment.py
extensions/front/kaldi/memoryoffset_batch_update.py
extensions/front/kaldi/replace_dropoutmask.py
extensions/front/kaldi/replace_eltwise_nin1.py
extensions/front/kaldi/replace_lstm_node_pattern.py
extensions/front/kaldi/replace_lstm_nonlinearity.py
@ -867,6 +868,7 @@ mo/front/kaldi/extractors/convolutional_1d_component_ext.py
mo/front/kaldi/extractors/convolutional_component_ext.py
mo/front/kaldi/extractors/copy_ext.py
mo/front/kaldi/extractors/crop_ext.py
mo/front/kaldi/extractors/dropoutmask_ext.py
mo/front/kaldi/extractors/elementwise_component_ext.py
mo/front/kaldi/extractors/fixed_affine_component_ext.py
mo/front/kaldi/extractors/generaldropout_ext.py
@ -980,6 +982,7 @@ mo/ops/convolution.py
mo/ops/crop.py
mo/ops/deconvolution.py
mo/ops/deformable_convolution.py
mo/ops/dropoutmask.py
mo/ops/eltwise.py
mo/ops/eltwise_n.py
mo/ops/eltwise_ninputs_in_1.py

View File

@ -0,0 +1,27 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
from mo.front.common.replacement import FrontReplacementPattern
from mo.graph.graph import Graph
class ReplaceDropoutMaskPattern(FrontReplacementPattern):
    """
    Replace every Kaldi 'dropoutmaskcomponent' node with a constant whose batch
    dimension is taken from the model's Parameter node at runtime.

    At inference time dropout is disabled, so the mask degenerates to a constant
    filled with the stored value (the node's ``dropout_proportion`` attribute,
    which the extractor sets to the keep probability). The constant is built
    with ``create_const_with_batch_from_input`` so its batch size follows the
    input Parameter instead of being frozen into the IR.
    """
    enabled = True
    # The pass rewires ports directly; re-running it recursively is unnecessary.
    run_non_recursive = True

    def run_after(self):
        from extensions.front.restore_ports import RestorePorts
        return [RestorePorts]

    def run_before(self):
        # The LSTM nonlinearity replacer consumes the mask produced here,
        # so the mask constant must exist before that pass runs.
        from extensions.front.kaldi.replace_lstm_nonlinearity import ReplaceLstmNonLinearityPattern
        return [ReplaceLstmNonLinearityPattern]

    def find_and_replace_pattern(self, graph: Graph):
        replace_nodes = graph.get_op_nodes(op='dropoutmaskcomponent')
        if not replace_nodes:
            # Nothing to replace; also avoids an IndexError below on graphs
            # that (for whatever reason) contain no Parameter node.
            return
        # Batch size is tracked from the first Parameter node's output.
        batch_port = graph.get_op_nodes(op="Parameter")[0].out_port(0)
        for dropout_node in replace_nodes:
            dp_const_node = create_const_with_batch_from_input(batch_port, dropout_node.size,
                                                               dropout_node.dropout_proportion)
            # Re-point all consumers of the mask to the new constant, then
            # drop the original component from the graph.
            dropout_node.out_port(0).get_connection().set_source(dp_const_node.out_port(0))
            graph.remove_node(dropout_node.id)

View File

@ -5,7 +5,7 @@ import numpy as np
from extensions.ops.activation_ops import Sigmoid, Tanh
from extensions.ops.elementwise import Add, Mul
from extensions.ops.split import Split
from extensions.ops.split import Split, AttributedVariadicSplit
from mo.front.caffe.extractors.utils import input_as_const
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
@ -27,12 +27,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
return [FullyConnectedDecomposer]
def replace_op(self, graph: Graph, node: Node):
# split input to (i_part, f_part, c_part, o_part, ct_1)
node_name = node.soft_get('name', node.id)
# check if we have dropout
input_port = node.in_port(0)
if node['use_dropout']:
split_dropout = AttributedVariadicSplit(graph,
{'name': node_name + '/split_dropout',
'size_splits': np.array([-1, 1, 1, 1],),
'axis': np.int64(1)}).create_node()
input_port.get_connection().set_destination(split_dropout.in_port(0))
input_port = split_dropout.out_port(0)
# split input to (i_part, f_part, c_part, o_part, ct_1)
split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)},
{'name': node_name + '/split_lstm_input',
'num_splits': 5})
node.in_port(0).get_connection().set_destination(split_node.in_port(0))
input_port.get_connection().set_destination(split_node.in_port(0))
i_part = split_node.out_port(0)
f_part = split_node.out_port(1)
o_part = split_node.out_port(3)
# i_t = Sigmoid(i_part + w_ic*ct_1)
i_scale_attrs = {'name': node_name + '/i_scaleshift',
@ -42,12 +56,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
split_node.out_port(4).connect(i_scale.in_port(0))
sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
split_node.out_port(0).connect(sum_i_c.in_port(0))
i_part.connect(sum_i_c.in_port(0))
i_scale.out_port(0).connect(sum_i_c.in_port(1))
i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node()
sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
if node['use_dropout']:
mul_dropout_i = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_i'}).create_node()
mul_dropout_i.in_port(0).connect(i_sigmoid.out_port(0))
mul_dropout_i.in_port(1).connect(split_dropout.out_port(1))
i_sigmoid = mul_dropout_i
# f_t = Sigmoid(f_part + w_fc*ct_1)
f_scale_attrs = {'name': node_name + '/f_scaleshift',
'bias_term': False}
@ -56,12 +76,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
split_node.out_port(4).connect(f_scale.in_port(0))
sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
split_node.out_port(1).connect(sum_f_c.in_port(0))
f_part.connect(sum_f_c.in_port(0))
f_scale.out_port(0).connect(sum_f_c.in_port(1))
f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node()
sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))
if node['use_dropout']:
mul_dropout_f = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_f'}).create_node()
mul_dropout_f.in_port(0).connect(f_sigmoid.out_port(0))
mul_dropout_f.in_port(1).connect(split_dropout.out_port(2))
f_sigmoid = mul_dropout_f
# c_t = f_t*ct_1 + i_t * tanh(c_part)
c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
split_node.out_port(2).connect(c_tanh.in_port(0))
@ -86,12 +112,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
sum_f_i.out_port(0).connect(o_scale.in_port(0))
sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
split_node.out_port(3).connect(sum_o_c.in_port(0))
o_part.connect(sum_o_c.in_port(0))
o_scale.out_port(0).connect(sum_o_c.in_port(1))
o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node()
sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))
if node['use_dropout']:
mul_dropout_o = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_o'}).create_node()
mul_dropout_o.in_port(0).connect(o_sigmoid.out_port(0))
mul_dropout_o.in_port(1).connect(split_dropout.out_port(3))
o_sigmoid = mul_dropout_o
# m_t = o_t * Tanh(c_t)
c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))

View File

@ -0,0 +1,22 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from mo.front.extractor import FrontExtractorOp
from mo.front.kaldi.loader.utils import collect_until_token, collect_until_token_and_read, read_binary_float_token
from mo.ops.dropoutmask import DropoutMask
class DropoutMaskComponentFrontExtractor(FrontExtractorOp):
    """
    Front extractor for the Kaldi '<DropoutMaskComponent>' node.

    Reads ``<OutputDim>`` and ``<DropoutProportion>`` from the binary Kaldi
    stream and stores them on the node for the DropoutMask op. The proportion
    is stored as the keep probability (1 - dropout proportion), which is the
    constant value the inference-time mask degenerates to.
    """
    op = 'dropoutmaskcomponent'
    enabled = True

    @classmethod
    def extract(cls, node):
        params = node.parameters
        # <OutputDim> gives the width of the mask vector.
        output_dim = collect_until_token_and_read(params, b'<OutputDim>')
        # <DropoutProportion> is a float token that follows later in the stream.
        collect_until_token(params, b'<DropoutProportion>')
        proportion = read_binary_float_token(params)
        attrs = {
            'dropout_proportion': 1.0 - proportion,  # keep probability
            'size': output_dim,
        }
        DropoutMask.update_node_stat(node, attrs)
        return cls.enabled

View File

@ -1,10 +1,11 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
from mo.front.caffe.extractors.utils import embed_input
from mo.front.extractor import FrontExtractorOp
from mo.front.kaldi.loader.utils import collect_until_token
from mo.front.kaldi.utils import read_binary_matrix
from mo.front.kaldi.loader.utils import collect_until_token, collect_until_token_and_read
from mo.front.kaldi.utils import read_binary_matrix, read_token_value
from mo.ops.lstmnonlinearity import LstmNonLinearity
@ -15,10 +16,13 @@ class LSTMNonlinearityFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
pb = node.parameters
collect_until_token(pb, b'<Params>')
ifo_x_weights, ifo_x_weights_shape = read_binary_matrix(pb)
mapping_rule = {}
use_dropout = collect_until_token_and_read(pb, b'<UseDropout>', np.bool)
mapping_rule = {'use_dropout': use_dropout}
assert len(ifo_x_weights_shape) == 2, "Unexpected shape of weights in LSTMNonLinearityComponent"
assert ifo_x_weights_shape[0] == 3, "Unexpected shape of weights in LSTMNonLinearityComponent"

View File

@ -353,13 +353,15 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
# parse input
in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
out_port = len(Node(graph, in_node_id).out_nodes())
in_port = len(Node(graph, node_name).in_nodes())
# don't create cyclic edges node to itself to avoid removing later
if in_node_id != node_name:
out_port = len(Node(graph, in_node_id).out_nodes())
in_port = len(Node(graph, node_name).in_nodes())
Node(graph, node_name).add_input_port(in_port)
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
Node(graph, node_name).add_input_port(in_port)
Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
elif tokens[0] == b'output-node':
layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
layer_name = str(layer_name).strip('b').replace('\'', "")

View File

@ -24,6 +24,7 @@ supported_components = [
'convolutional1dcomponent',
'convolutionalcomponent',
'copy',
'dropoutmaskcomponent',
'elementwiseproductcomponent',
'fixedaffinecomponent',
'fixedscalecomponent',

View File

@ -0,0 +1,22 @@
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from mo.graph.graph import Graph
from mo.ops.op import Op
class DropoutMask(Op):
    """
    Operation for dropout proportion; it is replaced by a broadcast constant
    on the front stage (see ReplaceDropoutMaskPattern), so it never reaches
    the IR.
    """
    op = 'dropoutmaskcomponent'

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'op': self.op,
            'dropout_proportion': None,  # filled in by the front extractor
            # 'type' stays None: this op must not appear in the emitted IR.
            'type': None,
            'infer': None,
            # Pure producer: no inputs, a single mask output.
            'in_ports_count': 0,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

View File

@ -13,6 +13,7 @@ class LstmNonLinearity(Op):
def __init__(self, graph: Graph, attrs: dict):
super().__init__(graph, {
'op': __class__.op,
'use_dropout': False,
'type': None, # type is None because this operation should not appear in IR
'infer': None,
'in_ports_count': 1,

View File

@ -14,9 +14,9 @@ from unit_tests.utils.graph import build_graph
class ReplaceLstmNonlinearityTests(unittest.TestCase):
# i_t = Sigmoid(i_part + w_ic*ct_1)
# f_t = Sigmoid(f_part + w_fc*ct_1)
# c_t = f_t*ct_1 + i_t * tanh(c_part)
# c_t = f_t * f_scale * ct_1 + i_t * i_scale * tanh(c_part)
# o_t = Sigmoid(o_part + w_oc*c_t)
# m_t = o_t * Tanh(c_t)
# m_t = o_t * o_scale * Tanh(c_t)
nodes_attributes = {
'in': {'kind': 'op', 'op': 'Parameter'},
'i_part': {'kind': 'op', 'op': 'Parameter'},
@ -24,6 +24,7 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
'c_part': {'kind': 'op', 'op': 'Parameter'},
'o_part': {'kind': 'op', 'op': 'Parameter'},
'split': {'kind': 'op', 'op': 'Split'},
'split_dropout': {'kind': 'op', 'op': 'AttributedVariadicSplit', 'size_splits': [-1, 1, 1, 1]},
'sigmoid_i': {'kind': 'op', 'op': 'Sigmoid'},
'sigmoid_f': {'kind': 'op', 'op': 'Sigmoid'},
'sigmoid_o': {'kind': 'op', 'op': 'Sigmoid'},
@ -31,6 +32,9 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
'f_plus_c': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
'fc_plus_itanhc': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
'o_plus_c': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
'scaled_i': {'kind': 'op', 'op': 'Mul'},
'scaled_f': {'kind': 'op', 'op': 'Mul'},
'scaled_o': {'kind': 'op', 'op': 'Mul'},
'scale_i_c': {'kind': 'op', 'op': 'ScaleShift'},
'scale_f_c': {'kind': 'op', 'op': 'ScaleShift'},
'scale_o_c': {'kind': 'op', 'op': 'ScaleShift'},
@ -50,6 +54,7 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
def test_lstm_nonlinearity(self):
graph = build_graph({'in': {'kind': 'op', 'op': 'Parameter'},
'lstm': {'kind': 'op', 'op': 'LstmNonLinearity',
'use_dropout': False,
'i_weights': np.array([]),
'f_weights': np.array([]),
'o_weights': np.array([]),},
@ -88,3 +93,53 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))
(flag, resp) = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
self.assertTrue(flag, resp)
def test_lstm_nonlinearity_dropout(self):
    """Check LstmNonLinearity replacement when use_dropout is enabled.

    With dropout the replacer must first insert an AttributedVariadicSplit
    that peels three per-gate dropout scales off the input, and multiply
    each gate's sigmoid output by its corresponding scale.
    """
    # Minimal source graph: Parameter -> LstmNonLinearity(use_dropout) -> out.
    graph = build_graph({'in': {'kind': 'op', 'op': 'Parameter'},
                         'lstm': {'kind': 'op', 'op': 'LstmNonLinearity',
                                  'use_dropout': True,
                                  'i_weights': np.array([]),
                                  'f_weights': np.array([]),
                                  'o_weights': np.array([]),},
                         'out': {'kind': 'op', 'op': 'Placeholder'}},
                        [('in', 'lstm'), ('lstm', 'out')], nodes_with_edges_only=True)
    graph.stage = 'front'
    # split input to (i_part, f_part, c_part, o_part, ct_1)
    # Reference graph the replacer is expected to produce.
    ref_graph = build_graph(self.nodes_attributes, [
        # Dropout split: port 0 carries the LSTM input, ports 1-3 the
        # i/f/o gate scales.
        ('in', 'split_dropout'),
        ('split_dropout', 'split', {'out': 0}),
        # i_t = Sigmoid(i_part + w_ic * ct_1), then scaled by dropout port 1.
        ('split', 'scale_i_c', {'out': 4}),
        ('scale_i_c', 'i_plus_c'),
        ('split', 'i_plus_c', {'out': 0}),
        ('i_plus_c', 'sigmoid_i'),
        ('sigmoid_i', 'scaled_i', {'in': 0}),
        ('split_dropout', 'scaled_i', {'out': 1, 'in': 1}),
        # f_t = Sigmoid(f_part + w_fc * ct_1), then scaled by dropout port 2.
        ('split', 'scale_f_c', {'out': 4}),
        ('scale_f_c', 'f_plus_c'),
        ('split', 'f_plus_c', {'out': 1}),
        ('f_plus_c', 'sigmoid_f'),
        ('sigmoid_f', 'scaled_f', {'in': 0}),
        ('split_dropout', 'scaled_f', {'out': 2, 'in': 1}),
        # c_t = f_t * ct_1 + i_t * tanh(c_part)
        ('split', 'tanhcp', {'out': 2}),
        ('tanhcp', 'i_mul_tanhc'),
        ('scaled_i', 'i_mul_tanhc'),
        ('scaled_f', 'f_mul_c'),
        ('split', 'f_mul_c', {'out': 4}),
        ('f_mul_c', 'fc_plus_itanhc'),
        ('i_mul_tanhc', 'fc_plus_itanhc'),
        # o_t = Sigmoid(o_part + w_oc * c_t), then scaled by dropout port 3.
        ('split', 'scale_o_c', {'out': 4}),
        ('scale_o_c', 'o_plus_c'),
        ('split', 'o_plus_c', {'out': 3}),
        ('o_plus_c', 'sigmoid_o'),
        ('sigmoid_o', 'scaled_o', {'in': 0}),
        ('split_dropout', 'scaled_o', {'out': 3, 'in': 1}),
        # m_t = o_t * Tanh(c_t); concat(c_t, m_t) forms the output.
        ('fc_plus_itanhc', 'tanhc'),
        ('scaled_o', 'o_mul_tanhc'),
        ('tanhc', 'o_mul_tanhc'),
        ('fc_plus_itanhc', 'concat'),
        ('o_mul_tanhc', 'concat'),
        ('lstm', 'out'),
        ], nodes_with_edges_only=True)
    # Run the replacement in place and compare against the reference.
    ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))
    (flag, resp) = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
    self.assertTrue(flag, resp)