Bell/revise lstm cell sequence (#7638)
* revise LSTM spec Signed-off-by: fishbell <bell.song@intel.com> * add param check and related test case Signed-off-by: fishbell <bell.song@intel.com> * fix clang-format Signed-off-by: fishbell <bell.song@intel.com> * use static_cast to replace c style force conversion Signed-off-by: fishbell <bell.song@intel.com>
This commit is contained in:
parent
f66d9216ef
commit
9ccc308523
@ -6,21 +6,20 @@
|
||||
|
||||
**Short description**: *LSTMCell* operation represents a single LSTM cell. It computes the output using the formula described in the original paper [Long Short-Term Memory](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.676.4320&rep=rep1&type=pdf).
|
||||
|
||||
**Detailed description**
|
||||
**Detailed description**: *LSTMCell* computes the output *Ht* and *ot* for the current time step based on the following formula:
|
||||
|
||||
```
|
||||
Formula:
|
||||
* - matrix mult
|
||||
(.) - eltwise mult
|
||||
* - matrix multiplication
|
||||
(.) - Hadamard product (element-wise)
|
||||
[,] - concatenation
|
||||
sigm - 1/(1 + e^{-x})
|
||||
tanh - (e^{2x} - 1)/(e^{2x} + 1)
|
||||
f = sigm(Wf*[Hi, X] + Bf)
|
||||
i = sigm(Wi*[Hi, X] + Bi)
|
||||
c = tanh(Wc*[Hi, X] + Bc)
|
||||
o = sigm(Wo*[Hi, X] + Bo)
|
||||
Co = f (.) Ci + i (.) c
|
||||
Ho = o (.) tanh(Co)
|
||||
f, g, h - activation functions
|
||||
it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
|
||||
ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
|
||||
ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
|
||||
Ct = ft (.) Ct-1 + it (.) ct
|
||||
ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
|
||||
Ht = ot (.) h(Ct)
|
||||
```
|
||||
|
||||
**Attributes**
|
||||
@ -37,7 +36,7 @@ tanh - (e^{2x} - 1)/(e^{2x} + 1)
|
||||
* **Description**: *activations* specifies activation functions for gates, there are three gates, so three activation functions should be specified as a value for this attribute
|
||||
* **Range of values**: any combination of *relu*, *sigmoid*, *tanh*
|
||||
* **Type**: a list of strings
|
||||
* **Default value**: *sigmoid,tanh,tanh*
|
||||
* **Default value**: *sigmoid* for f, *tanh* for g, *tanh* for h
|
||||
* **Required**: *no*
|
||||
|
||||
* *activations_alpha, activations_beta*
|
||||
@ -68,7 +67,7 @@ tanh - (e^{2x} - 1)/(e^{2x} + 1)
|
||||
|
||||
* **5**: `R` - 2D tensor of type *T* `[4 * hidden_size, hidden_size]`, the recurrence weights for matrix multiplication, gate order: fico. **Required.**
|
||||
|
||||
* **6**: `B` 1D tensor of type *T* `[4 * hidden_size]`, the sum of biases (weights and recurrence weights). **Required.**
|
||||
* **6**: `B` - 1D tensor of type *T* `[4 * hidden_size]`, the sum of biases (weights and recurrence weights), if not specified - assumed to be 0. **Optional.**
|
||||
|
||||
|
||||
**Outputs**
|
||||
|
@ -538,6 +538,23 @@ void op::v5::LSTMSequence::validate_and_infer_types() {
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
|
||||
"Parameter num_directions not matched in LSTMSequence.");
|
||||
|
||||
auto check_direction_valid = [](const ov::PartialShape& pshape, size_t index) -> bool {
|
||||
if (pshape[index].is_static())
|
||||
return static_cast<direction>(pshape[index].get_length()) == direction::FORWARD ||
|
||||
static_cast<direction>(pshape[index].get_length()) == direction::REVERSE ||
|
||||
static_cast<direction>(pshape[index].get_length()) == direction::BIDIRECTIONAL;
|
||||
return true;
|
||||
};
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
check_direction_valid(ht_pshape, 1),
|
||||
"Parameter direction must be Forward or Reverse or Bidirectional.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
m_direction == direction::FORWARD || m_direction == direction::REVERSE ||
|
||||
m_direction == direction::BIDIRECTIONAL,
|
||||
"Parameter direction must be Forward or Reverse or Bidirectional.");
|
||||
|
||||
// Validate hidden_size value for W, R, B inputs
|
||||
if (merged_hidden_size.is_static()) {
|
||||
if (w_pshape[1].is_static()) {
|
||||
|
@ -324,3 +324,22 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank) {
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm_sequence), true);
|
||||
}
|
||||
}
|
||||
|
||||
// Verifies that LSTMSequence::validate_and_infer_types() rejects a
// num_directions dimension that is neither 1 (FORWARD / REVERSE) nor
// 2 (BIDIRECTIONAL).
TEST(type_prop, lstm_sequence_invalid_input_direction) {
    recurrent_sequence_parameters param;

    param.batch_size = 24;
    param.num_directions = 3;  // invalid: only 1 or 2 are legal direction counts
    param.seq_length = 12;
    param.input_size = 8;
    param.hidden_size = 256;
    param.et = element::f32;

    auto lstm_sequence = lstm_seq_tensor_initialization(param);
    try {
        lstm_sequence->validate_and_infer_types();
        // Without this FAIL() the test silently passes when validation does
        // NOT throw, which is exactly the regression it is meant to catch.
        FAIL() << "Invalid num_directions not detected in LSTMSequence.";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(),
                             std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
    } catch (...) {
        FAIL() << "LSTMSequence direction validation failed for an unexpected reason.";
    }
}
|
||||
|
Loading…
Reference in New Issue
Block a user