Add slt in template plugin/lstm sequence (#9305)

* Remove fp16 Convert layer test from skip_tests.config.cpp as it works now

* update repo

* initial code commit

* add runtime reference

* apply ov::Model

* initial lstmcell-1 definition

* initial change

* apply Peepholes

* apply input_forget option

* apply initial test case of lstmsequence-1

* fix clang-format error

* fix clang-format error 2

* add lstm_sequence test cases using the runtime reference and ONNX test cases

* fix clang-format error

* fix clang-format error

* fix onnx test failure of LSTM IE_CPU

* fix clang-format issue

* fix clang-format issue 2

* add type_prop and visitor api test of lstm_sequence_v1

* fix clang-format error

* replace input/refOut data with hard-coded values and remove an unnecessary enum definition

* update namespace of Tensor()

* remove supported test cases from the disabling list
Wilson Seok 2022-01-27 20:49:32 -08:00 committed by GitHub
parent 658d9c3633
commit 1e0470f4e7
13 changed files with 4323 additions and 56 deletions

View File

@@ -2480,7 +2480,7 @@ bool evaluate(const shared_ptr<op::v0::RNNCell>& op, const HostTensorVector& out
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::LSTMCell>& op, const HostTensorVector& outputs, const HostTensorVector& inputs) {
using T = typename element_type_traits<ET>::value_type;
runtime::reference::lstm_cell<T>(inputs[0]->get_data_ptr<ET>(),
runtime::reference::lstm_cell_v1<T>(inputs[0]->get_data_ptr<ET>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<ET>(),
inputs[1]->get_shape(),
@@ -2492,12 +2492,16 @@ bool evaluate(const shared_ptr<op::v0::LSTMCell>& op, const HostTensorVector& ou
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<ET>(),
inputs[5]->get_shape(),
inputs[6]->get_data_ptr<ET>(),
inputs[6]->get_shape(),
outputs[0]->get_data_ptr<ET>(),
outputs[1]->get_data_ptr<ET>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_activations()[2],
op->get_clip());
op->get_clip(),
op->get_weights_format(),
op->get_input_forget());
return true;
}
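
For orientation, a minimal standalone sketch (not part of the PR) of calling the new reference kernel directly, with the same argument order the evaluator wires up above. The header path is an assumption, the input values are arbitrary, and the zero peepholes mirror the reference tests:

#include <cstddef>
#include <vector>
#include "ngraph/runtime/reference/lstm_cell.hpp"  // assumed location of lstm_cell_v1

int main() {
    using ngraph::Shape;
    const std::size_t b = 1, in = 2, hid = 2;
    std::vector<float> X(b * in, 0.5f), H(b * hid, 0.f), C(b * hid, 0.f);
    std::vector<float> W(4 * hid * in, 0.1f), R(4 * hid * hid, 0.1f);
    std::vector<float> B(4 * hid, 0.f), P(3 * hid, 0.f);  // zero peepholes, as in the tests
    std::vector<float> Ht(b * hid), Ct(b * hid);
    ngraph::runtime::reference::lstm_cell_v1<float>(X.data(), Shape{b, in},
                                                    H.data(), Shape{b, hid},
                                                    C.data(), Shape{b, hid},
                                                    W.data(), Shape{4 * hid, in},
                                                    R.data(), Shape{4 * hid, hid},
                                                    B.data(), Shape{4 * hid},
                                                    P.data(), Shape{3 * hid},
                                                    Ht.data(), Ct.data(),
                                                    "sigmoid", "tanh", "tanh",
                                                    0.f,                              // clip disabled
                                                    ov::op::LSTMWeightsFormat::FICO,  // only format the kernel accepts
                                                    false);                           // input_forget off
    return 0;
}
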
@@ -2592,6 +2596,42 @@ bool evaluate(const shared_ptr<op::v5::RNNSequence>& op,
return true;
}
namespace lstm_seq_v1 {
template <element::Type_t t1, element::Type_t t2>
inline void evaluate(const shared_ptr<op::v0::LSTMSequence>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs) {
using T1 = typename element_type_traits<t1>::value_type;
using T2 = typename element_type_traits<t2>::value_type;
runtime::reference::lstm_sequence_v1<T1, T2>(inputs[0]->get_data_ptr<char>(),
inputs[0]->get_shape(),
inputs[1]->get_data_ptr<char>(),
inputs[1]->get_shape(),
inputs[2]->get_data_ptr<char>(),
inputs[2]->get_shape(),
inputs[3]->get_data_ptr<char>(),
inputs[3]->get_shape(),
inputs[4]->get_data_ptr<char>(),
inputs[4]->get_shape(),
inputs[5]->get_data_ptr<char>(),
inputs[5]->get_shape(),
inputs[6]->get_data_ptr<char>(),
inputs[6]->get_shape(),
inputs[7]->get_data_ptr<char>(),
inputs[7]->get_shape(),
outputs[0]->get_data_ptr<char>(),
outputs[1]->get_data_ptr<char>(),
outputs[2]->get_data_ptr<char>(),
op->get_activations()[0],
op->get_activations()[1],
op->get_activations()[2],
op->get_clip_threshold(),
op->get_weights_format(),
op->get_input_forget(),
op->get_direction());
}
} // namespace lstm_seq_v1
namespace lstm_seq_v5 {
template <element::Type_t t1, element::Type_t t2>
inline void evaluate(const shared_ptr<op::v5::LSTMSequence>& op,
@@ -2624,6 +2664,25 @@ inline void evaluate(const shared_ptr<op::v5::LSTMSequence>& op,
}
} // namespace lstm_seq_v5
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v0::LSTMSequence>& op,
const HostTensorVector& outputs,
const HostTensorVector& inputs) {
switch (inputs[3]->get_element_type()) {
case element::Type_t::i64:
case element::Type_t::u64:
lstm_seq_v1::evaluate<ET, element::Type_t::i64>(op, outputs, inputs);
break;
case element::Type_t::i32:
case element::Type_t::u32:
lstm_seq_v1::evaluate<ET, element::Type_t::i32>(op, outputs, inputs);
break;
default:
return false;
}
return true;
}
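
The switch above is a two-level dispatch: the surrounding evaluate<ET> template fixes the data element type, while the runtime switch on the seq_lengths tensor picks the index type, with unsigned widths reusing the signed instantiations. A standalone sketch of the idiom (hypothetical names, not OpenVINO code):

#include <cstdint>
#include <iostream>

enum class Et { i32, i64 };

template <typename T, typename U>
void kernel() {
    std::cout << sizeof(T) << " x " << sizeof(U) << "\n";
}

template <typename T>
bool dispatch(Et seq_len_type) {
    switch (seq_len_type) {
    case Et::i64:
        kernel<T, int64_t>();  // a u64 case would share this instantiation
        break;
    case Et::i32:
        kernel<T, int32_t>();  // a u32 case would share this instantiation
        break;
    default:
        return false;
    }
    return true;
}

int main() {
    return dispatch<float>(Et::i32) ? 0 : 1;
}
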
template <element::Type_t ET>
bool evaluate(const shared_ptr<op::v5::LSTMSequence>& op,
const HostTensorVector& outputs,

View File

@@ -22,6 +22,7 @@ NGRAPH_OP(HardSigmoid, op::v0)
NGRAPH_OP(Interpolate, op::v0)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, op::v0)
NGRAPH_OP(LSTMSequence, op::v0)
NGRAPH_OP(MVN, ngraph::op::v0)
NGRAPH_OP(NormalizeL2, op::v0)
NGRAPH_OP(PriorBox, ngraph::op::v0)

View File

@@ -23,6 +23,7 @@ struct LSTMCellParams {
reference_tests::Tensor H_t;
reference_tests::Tensor C_t;
reference_tests::Tensor B;
reference_tests::Tensor P;
reference_tests::Tensor Ho;
reference_tests::Tensor Co;
std::string testcaseName;
@@ -39,6 +40,7 @@ struct Builder : ParamsBuilder<LSTMCellParams> {
REFERENCE_TESTS_ADD_SET_PARAM(Builder, H_t);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, C_t);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, B);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, P);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, Ho);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, Co);
REFERENCE_TESTS_ADD_SET_PARAM(Builder, testcaseName);
@@ -236,6 +238,15 @@ private:
};
class ReferenceLSTMCellV1TestBiasClip : public ReferenceLSTMCellTestBiasClip {
public:
void SetUp() override {
threshold = 1e-1f;
auto params = GetParam();
function = CreateFunction(params);
inputData = {params.X.data, params.H_t.data, params.C_t.data, params.W.data, params.R.data, params.B.data, params.P.data};
refOutData = {params.Ho.data, params.Co.data};
}
private:
static std::shared_ptr<Model> CreateFunction(const LSTMCellParams& params) {
const float clip_threshold = 3.5f;
@@ -246,6 +257,7 @@ private:
const auto H_t = std::make_shared<opset1::Parameter>(params.H_t.type, params.H_t.shape);
const auto C_t = std::make_shared<opset1::Parameter>(params.C_t.type, params.C_t.shape);
const auto B = std::make_shared<opset1::Parameter>(params.B.type, params.B.shape);
const auto P = std::make_shared<opset1::Parameter>(params.P.type, params.P.shape);
const auto lstm_cell =
std::make_shared<opset1::LSTMCell>(X,
@@ -254,14 +266,16 @@ private:
W,
R,
B,
P,
params.hiddenSize,
op::LSTMWeightsFormat::IFCO,
op::LSTMWeightsFormat::FICO,
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
std::vector<float>{},
std::vector<float>{},
clip_threshold);
clip_threshold,
false);
auto function = std::make_shared<Model>(lstm_cell->outputs(), ParameterVector{X, H_t, C_t, W, R, B});
auto function = std::make_shared<Model>(lstm_cell->outputs(), ParameterVector{X, H_t, C_t, W, R, B, P});
return function;
}
};
@@ -308,6 +322,7 @@ std::vector<LSTMCellParams> generateParams() {
.C_t(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{
0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f}))
.B(reference_tests::Tensor(ET, {4 * 3}, std::vector<T>(4 * 3, 0.f)))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f}))
.Co(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f}))
.testcaseName("lstm_cell_zero_bias_default_attrs")
@@ -371,6 +386,7 @@ std::vector<LSTMCellParams> generateParamsBiasDefaultAttrs() {
0.51022074f,
1.11389844f,
0.74174305f}))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81014400720596313,
0.76665538549423218,
0.82509011030197144,
@@ -444,6 +460,7 @@ std::vector<LSTMCellParams> generateParamsBiasClip() {
0.51022074f,
1.11389844f,
0.74174305f}))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81014400720596313,
0.76665538549423218,
0.82387429475784302,
@@ -515,6 +532,7 @@ std::vector<LSTMCellParams> generateParamsV1() {
.C_t(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{
0.8488452f, 0.18851636f, 0.5020695f, 0.29716516f, 0.06740791f, 0.45384037f}))
.B(reference_tests::Tensor(ET, {4 * 3}, std::vector<T>(4 * 3, 0.f)))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f}))
.Co(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f}))
.testcaseName("lstm_cell_v1_zero_bias_default_attrs")
@@ -578,6 +596,7 @@ std::vector<LSTMCellParams> generateParamsBiasDefaultAttrsV1() {
0.51022074f,
1.11389844f,
0.74174305f}))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81014400720596313,
0.76665538549423218,
0.82509011030197144,
@@ -651,6 +670,7 @@ std::vector<LSTMCellParams> generateParamsBiasClipV1() {
0.51022074f,
1.11389844f,
0.74174305f}))
.P(reference_tests::Tensor(ET, {3 * 3}, std::vector<T>(3 * 3, 0.f)))
.Ho(reference_tests::Tensor(ET, {2, 3}, std::vector<T>{0.81014400720596313,
0.76665538549423218,
0.82387429475784302,

File diff suppressed because it is too large

View File

@@ -24,6 +24,12 @@ enum class LSTMWeightsFormat {
IOFC, // ONNX
};
enum class LSTMPeepholesFormat {
FIO, // IE
IOF, // ONNX, PyTorch
IFO, // Caffe, DNNL, TF, MXNet
};
///
/// \brief Change data format of provided node.
///
@@ -43,6 +49,12 @@ std::shared_ptr<Node> OPENVINO_API convert_lstm_node_format(const Output<Node>&
LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO,
int64_t axis = 0);
std::shared_ptr<Node> OPENVINO_API
convert_lstm_peepholes_format(const Output<Node>& node,
LSTMPeepholesFormat from_format,
LSTMPeepholesFormat to_format = LSTMPeepholesFormat::FIO,
int64_t axis = 0);
/// \brief Base class for all recurrent network cells.
///
/// \note It holds all common attributes.
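
A hedged usage sketch for the new declaration (the header paths are assumptions; axis 1 is the dimension packing the three peephole gates of a [num_directions, 3*hidden_size] tensor):

#include <memory>
#include <openvino/op/parameter.hpp>
#include <openvino/op/util/rnn_cell_base.hpp>  // assumed home of convert_lstm_peepholes_format

int main() {
    // An ONNX-ordered (IOF) peephole tensor, reordered to the IE default (FIO).
    auto p_onnx = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 3 * 128});
    auto p_fio = ov::op::util::convert_lstm_peepholes_format(p_onnx,
                                                             ov::op::util::LSTMPeepholesFormat::IOF,
                                                             ov::op::util::LSTMPeepholesFormat::FIO,
                                                             /*axis=*/1);
    return p_fio != nullptr ? 0 : 1;
}
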

View File

@@ -155,6 +155,214 @@ void lstm_cell(const T* X,
// Ht = ot (.) h(Ct)
reference::multiply(X_W_fico[3].data(), Ct.data(), out_Ht, gate_shape, gate_shape, op::AutoBroadcastType::NUMPY);
}
template <typename T>
void lstm_cell_v1(const T* X,
const Shape& X_shape,
const T* H,
const Shape& H_shape,
const T* C,
const Shape& C_shape,
const T* W,
const Shape& W_shape,
const T* R,
const Shape& R_shape,
const T* B,
const Shape& B_shape,
const T* P,
const Shape& P_shape,
T* out_Ht,
T* out_Ct,
const std::string& activation_f,
const std::string& activation_g,
const std::string& activation_h,
float clip,
const ov::op::LSTMWeightsFormat weight_format,
bool input_forget) {
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// P - The peephole weights for input, output and forget gates.
// ------ VARIABLE NAMES ------
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight matrix for input, forget, cell and output gates
// Shape: [4*hidden_size, input_size]
// R - The recurrence weight matrix for input, forget, cell and output gates.
// Shape: [4*hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size,
// hidden_size].
// C_t - The cell state tensor at current time step. Shape: [batch_size,
// hidden_size].
// bias - The sum of biases (weight and recurrence) for input, forget, cell and
// output gates.
// Shape: [4 * hidden_size]
// p_[iof] - The peephole weight vector for respectively: input, output, and forget
// gates.
// Each peephole has shape [hidden_size].
//
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
//
// ---- Equations ----
// f, g, h - are activation functions.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ct = ft (.) Ct-1 + it (.) ct
// Ht = ot (.) h(Ct)
// --------------------
Shape gate_shape{X_shape[0], H_shape[1]};
Shape all_gates_shape{X_shape[0], 4 * H_shape[1]};
Shape P_gate_shape{H_shape[1]};
auto P_gate_size = H_shape[1];
auto gate_shape_size = X_shape[0] * H_shape[1];
auto all_gates_shape_size = gate_shape_size * 4;
if (weight_format != ov::op::LSTMWeightsFormat::FICO) {
throw ngraph_error("Only LSTMWeightsFormat = FICO is supported.");
}
// Xt*(W^T)
std::vector<T> Xt_W(all_gates_shape_size);
reference::matmul(X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true);
// Ht-1*(R^T)
std::vector<T> Ht_R(all_gates_shape_size);
reference::matmul(H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true);
// Ht-1*(R^T) + Wb + Rb
std::vector<T> Ht_R_B(all_gates_shape_size);
reference::add(Ht_R.data(), B, Ht_R_B.data(), all_gates_shape, B_shape, op::AutoBroadcastType::NUMPY);
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
std::vector<T> XHB(all_gates_shape_size);
reference::add(Xt_W.data(),
Ht_R_B.data(),
XHB.data(),
all_gates_shape,
all_gates_shape,
op::AutoBroadcastType::NUMPY);
std::vector<std::vector<T>> X_W_fico(4, std::vector<T>(all_gates_shape_size / 4));
std::vector<char*> pointers = {reinterpret_cast<char*>(X_W_fico[0].data()),
reinterpret_cast<char*>(X_W_fico[1].data()),
reinterpret_cast<char*>(X_W_fico[2].data()),
reinterpret_cast<char*>(X_W_fico[3].data())};
// split on gates
reference::split(reinterpret_cast<char*>(XHB.data()), all_gates_shape, sizeof(T), 1, 4, pointers.data());
auto clip_activation = [&clip](std::vector<T>& gate, const std::string& activation, bool enable_clip = true) {
if (clip > 0.f && enable_clip) {
reference::clamp(gate.data(), gate.data(), static_cast<T>(-clip), static_cast<T>(clip), gate.size());
}
if (activation == "relu") {
reference::relu(gate.data(), gate.data(), gate.size());
} else if (activation == "sigmoid") {
reference::sigmoid(gate.data(), gate.data(), gate.size());
} else if (activation == "tanh") {
reference::tanh(gate.data(), gate.data(), gate.size());
} else {
throw ngraph_error("Activation function " + activation + " is not supported.");
}
};
// Split P on gates f, i, o
std::vector<std::vector<T>> P_fio(3, std::vector<T>(P_gate_size));
std::vector<char*> P_pointers = {reinterpret_cast<char*>(P_fio[0].data()),
reinterpret_cast<char*>(P_fio[1].data()),
reinterpret_cast<char*>(P_fio[2].data())};
reference::split(reinterpret_cast<const char*>(P), P_shape, sizeof(T), 0, 3, P_pointers.data());
// Pf (.) Ct-1
std::vector<T> PfCt_1(gate_shape_size);
reference::multiply(P_fio[0].data(), C, PfCt_1.data(), P_gate_shape, C_shape, op::AutoBroadcastType::NUMPY);
// Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf + Pf (.) Ct-1
std::vector<T> XHBPf(gate_shape_size);
reference::add(X_W_fico[0].data(), PfCt_1.data(), XHBPf.data(), gate_shape, C_shape, op::AutoBroadcastType::NUMPY);
// Pi (.) Ct-1
std::vector<T> PiCt_1(gate_shape_size);
reference::multiply(P_fio[1].data(), C, PiCt_1.data(), P_gate_shape, C_shape, op::AutoBroadcastType::NUMPY);
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
clip_activation(XHBPf, activation_f);
// it calculation per input_forget condition
std::vector<T> XHBPi(gate_shape_size);
if (input_forget) {
// it = (1 - ft)
std::vector<T> ones(gate_shape_size, 1.f);
reference::subtract(ones.data(),
XHBPf.data(),
XHBPi.data(),
gate_shape,
gate_shape,
op::AutoBroadcastType::NUMPY);
} else {
// Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi + Pi (.) Ct-1
reference::add(X_W_fico[1].data(),
PiCt_1.data(),
XHBPi.data(),
gate_shape,
C_shape,
op::AutoBroadcastType::NUMPY);
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
clip_activation(XHBPi, activation_f);
}
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
clip_activation(X_W_fico[2], activation_g);
std::vector<T> mul1(gate_shape_size);
std::vector<T> mul2(gate_shape_size);
std::vector<T> Ct(gate_shape_size);
// ft (.) Ct-1
reference::multiply(XHBPf.data(), C, mul1.data(), gate_shape, C_shape, op::AutoBroadcastType::NUMPY);
// it (.) ct
reference::multiply(XHBPi.data(),
X_W_fico[2].data(),
mul2.data(),
gate_shape,
gate_shape,
op::AutoBroadcastType::NUMPY);
// input_forget=true: Ct = ft (.) Ct-1 + (1 - ft)(.) ct
// input_forget=false: Ct = ft (.) Ct-1 + it (.) ct
reference::add(mul1.data(), mul2.data(), Ct.data(), gate_shape, gate_shape, op::AutoBroadcastType::NUMPY);
std::memcpy(out_Ct, Ct.data(), Ct.size() * sizeof(T));
// Po (.) Ct
std::vector<T> PoCt(gate_shape_size);
reference::multiply(P_fio[2].data(),
Ct.data(),
PoCt.data(),
P_gate_shape,
gate_shape,
op::AutoBroadcastType::NUMPY);
// Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo + Po (.) Ct
std::vector<T> XHBPo(gate_shape_size);
reference::add(X_W_fico[3].data(), PoCt.data(), XHBPo.data(), gate_shape, gate_shape, op::AutoBroadcastType::NUMPY);
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
clip_activation(XHBPo, activation_f);
clip_activation(Ct, activation_h, false);
// Ht = ot (.) h(Ct)
reference::multiply(XHBPo.data(), Ct.data(), out_Ht, gate_shape, gate_shape, op::AutoBroadcastType::NUMPY);
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
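
In display math, the peephole cell implemented by lstm_cell_v1 above is (with $\odot$ denoting elementwise multiplication and $f, g, h$ the three activation functions):

\begin{aligned}
i_t &= f\left(X_t W_i^{T} + H_{t-1} R_i^{T} + P_i \odot C_{t-1} + W_{bi} + R_{bi}\right) \\
f_t &= f\left(X_t W_f^{T} + H_{t-1} R_f^{T} + P_f \odot C_{t-1} + W_{bf} + R_{bf}\right) \\
c_t &= g\left(X_t W_c^{T} + H_{t-1} R_c^{T} + W_{bc} + R_{bc}\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot c_t, \qquad i_t = 1 - f_t \text{ when input\_forget is set} \\
o_t &= f\left(X_t W_o^{T} + H_{t-1} R_o^{T} + P_o \odot C_t + W_{bo} + R_{bo}\right) \\
H_t &= o_t \odot h(C_t)
\end{aligned}
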

View File

@@ -20,6 +20,7 @@ enum class CellType {
RNN,
GRU,
LSTM,
LSTM_v1,
};
struct CellArgs {
@@ -28,6 +29,8 @@ struct CellArgs {
std::string activation_h; // RNN/GRU/LSTM
float clip; // RNN/GRU/LSTM
bool linear_before_reset = false; // GRU
ov::op::LSTMWeightsFormat weight_format = ov::op::LSTMWeightsFormat::FICO; // LSTM_v1
bool input_forget = false; // LSTM_v1
};
template <typename T, typename U>
@@ -94,7 +97,7 @@ void cell_pass(CellType type,
std::memcpy(H_i, inputs[2], ngraph::shape_size(shapes[2]) * sizeof(T));
char* C_i = nullptr; // LSTMCell only
if (type == CellType::LSTM) {
if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) {
C_i = outputs[2];
std::memcpy(C_i, inputs[3], ngraph::shape_size(shapes[3]) * sizeof(T));
}
@@ -119,6 +122,29 @@ void cell_pass(CellType type,
args.activation_g,
args.activation_h,
args.clip);
} else if (type == CellType::LSTM_v1) {
runtime::reference::lstm_cell_v1<T>(reinterpret_cast<const T*>(in_seqs.data() + time_step * part_size),
squeeze_axis(shapes[0], 1),
reinterpret_cast<const T*>(H_i),
squeeze_axis(shapes[2], 1),
reinterpret_cast<const T*>(C_i),
squeeze_axis(shapes[3], 1),
reinterpret_cast<const T*>(inputs[4]),
squeeze_axis(shapes[4], 0),
reinterpret_cast<const T*>(inputs[5]),
squeeze_axis(shapes[5], 0),
reinterpret_cast<const T*>(inputs[6]),
squeeze_axis(shapes[6], 0),
reinterpret_cast<const T*>(inputs[7]),
squeeze_axis(shapes[7], 0),
reinterpret_cast<T*>(outputs[1]),
reinterpret_cast<T*>(outputs[2]),
args.activation_f,
args.activation_g,
args.activation_h,
args.clip,
args.weight_format,
args.input_forget);
} else if (type == CellType::RNN) {
runtime::reference::rnn_cell<T>(reinterpret_cast<const T*>(in_seqs.data() + time_step * part_size),
squeeze_axis(shapes[0], 1),
@@ -159,13 +185,13 @@ void cell_pass(CellType type,
continue;
}
std::memcpy(h_list[time_step].data() + shift, outputs[1] + shift, part_size_single_batch);
if (type == CellType::LSTM) {
if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) {
std::memcpy(c_list[time_step].data() + shift, outputs[2] + shift, part_size_single_batch);
}
}
if ((num_splits - time_step) > 1) {
std::memcpy(outputs[1], h_list[time_step].data(), part_shape_size * sizeof(T));
if (type == CellType::LSTM) {
if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) {
std::memcpy(outputs[2], c_list[time_step].data(), part_shape_size * sizeof(T));
}
} else {
@@ -174,12 +200,12 @@ void cell_pass(CellType type,
auto shift = i * part_size_single_batch;
if (idx >= 0 && idx < h_list.size()) {
std::memcpy(outputs[1] + shift, h_list[idx].data() + shift, part_size_single_batch);
if (type == CellType::LSTM) {
if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) {
std::memcpy(outputs[2] + shift, c_list[idx].data() + shift, part_size_single_batch);
}
} else {
std::memset(outputs[1] + shift, 0, part_size_single_batch);
if (type == CellType::LSTM) {
if ((type == CellType::LSTM) || (type == CellType::LSTM_v1)) {
std::memset(outputs[2] + shift, 0, part_size_single_batch);
}
}
@@ -330,6 +356,136 @@ void lstm_sequence(const char* X,
}
}
template <typename T, typename U>
void lstm_sequence_v1(const char* X,
const Shape& X_shape,
const char* H,
const Shape& H_shape,
const char* C,
const Shape& C_shape,
const char* seq_lengths,
const Shape& seq_lengths_shape,
const char* W,
const Shape& W_shape,
const char* R,
const Shape& R_shape,
const char* B,
const Shape& B_shape,
const char* P,
const Shape& P_shape,
char* Y,
char* Ho,
char* Co,
const std::string& activation_f,
const std::string& activation_g,
const std::string& activation_h,
float clip,
const ov::op::LSTMWeightsFormat weight_format,
bool input_forget,
op::RecurrentSequenceDirection direction) {
OutputVector results;
if (direction == op::RecurrentSequenceDirection::FORWARD || direction == op::RecurrentSequenceDirection::REVERSE) {
CellArgs args;
args.activation_f = activation_f;
args.activation_g = activation_g;
args.activation_h = activation_h;
args.clip = clip;
args.weight_format = weight_format;
args.input_forget = input_forget;
std::vector<const char*> inputs = {X, seq_lengths, H, C, W, R, B, P};
std::vector<char*> outputs = {Y, Ho, Co};
std::vector<Shape> shapes = {X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape, P_shape};
cell_pass<T, U>(CellType::LSTM_v1,
inputs,
shapes,
outputs,
args,
direction == op::RecurrentSequenceDirection::REVERSE);
} else if (direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
// Split bidirectional case to forward + reverse passes.
// split inputs
std::vector<std::vector<char>> H_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(H_shape) / 2));
std::vector<std::vector<char>> C_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(C_shape) / 2));
std::vector<std::vector<char>> W_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(W_shape) / 2));
std::vector<std::vector<char>> R_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(R_shape) / 2));
std::vector<std::vector<char>> B_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(B_shape) / 2));
std::vector<std::vector<char>> P_split(2, std::vector<char>(sizeof(T) * ngraph::shape_size(P_shape) / 2));
char* h_pointers[2] = {H_split[0].data(), H_split[1].data()};
char* c_pointers[2] = {C_split[0].data(), C_split[1].data()};
char* w_pointers[2] = {W_split[0].data(), W_split[1].data()};
char* r_pointers[2] = {R_split[0].data(), R_split[1].data()};
char* b_pointers[2] = {B_split[0].data(), B_split[1].data()};
char* p_pointers[2] = {P_split[0].data(), P_split[1].data()};
reference::split(H, H_shape, sizeof(T), 1, 2, h_pointers);
reference::split(C, C_shape, sizeof(T), 1, 2, c_pointers);
reference::split(W, W_shape, sizeof(T), 0, 2, w_pointers);
reference::split(R, R_shape, sizeof(T), 0, 2, r_pointers);
reference::split(B, B_shape, sizeof(T), 0, 2, b_pointers);
reference::split(P, P_shape, sizeof(T), 0, 2, p_pointers);
std::vector<char> forward_res_y(sizeof(T) * H_shape[0] * H_shape[2] * X_shape[1]);
std::vector<char> reverse_res_y(sizeof(T) * H_shape[0] * H_shape[2] * X_shape[1]);
std::vector<std::vector<char>> forward_res(2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
std::vector<std::vector<char>> reverse_res(2, std::vector<char>(sizeof(T) * H_shape[0] * H_shape[2]));
CellArgs args;
args.activation_f = activation_f;
args.activation_g = activation_g;
args.activation_h = activation_h;
args.clip = clip;
args.weight_format = weight_format;
args.input_forget = input_forget;
std::vector<Shape> shapes = {X_shape, seq_lengths_shape, H_shape, C_shape, W_shape, R_shape, B_shape, P_shape};
// update H,C,W,R,B,P shapes after split
shapes[2][1] = 1;
shapes[3][1] = 1;
for (size_t i = 4; i < shapes.size(); ++i) {
shapes[i][0] = 1;
}
// forward pass
cell_pass<T, U>(
CellType::LSTM_v1,
{X, seq_lengths, h_pointers[0], c_pointers[0], w_pointers[0], r_pointers[0], b_pointers[0], p_pointers[0]},
shapes,
{forward_res_y.data(), forward_res[0].data(), forward_res[1].data()},
args,
false);
// reverse pass
cell_pass<T, U>(
CellType::LSTM_v1,
{X, seq_lengths, h_pointers[1], c_pointers[1], w_pointers[1], r_pointers[1], b_pointers[1], p_pointers[1]},
shapes,
{reverse_res_y.data(), reverse_res[0].data(), reverse_res[1].data()},
args,
true);
// Stack together respective outputs from both forward and reverse passes.
std::vector<Shape> in_shapes_y = {{H_shape[0], 1, X_shape[1], H_shape[2]},
{H_shape[0], 1, X_shape[1], H_shape[2]}};
std::vector<Shape> in_shapes_h_c = {{H_shape[0], 1, H_shape[2]}, {H_shape[0], 1, H_shape[2]}};
Shape output_shape_y{H_shape[0], 2, X_shape[1], H_shape[2]};
Shape output_shape_h_c{H_shape[0], 2, H_shape[2]};
runtime::reference::concat({forward_res_y.data(), reverse_res_y.data()},
Y,
in_shapes_y,
output_shape_y,
1,
sizeof(T));
runtime::reference::concat({forward_res[0].data(), reverse_res[0].data()},
Ho,
in_shapes_h_c,
output_shape_h_c,
1,
sizeof(T));
runtime::reference::concat({forward_res[1].data(), reverse_res[1].data()},
Co,
in_shapes_h_c,
output_shape_h_c,
1,
sizeof(T));
}
}
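
The three concat calls above all stack along axis 1 (the direction axis). A standalone sketch (not OpenVINO code) of that stacking for the Y output: two row-major [N, 1, S, H] buffers become one [N, 2, S, H] buffer by interleaving per-batch blocks of S*H values (for Ho/Co the same pattern applies with S = 1):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> stack_directions(const std::vector<float>& fwd,
                                    const std::vector<float>& rev,
                                    std::size_t N, std::size_t S, std::size_t H) {
    const std::size_t block = S * H;  // one direction's data for one batch element
    std::vector<float> out(2 * N * block);
    for (std::size_t n = 0; n < N; ++n) {
        std::copy(fwd.begin() + n * block, fwd.begin() + (n + 1) * block,
                  out.begin() + (2 * n) * block);      // direction 0 (forward)
        std::copy(rev.begin() + n * block, rev.begin() + (n + 1) * block,
                  out.begin() + (2 * n + 1) * block);  // direction 1 (reverse)
    }
    return out;
}

int main() {
    std::vector<float> f{1, 2, 3, 4}, r{5, 6, 7, 8};  // N=2, S=1, H=2
    for (float v : stack_directions(f, r, 2, 1, 2))
        std::cout << v << ' ';  // prints: 1 2 5 6 3 4 7 8
    std::cout << '\n';
    return 0;
}
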
template <typename T, typename U>
void gru_sequence(const char* X,
const Shape& X_shape,

View File

@@ -291,6 +291,14 @@ shared_ptr<Node> op::v0::LSTMSequence::prepare_input(Output<Node> node,
void op::v0::LSTMSequence::validate_and_infer_types() {
NGRAPH_OP_SCOPE(v0_LSTMSequence_validate_and_infer_types);
for (const auto& input : inputs()) {
if (input.get_partial_shape().rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic());
set_output_type(1, get_input_element_type(0), ov::PartialShape::dynamic());
set_output_type(2, get_input_element_type(0), ov::PartialShape::dynamic());
return;
}
}
std::vector<ov::PartialShape> input_param{};
auto lstm_seq_gates_count = 4;

View File

@@ -45,6 +45,28 @@ std::shared_ptr<ov::Node> ov::op::util::convert_lstm_node_format(const Output<No
return std::make_shared<ngraph::opset4::Concat>(nodes_in_new_format, axis);
}
std::shared_ptr<ov::Node> ov::op::util::convert_lstm_peepholes_format(const Output<Node>& node,
LSTMPeepholesFormat from_format,
LSTMPeepholesFormat to_format,
int64_t axis) {
static const std::map<op::util::LSTMPeepholesFormat, std::vector<size_t>> gate_order_map{
{op::util::LSTMPeepholesFormat::FIO, {0, 1, 2}},
{op::util::LSTMPeepholesFormat::IFO, {1, 0, 2}},
{op::util::LSTMPeepholesFormat::IOF, {1, 2, 0}},
};
const auto& from = gate_order_map.at(from_format);
const auto& to = gate_order_map.at(to_format);
size_t num_gates = 3;
auto axis_const = std::make_shared<ngraph::opset4::Constant>(element::i64, ngraph::Shape{}, axis);
OutputVector splitted_node = std::make_shared<ngraph::opset4::Split>(node, axis_const, num_gates)->outputs();
OutputVector nodes_in_new_format(num_gates);
for (size_t i = 0; i < num_gates; ++i) {
nodes_in_new_format[to[from[i]]] = splitted_node[i];
}
return std::make_shared<ngraph::opset4::Concat>(nodes_in_new_format, axis);
}
// Modify input vector in-place and return reference to modified vector.
static vector<string> to_lower_case(const vector<string>& vs) {
vector<string> res(vs);
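
A standalone trace (not OpenVINO code) of the index arithmetic above for the IOF-to-FIO conversion the ONNX importer performs. Each split slot i holds the canonical gate from[i] (F=0, I=1, O=2), and with the FIO target map being the identity permutation, to[from[i]] lands it in its FIO position:

#include <array>
#include <cstdio>

int main() {
    std::array<int, 3> from{1, 2, 0};        // IOF: slots hold I, O, F
    std::array<int, 3> to{0, 1, 2};          // FIO: slots hold F, I, O
    const char* slots[3] = {"I", "O", "F"};  // the node split in IOF slot order
    const char* out[3] = {};
    for (int i = 0; i < 3; ++i)
        out[to[from[i]]] = slots[i];  // same indexing as nodes_in_new_format above
    std::printf("%s %s %s\n", out[0], out[1], out[2]);  // prints: F I O
    return 0;
}
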

View File

@@ -73,10 +73,6 @@ INTERPRETER.onnx_model_dequantize_linear_1d_zero_scale_uint8_negative_axis
INTERPRETER.onnx_if_inside_if
INTERPRETER.onnx_if_inside_loop
# Legacy tests with unsupported features from opset4 LSTM/GRU/RNN
# Peepholes input unsupported
onnx_model_lstm_fwd_with_clip_peepholes
onnx_model_lstm_bdir_short_input_seq_peepholes
# Activation function hardsigmoid unsupported
onnx_model_gru_fwd_activations_relu_hardsigmoid
onnx_model_lstm_fwd_hardsigmoid_activation

View File

@@ -4,6 +4,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset5.hpp"
#include "util/type_prop.hpp"
@@ -56,6 +57,39 @@ shared_ptr<opset5::LSTMSequence> lstm_seq_tensor_initialization(const recurrent_
return lstm_sequence;
}
shared_ptr<opset1::LSTMSequence> lstm_seq_v1_tensor_initialization(const recurrent_sequence_parameters& param) {
auto batch_size = param.batch_size;
auto seq_length = param.seq_length;
auto input_size = param.input_size;
auto num_directions = param.num_directions;
auto hidden_size = param.hidden_size;
auto et = param.et;
const auto X = make_shared<opset5::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<opset5::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<opset5::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset5::Parameter>(et, PartialShape{batch_size});
const auto W = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4, input_size});
const auto R = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4, hidden_size});
const auto B = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
const auto P = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 3});
const auto lstm_sequence = make_shared<opset1::LSTMSequence>();
lstm_sequence->set_argument(0, X);
lstm_sequence->set_argument(1, initial_hidden_state);
lstm_sequence->set_argument(2, initial_cell_state);
lstm_sequence->set_argument(3, sequence_lengths);
lstm_sequence->set_argument(4, W);
lstm_sequence->set_argument(5, R);
lstm_sequence->set_argument(6, B);
lstm_sequence->set_argument(7, P);
return lstm_sequence;
}
TEST(type_prop, lstm_sequence_forward) {
const size_t batch_size = 8;
const size_t num_directions = 1;
@@ -102,6 +136,54 @@ TEST(type_prop, lstm_sequence_forward) {
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
TEST(type_prop, lstm_sequence_v1_forward) {
const size_t batch_size = 8;
const size_t num_directions = 1;
const size_t seq_length = 6;
const size_t input_size = 4;
const size_t hidden_size = 128;
const auto X = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset5::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto P = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 3 * hidden_size});
const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD;
const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
P,
hidden_size,
lstm_direction);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->outputs().size(), 3);
EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{batch_size, num_directions, seq_length, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
TEST(type_prop, lstm_sequence_bidirectional) {
const size_t batch_size = 24;
const size_t num_directions = 2;
@@ -152,6 +234,65 @@ TEST(type_prop, lstm_sequence_bidirectional) {
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
TEST(type_prop, lstm_sequence_v1_bidirectional) {
const size_t batch_size = 24;
const size_t num_directions = 2;
const size_t seq_length = 12;
const size_t input_size = 8;
const size_t hidden_size = 256;
const bool input_forget = true;
const ov::op::LSTMWeightsFormat weights_format = ov::op::LSTMWeightsFormat::FICO;
const float clip_threshold = 3.5f;
const auto X = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset5::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto P = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 3 * hidden_size});
const auto lstm_direction = opset5::LSTMSequence::direction::BIDIRECTIONAL;
const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};
const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), opset5::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 3.5f);
EXPECT_EQ(lstm_sequence->get_input_forget(), true);
EXPECT_EQ(lstm_sequence->get_weights_format(), ov::op::LSTMWeightsFormat::FICO);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0), (Shape{batch_size, num_directions, seq_length, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(1), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(2), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(2), (Shape{batch_size, num_directions, hidden_size}));
}
TEST(type_prop, lstm_sequence_dynamic_batch_size) {
recurrent_sequence_parameters param;
@@ -343,3 +484,171 @@ TEST(type_prop, lstm_sequence_invalid_input_direction) {
std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
}
}
TEST(type_prop, lstm_sequence_v1_dynamic_num_directions) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = Dimension::dynamic();
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->validate_and_infer_types();
EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
(PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
}
TEST(type_prop, lstm_sequence_v1_dynamic_seq_length) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 2;
param.seq_length = Dimension::dynamic();
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->validate_and_infer_types();
EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
(PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
}
TEST(type_prop, lstm_sequence_v1_dynamic_hidden_size) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 2;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = Dimension::dynamic();
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->validate_and_infer_types();
EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
(PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
}
TEST(type_prop, lstm_sequence_v1_dynamic_inputs) {
recurrent_sequence_parameters param;
param.batch_size = Dimension::dynamic();
param.input_size = Dimension::dynamic();
param.hidden_size = Dimension::dynamic();
param.num_directions = Dimension::dynamic();
param.seq_length = Dimension::dynamic();
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->validate_and_infer_types();
EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
(PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
(PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
}
TEST(type_prop, lstm_sequence_v1_invalid_input_dimension) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 2;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
auto invalid_rank0_tensor = make_shared<opset5::Parameter>(param.et, PartialShape{});
// Validate an invalid rank-0 tensor for all inputs: X, initial_hidden_state, initial_cell_state,
// sequence_lengths, W, R, B, P
for (size_t i = 0; i < lstm_sequence->get_input_size(); i++) {
lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_rank0_tensor);
ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
<< "LSTMSequence node was created with invalid data.";
}
}
TEST(type_prop, lstm_sequence_v1_invalid_input_dynamic_rank) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 2;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;
auto check_dynamic_lstm = [](const shared_ptr<opset1::LSTMSequence>& lstm) -> bool {
return lstm->output(0).get_partial_shape() == PartialShape::dynamic() &&
lstm->output(1).get_partial_shape() == PartialShape::dynamic() &&
lstm->output(2).get_partial_shape() == PartialShape::dynamic() &&
lstm->output(0).get_element_type() == lstm->input(0).get_element_type();
};
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
auto invalid_dynamic_tensor = make_shared<opset1::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));
// Validate an invalid dynamic-rank tensor for all inputs: X, initial_hidden_state,
// initial_cell_state, sequence_lengths, W, R, B, P
for (size_t i = 0; i < lstm_sequence->get_input_size(); i++) {
lstm_sequence = lstm_seq_v1_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_dynamic_tensor);
lstm_sequence->validate_and_infer_types();
EXPECT_EQ(check_dynamic_lstm(lstm_sequence), true);
}
}
TEST(type_prop, lstm_sequence_v1_invalid_input_direction) {
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 3;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;
auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
try {
lstm_sequence->validate_and_infer_types();
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
}
}

View File

@@ -64,3 +64,60 @@ TEST(attributes, lstm_sequence_op) {
EXPECT_EQ(g_lstm_sequence->get_clip(), lstm_sequence->get_clip());
EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
}
TEST(attributes, lstm_sequence_v1_op) {
NodeBuilder::get_ops().register_factory<opset5::LSTMSequence>();
const size_t batch_size = 4;
const size_t num_directions = 2;
const size_t seq_length = 8;
const size_t input_size = 16;
const size_t hidden_size = 64;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{num_directions, 3 * hidden_size});
const auto lstm_direction = op::RecurrentSequenceDirection::BIDIRECTIONAL;
const ov::op::LSTMWeightsFormat weights_format = ov::op::LSTMWeightsFormat::FICO;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
const float clip_threshold = 0.5f;
const bool input_forget = true;
const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
P,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget);
NodeBuilder builder(lstm_sequence);
auto g_lstm_sequence = ov::as_type_ptr<opset1::LSTMSequence>(builder.create());
EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha());
EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta());
EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold());
EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}

View File

@@ -48,6 +48,7 @@ struct LSTMNgInputMap {
const auto& ng_inputs = node.get_ng_inputs();
// We have input, output, forget and cell gates
constexpr std::size_t gates_count{4};
constexpr std::size_t P_gates_count{3};
// ----- Mandatory inputs ------
// Packed input sequences.
@@ -153,9 +154,25 @@
init_c_shape);
}
// `P` - The weight tensor for peepholes.
// Peepholes input is not supported by OpenVino
// ONNX Shape: [num_directions, 3*hidden_size]
// OpenVino Shape: [num_directions, 4*hidden_size]
if (ng_inputs.size() > 7 && !ngraph::op::is_null(ng_inputs.at(7))) {
NGRAPH_WARN << (node) << " Input `P` (peepholes) is not supported and will be ignored ";
m_input_map[LSTMInput::LSTM_INPUT_P] =
ov::op::util::convert_lstm_peepholes_format(ng_inputs.at(7),
ov::op::util::LSTMPeepholesFormat::IOF,
ov::op::util::LSTMPeepholesFormat::FIO,
1);
} else {
auto p_shape = std::make_shared<default_opset::Concat>(
OutputVector{num_directions_node,
std::make_shared<default_opset::Multiply>(
default_opset::Constant::create(element::Type_t::i64, Shape{1}, {P_gates_count}),
hidden_size_node)},
0);
m_input_map[LSTMInput::LSTM_INPUT_P] = std::make_shared<default_opset::Broadcast>(
default_opset::Constant::create(m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(), Shape{}, {0}),
p_shape);
m_input_map[LSTMInput::LSTM_INPUT_P].set_names({"P_blank"});
}
}
@@ -181,12 +198,6 @@ struct LSTMAttributes {
std::string direction = ngraph::to_lower(node.get_attribute_value<std::string>("direction", "forward"));
m_direction = ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
if (m_input_forget != 0) {
NGRAPH_WARN << (node)
<< " Attribute `input_forget` is not supported "
"and will be ignored ";
}
}
ngraph::op::RecurrentSequenceDirection m_direction;
@@ -204,8 +215,28 @@ namespace set_1 {
OutputVector lstm(const Node& node) {
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
std::shared_ptr<ngraph::Node> lstm_sequence;
auto lstm_sequence = std::make_shared<default_opset::LSTMSequence>(input_map.at(LSTMInput::LSTM_INPUT_X),
if ((input_map.at(LSTMInput::LSTM_INPUT_P).get_names() != std::unordered_set<std::string>({"P_blank"})) ||
(attributes.m_input_forget == true)) {
lstm_sequence = std::make_shared<ov::op::v0::LSTMSequence>(input_map.at(LSTMInput::LSTM_INPUT_X),
input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
input_map.at(LSTMInput::LSTM_INPUT_W),
input_map.at(LSTMInput::LSTM_INPUT_R),
input_map.at(LSTMInput::LSTM_INPUT_B),
input_map.at(LSTMInput::LSTM_INPUT_P),
attributes.m_hidden_size,
attributes.m_direction,
ov::op::LSTMWeightsFormat::FICO,
attributes.m_activation_alpha,
attributes.m_activation_beta,
attributes.m_activations,
attributes.m_clip_threshold,
attributes.m_input_forget);
} else {
lstm_sequence = std::make_shared<default_opset::LSTMSequence>(input_map.at(LSTMInput::LSTM_INPUT_X),
input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
input_map.at(LSTMInput::LSTM_INPUT_SEQ_LENGTHS),
@@ -218,6 +249,7 @@ OutputVector lstm(const Node& node) {
attributes.m_activation_beta,
attributes.m_activations,
attributes.m_clip_threshold);
}
const auto Y = lstm_sequence->output(0);
const auto Y_h = lstm_sequence->output(1);
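
Finally, a standalone restatement (hypothetical helper, not part of the PR) of the opset selection introduced above: the importer keeps the v5 op for models without real peepholes and without input_forget, and falls back to the v0 (opset1) op otherwise. "P_blank" is the sentinel name set on the synthetic zero-peephole tensor earlier in the file:

#include <string>
#include <unordered_set>

bool needs_lstm_sequence_v0(const std::unordered_set<std::string>& p_names, bool input_forget) {
    const bool p_is_synthetic_zero = (p_names == std::unordered_set<std::string>{"P_blank"});
    return !p_is_synthetic_zero || input_forget;
}

int main() {
    return (needs_lstm_sequence_v0({"P"}, false) &&         // real peepholes -> v0
            !needs_lstm_sequence_v0({"P_blank"}, false) &&  // zero stub, no input_forget -> v5
            needs_lstm_sequence_v0({"P_blank"}, true))      // input_forget alone -> v0
               ? 0
               : 1;
}
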