Files
openvino/model-optimizer/extensions/ops/lstm_cell.py
Ivan Tikhonov 2f5a28d44f LSTMCell/Sequence v1, reference implementations and decompose transformations for LSTM/GRU/RNN Cells (#2000)
* validate_and_infer_types() implementation

* input parameter validation for LSTM, GRU and RNN

* style-check applied

* Add LSTMSequence dynamic shape validation and test props for RNNCell, GRUCell, LSTMCell and LSTMSequence.

* recurrent_sequence.hpp moved to ngraph/core/include/ngraph/op/util/

* style check applied

* removed unused variable from LSTMSequence::validate_and_infer_types

* Add missing newline mark at the end of file.

* Add suppression macro for FusedOp deprecation.

* Add element type initialization

* transpose,rnn cell reference implementations

* Apply PR review remarks

* reference implementations for cells op, single layer tests, align lstm cell/sequence according to the spec

* lstm/gru/rnn cell decomposition transformations

* ngraph codestyle

* clean up

* ngraph code style

* change inheritance of Cells, fix build

* fix build

* fix build again

* remove Peepholes from LSTMSeq, fix copy_runtime_info in transformations

* Rewrite tests to use gtest exception assertions.

* resolve tests issues

* ngraph codestyle

* add missed files

* fix typeprop tests

* fix lstm sequence checks

* fix arm build

* fix arm again

* delete unnecessary file

* add convert weights format function, enable lstm test, resolve review comments

* add ngraph builders

* ngraph codestyle

* fix unit tests

* revert transpose reference implementation

* revert LSTM Cell v0, add LSTMCell v1, update transformation lstm_cell_to_cell_ie

* v1 version of LSTMCell op

* LSTMSequence v1 operation, exclude LSTMSeq from opset4

* fix python api tests

* resolve review comments, tests for decomposition transformations, switch lstm cell to opset4 in mo

Co-authored-by: Szymon Durawa <szymon.durawa@intel.com>
2020-09-04 09:04:36 +03:00

102 lines
3.3 KiB
Python

"""
Copyright (C) 2017-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.partial_infer.utils import mark_input_bins
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
from mo.utils.error import Error
class LSTMCell(Op):
    ''' A single LSTM cell (without a loop).

        3 inputs:
            - [0, required] input data (2D),
            - [1, required] initial hidden state (2D),
            - [2, required] initial cell state (2D),

        2 blobs:
            - [3, required] LSTM FC weights
            - [4, required] LSTM FC biases

        2 outputs:
            - [required] output data / resulting hidden state (2D)
            - [required] resulting cell state (2D)

        Notes on shape inference (see `infer`):
            - if the node has a truthy 'extra_inputs' attribute, 8 inputs are expected
              instead of 5 (this class does not interpret ports 5-7 itself);
            - the node may have only the first output connected.
    '''
    op = 'LSTMCell'

    def __init__(self, graph: Graph, attrs: dict):
        mandatory_props = {
            'type': __class__.op,
            'op': __class__.op,
            'version': 'opset4',
            'infer': __class__.infer,
            'in_ports_count': 5,
            'out_ports_count': 2,
            # NOTE(review): presumably the port of the weights blob (port 3 per the class docstring)
            # consumed by other passes -- confirm against its readers.
            'wr_input_id': 3,
            'gates_count': 4
        }
        # User-supplied attrs are passed after the mandatory ones to the Op constructor.
        super().__init__(graph, mandatory_props, attrs)

    def supported_attrs(self):
        """Return the names of the attributes this operation supports."""
        return [
            'hidden_size',  # number of the elements in hidden cell size
            'activations',
            'activation_alpha',
            'activation_beta',
            'clip',
        ]

    def backend_attrs(self):
        """Return the attributes emitted for the backend.

        'activations' is serialized as a comma-separated string. The attribute is optional
        (it is not among the mandatory props), so the serializer lambda returns None instead
        of failing when it is absent, None, or empty.
        """
        return [
            'hidden_size',  # number of the elements in hidden cell size
            ('activations', lambda node: ','.join(node.activations) if node.has_and_set('activations') else None),
            'activation_alpha',
            'activation_beta',
            'clip',
        ]

    @staticmethod
    def infer(node: Node):
        """Propagate state shapes to the outputs and validate/record 'hidden_size'.

        Output 0 gets a copy of the shape of input 1 (hidden state); output 1, when present,
        gets a copy of the shape of input 2 (cell state). Inputs starting at port 3 are marked
        via `mark_input_bins`. If 'hidden_size' is not already set on the node, it is taken
        from dimension 1 of the hidden state shape.

        :param node: the LSTMCell node to infer.
        :raises Error: if a pre-set 'hidden_size' disagrees with the hidden state shape.
        :raises AssertionError: on unexpected input/output counts, an undefined input shape,
            a cell state size different from the hidden size, or mismatching batch sizes.
        """
        expected_inputs = 8 if node.has_and_set('extra_inputs') else 5
        assert len(node.in_nodes()) == expected_inputs, \
            'LSTMCell node {} expects {} inputs, got {}'.format(
                node.soft_get('name'), expected_inputs, len(node.in_nodes()))
        assert len(node.out_nodes()) in [1, 2], \
            'LSTMCell node {} must have 1 or 2 outputs, got {}'.format(
                node.soft_get('name'), len(node.out_nodes()))

        # Copies, so that later changes to the output shapes do not alias the input shapes.
        hidden_shape = node.in_node(1).shape.copy()
        cell_shape = node.in_node(2).shape.copy()

        mark_input_bins(node, start_port=3)
        node.out_node(0).shape = hidden_shape
        if len(node.out_nodes()) == 2:
            node.out_node(1).shape = cell_shape
        hidden_size = hidden_shape[1]

        if node.has_valid('hidden_size'):
            if node.hidden_size != hidden_size:
                raise Error("Input shape {} for hidden size doesn't match pre-defined hidden_size in node {}".format(
                    node.in_node(1).shape, node.soft_get('name')))
        else:
            node['hidden_size'] = hidden_size

        assert cell_shape[1] == hidden_size, \
            'Cell state size {} does not match hidden size {} in node {}'.format(
                cell_shape[1], hidden_size, node.soft_get('name'))

        input_shape = node.in_node(0).shape
        assert input_shape is not None, \
            'Input data shape of node {} is not defined'.format(node.soft_get('name'))
        assert hidden_shape[0] == cell_shape[0] == input_shape[0], 'States are not broadcastable by batch'