GRU/RNN/LSTM sequence ops, reference implementations, single layer tests (#1594)
* gru/rnn sequences * update gru/rnn sequences ops, add unit tests * enable sequence transformations for cpu plugin * ngraph codestyle * update tensor iterator to rnn/gru/lstm sequence transformations, add unit tests * ngraph codestyle * add visitors for ngraph ie ops, fix a bug with incorrect axis, fix ngraph to ngraph ie conversion * update GRUSequence/GRUSequenceIE according to plugin format * fix ngraph ie implementations according to plugins restricrictions * fix naming issue * adapt unit tests to accordance to new changes * strict checks, additional unit tests * add descriptions for transformations, fix unit tests * enable ti to sequnece and unroll transformations in plugins for testing * disable tensor iterator ngraph reader tests * delete unnecessary cmake file * fix includes * clean up, resolve review comments * move ti to sequence transformation to ti folder * validate_and_infer_types() implementation * input parameter validation for LSTM, GRU and RNN * style-check applied * Add LSTMSequence dynamic shape validation and test props for RNNCell, GRUCell, LSTMCell and LSTMSequence. * recurrent_sequence.hpp moved to ngraph/core/include/ngraph/op/util/ * style check applied * removed unused variable from LSTMSequence::validate_and_infer_types * Add missing newline mark at the end of file. * Add supression macro for FusedOp deprecation. * Add element type initialization * transpose,rnn cell reference implementations * Apply PR review remarks * reference implementations for cells op, single layer tests, align lstm cell/sequence according to the spec * lstm/gru/rnn cell decompostion transformations * ngraph codestyle * clean up * ngraph code style * change inheritance of Cells, fix build * fix build * fix build again * remove Peepholes from LSTMSeq, fix copy_runtime_info in transformations * Rewrite tests to use gtest exception assertions. * resolve tests issues * ngraph codestyle * add missed files * fix typeprop tests * fix lstm sequence checks * fix arm build * fix arm again * delete unnecessary file * add convert weghts format function, enable lstm test, resolve review comments * add ngraph builders * ngraph codestyle * fix unit tests * revert transpose reference implementation * move ti to sequences transformation to another branch, resolve review comments * resolve review comments * revert fix in ie_layer_validators * revert LSTM Cell v0, add LSTMCell v1, update transformation lstm_cell_to_cell_ie * v1 version of LSTMCell op * LSTMSequence v1 operation, exclude LSTMSeq from opset4 * fix python api tests * resolve review comments, tests for decomposition transformations, switch lstm cell to opset4 in mo * references impl for RNN/GRU/LSTM Sequences, single layer tests, bidirectional transformation * fix unit tests * process dynamic ranks of rnn/gru/lstm ops * remove sequences specifications from opset4 * resolve review comments * fix validate_and_infer_types of GRU/RNN sequences Co-authored-by: Szymon Durawa <szymon.durawa@intel.com>
This commit is contained in:
@@ -62,7 +62,6 @@ declared in `namespace opset4`.
|
||||
* [GroupConvolution](convolution/GroupConvolution_1.md)
|
||||
* [GroupConvolutionBackpropData](convolution/GroupConvolutionBackpropData_1.md)
|
||||
* [GRUCell](sequence/GRUCell_3.md)
|
||||
* [GRUSequence](sequence/GRUSequence_4.md)
|
||||
* [HardSigmoid](activation/HardSigmoid_1.md)
|
||||
* [HSwish](activation/HSwish_4.md)
|
||||
* [Interpolate](image/Interpolate_4.md)
|
||||
@@ -75,7 +74,6 @@ declared in `namespace opset4`.
|
||||
* [LogicalXor](logical/LogicalXor_1.md)
|
||||
* [LRN](normalization/LRN_1.md)
|
||||
* [LSTMCell](sequence/LSTMCell_1.md)
|
||||
* [LSTMSequence](sequence/LSTMSequence_1.md)
|
||||
* [MatMul](matrix/MatMul_1.md)
|
||||
* [MaxPool](pooling/MaxPool_1.md)
|
||||
* [Maximum](arithmetic/Maximum_1.md)
|
||||
@@ -117,7 +115,6 @@ declared in `namespace opset4`.
|
||||
* [Reverse](movement/Reverse_1.md)
|
||||
* [ReverseSequence](movement/ReverseSequence_1.md)
|
||||
* [RNNCell](sequence/RNNCell_3.md)
|
||||
* [RNNSequence](sequence/RNNSequence_4.md)
|
||||
* [ROIAlign](detection/ROIAlign_3.md)
|
||||
* [ROIPooling](detection/ROIPooling_1.md)
|
||||
* [ScatterElementsUpdate](movement/ScatterElementsUpdate_3.md)
|
||||
|
||||
@@ -1,136 +0,0 @@
|
||||
## GRUSequence <a name="GRUSequence"></a> {#openvino_docs_ops_sequence_GRUSequence_4}
|
||||
|
||||
**Versioned name**: *GRUSequence-4*
|
||||
|
||||
**Category**: *Sequence processing*
|
||||
|
||||
**Short description**: *GRUSequence* operation represents a series of GRU cells. Each cell is implemented as <a href="#GRUCell">GRUCell</a> operation.
|
||||
|
||||
**Detailed description**
|
||||
|
||||
A single cell in the sequence is implemented in the same way as in <a href="#GRUCell">GRUCell</a> operation. *GRUSequence* represents a sequence of GRU cells. The sequence can be connected differently depending on `direction` attribute that specifies the direction of traversing of input data along sequence dimension or specifies whether it should be a bidirectional sequence. The most of the attributes are in sync with the specification of ONNX GRU operator defined <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#gru">GRUCell</a>.
|
||||
|
||||
|
||||
**Attributes**
|
||||
|
||||
* *hidden_size*
|
||||
|
||||
* **Description**: *hidden_size* specifies hidden state size.
|
||||
* **Range of values**: a positive integer
|
||||
* **Type**: `int`
|
||||
* **Default value**: None
|
||||
* **Required**: *yes*
|
||||
|
||||
* *activations*
|
||||
|
||||
* **Description**: *activations* specifies activation functions for gates, there are two gates, so two activation functions should be specified as a value for this attributes
|
||||
* **Range of values**: any combination of *relu*, *sigmoid*, *tanh*
|
||||
* **Type**: a list of strings
|
||||
* **Default value**: *sigmoid,tanh*
|
||||
* **Required**: *no*
|
||||
|
||||
* *activations_alpha, activations_beta*
|
||||
|
||||
* **Description**: *activations_alpha, activations_beta* attributes of functions; applicability and meaning of these attributes depends on choosen activation functions
|
||||
* **Range of values**: a list of floating-point numbers
|
||||
* **Type**: `float[]`
|
||||
* **Default value**: None
|
||||
* **Required**: *no*
|
||||
|
||||
* *clip*
|
||||
|
||||
* **Description**: *clip* specifies bound values *[-C, C]* for tensor clipping. Clipping is performed before activations.
|
||||
* **Range of values**: a positive floating-point number
|
||||
* **Type**: `float`
|
||||
* **Default value**: *infinity* that means that the clipping is not applied
|
||||
* **Required**: *no*
|
||||
|
||||
* *direction*
|
||||
|
||||
* **Description**: Specify if the RNN is forward, reverse, or bidirectional. If it is one of *forward* or *reverse* then `num_directions = 1`, if it is *bidirectional*, then `num_directions = 2`. This `num_directions` value specifies input/output shape requirements.
|
||||
* **Range of values**: *forward*, *reverse*, *bidirectional*
|
||||
* **Type**: `string`
|
||||
* **Default value**: None
|
||||
* **Required**: *Yes*
|
||||
|
||||
* *linear_before_reset*
|
||||
|
||||
* **Description**: *linear_before_reset* flag denotes if the layer behaves according to the modification of *GRUCell* described in the formula in the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md#GRU).
|
||||
* **Range of values**: True or False
|
||||
* **Type**: `boolean`
|
||||
* **Default value**: False
|
||||
* **Required**: *no*
|
||||
|
||||
**Inputs**
|
||||
|
||||
* **1**: `X` - 3D tensor of type *T1* `[batch_size, seq_length, input_size]`, input data. It differs from GRUCell 1st input only by additional axis with size `seq_length`. **Required.**
|
||||
|
||||
* **2**: `initial_hidden_state` - 3D tensor of type *T1* `[batch_size, num_directions, hidden_size]`, input hidden state data. **Required.**
|
||||
|
||||
* **3**: `sequence_lengths` - 1D tensor of type *T2* `[batch_size]`, specifies real sequence lengths for each batch element. **Required.**
|
||||
|
||||
* **4**: `W` - 3D tensor of type *T1* `[num_directions, 3 * hidden_size, input_size]`, the weights for matrix multiplication, gate order: zrh. **Required.**
|
||||
|
||||
* **5**: `R` - 3D tensor of type *T1* `[num_directions, 3 * hidden_size, hidden_size]`, the recurrence weights for matrix multiplication, gate order: zrh. **Required.**
|
||||
|
||||
* **6**: `B` - 2D tensor of type *T*. If *linear_before_reset* is set to 1, then the shape is `[num_directions, 4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. Otherwise the shape is `[num_directions, 3 * hidden_size]`, the sum of biases (weights and recurrence weights). **Required.**
|
||||
|
||||
**Outputs**
|
||||
|
||||
* **1**: `Y` – 3D tensor of type *T1* `[batch_size, num_directions, seq_len, hidden_size]`, concatenation of all the intermediate output values of the hidden.
|
||||
|
||||
* **2**: `Ho` - 3D tensor of type *T1* `[batch_size, num_directions, hidden_size]`, the last output value of hidden state.
|
||||
|
||||
**Types**
|
||||
|
||||
* *T1*: any supported floating point type.
|
||||
* *T2*: any supported integer type.
|
||||
|
||||
**Example**
|
||||
```xml
|
||||
<layer ... type="GRUSequence" ...>
|
||||
<data hidden_size="128"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>4</dim>
|
||||
<dim>16</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
<port id="3">
|
||||
<dim>1</dim>
|
||||
<dim>384</dim>
|
||||
<dim>16</dim>
|
||||
</port>
|
||||
<port id="4">
|
||||
<dim>1</dim>
|
||||
<dim>384</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="5">
|
||||
<dim>1</dim>
|
||||
<dim>384</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="6">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>4</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="7">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
```
|
||||
@@ -1,128 +0,0 @@
|
||||
## RNNSequence <a name="RNNSequence"></a> {#openvino_docs_ops_sequence_RNNSequence_4}
|
||||
|
||||
**Versioned name**: *RNNSequence-4*
|
||||
|
||||
**Category**: *Sequence processing*
|
||||
|
||||
**Short description**: *RNNSequence* operation represents a series of RNN cells. Each cell is implemented as <a href="#RNNCell">RNNCell</a> operation.
|
||||
|
||||
**Detailed description**
|
||||
|
||||
A single cell in the sequence is implemented in the same way as in <a href="#RNNCell">RNNCell</a> operation. *RNNSequence* represents a sequence of RNN cells. The sequence can be connected differently depending on `direction` attribute that specifies the direction of traversing of input data along sequence dimension or specifies whether it should be a bidirectional sequence. The most of the attributes are in sync with the specification of ONNX RNN operator defined <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#rnn">RNNCell</a>.
|
||||
|
||||
|
||||
**Attributes**
|
||||
|
||||
* *hidden_size*
|
||||
|
||||
* **Description**: *hidden_size* specifies hidden state size.
|
||||
* **Range of values**: a positive integer
|
||||
* **Type**: `int`
|
||||
* **Default value**: None
|
||||
* **Required**: *yes*
|
||||
|
||||
* *activations*
|
||||
|
||||
* **Description**: activation functions for gates
|
||||
* **Range of values**: any combination of *relu*, *sigmoid*, *tanh*
|
||||
* **Type**: a list of strings
|
||||
* **Default value**: *tanh*
|
||||
* **Required**: *no*
|
||||
|
||||
* *activations_alpha, activations_beta*
|
||||
|
||||
* **Description**: *activations_alpha, activations_beta* attributes of functions; applicability and meaning of these attributes depends on choosen activation functions
|
||||
* **Range of values**: a list of floating-point numbers
|
||||
* **Type**: `float[]`
|
||||
* **Default value**: None
|
||||
* **Required**: *no*
|
||||
|
||||
* *clip*
|
||||
|
||||
* **Description**: *clip* specifies bound values *[-C, C]* for tensor clipping. Clipping is performed before activations.
|
||||
* **Range of values**: a positive floating-point number
|
||||
* **Type**: `float`
|
||||
* **Default value**: *infinity* that means that the clipping is not applied
|
||||
* **Required**: *no*
|
||||
|
||||
* *direction*
|
||||
|
||||
* **Description**: Specify if the RNN is forward, reverse, or bidirectional. If it is one of *forward* or *reverse* then `num_directions = 1`, if it is *bidirectional*, then `num_directions = 2`. This `num_directions` value specifies input/output shape requirements.
|
||||
* **Range of values**: *forward*, *reverse*, *bidirectional*
|
||||
* **Type**: `string`
|
||||
* **Default value**: None
|
||||
* **Required**: *Yes*
|
||||
|
||||
**Inputs**
|
||||
|
||||
* **1**: `X` - 3D tensor of type *T1* `[batch_size, seq_length, input_size]`, input data. It differs from RNNCell 1st input only by additional axis with size `seq_length`. **Required.**
|
||||
|
||||
* **2**: `initial_hidden_state` - 3D tensor of type *T1* `[batch_size, num_directions, hidden_size]`, input hidden state data. **Required.**
|
||||
|
||||
* **3**: `sequence_lengths` - 1D tensor of type *T2* `[batch_size]`, specifies real sequence lengths for each batch element. **Required.**
|
||||
|
||||
* **4**: `W` - 3D tensor of type *T1* `[num_directions, hidden_size, input_size]`, the weights for matrix multiplication. **Required.**
|
||||
|
||||
* **5**: `R` - 3D tensor of type *T1* `[num_directions, hidden_size, hidden_size]`, the recurrence weights for matrix multiplication. **Required.**
|
||||
|
||||
* **6**: `B` - 2D tensor of type *T1* `[num_directions, hidden_size]`, the sum of biases (weights and recurrence weights). **Required.**
|
||||
|
||||
**Outputs**
|
||||
|
||||
* **1**: `Y` – 3D tensor of type *T1* `[batch_size, num_directions, seq_len, hidden_size]`, concatenation of all the intermediate output values of the hidden.
|
||||
|
||||
* **2**: `Ho` - 3D tensor of type *T1* `[batch_size, num_directions, hidden_size]`, the last output value of hidden state.
|
||||
|
||||
**Types**
|
||||
|
||||
* *T1*: any supported floating point type.
|
||||
* *T2*: any supported integer type.
|
||||
|
||||
**Example**
|
||||
```xml
|
||||
<layer ... type="RNNSequence" ...>
|
||||
<data hidden_size="128"/>
|
||||
<input>
|
||||
<port id="0">
|
||||
<dim>1</dim>
|
||||
<dim>4</dim>
|
||||
<dim>16</dim>
|
||||
</port>
|
||||
<port id="1">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="2">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
<port id="3">
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
<dim>16</dim>
|
||||
</port>
|
||||
<port id="4">
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="5">
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
<port id="6">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>4</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
<port id="7">
|
||||
<dim>1</dim>
|
||||
<dim>1</dim>
|
||||
<dim>128</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
```
|
||||
@@ -34,6 +34,9 @@
|
||||
#include "ngraph_ops/selu_ie.hpp"
|
||||
#include "ngraph_ops/rnn_cell_ie.hpp"
|
||||
#include "ngraph_ops/topk_ie.hpp"
|
||||
#include "ngraph_ops/rnn_sequence_ie.hpp"
|
||||
#include "ngraph_ops/lstm_sequence_ie.hpp"
|
||||
#include "ngraph_ops/gru_sequence_ie.hpp"
|
||||
#include "generic_ie.hpp"
|
||||
#include "exec_graph_info.hpp"
|
||||
|
||||
@@ -539,6 +542,111 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"GRUSequenceIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string>& params) -> CNNLayerPtr {
|
||||
|
||||
LayerParams attrs = {node->get_friendly_name(), "GRUSequence",
|
||||
details::convertPrecision(node->get_output_element_type(0))};
|
||||
auto res = std::make_shared<RNNSequenceLayer>(attrs);
|
||||
res->params = params;
|
||||
|
||||
if (res->params["direction"] == "reverse")
|
||||
res->params["direction"] = "Backward";
|
||||
else if (res->params["direction"] == "forward")
|
||||
res->params["direction"] = "Forward";
|
||||
else
|
||||
res->params["direction"] = "Bidirectional";
|
||||
|
||||
res->cellType = RNNSequenceLayer::CellType::GRU;
|
||||
if (res->params["linear_before_reset"] == "true") {
|
||||
res->cellType = RNNSequenceLayer::CellType::GRU_LBR;
|
||||
}
|
||||
|
||||
Builder::NodeConverter<ngraph::op::Constant> converter;
|
||||
const auto weightsNode = node->input_value(3).get_node_shared_ptr();
|
||||
if (converter.canCreate(weightsNode)) {
|
||||
const auto& weights = converter.createLayer(weightsNode);
|
||||
res->blobs["weights"] = weights->blobs["custom"];
|
||||
res->_weights = weights->blobs["custom"];
|
||||
}
|
||||
|
||||
const auto biasNode = node->input_value(4).get_node_shared_ptr();
|
||||
if (converter.canCreate(biasNode)) {
|
||||
const auto& bias = converter.createLayer(biasNode);
|
||||
res->blobs["biases"] = bias->blobs["custom"];
|
||||
res->_biases = bias->blobs["custom"];
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"RNNSequenceIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string>& params) -> CNNLayerPtr {
|
||||
|
||||
LayerParams attrs = {node->get_friendly_name(), "RNNSequence",
|
||||
details::convertPrecision(node->get_output_element_type(0))};
|
||||
auto res = std::make_shared<RNNSequenceLayer>(attrs);
|
||||
res->params = params;
|
||||
|
||||
res->cellType = RNNSequenceLayer::CellType::RNN;
|
||||
|
||||
if (res->params["direction"] == "reverse")
|
||||
res->params["direction"] = "Backward";
|
||||
else if (res->params["direction"] == "forward")
|
||||
res->params["direction"] = "Forward";
|
||||
else
|
||||
res->params["direction"] = "Bidirectional";
|
||||
|
||||
Builder::NodeConverter<ngraph::op::Constant> converter;
|
||||
const auto weightsNode = node->input_value(3).get_node_shared_ptr();
|
||||
if (converter.canCreate(weightsNode)) {
|
||||
const auto& weights = converter.createLayer(weightsNode);
|
||||
res->blobs["weights"] = weights->blobs["custom"];
|
||||
res->_weights = weights->blobs["custom"];
|
||||
}
|
||||
|
||||
const auto biasNode = node->input_value(4).get_node_shared_ptr();
|
||||
if (converter.canCreate(biasNode)) {
|
||||
const auto& bias = converter.createLayer(biasNode);
|
||||
res->blobs["biases"] = bias->blobs["custom"];
|
||||
res->_biases = bias->blobs["custom"];
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
addSpecificCreator({"LSTMSequenceIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
|
||||
const std::map<std::string, std::string>& params) -> CNNLayerPtr {
|
||||
|
||||
LayerParams attrs = {node->get_friendly_name(), "LSTMSequence",
|
||||
details::convertPrecision(node->get_output_element_type(0))};
|
||||
auto res = std::make_shared<RNNSequenceLayer>(attrs);
|
||||
res->params = params;
|
||||
|
||||
res->cellType = RNNSequenceLayer::CellType::LSTM;
|
||||
|
||||
if (res->params["direction"] == "reverse")
|
||||
res->params["direction"] = "Backward";
|
||||
else if (res->params["direction"] == "forward")
|
||||
res->params["direction"] = "Forward";
|
||||
else
|
||||
res->params["direction"] = "Bidirectional";
|
||||
|
||||
Builder::NodeConverter<ngraph::op::Constant> converter;
|
||||
const auto weightsNode = node->input_value(4).get_node_shared_ptr();
|
||||
if (converter.canCreate(weightsNode)) {
|
||||
const auto &weights = converter.createLayer(weightsNode);
|
||||
res->blobs["weights"] = weights->blobs["custom"];
|
||||
res->_weights = weights->blobs["custom"];
|
||||
}
|
||||
|
||||
const auto biasNode = node->input_value(5).get_node_shared_ptr();
|
||||
if (converter.canCreate(biasNode)) {
|
||||
const auto &bias = converter.createLayer(biasNode);
|
||||
res->blobs["biases"] = bias->blobs["custom"];
|
||||
res->_biases = bias->blobs["custom"];
|
||||
}
|
||||
return res;
|
||||
});
|
||||
|
||||
REQUIRED_IE_CONVERSION_CREATOR("Broadcast", "Tile");
|
||||
REQUIRED_IE_CONVERSION_CREATOR("Interpolate", "Interp");
|
||||
REQUIRED_IE_CONVERSION_CREATOR("NormalizeL2", "NormalizeIE");
|
||||
@@ -736,13 +844,24 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
|
||||
::ngraph::as_type_ptr<::ngraph::op::VariadicSplit>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::ScaleShiftIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::Transpose>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::LSTMSequenceIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::RNNSequenceIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::GRUSequenceIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::RNNCellIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::GRUCellIE>(consumerLayer)) {
|
||||
// Check that all input nodes except zero input are Constants for all ops except DeformableConvolutions
|
||||
// for which the input with index 1 is also dynamic
|
||||
size_t inputID = ::ngraph::as_type_ptr<::ngraph::op::v1::DeformableConvolution>(consumerLayer) ||
|
||||
size_t inputID = 1;
|
||||
if (::ngraph::as_type_ptr<::ngraph::op::v1::DeformableConvolution>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::GRUCellIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::RNNCellIE>(consumerLayer)? 2 : 1;
|
||||
::ngraph::as_type_ptr<::ngraph::op::RNNCellIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::GRUSequenceIE>(consumerLayer) ||
|
||||
::ngraph::as_type_ptr<::ngraph::op::RNNSequenceIE>(consumerLayer)) {
|
||||
inputID = 2;
|
||||
} else if (::ngraph::as_type_ptr<::ngraph::op::LSTMSequenceIE>(consumerLayer)) {
|
||||
inputID = 3;
|
||||
}
|
||||
|
||||
for (; inputID < consumerLayer->inputs().size(); ++inputID) {
|
||||
auto inputLayer = consumerLayer->input(inputID).get_source_output().get_node_shared_ptr();
|
||||
if (inputLayer == constLayer) {
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
class TRANSFORMATIONS_API GRUSequenceIE : public ngraph::op::util::RNNCellBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
GRUSequenceIE(const Output <Node> &X,
|
||||
const Output <Node> &H_t,
|
||||
const Output <Node> &seg_lengths,
|
||||
const Output <Node> &WR,
|
||||
const Output <Node> &B,
|
||||
size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string> &activations,
|
||||
const std::vector<float> &activations_alpha,
|
||||
const std::vector<float> &activations_beta,
|
||||
float clip,
|
||||
bool linear_before_reset);
|
||||
|
||||
GRUSequenceIE() = delete;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::size_t get_hidden_size() { return m_hidden_size; }
|
||||
|
||||
const std::vector<std::string> &get_activations() { return m_activations; }
|
||||
|
||||
const std::vector<float> &get_activations_alpha() { return m_activations_alpha; }
|
||||
|
||||
const std::vector<float> &get_activations_beta() { return m_activations_beta; }
|
||||
|
||||
float get_clip() { return m_clip; }
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
op::RecurrentSequenceDirection m_direction;
|
||||
bool m_linear_before_reset;
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
class TRANSFORMATIONS_API LSTMSequenceIE : public ngraph::op::util::RNNCellBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
LSTMSequenceIE() = delete;
|
||||
|
||||
LSTMSequenceIE(const Output <Node> &X,
|
||||
const Output <Node> &H_t,
|
||||
const Output <Node> &C_t,
|
||||
const Output <Node> &seq_lenghts,
|
||||
const Output <Node> &WR,
|
||||
const Output <Node> &B,
|
||||
size_t hidden_size,
|
||||
ngraph::op::RecurrentSequenceDirection lstm_direction,
|
||||
const std::vector<std::string> &activations,
|
||||
const std::vector<float> &activations_alpha,
|
||||
const std::vector<float> &activations_beta,
|
||||
float clip);
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
ngraph::op::RecurrentSequenceDirection get_direction() { return m_direction; }
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
ngraph::op::RecurrentSequenceDirection m_direction;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/op/op.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace op {
|
||||
class TRANSFORMATIONS_API RNNSequenceIE : public ngraph::op::util::RNNCellBase {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
RNNSequenceIE(const Output <Node> &X,
|
||||
const Output <Node> &H_t,
|
||||
const Output <Node> &seq_lengths,
|
||||
const Output <Node> &WR,
|
||||
const Output <Node> &B,
|
||||
size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string> &activations,
|
||||
const std::vector<float> &activations_alpha,
|
||||
const std::vector<float> &activations_beta,
|
||||
float clip);
|
||||
|
||||
RNNSequenceIE() = delete;
|
||||
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector &new_args) const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
std::size_t get_hidden_size() { return m_hidden_size; }
|
||||
|
||||
const std::vector<std::string> &get_activations() { return m_activations; }
|
||||
|
||||
const std::vector<float> &get_activations_alpha() { return m_activations_alpha; }
|
||||
|
||||
const std::vector<float> &get_activations_beta() { return m_activations_beta; }
|
||||
|
||||
float get_clip() { return m_clip; }
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
protected:
|
||||
op::RecurrentSequenceDirection m_direction;
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API BidirectionalLSTMSequenceDecomposition;
|
||||
class TRANSFORMATIONS_API BidirectionalGRUSequenceDecomposition;
|
||||
class TRANSFORMATIONS_API BidirectionalRNNSequenceDecomposition;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Decompose LSTMSequence to forward and reverse LSTMSequence.
|
||||
*
|
||||
*/
|
||||
|
||||
class ngraph::pass::BidirectionalLSTMSequenceDecomposition : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
BidirectionalLSTMSequenceDecomposition();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Decompose GRUSequence to forward and reverse GRUSequence.
|
||||
*
|
||||
*/
|
||||
|
||||
class ngraph::pass::BidirectionalGRUSequenceDecomposition : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
BidirectionalGRUSequenceDecomposition();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Decompose RNNSequence to forward and reverse RNNSequence.
|
||||
*
|
||||
*/
|
||||
|
||||
class ngraph::pass::BidirectionalRNNSequenceDecomposition : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
BidirectionalRNNSequenceDecomposition();
|
||||
};
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API ConvertLSTMSequenceMatcher;
|
||||
class TRANSFORMATIONS_API ConvertGRUSequenceMatcher;
|
||||
class TRANSFORMATIONS_API ConvertRNNSequenceMatcher;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts LSTMSequence to legacy LSTMSequenceIE.
|
||||
* SequenceIE op doesn't use seq_length input and num_direction (direction) attribute.
|
||||
* We squeeze num_direction dimension for all corresponding inputs and unsqueeze them after the SequenceIE op.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertLSTMSequenceMatcher : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertLSTMSequenceMatcher();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts GRUSequence to legacy GRUSequenceIE.
|
||||
* SequenceIE op doesn't use seq_length input and num_direction (direction) attribute.
|
||||
* We squeeze num_direction dimension for all corresponding inputs and unsqueeze them after the SequenceIE op.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertGRUSequenceMatcher : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertGRUSequenceMatcher();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Converts RNNSequence to legacy RNNSequenceIE.
|
||||
* SequenceIE op doesn't use seq_length input and num_direction (direction) attribute.
|
||||
* We squeeze num_direction dimension for all corresponding inputs and unsqueeze them after the SequenceIE op.
|
||||
*/
|
||||
|
||||
class ngraph::pass::ConvertRNNSequenceMatcher : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
ConvertRNNSequenceMatcher();
|
||||
};
|
||||
@@ -0,0 +1,85 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_ops/gru_sequence_ie.hpp"
|
||||
#include "ngraph/op/util/recurrent_sequence.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::GRUSequenceIE, "GRUSequenceIE", 4);
|
||||
|
||||
op::GRUSequenceIE::GRUSequenceIE(const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& seq_lenghts,
|
||||
const Output<Node>& WR,
|
||||
const Output<Node>& B,
|
||||
std::size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool linear_before_reset)
|
||||
: RNNCellBase({X, H_t, seq_lenghts, WR, B}, hidden_size, clip, activations, activations_alpha, activations_beta),
|
||||
m_direction(direction),
|
||||
m_linear_before_reset(linear_before_reset) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::GRUSequenceIE::validate_and_infer_types() {
|
||||
for (const auto& input : inputs()) {
|
||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
// rank validation
|
||||
auto x_pshape = get_input_partial_shape(0);
|
||||
auto h_state_pshape = get_input_partial_shape(1);
|
||||
auto seq_lengths_pshape = get_input_partial_shape(2);
|
||||
auto wr_pshape = get_input_partial_shape(3);
|
||||
auto b_pshape = get_input_partial_shape(4);
|
||||
std::vector<ngraph::PartialShape> pshapes = {x_pshape, h_state_pshape, seq_lengths_pshape, wr_pshape, b_pshape};
|
||||
|
||||
std::vector<std::string> in_names = {"X", "H", "seq_lenghts", "WR", "B"};
|
||||
// num_direction dimension should be squeezed, we don't support bidirectional case
|
||||
std::vector<size_t> ranks = {3, 2, 1, 2, 1};
|
||||
for (size_t i = 0; i < pshapes.size(); ++i) {
|
||||
NGRAPH_CHECK((pshapes[i].rank().get_length() == ranks[i]),
|
||||
"GRUSequenceIE ",
|
||||
in_names[i],
|
||||
" input rank is not correct.");
|
||||
}
|
||||
|
||||
element::Type arg_type = get_input_element_type(0);
|
||||
PartialShape output_shape_0{PartialShape::dynamic(3)};
|
||||
PartialShape output_shape_1{PartialShape::dynamic(2)};
|
||||
if (get_input_partial_shape(0).is_static()) {
|
||||
size_t batch_size = get_input_partial_shape(0).get_shape()[0];
|
||||
size_t seq_length = get_input_partial_shape(0).get_shape()[1];
|
||||
output_shape_0 = Shape{batch_size, seq_length, m_hidden_size};
|
||||
output_shape_1 = Shape{batch_size, m_hidden_size};
|
||||
}
|
||||
set_output_type(0, arg_type, output_shape_0);
|
||||
set_output_type(1, arg_type, output_shape_1);
|
||||
}
|
||||
|
||||
bool op::GRUSequenceIE::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
visitor.on_attribute("linear_before_reset", m_linear_before_reset);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::GRUSequenceIE::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return std::make_shared<op::GRUSequenceIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
new_args.at(4), m_hidden_size, m_direction, m_activations, m_activations_alpha, m_activations_beta, m_clip,
|
||||
m_linear_before_reset);
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_ops/lstm_sequence_ie.hpp"
|
||||
#include "ngraph/op/util/recurrent_sequence.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::LSTMSequenceIE, "LSTMSequenceIE", 5);
|
||||
|
||||
op::LSTMSequenceIE::LSTMSequenceIE(const Output<Node> &X,
|
||||
const Output<Node> &H_t,
|
||||
const Output<Node> &C_t,
|
||||
const Output<Node> &seq_lenghts,
|
||||
const Output<Node> &WR,
|
||||
const Output<Node> &B,
|
||||
std::size_t hidden_size,
|
||||
ngraph::op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string> &activations,
|
||||
const std::vector<float> &activations_alpha,
|
||||
const std::vector<float> &activations_beta,
|
||||
float clip)
|
||||
: RNNCellBase({X, H_t, C_t, seq_lenghts, WR, B}, hidden_size, clip, activations, activations_alpha, activations_beta),
|
||||
m_direction(direction) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::LSTMSequenceIE::validate_and_infer_types() {
|
||||
for (const auto& input : inputs()) {
|
||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(2, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
// rank validation
|
||||
auto x_pshape = get_input_partial_shape(0);
|
||||
auto h_state_pshape = get_input_partial_shape(1);
|
||||
auto c_state_pshape = get_input_partial_shape(2);
|
||||
auto seq_lengths_pshape = get_input_partial_shape(3);
|
||||
auto wr_pshape = get_input_partial_shape(4);
|
||||
auto b_pshape = get_input_partial_shape(5);
|
||||
|
||||
std::vector<ngraph::PartialShape> pshapes = {x_pshape, h_state_pshape, c_state_pshape,
|
||||
seq_lengths_pshape, wr_pshape, b_pshape};
|
||||
std::vector<std::string> in_names = {"X", "H", "C", "seq_lenghts", "WR", "B"};
|
||||
// num_direction dimension should be squeezed, we don't support bidirectional case
|
||||
std::vector<size_t> ranks = {3, 2, 2, 1, 2, 1};
|
||||
for (size_t i = 0; i < pshapes.size(); ++i) {
|
||||
NGRAPH_CHECK((pshapes[i].rank().get_length() == ranks[i]),
|
||||
"LSTMSequenceIE ",
|
||||
in_names[i],
|
||||
" input rank is not correct.");
|
||||
}
|
||||
|
||||
element::Type arg_type = get_input_element_type(0);
|
||||
PartialShape output_shape_0{PartialShape::dynamic(3)};
|
||||
PartialShape output_shape_1{PartialShape::dynamic(2)};
|
||||
if (get_input_partial_shape(0).is_static()) {
|
||||
size_t batch_size = get_input_partial_shape(0).get_shape()[0];
|
||||
size_t seq_length = get_input_partial_shape(0).get_shape()[1];
|
||||
output_shape_0 = Shape{batch_size, seq_length, m_hidden_size};
|
||||
output_shape_1 = Shape{batch_size, m_hidden_size};
|
||||
}
|
||||
set_output_type(0, arg_type, output_shape_0);
|
||||
set_output_type(1, arg_type, output_shape_1);
|
||||
set_output_type(2, arg_type, output_shape_1);
|
||||
}
|
||||
|
||||
bool ngraph::op::LSTMSequenceIE::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::LSTMSequenceIE::clone_with_new_inputs(const OutputVector &new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::LSTMSequenceIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
new_args.at(4), new_args.at(5), m_hidden_size, m_direction, m_activations, m_activations_alpha, m_activations_beta,
|
||||
m_clip);
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_ops/rnn_sequence_ie.hpp"
|
||||
#include "ngraph/op/util/recurrent_sequence.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::RNNSequenceIE, "RNNSequenceIE", 4);
|
||||
|
||||
op::RNNSequenceIE::RNNSequenceIE(const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& seq_lengths, // actually not supported
|
||||
const Output<Node>& WR,
|
||||
const Output<Node>& B,
|
||||
std::size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip)
|
||||
: RNNCellBase({X, H_t, seq_lengths, WR, B}, hidden_size, clip, activations, activations_alpha, activations_beta),
|
||||
m_direction(direction) {
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::RNNSequenceIE::validate_and_infer_types() {
|
||||
for (const auto& input : inputs()) {
|
||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
// rank validation
|
||||
auto x_pshape = get_input_partial_shape(0);
|
||||
auto h_state_pshape = get_input_partial_shape(1);
|
||||
auto seq_lengths_pshape = get_input_partial_shape(2);
|
||||
auto wr_pshape = get_input_partial_shape(3);
|
||||
auto b_pshape = get_input_partial_shape(4);
|
||||
|
||||
std::vector<ngraph::PartialShape> pshapes = {x_pshape, h_state_pshape, seq_lengths_pshape, wr_pshape, b_pshape};
|
||||
std::vector<std::string> in_names = {"X", "H", "seq_lenghts", "WR", "B"};
|
||||
// num_direction dimension should be squeezed, we don't support bidirectional case
|
||||
std::vector<size_t> ranks = {3, 2, 1, 2, 1};
|
||||
for (size_t i = 0; i < pshapes.size(); ++i) {
|
||||
NGRAPH_CHECK((pshapes[i].rank().get_length() == ranks[i]),
|
||||
"RNNSequenceIE ",
|
||||
in_names[i],
|
||||
" input rank is not correct.");
|
||||
}
|
||||
|
||||
element::Type arg_type = get_input_element_type(0);
|
||||
PartialShape output_shape_0{PartialShape::dynamic(3)};
|
||||
PartialShape output_shape_1{PartialShape::dynamic(2)};
|
||||
if (get_input_partial_shape(0).is_static()) {
|
||||
size_t batch_size = get_input_partial_shape(0).get_shape()[0];
|
||||
size_t seq_length = get_input_partial_shape(0).get_shape()[1];
|
||||
output_shape_0 = Shape{batch_size, seq_length, m_hidden_size};
|
||||
output_shape_1 = Shape{batch_size, m_hidden_size};
|
||||
}
|
||||
set_output_type(0, arg_type, output_shape_0);
|
||||
set_output_type(1, arg_type, output_shape_1);
|
||||
}
|
||||
|
||||
bool op::RNNSequenceIE::visit_attributes(AttributeVisitor& visitor) {
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::RNNSequenceIE::clone_with_new_inputs(const ngraph::OutputVector &new_args) const {
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::RNNSequenceIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
|
||||
new_args.at(4), m_hidden_size, m_direction, m_activations, m_activations_alpha, m_activations_beta, m_clip);
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/bidirectional_sequences_decomposition.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
ngraph::pass::BidirectionalLSTMSequenceDecomposition::BidirectionalLSTMSequenceDecomposition() {
|
||||
auto lstm_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::LSTMSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto lstm_sequence = std::dynamic_pointer_cast<ngraph::op::v5::LSTMSequence>(m.get_match_root());
|
||||
if (!lstm_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(lstm_sequence->input_value(1), axis_1, 2);
|
||||
auto C = std::make_shared<opset4::Split>(lstm_sequence->input_value(2), axis_1, 2);
|
||||
auto W = std::make_shared<opset4::Split>(lstm_sequence->input_value(4), axis_0, 2);
|
||||
auto R = std::make_shared<opset4::Split>(lstm_sequence->input_value(5), axis_0, 2);
|
||||
auto B = std::make_shared<opset4::Split>(lstm_sequence->input_value(6), axis_0, 2);
|
||||
auto lstm_sequence_forward = std::make_shared<ngraph::op::v5::LSTMSequence>(
|
||||
lstm_sequence->input_value(0),
|
||||
H->output(0),
|
||||
C->output(0),
|
||||
lstm_sequence->input_value(3),
|
||||
W->output(0),
|
||||
R->output(0),
|
||||
B->output(0),
|
||||
lstm_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
lstm_sequence->get_activations_alpha(),
|
||||
lstm_sequence->get_activations_beta(),
|
||||
lstm_sequence->get_activations(),
|
||||
lstm_sequence->get_clip());
|
||||
|
||||
auto lstm_sequence_reverse = std::make_shared<ngraph::op::v5::LSTMSequence>(
|
||||
lstm_sequence->input_value(0),
|
||||
H->output(1),
|
||||
C->output(1),
|
||||
lstm_sequence->input_value(3),
|
||||
W->output(1),
|
||||
R->output(1),
|
||||
B->output(1),
|
||||
lstm_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
lstm_sequence->get_activations_alpha(),
|
||||
lstm_sequence->get_activations_beta(),
|
||||
lstm_sequence->get_activations(),
|
||||
lstm_sequence->get_clip());
|
||||
|
||||
auto concat_0 = std::make_shared<opset4::Concat>(OutputVector{lstm_sequence_forward->output(0),
|
||||
lstm_sequence_reverse->output(0)}, 1);
|
||||
auto concat_1 = std::make_shared<opset4::Concat>(OutputVector{lstm_sequence_forward->output(1),
|
||||
lstm_sequence_reverse->output(1)}, 1);
|
||||
auto concat_2 = std::make_shared<opset4::Concat>(OutputVector{lstm_sequence_forward->output(2),
|
||||
lstm_sequence_reverse->output(2)}, 1);
|
||||
ngraph::copy_runtime_info(lstm_sequence, {H, C, W, R, B, lstm_sequence_forward, lstm_sequence_reverse,
|
||||
concat_0, concat_1, concat_2});
|
||||
concat_0->set_friendly_name(lstm_sequence->get_friendly_name()+".0");
|
||||
concat_1->set_friendly_name(lstm_sequence->get_friendly_name()+".1");
|
||||
concat_2->set_friendly_name(lstm_sequence->get_friendly_name()+".2");
|
||||
ngraph::replace_node(lstm_sequence, {concat_0->output(0), concat_1->output(0), concat_2->output(0)});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_sequence_ngraph, "BidirectionalLSTMSequenceDecomposition");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::BidirectionalGRUSequenceDecomposition::BidirectionalGRUSequenceDecomposition() {
|
||||
auto gru_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::GRUSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto gru_sequence = std::dynamic_pointer_cast<ngraph::op::v5::GRUSequence>(m.get_match_root());
|
||||
if (!gru_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(gru_sequence->input_value(1), axis_1, 2);
|
||||
auto W = std::make_shared<opset4::Split>(gru_sequence->input_value(3), axis_0, 2);
|
||||
auto R = std::make_shared<opset4::Split>(gru_sequence->input_value(4), axis_0, 2);
|
||||
auto B = std::make_shared<opset4::Split>(gru_sequence->input_value(5), axis_0, 2);
|
||||
auto gru_sequence_forward = std::make_shared<ngraph::op::v5::GRUSequence>(
|
||||
gru_sequence->input_value(0),
|
||||
H->output(0),
|
||||
gru_sequence->input_value(2),
|
||||
W->output(0),
|
||||
R->output(0),
|
||||
B->output(0),
|
||||
gru_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
gru_sequence->get_activations(),
|
||||
gru_sequence->get_activations_alpha(),
|
||||
gru_sequence->get_activations_beta(),
|
||||
gru_sequence->get_clip(),
|
||||
gru_sequence->get_linear_before_reset());
|
||||
|
||||
auto gru_sequence_reverse = std::make_shared<ngraph::op::v5::GRUSequence>(
|
||||
gru_sequence->input_value(0),
|
||||
H->output(1),
|
||||
gru_sequence->input_value(2),
|
||||
W->output(1),
|
||||
R->output(1),
|
||||
B->output(1),
|
||||
gru_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
gru_sequence->get_activations(),
|
||||
gru_sequence->get_activations_alpha(),
|
||||
gru_sequence->get_activations_beta(),
|
||||
gru_sequence->get_clip(),
|
||||
gru_sequence->get_linear_before_reset());
|
||||
|
||||
auto concat_0 = std::make_shared<opset4::Concat>(OutputVector{gru_sequence_forward->output(0),
|
||||
gru_sequence_reverse->output(0)}, 1);
|
||||
auto concat_1 = std::make_shared<opset4::Concat>(OutputVector{gru_sequence_forward->output(1),
|
||||
gru_sequence_reverse->output(1)}, 1);
|
||||
ngraph::copy_runtime_info(gru_sequence, {H, W, R, B, gru_sequence_forward, gru_sequence_reverse,
|
||||
concat_0, concat_1});
|
||||
concat_0->set_friendly_name(gru_sequence->get_friendly_name()+".0");
|
||||
concat_1->set_friendly_name(gru_sequence->get_friendly_name()+".1");
|
||||
ngraph::replace_node(gru_sequence, {concat_0->output(0), concat_1->output(0)});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gru_sequence_ngraph, "BidirectionalGRUSequenceDecomposition");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::BidirectionalRNNSequenceDecomposition::BidirectionalRNNSequenceDecomposition() {
|
||||
auto rnn_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::RNNSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto rnn_sequence = std::dynamic_pointer_cast<ngraph::op::v5::RNNSequence>(m.get_match_root());
|
||||
if (!rnn_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
|
||||
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
|
||||
auto H = std::make_shared<opset4::Split>(rnn_sequence->input_value(1), axis_1, 2);
|
||||
auto W = std::make_shared<opset4::Split>(rnn_sequence->input_value(3), axis_0, 2);
|
||||
auto R = std::make_shared<opset4::Split>(rnn_sequence->input_value(4), axis_0, 2);
|
||||
auto B = std::make_shared<opset4::Split>(rnn_sequence->input_value(5), axis_0, 2);
|
||||
auto rnn_sequence_forward = std::make_shared<ngraph::op::v5::RNNSequence>(
|
||||
rnn_sequence->input_value(0),
|
||||
H->output(0),
|
||||
rnn_sequence->input_value(2),
|
||||
W->output(0),
|
||||
R->output(0),
|
||||
B->output(0),
|
||||
rnn_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
rnn_sequence->get_activations(),
|
||||
rnn_sequence->get_activations_alpha(),
|
||||
rnn_sequence->get_activations_beta(),
|
||||
rnn_sequence->get_clip());
|
||||
|
||||
auto rnn_sequence_reverse = std::make_shared<ngraph::op::v5::RNNSequence>(
|
||||
rnn_sequence->input_value(0),
|
||||
H->output(1),
|
||||
rnn_sequence->input_value(2),
|
||||
W->output(1),
|
||||
R->output(1),
|
||||
B->output(1),
|
||||
rnn_sequence->get_hidden_size(),
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
rnn_sequence->get_activations(),
|
||||
rnn_sequence->get_activations_alpha(),
|
||||
rnn_sequence->get_activations_beta(),
|
||||
rnn_sequence->get_clip());
|
||||
|
||||
auto concat_0 = std::make_shared<opset4::Concat>(OutputVector{rnn_sequence_forward->output(0),
|
||||
rnn_sequence_reverse->output(0)}, 1);
|
||||
auto concat_1 = std::make_shared<opset4::Concat>(OutputVector{rnn_sequence_forward->output(1),
|
||||
rnn_sequence_reverse->output(1)}, 1);
|
||||
ngraph::copy_runtime_info(rnn_sequence, {H, W, R, B, rnn_sequence_forward, rnn_sequence_reverse,
|
||||
concat_0, concat_1});
|
||||
concat_0->set_friendly_name(rnn_sequence->get_friendly_name() + ".0");
|
||||
concat_1->set_friendly_name(rnn_sequence->get_friendly_name() + ".1");
|
||||
ngraph::replace_node(rnn_sequence, {concat_0->output(0), concat_1->output(0)});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_sequence_ngraph, "BidirectionalRNNSequenceDecomposition");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
@@ -32,6 +32,7 @@
|
||||
#include <transformations/convert_opset1_to_legacy/convert_strided_slice_to_crop.hpp>
|
||||
#include <transformations/convert_subtract.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_selu_to_selu_ie.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_sequences_to_sequences_ie.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_swish_to_swish_ie.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_tile_to_ie_tile.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_topk_to_topk_ie.hpp>
|
||||
@@ -157,6 +158,9 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
anchor->add_matcher<ngraph::pass::ConvertTopKToTopKIEMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMSToNMSIEMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertNMS4ToLegacyMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertGRUSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertRNNSequenceMatcher>();
|
||||
anchor->add_matcher<ngraph::pass::ConvertLSTMSequenceMatcher>();
|
||||
anchor->set_name("ngraph::pass::ConvertOpSet1ToLegacy");
|
||||
|
||||
// List of final conversion transformations that must to be executed
|
||||
|
||||
@@ -0,0 +1,195 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/convert_opset1_to_legacy/convert_sequences_to_sequences_ie.hpp"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include <ngraph_ops/lstm_sequence_ie.hpp>
|
||||
#include <ngraph_ops/gru_sequence_ie.hpp>
|
||||
#include <ngraph_ops/rnn_sequence_ie.hpp>
|
||||
|
||||
ngraph::pass::ConvertLSTMSequenceMatcher::ConvertLSTMSequenceMatcher() {
|
||||
auto lstm_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::LSTMSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto lstm_sequence = std::dynamic_pointer_cast<ngraph::op::v5::LSTMSequence>(m.get_match_root());
|
||||
if (!lstm_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
lstm_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto& R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
lstm_sequence->input_value(5).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(1).get_source_output(), axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(2).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(lstm_sequence->input(6).get_source_output(), axis_2);
|
||||
auto lstm_sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(
|
||||
lstm_sequence->input(0).get_source_output(), // X
|
||||
in_1, // initial_hidden_state
|
||||
in_2, // initial_cell_state
|
||||
lstm_sequence->input(3).get_source_output(),
|
||||
in_3, // WR
|
||||
in_4, // B
|
||||
lstm_sequence->get_hidden_size(),
|
||||
lstm_sequence->get_direction(),
|
||||
lstm_sequence->get_activations(),
|
||||
lstm_sequence->get_activations_alpha(),
|
||||
lstm_sequence->get_activations_beta(),
|
||||
lstm_sequence->get_clip());
|
||||
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(lstm_sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(lstm_sequence_ie->output(1), unsqueeze_axis);
|
||||
auto unsqueeze_3 = std::make_shared<ngraph::opset4::Unsqueeze>(lstm_sequence_ie->output(2), unsqueeze_axis);
|
||||
|
||||
ngraph::copy_runtime_info(lstm_sequence, {concat, lstm_sequence_ie, in_1, in_2, in_3, in_4, unsqueeze_1,
|
||||
unsqueeze_2, unsqueeze_3});
|
||||
unsqueeze_1->set_friendly_name(lstm_sequence->get_friendly_name()+".0");
|
||||
unsqueeze_2->set_friendly_name(lstm_sequence->get_friendly_name()+".1");
|
||||
unsqueeze_3->set_friendly_name(lstm_sequence->get_friendly_name()+".2");
|
||||
ngraph::replace_node(lstm_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0), unsqueeze_3->output(0)});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_sequence_ngraph, "ConvertLSTMSequenceToLSTMSequenceIE");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::ConvertGRUSequenceMatcher::ConvertGRUSequenceMatcher() {
|
||||
auto gru_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::GRUSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto gru_sequence = std::dynamic_pointer_cast<ngraph::op::v5::GRUSequence>(m.get_match_root());
|
||||
if (!gru_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
gru_sequence->input_value(3).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
gru_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// todo: add exception?
|
||||
if (gru_sequence->get_direction() == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
return false;
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(1).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(gru_sequence->input(5).get_source_output(), axis_2);
|
||||
|
||||
auto gru_sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(
|
||||
gru_sequence->input(0).get_source_output(), // X
|
||||
in_1, // initial_hidden_state
|
||||
gru_sequence->input(2).get_source_output(),
|
||||
in_3, // WR
|
||||
in_4, // B
|
||||
gru_sequence->get_hidden_size(),
|
||||
gru_sequence->get_direction(),
|
||||
gru_sequence->get_activations(),
|
||||
gru_sequence->get_activations_alpha(),
|
||||
gru_sequence->get_activations_beta(),
|
||||
gru_sequence->get_clip(),
|
||||
gru_sequence->get_linear_before_reset());
|
||||
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(gru_sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(gru_sequence_ie->output(1), unsqueeze_axis);
|
||||
|
||||
ngraph::copy_runtime_info(gru_sequence, {concat, gru_sequence_ie, unsqueeze_1, unsqueeze_2, in_1, in_3, in_4});
|
||||
unsqueeze_1->set_friendly_name(gru_sequence->get_friendly_name()+".0");
|
||||
unsqueeze_2->set_friendly_name(gru_sequence->get_friendly_name()+".1");
|
||||
ngraph::replace_node(gru_sequence, {unsqueeze_1, unsqueeze_2});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gru_sequence_ngraph, "ConvertGRUSequenceToGRUSequenceIE");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
ngraph::pass::ConvertRNNSequenceMatcher::ConvertRNNSequenceMatcher() {
|
||||
auto rnn_sequence_ngraph = ngraph::pattern::wrap_type<ngraph::op::v5::RNNSequence>();
|
||||
|
||||
ngraph::matcher_pass_callback callback = [](pattern::Matcher &m) {
|
||||
auto rnn_sequence = std::dynamic_pointer_cast<ngraph::op::v5::RNNSequence>(m.get_match_root());
|
||||
if (!rnn_sequence) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto W = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
rnn_sequence->input_value(3).get_node_shared_ptr());
|
||||
if (!W) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto R = std::dynamic_pointer_cast<ngraph::opset4::Constant>(
|
||||
rnn_sequence->input_value(4).get_node_shared_ptr());
|
||||
if (!R) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// for forward/reverse cases we can squeeze num_direction dimension
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(1).get_source_output(), axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(rnn_sequence->input(5).get_source_output(), axis_2);
|
||||
auto rnn_sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(
|
||||
rnn_sequence->input(0).get_source_output(), // X
|
||||
in_1, // initial_hidden_state
|
||||
rnn_sequence->input_value(2),
|
||||
in_3, // WR
|
||||
in_4, // B
|
||||
rnn_sequence->get_hidden_size(),
|
||||
rnn_sequence->get_direction(),
|
||||
rnn_sequence->get_activations(),
|
||||
rnn_sequence->get_activations_alpha(),
|
||||
rnn_sequence->get_activations_beta(),
|
||||
rnn_sequence->get_clip());
|
||||
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(rnn_sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(rnn_sequence_ie->output(1), unsqueeze_axis);
|
||||
|
||||
ngraph::copy_runtime_info(rnn_sequence, {concat, rnn_sequence_ie, in_1, in_3, in_4, unsqueeze_1,
|
||||
unsqueeze_2});
|
||||
unsqueeze_1->set_friendly_name(rnn_sequence->get_friendly_name()+".0");
|
||||
unsqueeze_2->set_friendly_name(rnn_sequence->get_friendly_name()+".1");
|
||||
ngraph::replace_node(rnn_sequence, {unsqueeze_1->output(0), unsqueeze_2->output(0)});
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_sequence_ngraph, "ConvertRNNSequenceToRNNSequenceIE");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
@@ -85,15 +85,15 @@ TEST(ConvertFunctionToCNNNetworkTests, OpsShouldBeConvertedToIERepresentation) {
|
||||
std::make_shared<ngraph::opset4::GroupConvolution>(),
|
||||
std::make_shared<ngraph::opset4::GroupConvolutionBackpropData>(),
|
||||
std::make_shared<ngraph::opset4::GRUCell>(),
|
||||
// std::make_shared<ngraph::opset4::GRUSequence>(), todo: enable after GRUSequence support
|
||||
// std::make_shared<ngraph::op::v5::GRUSequence>(), todo: enable after GRUSequence support
|
||||
std::make_shared<ngraph::opset4::HardSigmoid>(),
|
||||
std::make_shared<ngraph::opset4::LRN>(),
|
||||
std::make_shared<ngraph::opset4::LSTMCell>(),
|
||||
// std::make_shared<ngraph::opset4::LSTMSequence>(), todo: enable after LSTMSequence support
|
||||
// std::make_shared<ngraph::op::v5::LSTMSequence>(), todo: enable after LSTMSequence support
|
||||
std::make_shared<ngraph::opset4::NonMaxSuppression>(),
|
||||
std::make_shared<ngraph::opset4::NormalizeL2>(),
|
||||
std::make_shared<ngraph::opset4::RNNCell>(),
|
||||
// std::make_shared<ngraph::opset4::RNNSequence>(), todo: enable after RNNSequence support
|
||||
// std::make_shared<ngraph::op::v5::RNNSequence>(), todo: enable after RNNSequence support
|
||||
std::make_shared<ngraph::opset4::OneHot>(),
|
||||
std::make_shared<ngraph::opset4::Pad>(),
|
||||
std::make_shared<ngraph::opset4::PriorBoxClustered>(),
|
||||
|
||||
@@ -0,0 +1,271 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/convert_opset1_to_legacy/convert_sequences_to_sequences_ie.hpp>
|
||||
|
||||
#include <ngraph/ops.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph_ops/gru_sequence_ie.hpp>
|
||||
#include <ngraph_ops/rnn_sequence_ie.hpp>
|
||||
#include <ngraph_ops/lstm_sequence_ie.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
|
||||
TEST(TransformationTests, GRUSequenceConversionTest) {
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
std::shared_ptr<ngraph::op::v5::GRUSequence> sequence;
|
||||
|
||||
const size_t batch_size = 2;
|
||||
const size_t input_size = 3;
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 3;
|
||||
const size_t num_directions = 1;
|
||||
{
|
||||
const auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, 1, input_size});
|
||||
const auto W =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size, input_size});
|
||||
const auto R =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions, hidden_size});
|
||||
const auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size});
|
||||
|
||||
const auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i32, ngraph::Shape{batch_size});
|
||||
sequence = std::make_shared<ngraph::op::v5::GRUSequence>(X, H_t, seq_len, W, R, B, hidden_size,
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
sequence->set_friendly_name("test_sequence");
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sequence}, ngraph::ParameterVector{X, H_t});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertGRUSequenceMatcher>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
const auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, 1, input_size});
|
||||
const auto W =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size, input_size});
|
||||
const auto R =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions, hidden_size});
|
||||
const auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions, gates_count * hidden_size});
|
||||
|
||||
const auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i32, ngraph::Shape{batch_size}, 1);
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(H_t, axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(B, axis_2);
|
||||
auto sequence_ie = std::make_shared<ngraph::op::GRUSequenceIE>(X,
|
||||
in_1,
|
||||
seq_len, // this input is not supported
|
||||
in_3,
|
||||
in_4,
|
||||
sequence->get_hidden_size(),
|
||||
sequence->get_direction(),
|
||||
sequence->get_activations(),
|
||||
sequence->get_activations_alpha(),
|
||||
sequence->get_activations_beta(),
|
||||
sequence->get_clip(),
|
||||
sequence->get_linear_before_reset());
|
||||
sequence_ie->set_friendly_name("test_sequence");
|
||||
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(1), unsqueeze_axis);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{unsqueeze_1}, ngraph::ParameterVector{X, H_t});
|
||||
}
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
auto result_node_of_converted_f = f->get_output_op(0);
|
||||
auto sequence_node = result_node_of_converted_f->input_value(0).get_node_shared_ptr()
|
||||
->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, RNNSequenceConversionTest) {
|
||||
const size_t hidden_size = 3;
|
||||
const size_t num_directions = 1;
|
||||
const size_t batch_size = 2;
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
std::shared_ptr<ngraph::op::v5::RNNSequence> sequence;
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{batch_size, 1, 3});
|
||||
auto H = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{batch_size, num_directions, 3});
|
||||
auto W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3, 3});
|
||||
auto R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3, 3});
|
||||
auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3});
|
||||
auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{2});
|
||||
sequence = std::make_shared<ngraph::op::v5::RNNSequence>(X, H, seq_len, W, R, B, hidden_size,
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
sequence->set_friendly_name("test_sequence");
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{sequence}, ngraph::ParameterVector{X, H});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertRNNSequenceMatcher>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{batch_size, 1, 3});
|
||||
auto H = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, ngraph::Shape{batch_size, num_directions, 3});
|
||||
auto W = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3, 3});
|
||||
auto R = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3, 3});
|
||||
auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{num_directions, 3});
|
||||
auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, ngraph::Shape{batch_size}, 1);
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(H, axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(B, axis_2);
|
||||
auto sequence_ie = std::make_shared<ngraph::op::RNNSequenceIE>(X,
|
||||
in_1,
|
||||
seq_len,
|
||||
in_3,
|
||||
in_4,
|
||||
sequence->get_hidden_size(),
|
||||
sequence->get_direction(),
|
||||
sequence->get_activations(),
|
||||
sequence->get_activations_alpha(),
|
||||
sequence->get_activations_beta(),
|
||||
sequence->get_clip());
|
||||
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(1), unsqueeze_axis);
|
||||
sequence_ie->set_friendly_name("test_sequence");
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{unsqueeze_1}, ngraph::ParameterVector{X, H});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
auto result_node_of_converted_f = f->get_output_op(0);
|
||||
auto sequence_node = result_node_of_converted_f->input_value(0).get_node_shared_ptr()
|
||||
->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
|
||||
TEST(TransformationTests, LSTMSequenceConversionTest) {
|
||||
const size_t batch_size = 2;
|
||||
const size_t input_size = 3;
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 4;
|
||||
const size_t num_directions = 1;
|
||||
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
|
||||
std::shared_ptr<ngraph::op::v5::LSTMSequence> sequence;
|
||||
{
|
||||
const auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, 10, input_size});
|
||||
const auto W =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size, input_size});
|
||||
const auto R =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions,
|
||||
hidden_size});
|
||||
const auto C_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions, hidden_size});
|
||||
const auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size});
|
||||
const auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i32, ngraph::Shape{batch_size});
|
||||
sequence = std::make_shared<ngraph::op::v5::LSTMSequence>(X, H_t, C_t, seq_len, W, R, B, hidden_size,
|
||||
ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
sequence->set_friendly_name("test_sequence");
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::OutputVector{sequence->output(0)}, ngraph::ParameterVector{X, H_t, C_t});
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::ConvertLSTMSequenceMatcher>();
|
||||
manager.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
}
|
||||
|
||||
{
|
||||
const auto X = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, 10, input_size});
|
||||
const auto W =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size, input_size});
|
||||
const auto R =
|
||||
std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions, hidden_size});
|
||||
const auto C_t = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size, num_directions, hidden_size});
|
||||
const auto seq_lenghts = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{batch_size});
|
||||
const auto B = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32,
|
||||
ngraph::Shape{num_directions,
|
||||
gates_count * hidden_size});
|
||||
// const auto seq_len = std::make_shared<ngraph::opset4::Constant>(ngraph::element::i32, ngraph::Shape{1}, 1);
|
||||
auto axis_1 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto in_1 = std::make_shared<ngraph::opset4::Squeeze>(H_t, axis_1);
|
||||
auto in_2 = std::make_shared<ngraph::opset4::Squeeze>(C_t, axis_1);
|
||||
auto concat = std::make_shared<ngraph::opset4::Concat>(ngraph::NodeVector({W, R}), 2);
|
||||
auto axis_2 = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {0});
|
||||
auto in_3 = std::make_shared<ngraph::opset4::Squeeze>(concat->output(0), axis_2);
|
||||
auto in_4 = std::make_shared<ngraph::opset4::Squeeze>(B, axis_2);
|
||||
auto sequence_ie = std::make_shared<ngraph::op::LSTMSequenceIE>(X,
|
||||
in_1,
|
||||
in_2,
|
||||
seq_lenghts,
|
||||
in_3,
|
||||
in_4,
|
||||
sequence->get_hidden_size(),
|
||||
sequence->get_direction(),
|
||||
sequence->get_activations(),
|
||||
sequence->get_activations_alpha(),
|
||||
sequence->get_activations_beta(),
|
||||
sequence->get_clip());
|
||||
sequence_ie->set_friendly_name("test_sequence");
|
||||
auto unsqueeze_axis = ngraph::opset4::Constant::create(ngraph::element::i64, ngraph::Shape{1}, {1});
|
||||
auto unsqueeze_1 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(0), unsqueeze_axis);
|
||||
auto unsqueeze_2 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(1), unsqueeze_axis);
|
||||
auto unsqueeze_3 = std::make_shared<ngraph::opset4::Unsqueeze>(sequence_ie->output(2), unsqueeze_axis);
|
||||
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{unsqueeze_1},
|
||||
ngraph::ParameterVector{X, H_t, C_t});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
auto result_node_of_converted_f = f->get_output_op(0);
|
||||
auto sequence_node = result_node_of_converted_f->input_value(0).get_node_shared_ptr()
|
||||
->input_value(0).get_node_shared_ptr();
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/gru_sequence.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
// without clip values increase rapidly, so use only seq_lenghts = 2
|
||||
std::vector<size_t> seq_lengths_zero_clip{2};
|
||||
std::vector<size_t> seq_lengths_clip_non_zero{20};
|
||||
std::vector<size_t> batch{1, 10};
|
||||
std::vector<size_t> hidden_size{1, 10};
|
||||
std::vector<size_t> input_size{10};
|
||||
std::vector<std::vector<std::string>> activations = {{"relu", "tanh"}, {"tanh", "sigmoid"}, {"sigmoid", "tanh"},
|
||||
{"tanh", "relu"}};
|
||||
std::vector<bool> linear_before_reset = {true, false};
|
||||
std::vector<float> clip{0.f};
|
||||
std::vector<float> clip_non_zeros{0.7f};
|
||||
std::vector<ngraph::op::RecurrentSequenceDirection> direction = {ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL
|
||||
};
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(GRUSequenceCommonZeroClip, GRUSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_zero_clip),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::ValuesIn(linear_before_reset),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GRUSequenceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(GRUSequenceCommonClip, GRUSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_clip_non_zero),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip_non_zeros),
|
||||
::testing::ValuesIn(linear_before_reset),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
GRUSequenceTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,57 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/lstm_sequence.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
// without clip values increase rapidly, so use only seq_lenghts = 2
|
||||
std::vector<size_t> seq_lengths_zero_clip{2};
|
||||
std::vector<size_t> seq_lengths_clip_non_zero{20};
|
||||
std::vector<size_t> batch{1, 10};
|
||||
std::vector<size_t> hidden_size{1, 10};
|
||||
std::vector<size_t> input_size{10};
|
||||
std::vector<std::vector<std::string>> activations = {{"relu", "sigmoid", "tanh"}, {"sigmoid", "tanh", "tanh"},
|
||||
{"tanh", "relu", "sigmoid"}, {"sigmoid", "sigmoid", "sigmoid"},
|
||||
{"tanh", "tanh", "tanh"}, {"relu", "relu", "relu"}};
|
||||
std::vector<float> clip{0.f};
|
||||
std::vector<float> clip_non_zeros{0.7f};
|
||||
std::vector<ngraph::op::RecurrentSequenceDirection> direction = {ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL
|
||||
};
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(LSTMSequenceCommonZeroClip, LSTMSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_zero_clip),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
LSTMSequenceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(LSTMSequenceCommonClip, LSTMSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_clip_non_zero),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip_non_zeros),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
LSTMSequenceTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,55 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "single_layer_tests/rnn_sequence.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
// without clip values increase rapidly, so use only seq_lenghts = 2
|
||||
std::vector<size_t> seq_lengths_zero_clip{2};
|
||||
std::vector<size_t> seq_lengths_clip_non_zero{20};
|
||||
std::vector<size_t> batch{1, 10};
|
||||
std::vector<size_t> hidden_size{1, 10};
|
||||
std::vector<size_t> input_size{10};
|
||||
std::vector<std::vector<std::string>> activations = {{"relu"}, {"sigmoid"}, {"tanh"}};
|
||||
std::vector<float> clip{0.f};
|
||||
std::vector<float> clip_non_zeros{0.7f};
|
||||
std::vector<ngraph::op::RecurrentSequenceDirection> direction = {ngraph::op::RecurrentSequenceDirection::FORWARD,
|
||||
ngraph::op::RecurrentSequenceDirection::REVERSE,
|
||||
ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL
|
||||
};
|
||||
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(RNNSequenceCommonZeroClip, RNNSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_zero_clip),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
RNNSequenceTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(RNNSequenceCommonClip, RNNSequenceTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(seq_lengths_clip_non_zero),
|
||||
::testing::ValuesIn(batch),
|
||||
::testing::ValuesIn(hidden_size),
|
||||
::testing::ValuesIn(input_size),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip_non_zeros),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)),
|
||||
RNNSequenceTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using GRUSequenceParams = typename std::tuple<
|
||||
// bool, // using decompose to sub-ops transformation
|
||||
size_t, // seq_lengths
|
||||
size_t, // batch
|
||||
size_t, // hidden size
|
||||
size_t, // input size
|
||||
std::vector<std::string>, // activations
|
||||
float, // clip
|
||||
bool, // linear_before_reset
|
||||
ngraph::op::RecurrentSequenceDirection, // direction
|
||||
InferenceEngine::Precision, // Network precision
|
||||
std::string>; // Device name
|
||||
|
||||
class GRUSequenceTest : public testing::WithParamInterface<GRUSequenceParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GRUSequenceParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using LSTMSequenceParams = typename std::tuple<
|
||||
// bool, // using decompose to sub-ops transformation
|
||||
size_t, // seq_lengths
|
||||
size_t, // batch
|
||||
size_t, // hidden size
|
||||
size_t, // input size
|
||||
std::vector<std::string>, // activations
|
||||
float, // clip
|
||||
ngraph::op::RecurrentSequenceDirection, // direction
|
||||
InferenceEngine::Precision, // Network precision
|
||||
std::string>; // Device name
|
||||
|
||||
class LSTMSequenceTest : public testing::WithParamInterface<LSTMSequenceParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<LSTMSequenceParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <ngraph/op/util/attr_types.hpp>
|
||||
#include "functional_test_utils/layer_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/utils/ngraph_helpers.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using RNNSequenceParams = typename std::tuple<
|
||||
// bool, // using decompose to sub-ops transformation
|
||||
size_t, // seq_lengths
|
||||
size_t, // batch
|
||||
size_t, // hidden size
|
||||
size_t, // input size
|
||||
std::vector<std::string>, // activations
|
||||
float, // clip
|
||||
ngraph::op::RecurrentSequenceDirection, // direction
|
||||
InferenceEngine::Precision, // Network precision
|
||||
std::string>; // Device name
|
||||
|
||||
class RNNSequenceTest : public testing::WithParamInterface<RNNSequenceParams>,
|
||||
virtual public LayerTestsUtils::LayerTestsCommon {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<RNNSequenceParams> &obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -72,7 +72,8 @@ void GRUCellTest::SetUp() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
|
||||
auto gru_cell = ngraph::builder::makeGRUCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
auto gru_cell = ngraph::builder::makeGRU(
|
||||
ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip, linear_before_reset);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_cell->output(0))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gru_cell");
|
||||
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "functional_test_utils/precision_utils.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "functional_test_utils/skip_tests_config.hpp"
|
||||
|
||||
#include "single_layer_tests/gru_sequence.hpp"
|
||||
#include <transformations/bidirectional_sequences_decomposition.hpp>
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string GRUSequenceTest::getTestCaseName(const testing::TestParamInfo<GRUSequenceParams> &obj) {
|
||||
//bool should_decompose;
|
||||
size_t seq_lenghts;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
bool linear_before_reset;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, linear_before_reset, direction, netPrecision,
|
||||
targetDevice) = obj.param;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
|
||||
{3 * hidden_size, hidden_size}, {(linear_before_reset ? 4 : 3) * hidden_size}},
|
||||
};
|
||||
std::ostringstream result;
|
||||
result << "seq_lenghts" << seq_lenghts << "_";
|
||||
result << "batch=" << batch << "_";
|
||||
result << "hidden_size=" << hidden_size << "_";
|
||||
result << "input_size=" << input_size << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
|
||||
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
|
||||
result << "direction=" << direction << "_";
|
||||
result << "clip=" << clip << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void GRUSequenceTest::SetUp() {
|
||||
size_t seq_lenghts;
|
||||
// bool should_decompose;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
bool linear_before_reset;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, linear_before_reset, direction, netPrecision,
|
||||
targetDevice) = this->GetParam();
|
||||
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, seq_lenghts, input_size}, {batch, num_directions, hidden_size}, {batch},
|
||||
{num_directions, 3 * hidden_size, input_size}, {num_directions, 3 * hidden_size, hidden_size},
|
||||
{num_directions, (linear_before_reset ? 4 : 3) * hidden_size}},
|
||||
};
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5], inputShapes[2]};
|
||||
auto gru_sequence = ngraph::builder::makeGRU(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip, linear_before_reset, true, direction);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_sequence->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(gru_sequence->output(1))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "gru_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalGRUSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_P(GRUSequenceTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -70,7 +70,7 @@ void LSTMCellTest::SetUp() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5]};
|
||||
auto lstm_cell = ngraph::builder::makeLSTMCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
auto lstm_cell = ngraph::builder::makeLSTM(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_cell->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(lstm_cell->output(1))};
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "functional_test_utils/precision_utils.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "functional_test_utils/skip_tests_config.hpp"
|
||||
|
||||
#include "single_layer_tests/lstm_sequence.hpp"
|
||||
#include <transformations/bidirectional_sequences_decomposition.hpp>
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string LSTMSequenceTest::getTestCaseName(const testing::TestParamInfo<LSTMSequenceParams> &obj) {
|
||||
//bool should_decompose;
|
||||
size_t seq_lenghts;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
|
||||
targetDevice) = obj.param;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
|
||||
{4 * hidden_size, hidden_size}, {4 * hidden_size}},
|
||||
};
|
||||
std::ostringstream result;
|
||||
result << "seq_lenghts" << seq_lenghts << "_";
|
||||
result << "batch=" << batch << "_";
|
||||
result << "hidden_size=" << hidden_size << "_";
|
||||
result << "input_size=" << input_size << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
|
||||
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
|
||||
result << "direction=" << direction << "_";
|
||||
result << "clip=" << clip << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void LSTMSequenceTest::SetUp() {
|
||||
size_t seq_lenghts;
|
||||
// bool should_decompose;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
|
||||
targetDevice) = this->GetParam();
|
||||
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, seq_lenghts, input_size}, {batch, num_directions, hidden_size}, {batch, num_directions, hidden_size},
|
||||
{batch}, {num_directions, 4 * hidden_size, input_size}, {num_directions, 4 * hidden_size, hidden_size}, {num_directions, 4 * hidden_size}},
|
||||
};
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[4], inputShapes[5], inputShapes[6], inputShapes[3]};
|
||||
auto lstm_sequence = ngraph::builder::makeLSTM(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip, true, direction);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(1)),
|
||||
std::make_shared<ngraph::opset1::Result>(lstm_sequence->output(2))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "lstm_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalLSTMSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_P(LSTMSequenceTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -64,7 +64,8 @@ void RNNCellTest::SetUp() {
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
|
||||
auto rnn_cell = ngraph::builder::makeRNNCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
auto rnn_cell = ngraph::builder::makeRNN(
|
||||
ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_cell)};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "rnn_cell");
|
||||
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (C) 2019 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <functional>
|
||||
|
||||
#include "ie_core.hpp"
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "functional_test_utils/precision_utils.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
#include "functional_test_utils/skip_tests_config.hpp"
|
||||
|
||||
#include "single_layer_tests/rnn_sequence.hpp"
|
||||
#include <transformations/bidirectional_sequences_decomposition.hpp>
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string RNNSequenceTest::getTestCaseName(const testing::TestParamInfo<RNNSequenceParams> &obj) {
|
||||
//bool should_decompose;
|
||||
size_t seq_lenghts;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::string targetDevice;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
|
||||
targetDevice) = obj.param;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {hidden_size, input_size},
|
||||
{hidden_size, hidden_size}, {hidden_size}},
|
||||
};
|
||||
std::ostringstream result;
|
||||
result << "seq_lenghts" << seq_lenghts << "_";
|
||||
result << "batch=" << batch << "_";
|
||||
result << "hidden_size=" << hidden_size << "_";
|
||||
result << "input_size=" << input_size << "_";
|
||||
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
|
||||
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
|
||||
result << "direction=" << direction << "_";
|
||||
result << "clip=" << clip << "_";
|
||||
result << "netPRC=" << netPrecision.name() << "_";
|
||||
result << "targetDevice=" << targetDevice << "_";
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void RNNSequenceTest::SetUp() {
|
||||
size_t seq_lenghts;
|
||||
// bool should_decompose;
|
||||
size_t batch;
|
||||
size_t hidden_size;
|
||||
size_t input_size;
|
||||
std::vector<std::string> activations;
|
||||
std::vector<float> activations_alpha;
|
||||
std::vector<float> activations_beta;
|
||||
float clip;
|
||||
ngraph::op::RecurrentSequenceDirection direction;
|
||||
InferenceEngine::Precision netPrecision;
|
||||
std::tie(seq_lenghts, batch, hidden_size, input_size, activations, clip, direction, netPrecision,
|
||||
targetDevice) = this->GetParam();
|
||||
size_t num_directions = direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL ? 2 : 1;
|
||||
std::vector<std::vector<size_t>> inputShapes = {
|
||||
{{batch, seq_lenghts, input_size}, {batch, num_directions, hidden_size}, {batch},
|
||||
{num_directions, hidden_size, input_size}, {num_directions, hidden_size, hidden_size},
|
||||
{num_directions, hidden_size}},
|
||||
};
|
||||
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
|
||||
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
|
||||
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5], inputShapes[2]};
|
||||
auto rnn_sequence = ngraph::builder::makeRNN(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
|
||||
WRB, hidden_size, activations, {}, {}, clip, true, direction);
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(0)),
|
||||
std::make_shared<ngraph::opset1::Result>(rnn_sequence->output(1))};
|
||||
function = std::make_shared<ngraph::Function>(results, params, "rnn_sequence");
|
||||
if (direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
ngraph::pass::Manager m;
|
||||
m.register_pass<ngraph::pass::BidirectionalRNNSequenceDecomposition>();
|
||||
m.run_passes(function);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
TEST_P(RNNSequenceTest, CompareWithRefs) {
|
||||
Run();
|
||||
};
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -390,35 +390,40 @@ std::shared_ptr<ngraph::Node> makePad(const ngraph::Output<Node>& data,
|
||||
std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>& data,
|
||||
double epsilon);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeLSTMCell(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::shared_ptr<ngraph::Node> makeLSTM(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations =
|
||||
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f);
|
||||
float clip = 0.f,
|
||||
bool make_sequence = false,
|
||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations =
|
||||
std::vector<std::string>{"sigmoid", "tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f,
|
||||
bool linear_before_reset = false);
|
||||
std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations =
|
||||
std::vector<std::string>{"sigmoid", "tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f,
|
||||
bool linear_before_reset = false,
|
||||
bool make_sequence = false,
|
||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f);
|
||||
std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f,
|
||||
bool make_sequence = false,
|
||||
ngraph::op::RecurrentSequenceDirection direction = ngraph::op::RecurrentSequenceDirection::FORWARD);
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeTile(const ngraph::Output<Node>& in,
|
||||
const std::vector<size_t>& repeats);
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
|
||||
@@ -10,21 +10,30 @@
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool linear_before_reset) {
|
||||
std::shared_ptr<ngraph::Node> makeGRU(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool linear_before_reset,
|
||||
bool make_sequence,
|
||||
ngraph::op::RecurrentSequenceDirection direction) {
|
||||
std::vector<float> empty;
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
|
||||
return std::make_shared<ngraph::opset4::GRUCell>(in[0], in[1], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip, linear_before_reset);
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), constants[2], empty, true);
|
||||
if (!make_sequence) {
|
||||
return std::make_shared<ngraph::opset4::GRUCell>(in[0], in[1], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip,
|
||||
linear_before_reset);
|
||||
} else {
|
||||
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
|
||||
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
|
||||
return std::make_shared<ngraph::op::v5::GRUSequence>(in[0], in[1], seq_lenghts, W, R, B, hidden_size, direction,
|
||||
activations, activations_alpha, activations_beta, clip, linear_before_reset);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
@@ -10,20 +10,28 @@
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeLSTMCell(const std::vector<ngraph::Output<Node>>& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip) {
|
||||
std::shared_ptr<ngraph::Node> makeLSTM(const std::vector<ngraph::Output<Node>>& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool make_sequence,
|
||||
ngraph::op::RecurrentSequenceDirection direction) {
|
||||
std::vector<float> empty;
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
|
||||
return std::make_shared<ngraph::opset4::LSTMCell>(in[0], in[1], in[2], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip);
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), constants[2], empty, true);
|
||||
if (!make_sequence) {
|
||||
return std::make_shared<ngraph::opset4::LSTMCell>(in[0], in[1], in[2], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip);
|
||||
} else {
|
||||
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
|
||||
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
|
||||
return std::make_shared<ngraph::op::v5::LSTMSequence>(in[0], in[1], in[2], seq_lenghts, W, R, B, hidden_size, direction,
|
||||
activations_alpha, activations_beta, activations, clip);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
@@ -10,20 +10,28 @@
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
|
||||
std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& WRB,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip) {
|
||||
std::shared_ptr<ngraph::Node> makeRNN(const OutputVector& in,
|
||||
const std::vector<ngraph::Shape>& constants,
|
||||
std::size_t hidden_size,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool make_sequence,
|
||||
ngraph::op::RecurrentSequenceDirection direction) {
|
||||
std::vector<float> empty;
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
|
||||
return std::make_shared<ngraph::opset4::RNNCell>(in[0], in[1], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip);
|
||||
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), constants[0], empty, true);
|
||||
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), constants[1], empty, true);
|
||||
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), constants[2], empty, true);
|
||||
if (!make_sequence) {
|
||||
return std::make_shared<ngraph::opset4::RNNCell>(in[0], in[1], W, R, B, hidden_size, activations,
|
||||
activations_alpha, activations_beta, clip);
|
||||
} else {
|
||||
std::vector<float> lenghts(in[0].get_shape()[0], in[0].get_shape()[1]);
|
||||
auto seq_lenghts = ngraph::builder::makeConstant(in[0].get_element_type(), constants[3], lenghts, false);
|
||||
return std::make_shared<ngraph::op::v5::RNNSequence>(in[0], in[1], seq_lenghts, W, R, B, hidden_size, direction,
|
||||
activations, activations_alpha, activations_beta, clip);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
67
ngraph/core/include/ngraph/op/gru_sequence.hpp
Normal file
67
ngraph/core/include/ngraph/op/gru_sequence.hpp
Normal file
@@ -0,0 +1,67 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/rnn_cell_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v5
|
||||
{
|
||||
class NGRAPH_API GRUSequence : public util::RNNCellBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
GRUSequence();
|
||||
|
||||
GRUSequence(const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& sequence_lengths,
|
||||
const Output<Node>& W,
|
||||
const Output<Node>& R,
|
||||
const Output<Node>& B,
|
||||
size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations =
|
||||
std::vector<std::string>{"sigmoid", "tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f,
|
||||
bool linear_before_reset = false);
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
bool get_linear_before_reset() const { return m_linear_before_reset; }
|
||||
op::RecurrentSequenceDirection get_direction() const { return m_direction; }
|
||||
protected:
|
||||
op::RecurrentSequenceDirection m_direction;
|
||||
bool m_linear_before_reset;
|
||||
};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -47,8 +47,7 @@ namespace ngraph
|
||||
class NGRAPH_API LSTMSequence : public util::FusedOp
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"LSTMSequence", 0};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
LSTMSequence() = default;
|
||||
|
||||
using direction = RecurrentSequenceDirection;
|
||||
@@ -102,11 +101,11 @@ namespace ngraph
|
||||
const std::int64_t hidden_size,
|
||||
const direction lstm_direction,
|
||||
LSTMWeightsFormat weights_format = LSTMWeightsFormat::IFCO,
|
||||
const std::vector<float> activations_alpha = {},
|
||||
const std::vector<float> activations_beta = {},
|
||||
const std::vector<std::string> activations = {"sigmoid",
|
||||
"tanh",
|
||||
"tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
const std::vector<std::string>& activations = {"sigmoid",
|
||||
"tanh",
|
||||
"tanh"},
|
||||
const float clip_threshold = 0,
|
||||
const bool input_forget = false)
|
||||
: LSTMSequence(
|
||||
@@ -186,7 +185,7 @@ namespace ngraph
|
||||
};
|
||||
}
|
||||
|
||||
namespace v1
|
||||
namespace v5
|
||||
{
|
||||
///
|
||||
/// \brief Class for lstm sequence node.
|
||||
@@ -200,8 +199,7 @@ namespace ngraph
|
||||
class NGRAPH_API LSTMSequence : public util::RNNCellBase
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"LSTMSequence", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
LSTMSequence() = default;
|
||||
|
||||
using direction = RecurrentSequenceDirection;
|
||||
@@ -216,11 +214,11 @@ namespace ngraph
|
||||
const Output<Node>& B,
|
||||
const std::int64_t hidden_size,
|
||||
const direction lstm_direction,
|
||||
const std::vector<float> activations_alpha = {},
|
||||
const std::vector<float> activations_beta = {},
|
||||
const std::vector<std::string> activations = {"sigmoid",
|
||||
"tanh",
|
||||
"tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
const std::vector<std::string>& activations = {"sigmoid",
|
||||
"tanh",
|
||||
"tanh"},
|
||||
const float clip = 0.f)
|
||||
: RNNCellBase(
|
||||
{X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
|
||||
@@ -237,7 +235,7 @@ namespace ngraph
|
||||
void validate_and_infer_types() override;
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
virtual std::shared_ptr<Node>
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
direction get_direction() const { return m_direction; }
|
||||
|
||||
66
ngraph/core/include/ngraph/op/rnn_sequence.hpp
Normal file
66
ngraph/core/include/ngraph/op/rnn_sequence.hpp
Normal file
@@ -0,0 +1,66 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/op/op.hpp"
|
||||
#include "ngraph/op/util/rnn_cell_base.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace op
|
||||
{
|
||||
namespace v5
|
||||
{
|
||||
class NGRAPH_API RNNSequence : public util::RNNCellBase
|
||||
{
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
RNNSequence();
|
||||
|
||||
RNNSequence(
|
||||
const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& sequence_lengths,
|
||||
const Output<Node>& W,
|
||||
const Output<Node>& R,
|
||||
const Output<Node>& B,
|
||||
size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
|
||||
const std::vector<float>& activations_alpha = {},
|
||||
const std::vector<float>& activations_beta = {},
|
||||
float clip = 0.f);
|
||||
|
||||
std::shared_ptr<Node>
|
||||
clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
|
||||
bool visit_attributes(AttributeVisitor& visitor) override;
|
||||
|
||||
op::RecurrentSequenceDirection get_direction() const { return m_direction; }
|
||||
protected:
|
||||
op::RecurrentSequenceDirection m_direction;
|
||||
};
|
||||
}
|
||||
} // namespace op
|
||||
} // namespace ngraph
|
||||
@@ -74,6 +74,7 @@
|
||||
#include "ngraph/op/grn.hpp"
|
||||
#include "ngraph/op/group_conv.hpp"
|
||||
#include "ngraph/op/gru_cell.hpp"
|
||||
#include "ngraph/op/gru_sequence.hpp"
|
||||
#include "ngraph/op/hard_sigmoid.hpp"
|
||||
#include "ngraph/op/hswish.hpp"
|
||||
#include "ngraph/op/interpolate.hpp"
|
||||
@@ -131,6 +132,7 @@
|
||||
#include "ngraph/op/reverse.hpp"
|
||||
#include "ngraph/op/reverse_sequence.hpp"
|
||||
#include "ngraph/op/rnn_cell.hpp"
|
||||
#include "ngraph/op/rnn_sequence.hpp"
|
||||
#include "ngraph/op/roi_align.hpp"
|
||||
#include "ngraph/op/roi_pooling.hpp"
|
||||
#include "ngraph/op/round.hpp"
|
||||
|
||||
@@ -156,8 +156,8 @@ NGRAPH_OP(Atanh, ngraph::op::v3)
|
||||
NGRAPH_OP(CTCLoss, ngraph::op::v4)
|
||||
NGRAPH_OP(HSwish, ngraph::op::v4)
|
||||
NGRAPH_OP(Interpolate, ngraph::op::v4)
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
|
||||
NGRAPH_OP(Mish, ngraph::op::v4)
|
||||
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
|
||||
NGRAPH_OP(ReduceL1, ngraph::op::v4)
|
||||
NGRAPH_OP(ReduceL2, ngraph::op::v4)
|
||||
NGRAPH_OP(SoftPlus, ngraph::op::v4)
|
||||
|
||||
@@ -0,0 +1,539 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <ngraph/runtime/reference/concat.hpp>
|
||||
#include <ngraph/runtime/reference/gru_cell.hpp>
|
||||
#include <ngraph/runtime/reference/lstm_cell.hpp>
|
||||
#include <ngraph/runtime/reference/rnn_cell.hpp>
|
||||
#include <ngraph/runtime/reference/split.hpp>
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
enum class CellType
|
||||
{
|
||||
RNN,
|
||||
GRU,
|
||||
LSTM,
|
||||
};
|
||||
|
||||
struct CellArgs
|
||||
{
|
||||
std::string activation_f; // RNN
|
||||
std::string activation_g; // RNN/GRU
|
||||
std::string activation_h; // RNN/GRU/LSTM
|
||||
float clip; // RNN/GRU/LSTM
|
||||
bool linear_before_reset = false; // GRU
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void cell_pass(CellType type,
|
||||
const std::vector<const char*>& inputs,
|
||||
const std::vector<Shape>& shapes,
|
||||
const std::vector<char*>& outputs,
|
||||
const CellArgs& args,
|
||||
bool is_reverse)
|
||||
{
|
||||
auto squeeze_axis = [](const Shape& shape, size_t axis) -> Shape {
|
||||
Shape new_shape(shape.size() - 1);
|
||||
for (size_t i = 0, j = 0; i < shape.size(); ++i)
|
||||
{
|
||||
if (i != axis)
|
||||
{
|
||||
new_shape[j] = shape[i];
|
||||
j++;
|
||||
}
|
||||
}
|
||||
return new_shape;
|
||||
};
|
||||
|
||||
size_t x_shape_size = ngraph::shape_size(shapes[0]);
|
||||
|
||||
// split X
|
||||
size_t num_splits = shapes[0].at(1);
|
||||
std::vector<std::vector<char>> in_seqs(
|
||||
num_splits, std::vector<char>(x_shape_size / num_splits * sizeof(T)));
|
||||
std::vector<char*> pointers(num_splits);
|
||||
for (size_t i = 0; i < num_splits; ++i)
|
||||
pointers[is_reverse ? num_splits - i - 1 : i] = in_seqs[i].data();
|
||||
reference::split(inputs[0], shapes[0], sizeof(T), 1, num_splits, pointers.data());
|
||||
|
||||
Shape part_shape{shapes[0][0], 1, shapes[2][2]};
|
||||
size_t part_shape_size = ngraph::shape_size(part_shape);
|
||||
std::vector<std::vector<char>> h_list(
|
||||
num_splits, std::vector<char>(ngraph::shape_size(part_shape) * sizeof(T)));
|
||||
|
||||
// use outputs as a buffer for temporarily values
|
||||
char* H_i = outputs[1];
|
||||
std::memcpy(H_i, inputs[2], ngraph::shape_size(shapes[2]) * sizeof(T));
|
||||
|
||||
char* C_i = nullptr; // LSTMCell only
|
||||
if (type == CellType::LSTM)
|
||||
{
|
||||
C_i = outputs[2];
|
||||
std::memcpy(C_i, inputs[3], ngraph::shape_size(shapes[3]) * sizeof(T));
|
||||
}
|
||||
|
||||
for (size_t time_step = 0; time_step < num_splits; ++time_step)
|
||||
{
|
||||
if (type == CellType::LSTM)
|
||||
{
|
||||
runtime::reference::lstm_cell<T>(
|
||||
reinterpret_cast<const T*>(in_seqs[time_step].data()),
|
||||
squeeze_axis(shapes[0], 1),
|
||||
reinterpret_cast<const T*>(H_i),
|
||||
squeeze_axis(shapes[2], 1),
|
||||
reinterpret_cast<const T*>(C_i),
|
||||
squeeze_axis(shapes[3], 1),
|
||||
reinterpret_cast<const T*>(inputs[4]),
|
||||
squeeze_axis(shapes[4], 0),
|
||||
reinterpret_cast<const T*>(inputs[5]),
|
||||
squeeze_axis(shapes[5], 0),
|
||||
reinterpret_cast<const T*>(inputs[6]),
|
||||
squeeze_axis(shapes[6], 0),
|
||||
reinterpret_cast<T*>(outputs[1]),
|
||||
reinterpret_cast<T*>(outputs[2]),
|
||||
args.activation_f,
|
||||
args.activation_g,
|
||||
args.activation_h,
|
||||
args.clip);
|
||||
}
|
||||
else if (type == CellType::RNN)
|
||||
{
|
||||
runtime::reference::rnn_cell<T>(
|
||||
reinterpret_cast<const T*>(in_seqs[time_step].data()),
|
||||
squeeze_axis(shapes[0], 1),
|
||||
reinterpret_cast<const T*>(H_i),
|
||||
squeeze_axis(shapes[2], 1),
|
||||
reinterpret_cast<const T*>(inputs[3]),
|
||||
squeeze_axis(shapes[3], 0),
|
||||
reinterpret_cast<const T*>(inputs[4]),
|
||||
squeeze_axis(shapes[4], 0),
|
||||
reinterpret_cast<const T*>(inputs[5]),
|
||||
squeeze_axis(shapes[5], 0),
|
||||
reinterpret_cast<T*>(outputs[1]),
|
||||
args.activation_f,
|
||||
args.clip);
|
||||
}
|
||||
else if (type == CellType::GRU)
|
||||
{
|
||||
runtime::reference::gru_cell<T>(
|
||||
reinterpret_cast<const T*>(in_seqs[time_step].data()),
|
||||
squeeze_axis(shapes[0], 1),
|
||||
reinterpret_cast<const T*>(H_i),
|
||||
squeeze_axis(shapes[2], 1),
|
||||
reinterpret_cast<const T*>(inputs[3]),
|
||||
squeeze_axis(shapes[3], 0),
|
||||
reinterpret_cast<const T*>(inputs[4]),
|
||||
squeeze_axis(shapes[4], 0),
|
||||
reinterpret_cast<const T*>(inputs[5]),
|
||||
squeeze_axis(shapes[5], 0),
|
||||
reinterpret_cast<T*>(outputs[1]),
|
||||
args.activation_f,
|
||||
args.activation_g,
|
||||
args.clip,
|
||||
args.linear_before_reset);
|
||||
}
|
||||
std::memcpy(h_list[time_step].data(), outputs[1], part_shape_size * sizeof(T));
|
||||
}
|
||||
// The tensor that concats all the intermediate output values of the hidden.
|
||||
// It has shape [batch_size, seq_length, hidden_size]
|
||||
std::vector<Shape> in_shapes(num_splits, part_shape);
|
||||
std::vector<const char*> to_concat_pointers(num_splits);
|
||||
for (size_t i = 0; i < num_splits; ++i)
|
||||
to_concat_pointers[is_reverse ? num_splits - i - 1 : i] = h_list[i].data();
|
||||
runtime::reference::concat(to_concat_pointers,
|
||||
outputs[0],
|
||||
in_shapes,
|
||||
{shapes[0][0], shapes[0][1], shapes[2][2]},
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void lstm_sequence(const char* X,
|
||||
const Shape& X_shape,
|
||||
const char* H,
|
||||
const Shape& H_shape,
|
||||
const char* C,
|
||||
const Shape& C_shape,
|
||||
const char* seq_lengths,
|
||||
const Shape& seq_lengths_shape,
|
||||
const char* W,
|
||||
const Shape& W_shape,
|
||||
const char* R,
|
||||
const Shape& R_shape,
|
||||
const char* B,
|
||||
const Shape& B_shape,
|
||||
char* Y,
|
||||
char* Ho,
|
||||
char* Co,
|
||||
const std::string& activation_f,
|
||||
const std::string& activation_g,
|
||||
const std::string& activation_h,
|
||||
float clip,
|
||||
op::RecurrentSequenceDirection direction)
|
||||
{
|
||||
OutputVector results;
|
||||
if (direction == op::RecurrentSequenceDirection::FORWARD ||
|
||||
direction == op::RecurrentSequenceDirection::REVERSE)
|
||||
{
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.activation_g = activation_g;
|
||||
args.activation_h = activation_h;
|
||||
args.clip = clip;
|
||||
std::vector<const char*> inputs = {X, seq_lengths, H, C, W, R, B};
|
||||
std::vector<char*> outputs = {Y, Ho, Co};
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
|
||||
cell_pass<T>(CellType::LSTM,
|
||||
inputs,
|
||||
shapes,
|
||||
outputs,
|
||||
args,
|
||||
direction == op::RecurrentSequenceDirection::REVERSE);
|
||||
}
|
||||
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
{
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> C_split(
|
||||
2, std::vector<char>(ngraph::shape_size(C_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
3, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
3, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.activation_g = activation_g;
|
||||
args.activation_h = activation_h;
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape};
|
||||
// forward pass
|
||||
cell_pass<T>(
|
||||
CellType::LSTM,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[0],
|
||||
c_pointers[0],
|
||||
w_pointers[0],
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data(), forward_res[2].data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
cell_pass<T>(
|
||||
CellType::LSTM,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[1],
|
||||
c_pointers[1],
|
||||
w_pointers[1],
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data(), reverse_res[2].data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[2].data(), reverse_res[2].data()},
|
||||
Co,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void gru_sequence(const char* X,
|
||||
const Shape& X_shape,
|
||||
const char* H,
|
||||
const Shape& H_shape,
|
||||
const char* seq_lengths,
|
||||
const Shape& seq_lengths_shape,
|
||||
const char* W,
|
||||
const Shape& W_shape,
|
||||
const char* R,
|
||||
const Shape& R_shape,
|
||||
const char* B,
|
||||
const Shape& B_shape,
|
||||
char* Y,
|
||||
char* Ho,
|
||||
const std::string& activation_f,
|
||||
const std::string& activation_g,
|
||||
const float clip,
|
||||
const op::RecurrentSequenceDirection direction,
|
||||
const bool linear_before_reset)
|
||||
{
|
||||
OutputVector results;
|
||||
if (direction == op::RecurrentSequenceDirection::FORWARD ||
|
||||
direction == op::RecurrentSequenceDirection::REVERSE)
|
||||
{
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.activation_g = activation_g;
|
||||
args.linear_before_reset = linear_before_reset;
|
||||
args.clip = clip;
|
||||
std::vector<const char*> inputs = {X, seq_lengths, H, W, R, B};
|
||||
std::vector<char*> outputs = {Y, Ho};
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
cell_pass<T>(CellType::GRU,
|
||||
inputs,
|
||||
shapes,
|
||||
outputs,
|
||||
args,
|
||||
direction == op::RecurrentSequenceDirection::REVERSE);
|
||||
}
|
||||
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
{
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.activation_g = activation_g;
|
||||
args.linear_before_reset = linear_before_reset;
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
// forward pass
|
||||
cell_pass<T>(CellType::GRU,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[0],
|
||||
w_pointers[0],
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
cell_pass<T>(CellType::GRU,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[1],
|
||||
w_pointers[1],
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void rnn_sequence(const char* X,
|
||||
const Shape& X_shape,
|
||||
const char* H,
|
||||
const Shape& H_shape,
|
||||
const char* seq_lengths,
|
||||
const Shape& seq_lengths_shape,
|
||||
const char* W,
|
||||
const Shape& W_shape,
|
||||
const char* R,
|
||||
const Shape& R_shape,
|
||||
const char* B,
|
||||
const Shape& B_shape,
|
||||
char* Y,
|
||||
char* Ho,
|
||||
const std::string& activation_f,
|
||||
float clip,
|
||||
const op::RecurrentSequenceDirection direction)
|
||||
{
|
||||
OutputVector results;
|
||||
if (direction == op::RecurrentSequenceDirection::FORWARD ||
|
||||
direction == op::RecurrentSequenceDirection::REVERSE)
|
||||
{
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.clip = clip;
|
||||
std::vector<const char*> inputs = {X, seq_lengths, H, W, R, B};
|
||||
std::vector<char*> outputs = {Y, Ho};
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
cell_pass<T>(CellType::RNN,
|
||||
inputs,
|
||||
shapes,
|
||||
outputs,
|
||||
args,
|
||||
direction == op::RecurrentSequenceDirection::REVERSE);
|
||||
}
|
||||
else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL)
|
||||
{
|
||||
// Split bidirectional case to forward + reverse passes.
|
||||
// split inputs
|
||||
std::vector<std::vector<char>> H_split(
|
||||
2, std::vector<char>(ngraph::shape_size(H_shape) / 2));
|
||||
std::vector<std::vector<char>> W_split(
|
||||
2, std::vector<char>(ngraph::shape_size(W_shape) / 2));
|
||||
std::vector<std::vector<char>> R_split(
|
||||
2, std::vector<char>(ngraph::shape_size(R_shape) / 2));
|
||||
std::vector<std::vector<char>> B_split(
|
||||
2, std::vector<char>(ngraph::shape_size(B_shape) / 2));
|
||||
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
|
||||
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
|
||||
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
|
||||
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
|
||||
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
|
||||
reference::split(W, W_shape, sizeof(T), 1, 2, w_pointers);
|
||||
reference::split(R, R_shape, sizeof(T), 1, 2, r_pointers);
|
||||
reference::split(B, B_shape, sizeof(T), 1, 2, b_pointers);
|
||||
std::vector<std::vector<char>> forward_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
std::vector<std::vector<char>> reverse_res(
|
||||
2, std::vector<char>(H_shape[0] * H_shape[2]));
|
||||
|
||||
CellArgs args;
|
||||
args.activation_f = activation_f;
|
||||
args.clip = clip;
|
||||
std::vector<Shape> shapes = {
|
||||
X_shape, seq_lengths_shape, H_shape, W_shape, R_shape, B_shape};
|
||||
// forward pass
|
||||
cell_pass<T>(CellType::RNN,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[0],
|
||||
w_pointers[0],
|
||||
r_pointers[0],
|
||||
b_pointers[0]},
|
||||
shapes,
|
||||
{forward_res[0].data(), forward_res[1].data()},
|
||||
args,
|
||||
false);
|
||||
// reverse pass
|
||||
cell_pass<T>(CellType::RNN,
|
||||
{X,
|
||||
seq_lengths,
|
||||
h_pointers[1],
|
||||
w_pointers[1],
|
||||
r_pointers[1],
|
||||
b_pointers[1]},
|
||||
shapes,
|
||||
{reverse_res[0].data(), reverse_res[1].data()},
|
||||
args,
|
||||
true);
|
||||
|
||||
// Stack together respective outputs from both forward and reverse passes.
|
||||
std::vector<Shape> in_shapes = {{H_shape[0], 1, H_shape[2]},
|
||||
{H_shape[0], 1, H_shape[2]}};
|
||||
Shape output_shape = {{H_shape[0], 2, H_shape[2]}};
|
||||
|
||||
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
|
||||
Y,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
|
||||
Ho,
|
||||
in_shapes,
|
||||
output_shape,
|
||||
1,
|
||||
sizeof(T));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -109,6 +109,14 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor)
|
||||
|
||||
void op::v3::GRUCell::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
199
ngraph/core/src/op/gru_sequence.cpp
Normal file
199
ngraph/core/src/op/gru_sequence.cpp
Normal file
@@ -0,0 +1,199 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "ngraph/op/gru_sequence.hpp"
|
||||
#include "ngraph/op/util/recurrent_sequence.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v5::GRUSequence, "GRUSequence", 5);
|
||||
|
||||
op::v5::GRUSequence::GRUSequence()
|
||||
: m_direction(op::RecurrentSequenceDirection::FORWARD)
|
||||
, m_linear_before_reset(false)
|
||||
{
|
||||
}
|
||||
|
||||
op::v5::GRUSequence::GRUSequence(const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& sequence_lengths,
|
||||
const Output<Node>& W,
|
||||
const Output<Node>& R,
|
||||
const Output<Node>& B,
|
||||
std::size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip,
|
||||
bool linear_before_reset)
|
||||
: RNNCellBase({X, H_t, sequence_lengths, W, R, B},
|
||||
hidden_size,
|
||||
clip,
|
||||
activations,
|
||||
activations_alpha,
|
||||
activations_beta)
|
||||
, m_direction(direction)
|
||||
, m_linear_before_reset(linear_before_reset)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v5::GRUSequence::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto gru_seq_gates_count = 3;
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto merged_num_directions = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
auto x_pshape = get_input_partial_shape(0);
|
||||
auto ht_pshape = get_input_partial_shape(1);
|
||||
auto sl_pshape = get_input_partial_shape(2);
|
||||
auto w_pshape = get_input_partial_shape(3);
|
||||
auto r_pshape = get_input_partial_shape(4);
|
||||
auto b_pshape = get_input_partial_shape(5);
|
||||
|
||||
ngraph::op::util::validate_seq_input_rank_dimension(
|
||||
{x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape});
|
||||
|
||||
// Validate input types and save result for output type
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(5)),
|
||||
"Element types for X, initial_hidden_state, W, R and B inputs do not "
|
||||
"match.");
|
||||
|
||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]),
|
||||
"Parameter batch_size not matched in RNNSequence.");
|
||||
|
||||
// Merge hidden_size dimension across all inputs to evaluate output dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) &&
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]),
|
||||
"Parameter hidden_size not matched RNNSequence.");
|
||||
|
||||
// Merge num_directions dimension across all inputs to evaluate output dimension
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, ht_pshape[1]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
|
||||
"Parameter num_directions not matched in RNNSequence.");
|
||||
|
||||
// Validate hidden_size value for W, R, B inputs
|
||||
if (merged_hidden_size.is_static())
|
||||
{
|
||||
if (w_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
w_pshape[1].compatible(merged_hidden_size * gru_seq_gates_count),
|
||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
||||
w_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * gru_seq_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
r_pshape[1].compatible(merged_hidden_size * gru_seq_gates_count),
|
||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
||||
r_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * gru_seq_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
b_pshape[1].compatible(merged_hidden_size * (m_linear_before_reset
|
||||
? (gru_seq_gates_count + 1)
|
||||
: gru_seq_gates_count)),
|
||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
||||
b_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() *
|
||||
(m_linear_before_reset ? (gru_seq_gates_count + 1) : gru_seq_gates_count),
|
||||
".");
|
||||
}
|
||||
}
|
||||
|
||||
// Mark inputs which are relevant to output parameters
|
||||
for (size_t i = 0; i <= 5; ++i)
|
||||
set_input_is_relevant_to_shape(i);
|
||||
|
||||
// Set output size, type and shape
|
||||
set_output_size(2);
|
||||
set_output_type(
|
||||
0, result_et, {merged_batch_size, merged_num_directions, x_pshape[1], merged_hidden_size});
|
||||
set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
|
||||
}
|
||||
|
||||
bool op::v5::GRUSequence::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
visitor.on_attribute("linear_before_reset", m_linear_before_reset);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v5::GRUSequence::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v5::GRUSequence>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
new_args.at(3),
|
||||
new_args.at(4),
|
||||
new_args.at(5),
|
||||
m_hidden_size,
|
||||
m_direction,
|
||||
m_activations,
|
||||
m_activations_alpha,
|
||||
m_activations_beta,
|
||||
m_clip,
|
||||
m_linear_before_reset);
|
||||
}
|
||||
@@ -142,6 +142,16 @@ bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
|
||||
|
||||
void op::v0::LSTMCell::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ngraph::PartialShape> input_param{};
|
||||
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
@@ -436,6 +446,15 @@ bool ngraph::op::v4::LSTMCell::visit_attributes(AttributeVisitor& visitor)
|
||||
|
||||
void op::v4::LSTMCell::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
@@ -448,11 +467,6 @@ void op::v4::LSTMCell::validate_and_infer_types()
|
||||
const auto& r_pshape = get_input_partial_shape(4);
|
||||
const auto& b_pshape = get_input_partial_shape(5);
|
||||
|
||||
// Validate rank and dimension for initial_cell_state input
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
(ct_pshape.rank().is_static()),
|
||||
"LSTMCell input tensor initial_cell_state shall have static rank.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
(ct_pshape.rank().get_length() == 2),
|
||||
"LSTMCell input tensor initial_cell_state shall have dimension 2D.");
|
||||
|
||||
@@ -29,8 +29,8 @@
|
||||
using namespace ngraph;
|
||||
using namespace std;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::LSTMSequence::type_info;
|
||||
constexpr NodeTypeInfo op::v0::LSTMSequence::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v0::LSTMSequence, "LSTMSequence", 0);
|
||||
NGRAPH_RTTI_DEFINITION(op::v5::LSTMSequence, "LSTMSequence", 5);
|
||||
|
||||
bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
@@ -353,7 +353,7 @@ void op::v0::LSTMSequence::validate_and_infer_types()
|
||||
// Validate hidden_size value for W, R, B and P inputs
|
||||
if (merged_hidden_size.is_static())
|
||||
{
|
||||
if (w_pshape[0].is_static())
|
||||
if (w_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -365,7 +365,7 @@ void op::v0::LSTMSequence::validate_and_infer_types()
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[0].is_static())
|
||||
if (r_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -377,7 +377,7 @@ void op::v0::LSTMSequence::validate_and_infer_types()
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[0].is_static())
|
||||
if (b_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -389,7 +389,7 @@ void op::v0::LSTMSequence::validate_and_infer_types()
|
||||
".");
|
||||
}
|
||||
|
||||
if (p_pshape[0].is_static())
|
||||
if (p_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -419,18 +419,18 @@ void op::v0::LSTMSequence::validate_and_infer_types()
|
||||
set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
|
||||
}
|
||||
|
||||
bool ngraph::op::v1::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
|
||||
bool ngraph::op::v5::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node> op::v1::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
shared_ptr<Node> op::v5::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
if (new_args.size() == 7)
|
||||
{
|
||||
return make_shared<op::v1::LSTMSequence>(new_args.at(0), // X
|
||||
return make_shared<op::v5::LSTMSequence>(new_args.at(0), // X
|
||||
new_args.at(1), // initial_hidden_state
|
||||
new_args.at(2), // initial_cell_state
|
||||
new_args.at(3), // sequence_lengths
|
||||
@@ -450,8 +450,18 @@ shared_ptr<Node> op::v1::LSTMSequence::clone_with_new_inputs(const OutputVector&
|
||||
}
|
||||
}
|
||||
|
||||
void op::v1::LSTMSequence::validate_and_infer_types()
|
||||
void op::v5::LSTMSequence::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(2, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::vector<ngraph::PartialShape> input_param{};
|
||||
|
||||
auto lstm_seq_gates_count = 4;
|
||||
@@ -482,10 +492,6 @@ void op::v1::LSTMSequence::validate_and_infer_types()
|
||||
ngraph::op::util::validate_seq_input_rank_dimension(input_param);
|
||||
|
||||
// Validate rank and dimension for initial_cell_state input
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
(ct_pshape.rank().is_static()),
|
||||
"LSTMSequence input tensor initial_cell_state shall have static rank.");
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
(ct_pshape.rank().get_length() == 3),
|
||||
"LSTMSequence input tensor initial_cell_state shall have dimension 3D.");
|
||||
@@ -532,7 +538,7 @@ void op::v1::LSTMSequence::validate_and_infer_types()
|
||||
// Validate hidden_size value for W, R, B inputs
|
||||
if (merged_hidden_size.is_static())
|
||||
{
|
||||
if (w_pshape[0].is_static())
|
||||
if (w_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -544,7 +550,7 @@ void op::v1::LSTMSequence::validate_and_infer_types()
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[0].is_static())
|
||||
if (r_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
@@ -556,7 +562,7 @@ void op::v1::LSTMSequence::validate_and_infer_types()
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[0].is_static())
|
||||
if (b_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
|
||||
@@ -83,6 +83,14 @@ bool op::v0::RNNCell::visit_attributes(AttributeVisitor& visitor)
|
||||
|
||||
void op::v0::RNNCell::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
192
ngraph/core/src/op/rnn_sequence.cpp
Normal file
192
ngraph/core/src/op/rnn_sequence.cpp
Normal file
@@ -0,0 +1,192 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "ngraph/op/rnn_sequence.hpp"
|
||||
#include "ngraph/op/util/recurrent_sequence.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(op::v5::RNNSequence, "RNNSequence", 4);
|
||||
|
||||
op::v5::RNNSequence::RNNSequence()
|
||||
: m_direction(op::RecurrentSequenceDirection::FORWARD)
|
||||
{
|
||||
}
|
||||
|
||||
op::v5::RNNSequence::RNNSequence(const Output<Node>& X,
|
||||
const Output<Node>& H_t,
|
||||
const Output<Node>& sequence_lengths,
|
||||
const Output<Node>& W,
|
||||
const Output<Node>& R,
|
||||
const Output<Node>& B,
|
||||
std::size_t hidden_size,
|
||||
op::RecurrentSequenceDirection direction,
|
||||
const std::vector<std::string>& activations,
|
||||
const std::vector<float>& activations_alpha,
|
||||
const std::vector<float>& activations_beta,
|
||||
float clip)
|
||||
: RNNCellBase({X, H_t, sequence_lengths, W, R, B},
|
||||
hidden_size,
|
||||
clip,
|
||||
activations,
|
||||
activations_alpha,
|
||||
activations_beta)
|
||||
, m_direction(direction)
|
||||
{
|
||||
constructor_validate_and_infer_types();
|
||||
}
|
||||
|
||||
void op::v5::RNNSequence::validate_and_infer_types()
|
||||
{
|
||||
for (const auto& input : inputs())
|
||||
{
|
||||
if (input.get_partial_shape().rank().is_dynamic())
|
||||
{
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
|
||||
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
auto rnn_seq_gates_count = 1;
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto merged_num_directions = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
auto x_pshape = get_input_partial_shape(0);
|
||||
auto ht_pshape = get_input_partial_shape(1);
|
||||
auto sl_pshape = get_input_partial_shape(2);
|
||||
auto w_pshape = get_input_partial_shape(3);
|
||||
auto r_pshape = get_input_partial_shape(4);
|
||||
auto b_pshape = get_input_partial_shape(5);
|
||||
|
||||
ngraph::op::util::validate_seq_input_rank_dimension(
|
||||
{x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape});
|
||||
|
||||
// Validate input types and save result for output type
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(5)),
|
||||
"Element types for X, initial_hidden_state, W, R and B inputs do not "
|
||||
"match.");
|
||||
|
||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]),
|
||||
"Parameter batch_size not matched in RNNSequence.");
|
||||
|
||||
// Merge hidden_size dimension across all inputs to evaluate output dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) &&
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]),
|
||||
"Parameter hidden_size not matched RNNSequence.");
|
||||
|
||||
// Merge num_directions dimension across all inputs to evaluate output dimension
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, ht_pshape[1]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) &&
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
|
||||
"Parameter num_directions not matched in RNNSequence.");
|
||||
|
||||
// Validate hidden_size value for W, R, B inputs
|
||||
if (merged_hidden_size.is_static())
|
||||
{
|
||||
if (w_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
w_pshape[1].compatible(merged_hidden_size * rnn_seq_gates_count),
|
||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
||||
w_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * rnn_seq_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
r_pshape[1].compatible(merged_hidden_size * rnn_seq_gates_count),
|
||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
||||
r_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * rnn_seq_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[1].is_static())
|
||||
{
|
||||
NODE_VALIDATION_CHECK(
|
||||
this,
|
||||
b_pshape[1].compatible(merged_hidden_size * rnn_seq_gates_count),
|
||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
||||
b_pshape[1].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * rnn_seq_gates_count,
|
||||
".");
|
||||
}
|
||||
}
|
||||
|
||||
// Mark inputs which are relevant to output parameters
|
||||
for (size_t i = 0; i <= 5; ++i)
|
||||
set_input_is_relevant_to_shape(i);
|
||||
|
||||
// Set output size, type and shape
|
||||
set_output_size(2);
|
||||
set_output_type(
|
||||
0, result_et, {merged_batch_size, merged_num_directions, x_pshape[1], merged_hidden_size});
|
||||
set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
|
||||
}
|
||||
|
||||
bool op::v5::RNNSequence::visit_attributes(AttributeVisitor& visitor)
|
||||
{
|
||||
visitor.on_attribute("direction", m_direction);
|
||||
return op::util::RNNCellBase::visit_attributes(visitor);
|
||||
}
|
||||
|
||||
shared_ptr<Node>
|
||||
op::v5::RNNSequence::clone_with_new_inputs(const ngraph::OutputVector& new_args) const
|
||||
{
|
||||
check_new_args_count(this, new_args);
|
||||
return make_shared<op::v5::RNNSequence>(new_args.at(0),
|
||||
new_args.at(1),
|
||||
new_args.at(2),
|
||||
new_args.at(3),
|
||||
new_args.at(4),
|
||||
new_args.at(5),
|
||||
m_hidden_size,
|
||||
m_direction,
|
||||
m_activations,
|
||||
m_activations_alpha,
|
||||
m_activations_beta,
|
||||
m_clip);
|
||||
}
|
||||
@@ -129,6 +129,7 @@ set(SRC
|
||||
type_prop/group_convolution.cpp
|
||||
type_prop/group_convolution_backprop_data.cpp
|
||||
type_prop/gru_cell.cpp
|
||||
type_prop/gru_sequence.cpp
|
||||
type_prop/hard_sigmoid.cpp
|
||||
type_prop/hswish.cpp
|
||||
type_prop/interpolate.cpp
|
||||
@@ -160,6 +161,7 @@ set(SRC
|
||||
type_prop/reverse_sequence.cpp
|
||||
type_prop/roi_align.cpp
|
||||
type_prop/rnn_cell.cpp
|
||||
type_prop/rnn_sequence.cpp
|
||||
type_prop/scatter_elements_update.cpp
|
||||
type_prop/scatter_nd_update.cpp
|
||||
type_prop/scatter_update.cpp
|
||||
|
||||
@@ -1099,7 +1099,7 @@ TEST(attributes, lstm_cell_op)
|
||||
|
||||
TEST(attributes, lstm_sequence_op)
|
||||
{
|
||||
FactoryRegistry<Node>::get().register_factory<op::v1::LSTMSequence>();
|
||||
FactoryRegistry<Node>::get().register_factory<op::v5::LSTMSequence>();
|
||||
|
||||
const size_t batch_size = 4;
|
||||
const size_t num_directions = 2;
|
||||
@@ -1126,7 +1126,7 @@ TEST(attributes, lstm_sequence_op)
|
||||
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
|
||||
const float clip_threshold = 0.5f;
|
||||
|
||||
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
|
||||
const auto lstm_sequence = make_shared<op::v5::LSTMSequence>(X,
|
||||
initial_hidden_state,
|
||||
initial_cell_state,
|
||||
sequence_lengths,
|
||||
@@ -1140,7 +1140,7 @@ TEST(attributes, lstm_sequence_op)
|
||||
activations,
|
||||
clip_threshold);
|
||||
NodeBuilder builder(lstm_sequence);
|
||||
auto g_lstm_sequence = as_type_ptr<op::v1::LSTMSequence>(builder.create());
|
||||
auto g_lstm_sequence = as_type_ptr<op::v5::LSTMSequence>(builder.create());
|
||||
|
||||
EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
|
||||
EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
|
||||
|
||||
@@ -85,6 +85,7 @@
|
||||
#include "ngraph/runtime/reference/round.hpp"
|
||||
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
|
||||
#include "ngraph/runtime/reference/select.hpp"
|
||||
#include "ngraph/runtime/reference/sequences.hpp"
|
||||
#include "ngraph/runtime/reference/sigmoid.hpp"
|
||||
#include "ngraph/runtime/reference/sign.hpp"
|
||||
#include "ngraph/runtime/reference/sin.hpp"
|
||||
@@ -758,6 +759,82 @@ protected:
|
||||
rnn_cell->get_clip());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::LSTMSequence:
|
||||
case OP_TYPEID::LSTMSequence_v5:
|
||||
{
|
||||
auto lstm_seq = static_cast<const op::v5::LSTMSequence*>(&node);
|
||||
runtime::reference::lstm_sequence<T>(args[0]->get_data_ptr<char>(),
|
||||
args[0]->get_shape(),
|
||||
args[1]->get_data_ptr<char>(),
|
||||
args[1]->get_shape(),
|
||||
args[2]->get_data_ptr<char>(),
|
||||
args[2]->get_shape(),
|
||||
args[3]->get_data_ptr<char>(),
|
||||
args[3]->get_shape(),
|
||||
args[4]->get_data_ptr<char>(),
|
||||
args[4]->get_shape(),
|
||||
args[5]->get_data_ptr<char>(),
|
||||
args[5]->get_shape(),
|
||||
args[6]->get_data_ptr<char>(),
|
||||
args[6]->get_shape(),
|
||||
out[0]->get_data_ptr<char>(),
|
||||
out[1]->get_data_ptr<char>(),
|
||||
out[2]->get_data_ptr<char>(),
|
||||
lstm_seq->get_activations()[0],
|
||||
lstm_seq->get_activations()[1],
|
||||
lstm_seq->get_activations()[2],
|
||||
lstm_seq->get_clip(),
|
||||
lstm_seq->get_direction());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::GRUSequence_v5:
|
||||
{
|
||||
auto gru_seq = static_cast<const op::v5::GRUSequence*>(&node);
|
||||
runtime::reference::gru_sequence<T>(args[0]->get_data_ptr<char>(),
|
||||
args[0]->get_shape(),
|
||||
args[1]->get_data_ptr<char>(),
|
||||
args[1]->get_shape(),
|
||||
args[2]->get_data_ptr<char>(),
|
||||
args[2]->get_shape(),
|
||||
args[3]->get_data_ptr<char>(),
|
||||
args[3]->get_shape(),
|
||||
args[4]->get_data_ptr<char>(),
|
||||
args[4]->get_shape(),
|
||||
args[5]->get_data_ptr<char>(),
|
||||
args[5]->get_shape(),
|
||||
out[0]->get_data_ptr<char>(),
|
||||
out[1]->get_data_ptr<char>(),
|
||||
gru_seq->get_activations()[0],
|
||||
gru_seq->get_activations()[1],
|
||||
gru_seq->get_clip(),
|
||||
gru_seq->get_direction(),
|
||||
gru_seq->get_linear_before_reset()
|
||||
|
||||
);
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::RNNSequence_v5:
|
||||
{
|
||||
auto rnn_seq = static_cast<const op::v5::RNNSequence*>(&node);
|
||||
runtime::reference::rnn_sequence<T>(args[0]->get_data_ptr<char>(),
|
||||
args[0]->get_shape(),
|
||||
args[1]->get_data_ptr<char>(),
|
||||
args[1]->get_shape(),
|
||||
args[2]->get_data_ptr<char>(),
|
||||
args[2]->get_shape(),
|
||||
args[3]->get_data_ptr<char>(),
|
||||
args[3]->get_shape(),
|
||||
args[4]->get_data_ptr<char>(),
|
||||
args[4]->get_shape(),
|
||||
args[5]->get_data_ptr<char>(),
|
||||
args[5]->get_shape(),
|
||||
out[0]->get_data_ptr<char>(),
|
||||
out[1]->get_data_ptr<char>(),
|
||||
rnn_seq->get_activations()[0],
|
||||
rnn_seq->get_clip(),
|
||||
rnn_seq->get_direction());
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::Log:
|
||||
{
|
||||
size_t element_count = shape_size(node.get_output_shape(0));
|
||||
@@ -1285,7 +1362,6 @@ protected:
|
||||
case OP_TYPEID::GroupConvolutionBackpropData:
|
||||
case OP_TYPEID::HardSigmoid:
|
||||
case OP_TYPEID::Interpolate:
|
||||
case OP_TYPEID::LSTMSequence:
|
||||
case OP_TYPEID::MVN:
|
||||
case OP_TYPEID::NormalizeL2:
|
||||
case OP_TYPEID::PRelu:
|
||||
|
||||
@@ -48,3 +48,9 @@ NGRAPH_OP(ScatterUpdate, op::v3)
|
||||
NGRAPH_OP(CTCLoss, op::v4)
|
||||
NGRAPH_OP(LSTMCell, op::v4)
|
||||
#undef ID_SUFFIX
|
||||
|
||||
#define ID_SUFFIX(NAME) NAME##_v5
|
||||
NGRAPH_OP(LSTMSequence, op::v5)
|
||||
NGRAPH_OP(GRUSequence, op::v5)
|
||||
NGRAPH_OP(RNNSequence, op::v5)
|
||||
#undef ID_SUFFIX
|
||||
|
||||
@@ -231,38 +231,38 @@ TEST(type_prop, gru_cell_invalid_input_dynamic_rank)
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
auto check_dynamic_gru = [](const shared_ptr<opset4::GRUCell>& gru) -> bool {
|
||||
return gru->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
||||
gru->output(0).get_element_type() == gru->input(0).get_element_type();
|
||||
};
|
||||
|
||||
// Invalid dynamic rank for W tensor.
|
||||
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
auto gru_w = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_w), true);
|
||||
|
||||
// Invalid dynamic rank for X tensor.
|
||||
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
auto gru_x = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_x), true);
|
||||
|
||||
// Invalid dynamic rank for H_t tensor.
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
auto gru_h = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_h), true);
|
||||
|
||||
// Invalid dynamic rank for R tensor.
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
auto gru_r = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_r), true);
|
||||
|
||||
// Invalid dynamic rank for B tensor.
|
||||
R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
auto gru_b = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_b), true);
|
||||
}
|
||||
|
||||
64
ngraph/test/type_prop/gru_sequence.cpp
Normal file
64
ngraph/test/type_prop/gru_sequence.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, gru_sequence_forward)
|
||||
{
|
||||
const size_t batch_size = 8;
|
||||
const size_t num_directions = 1;
|
||||
const size_t seq_length = 6;
|
||||
const size_t input_size = 4;
|
||||
const size_t hidden_size = 128;
|
||||
|
||||
const auto X =
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
|
||||
const auto initial_hidden_state = make_shared<opset4::Parameter>(
|
||||
element::f32, Shape{batch_size, num_directions, hidden_size});
|
||||
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
|
||||
const auto W = make_shared<opset4::Parameter>(
|
||||
element::f32, Shape{num_directions, 3 * hidden_size, input_size});
|
||||
const auto R = make_shared<opset4::Parameter>(
|
||||
element::f32, Shape{num_directions, 3 * hidden_size, hidden_size});
|
||||
const auto B =
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 3 * hidden_size});
|
||||
|
||||
const auto direction = op::RecurrentSequenceDirection::FORWARD;
|
||||
|
||||
const auto sequence = make_shared<op::v5::GRUSequence>(
|
||||
X, initial_hidden_state, sequence_lengths, W, R, B, hidden_size, direction);
|
||||
|
||||
EXPECT_EQ(sequence->get_hidden_size(), hidden_size);
|
||||
EXPECT_EQ(sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
|
||||
EXPECT_TRUE(sequence->get_activations_alpha().empty());
|
||||
EXPECT_TRUE(sequence->get_activations_beta().empty());
|
||||
EXPECT_EQ(sequence->get_activations()[0], "sigmoid");
|
||||
EXPECT_EQ(sequence->get_activations()[1], "tanh");
|
||||
EXPECT_EQ(sequence->get_clip(), 0.f);
|
||||
EXPECT_EQ(sequence->get_linear_before_reset(), false);
|
||||
EXPECT_EQ(sequence->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(sequence->outputs().size(), 2);
|
||||
EXPECT_EQ(sequence->get_output_shape(0),
|
||||
(Shape{batch_size, num_directions, seq_length, hidden_size}));
|
||||
EXPECT_EQ(sequence->get_output_element_type(1), element::f32);
|
||||
EXPECT_EQ(sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
|
||||
}
|
||||
@@ -290,46 +290,46 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank)
|
||||
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
auto check_dynamic_lstm = [](const shared_ptr<opset4::LSTMCell>& lstm) -> bool {
|
||||
return lstm->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
||||
lstm->output(1).get_partial_shape() == PartialShape::dynamic() &&
|
||||
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
|
||||
};
|
||||
|
||||
// Invalid dynamic rank for W tensor.
|
||||
W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
auto lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
|
||||
// Invalid dynamic rank for X tensor.
|
||||
W = make_shared<opset4::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
|
||||
// Invalid dynamic rank for H_t tensor.
|
||||
X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
|
||||
// Invalid dynamic rank for C_t tensor.
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
|
||||
// Invalid dynamic rank for R tensor.
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
|
||||
// Invalid dynamic rank for B tensor.
|
||||
R = make_shared<opset4::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
lstm = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm), true);
|
||||
}
|
||||
|
||||
@@ -41,7 +41,7 @@ struct recurrent_sequence_parameters
|
||||
//
|
||||
// Create and initialize default input test tensors.
|
||||
//
|
||||
shared_ptr<op::v1::LSTMSequence>
|
||||
shared_ptr<op::v5::LSTMSequence>
|
||||
lstm_seq_tensor_initialization(const recurrent_sequence_parameters& param)
|
||||
{
|
||||
auto batch_size = param.batch_size;
|
||||
@@ -65,7 +65,7 @@ shared_ptr<op::v1::LSTMSequence>
|
||||
const auto B =
|
||||
make_shared<opset4::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
|
||||
|
||||
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>();
|
||||
const auto lstm_sequence = make_shared<op::v5::LSTMSequence>();
|
||||
|
||||
lstm_sequence->set_argument(0, X);
|
||||
lstm_sequence->set_argument(1, initial_hidden_state);
|
||||
@@ -102,7 +102,7 @@ TEST(type_prop, lstm_sequence_forward)
|
||||
|
||||
const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD;
|
||||
|
||||
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
|
||||
const auto lstm_sequence = make_shared<op::v5::LSTMSequence>(X,
|
||||
initial_hidden_state,
|
||||
initial_cell_state,
|
||||
sequence_lengths,
|
||||
@@ -121,6 +121,7 @@ TEST(type_prop, lstm_sequence_forward)
|
||||
EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
|
||||
EXPECT_EQ(lstm_sequence->get_clip(), 0.f);
|
||||
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(lstm_sequence->outputs().size(), 3);
|
||||
EXPECT_EQ(lstm_sequence->get_output_shape(0),
|
||||
(Shape{batch_size, num_directions, seq_length, hidden_size}));
|
||||
EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
|
||||
@@ -151,12 +152,12 @@ TEST(type_prop, lstm_sequence_bidirectional)
|
||||
const auto B =
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
|
||||
|
||||
const auto lstm_direction = op::v1::LSTMSequence::direction::BIDIRECTIONAL;
|
||||
const auto lstm_direction = op::v5::LSTMSequence::direction::BIDIRECTIONAL;
|
||||
const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
|
||||
const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
|
||||
const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};
|
||||
|
||||
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
|
||||
const auto lstm_sequence = make_shared<op::v5::LSTMSequence>(X,
|
||||
initial_hidden_state,
|
||||
initial_cell_state,
|
||||
sequence_lengths,
|
||||
@@ -169,7 +170,7 @@ TEST(type_prop, lstm_sequence_bidirectional)
|
||||
activations_beta,
|
||||
activations);
|
||||
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
|
||||
EXPECT_EQ(lstm_sequence->get_direction(), op::v1::LSTMSequence::direction::BIDIRECTIONAL);
|
||||
EXPECT_EQ(lstm_sequence->get_direction(), op::v5::LSTMSequence::direction::BIDIRECTIONAL);
|
||||
EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
|
||||
EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
|
||||
EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
|
||||
@@ -351,6 +352,13 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank)
|
||||
param.hidden_size = 256;
|
||||
param.et = element::f32;
|
||||
|
||||
auto check_dynamic_lstm = [](const shared_ptr<op::v5::LSTMSequence>& lstm) -> bool {
|
||||
return lstm->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
||||
lstm->output(1).get_partial_shape() == PartialShape::dynamic() &&
|
||||
lstm->output(2).get_partial_shape() == PartialShape::dynamic() &&
|
||||
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
|
||||
};
|
||||
|
||||
auto lstm_sequence = lstm_seq_tensor_initialization(param);
|
||||
auto invalid_dynamic_tensor =
|
||||
make_shared<opset4::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));
|
||||
@@ -361,7 +369,7 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank)
|
||||
{
|
||||
lstm_sequence = lstm_seq_tensor_initialization(param);
|
||||
lstm_sequence->set_argument(i, invalid_dynamic_tensor);
|
||||
ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
|
||||
<< "LSTMSequence node was created with invalid data.";
|
||||
lstm_sequence->validate_and_infer_types();
|
||||
EXPECT_EQ(check_dynamic_lstm(lstm_sequence), true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -223,37 +223,36 @@ TEST(type_prop, rnn_cell_invalid_input_dynamic_rank)
|
||||
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
|
||||
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
|
||||
auto check_dynamic_rnn = [](const shared_ptr<opset4::RNNCell>& rnn) -> bool {
|
||||
return rnn->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
||||
rnn->output(0).get_element_type() == rnn->input(0).get_element_type();
|
||||
};
|
||||
// Invalid dynamic rank for W tensor.
|
||||
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "RNNCell node was created with invalid data.";
|
||||
auto rnn_w = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_rnn(rnn_w), true);
|
||||
|
||||
// Invalid dynamic rank for X tensor.
|
||||
W = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
|
||||
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "RNNCell node was created with invalid data.";
|
||||
auto rnn_x = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_rnn(rnn_x), true);
|
||||
|
||||
// Invalid dynamic rank for H_t tensor.
|
||||
X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "RNNCell node was created with invalid data.";
|
||||
auto rnn_h = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_rnn(rnn_h), true);
|
||||
|
||||
// Invalid dynamic rank for R tensor.
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "RNNCell node was created with invalid data.";
|
||||
auto rnn_r = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_rnn(rnn_r), true);
|
||||
|
||||
// Invalid dynamic rank for B tensor.
|
||||
R = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
|
||||
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "RNNCell node was created with invalid data.";
|
||||
auto rnn_b = make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_rnn(rnn_b), true);
|
||||
}
|
||||
|
||||
62
ngraph/test/type_prop/rnn_sequence.cpp
Normal file
62
ngraph/test/type_prop/rnn_sequence.cpp
Normal file
@@ -0,0 +1,62 @@
|
||||
//*****************************************************************************
|
||||
// Copyright 2020 Intel Corporation
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//*****************************************************************************
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
TEST(type_prop, rnn_sequence_forward)
|
||||
{
|
||||
const size_t batch_size = 8;
|
||||
const size_t num_directions = 1;
|
||||
const size_t seq_length = 6;
|
||||
const size_t input_size = 4;
|
||||
const size_t hidden_size = 128;
|
||||
|
||||
const auto X =
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
|
||||
const auto initial_hidden_state = make_shared<opset4::Parameter>(
|
||||
element::f32, Shape{batch_size, num_directions, hidden_size});
|
||||
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
|
||||
|
||||
const auto W = make_shared<opset4::Parameter>(element::f32,
|
||||
Shape{num_directions, hidden_size, input_size});
|
||||
const auto R = make_shared<opset4::Parameter>(element::f32,
|
||||
Shape{num_directions, hidden_size, hidden_size});
|
||||
const auto B = make_shared<opset4::Parameter>(element::f32, Shape{num_directions, hidden_size});
|
||||
|
||||
const auto direction = op::RecurrentSequenceDirection::FORWARD;
|
||||
|
||||
const auto sequence = make_shared<op::v5::RNNSequence>(
|
||||
X, initial_hidden_state, sequence_lengths, W, R, B, hidden_size, direction);
|
||||
|
||||
EXPECT_EQ(sequence->get_hidden_size(), hidden_size);
|
||||
EXPECT_EQ(sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
|
||||
EXPECT_TRUE(sequence->get_activations_alpha().empty());
|
||||
EXPECT_TRUE(sequence->get_activations_beta().empty());
|
||||
EXPECT_EQ(sequence->get_activations()[0], "tanh");
|
||||
EXPECT_EQ(sequence->get_clip(), 0.f);
|
||||
EXPECT_EQ(sequence->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(sequence->outputs().size(), 2);
|
||||
EXPECT_EQ(sequence->get_output_shape(0),
|
||||
(Shape{batch_size, num_directions, seq_length, hidden_size}));
|
||||
EXPECT_EQ(sequence->get_output_element_type(1), element::f32);
|
||||
EXPECT_EQ(sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
|
||||
}
|
||||
Reference in New Issue
Block a user