Bell/revise lstm cell sequence (#7638)

* revise LSTM spec

Signed-off-by: fishbell <bell.song@intel.com>

* add param check and related test case

Signed-off-by: fishbell <bell.song@intel.com>

* fix clang-format

Signed-off-by: fishbell <bell.song@intel.com>

* use static_cast to replace c style force conversion

Signed-off-by: fishbell <bell.song@intel.com>
song, bell 2021-11-15 22:13:26 -05:00 committed by GitHub
parent f66d9216ef
commit 9ccc308523
3 changed files with 48 additions and 13 deletions


@@ -6,21 +6,20 @@
**Short description**: *LSTMCell* operation represents a single LSTM cell. It computes the output using the formula described in the original paper [Long Short-Term Memory](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.676.4320&rep=rep1&type=pdf).
**Detailed description**
**Detailed description**: *LSTMCell* computes the hidden state *Ht* and the cell state *Ct* for the current time step based on the following formula:
```
Formula:
* - matrix mult
(.) - eltwise mult
* - matrix multiplication
(.) - Hadamard product (element-wise)
[,] - concatenation
sigm - 1/(1 + e^{-x})
tanh - (e^{2x} - 1)/(e^{2x} + 1)
f = sigm(Wf*[Hi, X] + Bf)
i = sigm(Wi*[Hi, X] + Bi)
c = tanh(Wc*[Hi, X] + Bc)
o = sigm(Wo*[Hi, X] + Bo)
Co = f (.) Ci + i (.) c
Ho = o (.) tanh(Co)
f, g, h - activation functions (specified by the activations attribute)
it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
Ct = ft (.) Ct-1 + it (.) ct
ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
Ht = ot (.) h(Ct)
```
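
To make the notation above concrete, here is a minimal standalone C++ sketch (illustrative only, not OpenVINO code) of one LSTMCell time step for a single batch element. It assumes the default activations (*sigmoid* for f, *tanh* for g and *tanh* for h), and its per-gate bias corresponds to a slice of the summed `B` input described in the Inputs section; the names `GateParams` and `lstm_cell_step` are made up for the example.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;
using Mat = std::vector<Vec>;  // row-major: Mat[row][col]

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }
static float tanh_act(float x) { return std::tanh(x); }

// y = M * x for M [rows x cols] and x [cols].
static Vec matvec(const Mat& M, const Vec& x) {
    Vec y(M.size(), 0.0f);
    for (std::size_t r = 0; r < M.size(); ++r)
        for (std::size_t c = 0; c < x.size(); ++c)
            y[r] += M[r][c] * x[c];
    return y;
}

// Per-gate parameters: W is [hidden_size x input_size], R is [hidden_size x hidden_size],
// B is this gate's slice of the summed bias (Wb + Rb).
struct GateParams {
    Mat W, R;
    Vec B;
};

// One time step for a single batch element; Ht and Ct are updated in place:
//   it = sigm(Xt*Wi^T + Ht-1*Ri^T + Bi), ft and ot computed analogously   (f = sigmoid)
//   ct = tanh(Xt*Wc^T + Ht-1*Rc^T + Bc)                                   (g = tanh)
//   Ct = ft (.) Ct-1 + it (.) ct
//   Ht = ot (.) tanh(Ct)                                                  (h = tanh)
void lstm_cell_step(const Vec& Xt, Vec& Ht, Vec& Ct,
                    const GateParams& f, const GateParams& i,
                    const GateParams& c, const GateParams& o) {
    auto gate = [&](const GateParams& p, float (*act)(float)) {
        Vec z = matvec(p.W, Xt);
        const Vec rh = matvec(p.R, Ht);  // Ht still holds Ht-1 at this point
        for (std::size_t k = 0; k < z.size(); ++k)
            z[k] = act(z[k] + rh[k] + p.B[k]);
        return z;
    };
    const Vec it = gate(i, sigmoid), ft = gate(f, sigmoid), ot = gate(o, sigmoid);
    const Vec ct = gate(c, tanh_act);
    for (std::size_t k = 0; k < Ht.size(); ++k) {
        Ct[k] = ft[k] * Ct[k] + it[k] * ct[k];
        Ht[k] = ot[k] * tanh_act(Ct[k]);
    }
}
```

In the operation itself the four gates are packed along the first axis of `W`, `R` and `B` in fico order (see the Inputs section); the sketch keeps them as separate parameters to mirror the formula.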
**Attributes**
@@ -37,7 +36,7 @@ tanh - (e^{2x} - 1)/(e^{2x} + 1)
* **Description**: *activations* specifies activation functions for gates; there are three gates, so three activation functions should be specified as a value for this attribute
* **Range of values**: any combination of *relu*, *sigmoid*, *tanh*
* **Type**: a list of strings
* **Default value**: *sigmoid,tanh,tanh*
* **Default value**: *sigmoid* for f, *tanh* for g, *tanh* for h
* **Required**: *no*
* *activations_alpha, activations_beta*
@@ -68,7 +67,7 @@ tanh - (e^{2x} - 1)/(e^{2x} + 1)
* **5**: `R` - 2D tensor of type *T* `[4 * hidden_size, hidden_size]`, the recurrence weights for matrix multiplication, gate order: fico. **Required.**
* **6**: `B` 1D tensor of type *T* `[4 * hidden_size]`, the sum of biases (weights and recurrence weights). **Required.**
* **6**: `B` - 1D tensor of type *T* `[4 * hidden_size]`, the sum of biases (weights and recurrence weights); if not specified, it is assumed to be filled with zeros. **Optional.**
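
As a reading aid for the packed layouts above, the short C++ sketch below (illustrative only, not OpenVINO API) prints the block of rows each gate occupies in `W`, `R` and `B` under the fico gate order; the `hidden_size` value used is an arbitrary example.

```cpp
#include <cstddef>
#include <iostream>

// In the packed tensors W [4 * hidden_size, input_size], R [4 * hidden_size, hidden_size]
// and B [4 * hidden_size], the first axis stores the gates in fico order:
// f (forget), i (input), c (cell candidate), o (output).
struct GateRows {
    char gate;          // 'f', 'i', 'c' or 'o'
    std::size_t begin;  // first row of this gate's block
    std::size_t end;    // one past the last row (half-open range)
};

int main() {
    const std::size_t hidden_size = 128;  // example value only
    const char order[4] = {'f', 'i', 'c', 'o'};
    for (std::size_t g = 0; g < 4; ++g) {
        const GateRows r{order[g], g * hidden_size, (g + 1) * hidden_size};
        std::cout << r.gate << ": rows [" << r.begin << ", " << r.end << ")\n";
    }
    // If B is not provided, the operation behaves as if this whole [4 * hidden_size] bias were zero.
}
```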
**Outputs**


@@ -538,6 +538,23 @@ void op::v5::LSTMSequence::validate_and_infer_types() {
                          Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
                          "Parameter num_directions not matched in LSTMSequence.");

    // Accept a static num_directions dimension only if its value matches one of the
    // direction enum values; a dynamic dimension passes the check.
    auto check_direction_valid = [](const ov::PartialShape& pshape, size_t index) -> bool {
        if (pshape[index].is_static())
            return static_cast<direction>(pshape[index].get_length()) == direction::FORWARD ||
                   static_cast<direction>(pshape[index].get_length()) == direction::REVERSE ||
                   static_cast<direction>(pshape[index].get_length()) == direction::BIDIRECTIONAL;
        return true;
    };
    NODE_VALIDATION_CHECK(this,
                          check_direction_valid(ht_pshape, 1),
                          "Parameter direction must be Forward or Reverse or Bidirectional.");
    NODE_VALIDATION_CHECK(this,
                          m_direction == direction::FORWARD || m_direction == direction::REVERSE ||
                              m_direction == direction::BIDIRECTIONAL,
                          "Parameter direction must be Forward or Reverse or Bidirectional.");

    // Validate hidden_size value for W, R, B inputs
    if (merged_hidden_size.is_static()) {
        if (w_pshape[1].is_static()) {
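
The `check_direction_valid` lambda added here treats a static `num_directions` dimension as valid only when its value casts to one of the direction enum values, and lets a dynamic dimension through. The standalone sketch below (not the real ngraph types) mirrors that logic, assuming the enum is declared in the order FORWARD, REVERSE, BIDIRECTIONAL (values 0, 1, 2); it also shows why the `num_directions = 3` case exercised by the new test is rejected.

```cpp
#include <cstdint>
#include <iostream>

// Simplified stand-in for the sequence direction enum; the assumed declaration
// order is FORWARD, REVERSE, BIDIRECTIONAL, i.e. the values 0, 1, 2.
enum class direction : std::int64_t { FORWARD = 0, REVERSE = 1, BIDIRECTIONAL = 2 };

// Mirrors check_direction_valid for a single dimension: a static dimension is
// cast to the enum and must equal one of the three named directions; a dynamic
// (unknown) dimension is accepted at this stage.
bool direction_dim_valid(bool is_static, std::int64_t length) {
    if (is_static)
        return static_cast<direction>(length) == direction::FORWARD ||
               static_cast<direction>(length) == direction::REVERSE ||
               static_cast<direction>(length) == direction::BIDIRECTIONAL;
    return true;
}

int main() {
    std::cout << direction_dim_valid(true, 1) << "\n";   // 1: valid (a single direction)
    std::cout << direction_dim_valid(true, 2) << "\n";   // 1: valid (bidirectional)
    std::cout << direction_dim_valid(true, 3) << "\n";   // 0: rejected, as in the new num_directions = 3 test
    std::cout << direction_dim_valid(false, -1) << "\n"; // 1: a dynamic dimension passes through
}
```

Note that a static value of 0 also satisfies this check (it casts to FORWARD), so the validation bounds the value to the enum's range rather than to exactly 1 or 2.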


@@ -324,3 +324,22 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank) {
        EXPECT_EQ(check_dynamic_lstm(lstm_sequence), true);
    }
}

TEST(type_prop, lstm_sequence_invalid_input_direction) {
    recurrent_sequence_parameters param;
    param.batch_size = 24;
    // 3 is not a valid num_directions value: forward/reverse sequences use 1, bidirectional uses 2.
    param.num_directions = 3;
    param.seq_length = 12;
    param.input_size = 8;
    param.hidden_size = 256;
    param.et = element::f32;

    auto lstm_sequence = lstm_seq_tensor_initialization(param);
    try {
        lstm_sequence->validate_and_infer_types();
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
    }
}
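
One note on the new test: since there is no assertion after `validate_and_infer_types()`, it would pass silently if no exception were thrown at all. A stricter variant is sketched below; the test name is hypothetical and it reuses the `recurrent_sequence_parameters` and `lstm_seq_tensor_initialization` helpers already used in this file, adding a `FAIL()` so a missing exception is reported too.

```cpp
// Sketch only, not part of this commit: same scenario, but the test also fails
// when validate_and_infer_types() does not throw.
TEST(type_prop, lstm_sequence_invalid_input_direction_strict) {
    recurrent_sequence_parameters param;
    param.batch_size = 24;
    param.num_directions = 3;  // invalid: only 1 (single direction) or 2 (bidirectional) occur
    param.seq_length = 12;
    param.input_size = 8;
    param.hidden_size = 256;
    param.et = element::f32;

    auto lstm_sequence = lstm_seq_tensor_initialization(param);
    try {
        lstm_sequence->validate_and_infer_types();
        FAIL() << "Invalid num_directions dimension not detected";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
    }
}
```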