Added Kaldi DropoutMask extraction and extended the Kaldi LstmNonlinearity replacer for the dropout case
This commit is contained in:
parent
667ca3c3f6
commit
0dd05f8053
@@ -162,6 +162,7 @@ extensions/front/kaldi/apply_counts.py
 extensions/front/kaldi/logsoftmax_component_ext.py
 extensions/front/kaldi/memory_offset_adjustment.py
 extensions/front/kaldi/memoryoffset_batch_update.py
+extensions/front/kaldi/replace_dropoutmask.py
 extensions/front/kaldi/replace_eltwise_nin1.py
 extensions/front/kaldi/replace_lstm_node_pattern.py
 extensions/front/kaldi/replace_lstm_nonlinearity.py
@@ -867,6 +868,7 @@ mo/front/kaldi/extractors/convolutional_1d_component_ext.py
 mo/front/kaldi/extractors/convolutional_component_ext.py
 mo/front/kaldi/extractors/copy_ext.py
 mo/front/kaldi/extractors/crop_ext.py
+mo/front/kaldi/extractors/dropoutmask_ext.py
 mo/front/kaldi/extractors/elementwise_component_ext.py
 mo/front/kaldi/extractors/fixed_affine_component_ext.py
 mo/front/kaldi/extractors/generaldropout_ext.py
@@ -980,6 +982,7 @@ mo/ops/convolution.py
 mo/ops/crop.py
 mo/ops/deconvolution.py
 mo/ops/deformable_convolution.py
+mo/ops/dropoutmask.py
 mo/ops/eltwise.py
 mo/ops/eltwise_n.py
 mo/ops/eltwise_ninputs_in_1.py
27 model-optimizer/extensions/front/kaldi/replace_dropoutmask.py Normal file
@@ -0,0 +1,27 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+from extensions.middle.MakeKaldiConstReshapable import create_const_with_batch_from_input
+from mo.front.common.replacement import FrontReplacementPattern
+from mo.graph.graph import Graph
+
+
+class ReplaceDropoutMaskPattern(FrontReplacementPattern):
+    enabled = True
+    run_not_recursively = True
+
+    def run_after(self):
+        from extensions.front.restore_ports import RestorePorts
+        return [RestorePorts]
+
+    def run_before(self):
+        from extensions.front.kaldi.replace_lstm_nonlinearity import ReplaceLstmNonLinearityPattern
+        return [ReplaceLstmNonLinearityPattern]
+
+    def find_and_replace_pattern(self, graph: Graph):
+        batch_port = graph.get_op_nodes(op="Parameter")[0].out_port(0)
+        replace_nodes = graph.get_op_nodes(op='dropoutmaskcomponent')
+        for dropout_node in replace_nodes:
+            dp_const_node = create_const_with_batch_from_input(batch_port, dropout_node.size,
+                                                               dropout_node.dropout_proportion)
+            dropout_node.out_port(0).get_connection().set_source(dp_const_node.out_port(0))
+            graph.remove_node(dropout_node.id)
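For reference, a minimal standalone sketch of what this pass produces: for every 'dropoutmaskcomponent' node, create_const_with_batch_from_input builds a [batch, size] constant whose batch dimension is taken from the Parameter node and whose elements equal the node's dropout_proportion attribute. The batch, size and proportion values below are assumptions for illustration, not taken from a real model:

import numpy as np

batch = 8                  # in the real pass this comes from the Parameter node's batch dimension
size = 1024                # 'size' attribute read from the Kaldi <OutputDim> token (assumed value)
dropout_proportion = 0.8   # 'dropout_proportion' attribute stored by the extractor (assumed value)

# conceptually, the DropoutMask node is replaced by this filled constant
dropout_mask = np.full((batch, size), dropout_proportion, dtype=np.float32)
print(dropout_mask.shape)  # (8, 1024)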
model-optimizer/extensions/front/kaldi/replace_lstm_nonlinearity.py
@@ -5,7 +5,7 @@ import numpy as np
 
 from extensions.ops.activation_ops import Sigmoid, Tanh
 from extensions.ops.elementwise import Add, Mul
-from extensions.ops.split import Split
+from extensions.ops.split import Split, AttributedVariadicSplit
 from mo.front.caffe.extractors.utils import input_as_const
 from mo.front.common.replacement import FrontReplacementOp
 from mo.front.tf.graph_utils import create_op_with_const_inputs
@@ -27,12 +27,26 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         return [FullyConnectedDecomposer]
 
     def replace_op(self, graph: Graph, node: Node):
-        # split input to (i_part, f_part, c_part, o_part, ct_1)
         node_name = node.soft_get('name', node.id)
+        # check if we have dropout
+        input_port = node.in_port(0)
+        if node['use_dropout']:
+            split_dropout = AttributedVariadicSplit(graph,
+                                                    {'name': node_name + '/split_dropout',
+                                                     'size_splits': np.array([-1, 1, 1, 1]),
+                                                     'axis': np.int64(1)}).create_node()
+            input_port.get_connection().set_destination(split_dropout.in_port(0))
+            input_port = split_dropout.out_port(0)
+
+        # split input to (i_part, f_part, c_part, o_part, ct_1)
         split_node = create_op_with_const_inputs(graph, Split, {1: np.int64(1)},
                                                  {'name': node_name + '/split_lstm_input',
                                                   'num_splits': 5})
-        node.in_port(0).get_connection().set_destination(split_node.in_port(0))
+        input_port.get_connection().set_destination(split_node.in_port(0))
 
+        i_part = split_node.out_port(0)
+        f_part = split_node.out_port(1)
+        o_part = split_node.out_port(3)
+
         # i_t = Sigmoid(i_part + w_ic*ct_1)
         i_scale_attrs = {'name': node_name + '/i_scaleshift',
@@ -42,12 +56,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         split_node.out_port(4).connect(i_scale.in_port(0))
 
         sum_i_c = Add(graph, {'name': node_name + '/sum_i_c_'}).create_node()
-        split_node.out_port(0).connect(sum_i_c.in_port(0))
+        i_part.connect(sum_i_c.in_port(0))
         i_scale.out_port(0).connect(sum_i_c.in_port(1))
 
         i_sigmoid = Sigmoid(graph, {'name': node_name + '/i_sigmoid'}).create_node()
         sum_i_c.out_port(0).connect(i_sigmoid.in_port(0))
 
+        if node['use_dropout']:
+            mul_dropout_i = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_i'}).create_node()
+            mul_dropout_i.in_port(0).connect(i_sigmoid.out_port(0))
+            mul_dropout_i.in_port(1).connect(split_dropout.out_port(1))
+            i_sigmoid = mul_dropout_i
+
         # f_t = Sigmoid(f_part + w_fc*ct_1)
         f_scale_attrs = {'name': node_name + '/f_scaleshift',
                          'bias_term': False}
@@ -56,12 +76,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         split_node.out_port(4).connect(f_scale.in_port(0))
 
         sum_f_c = Add(graph, {'name': node_name + '/sum_f_c_'}).create_node()
-        split_node.out_port(1).connect(sum_f_c.in_port(0))
+        f_part.connect(sum_f_c.in_port(0))
         f_scale.out_port(0).connect(sum_f_c.in_port(1))
 
         f_sigmoid = Sigmoid(graph, {'name': node_name + '/f_sigmoid'}).create_node()
         sum_f_c.out_port(0).connect(f_sigmoid.in_port(0))
 
+        if node['use_dropout']:
+            mul_dropout_f = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_f'}).create_node()
+            mul_dropout_f.in_port(0).connect(f_sigmoid.out_port(0))
+            mul_dropout_f.in_port(1).connect(split_dropout.out_port(2))
+            f_sigmoid = mul_dropout_f
+
         # c_t = f_t*ct_1 + i_t * tanh(c_part)
         c_tanh = Tanh(graph, {'name': node_name + '/c_tanh'}).create_node()
         split_node.out_port(2).connect(c_tanh.in_port(0))
@@ -86,12 +112,18 @@ class ReplaceLstmNonLinearityPattern(FrontReplacementOp):
         sum_f_i.out_port(0).connect(o_scale.in_port(0))
 
         sum_o_c = Add(graph, {'name': node_name + '/sum_o_c_'}).create_node()
-        split_node.out_port(3).connect(sum_o_c.in_port(0))
+        o_part.connect(sum_o_c.in_port(0))
         o_scale.out_port(0).connect(sum_o_c.in_port(1))
 
         o_sigmoid = Sigmoid(graph, {'name': node_name + '/o_sigmoid'}).create_node()
         sum_o_c.out_port(0).connect(o_sigmoid.in_port(0))
 
+        if node['use_dropout']:
+            mul_dropout_o = Mul(graph, {'name': split_node.soft_get('name', split_node.id) + '/mul_o'}).create_node()
+            mul_dropout_o.in_port(0).connect(o_sigmoid.out_port(0))
+            mul_dropout_o.in_port(1).connect(split_dropout.out_port(3))
+            o_sigmoid = mul_dropout_o
+
         # m_t = o_t * Tanh(c_t)
         c_t_tanh = Tanh(graph, {'name': node_name + '/c_t_tanh'}).create_node()
         sum_f_i.out_port(0).connect(c_t_tanh.in_port(0))
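To cross-check the subgraph built above, here is a hedged NumPy sketch of the computation the replacer emits when use_dropout is set; the cell_dim, random input and unit peephole weights are assumptions for illustration. The last three columns produced by the AttributedVariadicSplit act as per-gate scales for the i, f and o sigmoids, and the output is the concatenation of c_t and m_t:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

cell_dim = 4                                                  # assumed
x = np.random.randn(1, 5 * cell_dim + 3).astype(np.float32)   # columns: [i, f, c, o, ct_1 | d_i, d_f, d_o]
w_ic = w_fc = w_oc = np.ones(cell_dim, dtype=np.float32)      # peephole weights (i/f/o_weights), assumed

rest, d_i, d_f, d_o = np.split(x, [5 * cell_dim, 5 * cell_dim + 1, 5 * cell_dim + 2], axis=1)
i_part, f_part, c_part, o_part, ct_1 = np.split(rest, 5, axis=1)

i_t = sigmoid(i_part + w_ic * ct_1) * d_i        # i gate, scaled by split_dropout output 1
f_t = sigmoid(f_part + w_fc * ct_1) * d_f        # f gate, scaled by split_dropout output 2
c_t = f_t * ct_1 + i_t * np.tanh(c_part)
o_t = sigmoid(o_part + w_oc * c_t) * d_o         # o gate, scaled by split_dropout output 3
m_t = o_t * np.tanh(c_t)

out = np.concatenate([c_t, m_t], axis=1)         # mirrors the final Concat in the replacer
print(out.shape)                                 # (1, 8)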
22 model-optimizer/mo/front/kaldi/extractors/dropoutmask_ext.py Normal file
@@ -0,0 +1,22 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+from mo.front.extractor import FrontExtractorOp
+from mo.front.kaldi.loader.utils import collect_until_token, collect_until_token_and_read, read_binary_float_token
+from mo.ops.dropoutmask import DropoutMask
+
+
+class DropoutMaskComponentFrontExtractor(FrontExtractorOp):
+    op = 'dropoutmaskcomponent'
+    enabled = True
+
+    @classmethod
+    def extract(cls, node):
+        pb = node.parameters
+
+        size = collect_until_token_and_read(pb, b'<OutputDim>')
+        collect_until_token(pb, b'<DropoutProportion>')
+        dropout_proportion = read_binary_float_token(pb)
+        DropoutMask.update_node_stat(node, {'dropout_proportion': 1.0 - dropout_proportion,
+                                            'size': size})
+
+        return cls.enabled
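A small hedged illustration of what this extractor leaves on the node (the token values are hypothetical); note that it stores the keep value 1.0 - <DropoutProportion>, which is exactly the constant that the ReplaceDropoutMaskPattern pass broadcasts later:

output_dim = 512            # value following <OutputDim> in the Kaldi model (assumed)
dropout_proportion = 0.15   # value following <DropoutProportion> (assumed)

# mirrors the attributes set by DropoutMask.update_node_stat() above
node_attrs = {
    'op': 'dropoutmaskcomponent',
    'size': output_dim,
    'dropout_proportion': 1.0 - dropout_proportion,   # keep value, 0.85 here
}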
@@ -1,10 +1,11 @@
 # Copyright (C) 2018-2021 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+import numpy as np
 
 from mo.front.caffe.extractors.utils import embed_input
 from mo.front.extractor import FrontExtractorOp
-from mo.front.kaldi.loader.utils import collect_until_token
-from mo.front.kaldi.utils import read_binary_matrix
+from mo.front.kaldi.loader.utils import collect_until_token, collect_until_token_and_read
+from mo.front.kaldi.utils import read_binary_matrix, read_token_value
 from mo.ops.lstmnonlinearity import LstmNonLinearity
 
 
@@ -15,10 +16,13 @@ class LSTMNonlinearityFrontExtractor(FrontExtractorOp):
     @classmethod
     def extract(cls, node):
         pb = node.parameters
 
         collect_until_token(pb, b'<Params>')
         ifo_x_weights, ifo_x_weights_shape = read_binary_matrix(pb)
 
-        mapping_rule = {}
+        use_dropout = collect_until_token_and_read(pb, b'<UseDropout>', np.bool)
+
+        mapping_rule = {'use_dropout': use_dropout}
+
         assert len(ifo_x_weights_shape) == 2, "Unexpected shape of weights in LSTMNonLinearityComponent"
         assert ifo_x_weights_shape[0] == 3, "Unexpected shape of weights in LSTMNonLinearityComponent"
@@ -353,13 +353,15 @@ def read_node(file_descr, graph, component_layer_map, layer_node_map):
 
         # parse input
         in_node_id = parse_input_for_node(s[s.find(b'input=') + 6:], graph, layer_node_map)
-        out_port = len(Node(graph, in_node_id).out_nodes())
-        in_port = len(Node(graph, node_name).in_nodes())
+        # don't create cyclic edges node to itself to avoid removing later
+        if in_node_id != node_name:
+            out_port = len(Node(graph, in_node_id).out_nodes())
+            in_port = len(Node(graph, node_name).in_nodes())
 
-        Node(graph, node_name).add_input_port(in_port)
-        Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
+            Node(graph, node_name).add_input_port(in_port)
+            Node(graph, in_node_id).add_output_port(out_port, skip_if_exist=True)
 
-        graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
+            graph.add_edge(in_node_id, node_name, **create_edge_attrs(in_node_id, node_name, in_node_id, in_port, out_port))
     elif tokens[0] == b'output-node':
         layer_name = s[s.find(b'name=') + len(b'name='):].split(b' ')[0]
         layer_name = str(layer_name).strip('b').replace('\'', "")
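To see why the guard above matters, here is a toy sketch with networkx (the Model Optimizer Graph class is a networkx.MultiDiGraph); a descriptor that resolves to the node itself is an assumed scenario for illustration:

import networkx as nx

graph = nx.MultiDiGraph()
graph.add_node('lstm1')

node_name = 'lstm1'
in_node_id = 'lstm1'        # parsed input descriptor points back at the same node (assumed case)

# the guard added in read_node(): only connect genuinely different nodes
if in_node_id != node_name:
    graph.add_edge(in_node_id, node_name)

print(list(graph.edges()))  # [] -- no cyclic self-edge that would have to be removed later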
@@ -24,6 +24,7 @@ supported_components = [
     'convolutional1dcomponent',
     'convolutionalcomponent',
     'copy',
+    'dropoutmaskcomponent',
     'elementwiseproductcomponent',
     'fixedaffinecomponent',
     'fixedscalecomponent',
22 model-optimizer/mo/ops/dropoutmask.py Normal file
@@ -0,0 +1,22 @@
+# Copyright (C) 2018-2021 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+from mo.graph.graph import Graph
+from mo.ops.op import Op
+
+
+class DropoutMask(Op):
+    """
+    Operation for dropout proportion, it will be replaced by broadcast constant on front stage
+    """
+    op = 'dropoutmaskcomponent'
+
+    def __init__(self, graph: Graph, attrs: dict):
+        super().__init__(graph, {
+            'op': self.op,
+            'dropout_proportion': None,
+            'type': None,  # type is None because this operation should not appear in IR
+            'infer': None,
+            'in_ports_count': 0,
+            'out_ports_count': 1,
+        }, attrs)
@@ -13,6 +13,7 @@ class LstmNonLinearity(Op):
     def __init__(self, graph: Graph, attrs: dict):
         super().__init__(graph, {
             'op': __class__.op,
+            'use_dropout': False,
             'type': None,  # type is None because this operation should not appear in IR
             'infer': None,
             'in_ports_count': 1,
@@ -14,9 +14,9 @@ from unit_tests.utils.graph import build_graph
 class ReplaceLstmNonlinearityTests(unittest.TestCase):
     # i_t = Sigmoid(i_part + w_ic*ct_1)
     # f_t = Sigmoid(f_part + w_fc*ct_1)
-    # c_t = f_t*ct_1 + i_t * tanh(c_part)
+    # c_t = f_t * f_scale * ct_1 + i_t * i_scale * tanh(c_part)
     # o_t = Sigmoid(o_part + w_oc*c_t)
-    # m_t = o_t * Tanh(c_t)
+    # m_t = o_t * o_scale * Tanh(c_t)
     nodes_attributes = {
         'in': {'kind': 'op', 'op': 'Parameter'},
         'i_part': {'kind': 'op', 'op': 'Parameter'},
@@ -24,6 +24,7 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
         'c_part': {'kind': 'op', 'op': 'Parameter'},
         'o_part': {'kind': 'op', 'op': 'Parameter'},
         'split': {'kind': 'op', 'op': 'Split'},
+        'split_dropout': {'kind': 'op', 'op': 'AttributedVariadicSplit', 'size_splits': [-1, 1, 1, 1]},
         'sigmoid_i': {'kind': 'op', 'op': 'Sigmoid'},
         'sigmoid_f': {'kind': 'op', 'op': 'Sigmoid'},
         'sigmoid_o': {'kind': 'op', 'op': 'Sigmoid'},
@@ -31,6 +32,9 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
         'f_plus_c': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
         'fc_plus_itanhc': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
         'o_plus_c': {'kind': 'op', 'op': 'Eltwise', 'operation': 'sum'},
+        'scaled_i': {'kind': 'op', 'op': 'Mul'},
+        'scaled_f': {'kind': 'op', 'op': 'Mul'},
+        'scaled_o': {'kind': 'op', 'op': 'Mul'},
         'scale_i_c': {'kind': 'op', 'op': 'ScaleShift'},
         'scale_f_c': {'kind': 'op', 'op': 'ScaleShift'},
         'scale_o_c': {'kind': 'op', 'op': 'ScaleShift'},
@@ -50,6 +54,7 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
     def test_lstm_nonlinearity(self):
         graph = build_graph({'in': {'kind': 'op', 'op': 'Parameter'},
                              'lstm': {'kind': 'op', 'op': 'LstmNonLinearity',
+                                      'use_dropout': False,
                                       'i_weights': np.array([]),
                                       'f_weights': np.array([]),
                                       'o_weights': np.array([]),},
@@ -88,3 +93,53 @@ class ReplaceLstmNonlinearityTests(unittest.TestCase):
         ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))
         (flag, resp) = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
         self.assertTrue(flag, resp)
+
+    def test_lstm_nonlinearity_dropout(self):
+        graph = build_graph({'in': {'kind': 'op', 'op': 'Parameter'},
+                             'lstm': {'kind': 'op', 'op': 'LstmNonLinearity',
+                                      'use_dropout': True,
+                                      'i_weights': np.array([]),
+                                      'f_weights': np.array([]),
+                                      'o_weights': np.array([]),},
+                             'out': {'kind': 'op', 'op': 'Placeholder'}},
+                            [('in', 'lstm'), ('lstm', 'out')], nodes_with_edges_only=True)
+        graph.stage = 'front'
+        # split input to (i_part, f_part, c_part, o_part, ct_1)
+        ref_graph = build_graph(self.nodes_attributes, [
+            ('in', 'split_dropout'),
+            ('split_dropout', 'split', {'out': 0}),
+            ('split', 'scale_i_c', {'out': 4}),
+            ('scale_i_c', 'i_plus_c'),
+            ('split', 'i_plus_c', {'out': 0}),
+            ('i_plus_c', 'sigmoid_i'),
+            ('sigmoid_i', 'scaled_i', {'in': 0}),
+            ('split_dropout', 'scaled_i', {'out': 1, 'in': 1}),
+            ('split', 'scale_f_c', {'out': 4}),
+            ('scale_f_c', 'f_plus_c'),
+            ('split', 'f_plus_c', {'out': 1}),
+            ('f_plus_c', 'sigmoid_f'),
+            ('sigmoid_f', 'scaled_f', {'in': 0}),
+            ('split_dropout', 'scaled_f', {'out': 2, 'in': 1}),
+            ('split', 'tanhcp', {'out': 2}),
+            ('tanhcp', 'i_mul_tanhc'),
+            ('scaled_i', 'i_mul_tanhc'),
+            ('scaled_f', 'f_mul_c'),
+            ('split', 'f_mul_c', {'out': 4}),
+            ('f_mul_c', 'fc_plus_itanhc'),
+            ('i_mul_tanhc', 'fc_plus_itanhc'),
+            ('split', 'scale_o_c', {'out': 4}),
+            ('scale_o_c', 'o_plus_c'),
+            ('split', 'o_plus_c', {'out': 3}),
+            ('o_plus_c', 'sigmoid_o'),
+            ('sigmoid_o', 'scaled_o', {'in': 0}),
+            ('split_dropout', 'scaled_o', {'out': 3, 'in': 1}),
+            ('fc_plus_itanhc', 'tanhc'),
+            ('scaled_o', 'o_mul_tanhc'),
+            ('tanhc', 'o_mul_tanhc'),
+            ('fc_plus_itanhc', 'concat'),
+            ('o_mul_tanhc', 'concat'),
+            ('lstm', 'out'),
+        ], nodes_with_edges_only=True)
+        ReplaceLstmNonLinearityPattern().replace_op(graph, Node(graph, 'lstm'))
+        (flag, resp) = compare_graphs(graph, ref_graph, 'out', check_op_attrs=True)
+        self.assertTrue(flag, resp)