[ONNX] Set defaults in LSTMSequence even if dimensions are dynamic (#3491)
Ticket: 43221
This commit is contained in:
parent
6888ffa328
commit
8b90c9e7e2
@ -60,62 +60,8 @@ namespace ngraph
|
||||
LSTM_INPUT_P
|
||||
};
|
||||
|
||||
enum class LSTMInputDimension
|
||||
{
|
||||
BATCH_SIZE,
|
||||
SEQ_LENGTH,
|
||||
NUM_DIRECTIONS,
|
||||
HIDDEN_SIZE,
|
||||
};
|
||||
|
||||
struct LSTMNgInputMap
|
||||
{
|
||||
// Check if input shape dimension at dimension_index is static
|
||||
bool check_static_input_dim(LSTMInput input, const size_t dimension_index)
|
||||
{
|
||||
return m_input_map[input].get_partial_shape().rank().is_static() &&
|
||||
m_input_map[input].get_partial_shape().rank().get_length() >
|
||||
dimension_index &&
|
||||
m_input_map[input].get_partial_shape()[dimension_index].is_static();
|
||||
}
|
||||
|
||||
// Validate and handle dimensions required to create default inputs
|
||||
void init_dim_map()
|
||||
{
|
||||
// batch_size
|
||||
if (check_static_input_dim(LSTMInput::LSTM_INPUT_X, 0))
|
||||
{
|
||||
m_dim_map[LSTMInputDimension::BATCH_SIZE] =
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X]
|
||||
.get_partial_shape()[0]
|
||||
.get_length();
|
||||
}
|
||||
// seq_length
|
||||
if (check_static_input_dim(LSTMInput::LSTM_INPUT_X, 1))
|
||||
{
|
||||
m_dim_map[LSTMInputDimension::SEQ_LENGTH] =
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X]
|
||||
.get_partial_shape()[1]
|
||||
.get_length();
|
||||
}
|
||||
// num_directions
|
||||
if (check_static_input_dim(LSTMInput::LSTM_INPUT_R, 0))
|
||||
{
|
||||
m_dim_map[LSTMInputDimension::NUM_DIRECTIONS] =
|
||||
m_input_map[LSTMInput::LSTM_INPUT_R]
|
||||
.get_partial_shape()[0]
|
||||
.get_length();
|
||||
}
|
||||
// hidden_size
|
||||
if (check_static_input_dim(LSTMInput::LSTM_INPUT_R, 2))
|
||||
{
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE] =
|
||||
m_input_map[LSTMInput::LSTM_INPUT_R]
|
||||
.get_partial_shape()[2]
|
||||
.get_length();
|
||||
}
|
||||
}
|
||||
|
||||
explicit LSTMNgInputMap(const Node& node)
|
||||
{
|
||||
const auto& ng_inputs = node.get_ng_inputs();
|
||||
@ -150,7 +96,29 @@ namespace ngraph
|
||||
1);
|
||||
|
||||
// Get dimensions needed for default inputs creation
|
||||
init_dim_map();
|
||||
auto shape_of_x = std::make_shared<default_opset::ShapeOf>(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X]);
|
||||
auto axes =
|
||||
default_opset::Constant::create(element::Type_t::i32, Shape{1}, {0});
|
||||
auto batch_size_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_x,
|
||||
default_opset::Constant::create(element::Type_t::i32, Shape{1}, {0}),
|
||||
axes);
|
||||
auto seq_length_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_x,
|
||||
default_opset::Constant::create(element::Type_t::i32, Shape{1}, {1}),
|
||||
axes);
|
||||
|
||||
auto shape_of_r = std::make_shared<default_opset::ShapeOf>(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_R]);
|
||||
auto num_directions_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_r,
|
||||
default_opset::Constant::create(element::Type_t::i32, Shape{1}, {0}),
|
||||
axes);
|
||||
auto hidden_size_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_r,
|
||||
default_opset::Constant::create(element::Type_t::i32, Shape{1}, {2}),
|
||||
axes);
|
||||
|
||||
// ------ Optional inputs ------
|
||||
// `B` - The bias tensor for input gate.
|
||||
@ -174,23 +142,20 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(m_dim_map.count(LSTMInputDimension::NUM_DIRECTIONS) &&
|
||||
m_dim_map.count(LSTMInputDimension::HIDDEN_SIZE),
|
||||
"ONNX LSTM: Can't create default `B` input, "
|
||||
"because at least one of required dimensions "
|
||||
"(num_directions, hidden_size) is dynamic. "
|
||||
"\n`R` input onnx shape {num_directions, "
|
||||
"gates_count*hidden_size, hidden_size}: ",
|
||||
ng_inputs.at(2).get_partial_shape());
|
||||
|
||||
m_input_map[LSTMInput::LSTM_INPUT_B] = default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{m_dim_map[LSTMInputDimension::NUM_DIRECTIONS],
|
||||
gates_count * m_dim_map[LSTMInputDimension::HIDDEN_SIZE]},
|
||||
std::vector<float>(m_dim_map[LSTMInputDimension::NUM_DIRECTIONS] *
|
||||
gates_count *
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE],
|
||||
0.f));
|
||||
auto b_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{num_directions_node,
|
||||
std::make_shared<default_opset::Multiply>(
|
||||
default_opset::Constant::create(
|
||||
element::Type_t::i64, Shape{1}, {gates_count}),
|
||||
hidden_size_node)},
|
||||
0);
|
||||
m_input_map[LSTMInput::LSTM_INPUT_B] =
|
||||
std::make_shared<default_opset::Broadcast>(
|
||||
default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{},
|
||||
{0}),
|
||||
b_shape);
|
||||
}
|
||||
// `sequence_lens`- The lengths of the sequences in a batch.
|
||||
// Shape: [batch_size]
|
||||
@ -200,22 +165,9 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
m_dim_map.count(LSTMInputDimension::BATCH_SIZE) &&
|
||||
m_dim_map.count(LSTMInputDimension::SEQ_LENGTH),
|
||||
"ONNX LSTM: Can't create default `sequence_lens` input, ",
|
||||
"because at least one of required dimensions "
|
||||
"(batch_size, seq_length) is dynamic. "
|
||||
"\n`X` input onnx shape {seq_length, batch_size, input_size} is ",
|
||||
ng_inputs.at(0).get_partial_shape());
|
||||
|
||||
m_input_map[LSTMInput::LSTM_INPUT_SEQ_LENGTHS] =
|
||||
default_opset::Constant::create(
|
||||
element::i32,
|
||||
Shape{m_dim_map[LSTMInputDimension::BATCH_SIZE]},
|
||||
std::vector<std::int32_t>(
|
||||
m_dim_map[LSTMInputDimension::BATCH_SIZE],
|
||||
m_dim_map[LSTMInputDimension::SEQ_LENGTH]));
|
||||
std::make_shared<default_opset::Broadcast>(seq_length_node,
|
||||
batch_size_node);
|
||||
}
|
||||
// `initial_h` - The initial value of the hidden.
|
||||
// ONNX Shape: [num_directions, batch_size, hidden_size]
|
||||
@ -227,30 +179,17 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
m_dim_map.count(LSTMInputDimension::BATCH_SIZE) &&
|
||||
m_dim_map.count(LSTMInputDimension::NUM_DIRECTIONS) &&
|
||||
m_dim_map.count(LSTMInputDimension::HIDDEN_SIZE),
|
||||
"ONNX LSTM: Can't create default `initial_h` input, "
|
||||
"because at least one of required dimensions "
|
||||
"(batch_size, num_directions, hidden_size) is dynamic. "
|
||||
"\n`X` input onnx shape {seq_length, batch_size, input_size} is ",
|
||||
ng_inputs.at(0).get_partial_shape(),
|
||||
"\n`R` input onnx shape {num_directions, 4*hidden_size, "
|
||||
"hidden_size} is ",
|
||||
ng_inputs.at(2).get_partial_shape());
|
||||
|
||||
auto init_h_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{
|
||||
batch_size_node, num_directions_node, hidden_size_node},
|
||||
0);
|
||||
m_input_map[LSTMInput::LSTM_INPUT_INIT_H] =
|
||||
default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{m_dim_map[LSTMInputDimension::BATCH_SIZE],
|
||||
m_dim_map[LSTMInputDimension::NUM_DIRECTIONS],
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE]},
|
||||
std::vector<float>(
|
||||
m_dim_map[LSTMInputDimension::BATCH_SIZE] *
|
||||
m_dim_map[LSTMInputDimension::NUM_DIRECTIONS] *
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE],
|
||||
0.f));
|
||||
std::make_shared<default_opset::Broadcast>(
|
||||
default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{},
|
||||
{0}),
|
||||
init_h_shape);
|
||||
}
|
||||
// `initial_c` - The initial value of the cell.
|
||||
// ONNX Shape: [num_directions, batch_size, hidden_size]
|
||||
@ -262,30 +201,17 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
NGRAPH_CHECK(
|
||||
m_dim_map.count(LSTMInputDimension::BATCH_SIZE) &&
|
||||
m_dim_map.count(LSTMInputDimension::NUM_DIRECTIONS) &&
|
||||
m_dim_map.count(LSTMInputDimension::HIDDEN_SIZE),
|
||||
"ONNX LSTM: Can't create default `initial_c` input, "
|
||||
"because at least one of required dimensions "
|
||||
"(batch_size, num_directions, hidden_size) is dynamic. "
|
||||
"\n`X` input onnx shape {seq_length, batch_size, input_size} is ",
|
||||
ng_inputs.at(0).get_partial_shape(),
|
||||
"\n`R` input onnx shape {num_directions, 4*hidden_size, "
|
||||
"hidden_size} is ",
|
||||
ng_inputs.at(2).get_partial_shape());
|
||||
|
||||
auto init_c_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{
|
||||
batch_size_node, num_directions_node, hidden_size_node},
|
||||
0);
|
||||
m_input_map[LSTMInput::LSTM_INPUT_INIT_C] =
|
||||
default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{m_dim_map[LSTMInputDimension::BATCH_SIZE],
|
||||
m_dim_map[LSTMInputDimension::NUM_DIRECTIONS],
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE]},
|
||||
std::vector<float>(
|
||||
m_dim_map[LSTMInputDimension::BATCH_SIZE] *
|
||||
m_dim_map[LSTMInputDimension::NUM_DIRECTIONS] *
|
||||
m_dim_map[LSTMInputDimension::HIDDEN_SIZE],
|
||||
0.f));
|
||||
std::make_shared<default_opset::Broadcast>(
|
||||
default_opset::Constant::create(
|
||||
m_input_map[LSTMInput::LSTM_INPUT_X].get_element_type(),
|
||||
Shape{},
|
||||
{0}),
|
||||
init_c_shape);
|
||||
}
|
||||
// `P` - The weight tensor for peepholes.
|
||||
// Peepholes input is not supported by OpenVino
|
||||
@ -299,7 +225,6 @@ namespace ngraph
|
||||
|
||||
Output<ngraph::Node>& at(const LSTMInput& key) { return m_input_map.at(key); }
|
||||
std::map<LSTMInput, Output<ngraph::Node>> m_input_map;
|
||||
std::map<LSTMInputDimension, size_t> m_dim_map;
|
||||
};
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
@ -0,0 +1,185 @@
|
||||
ir_version: 7
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
output: "W"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 12
|
||||
dims: 1
|
||||
data_type: 1
|
||||
float_data: 0.31403765082359314
|
||||
float_data: -0.16793324053287506
|
||||
float_data: 1.3882579803466797
|
||||
float_data: -0.690295398235321
|
||||
float_data: -0.39940449595451355
|
||||
float_data: -0.7833511233329773
|
||||
float_data: -0.30992957949638367
|
||||
float_data: 0.35575729608535767
|
||||
float_data: -0.46826308965682983
|
||||
float_data: 1.1741459369659424
|
||||
float_data: -2.4147889614105225
|
||||
float_data: -0.42783254384994507
|
||||
name: "const_tensor_W"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "R"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 12
|
||||
dims: 3
|
||||
data_type: 1
|
||||
float_data: 0.8490582704544067
|
||||
float_data: 0.45121243596076965
|
||||
float_data: -1.179901361465454
|
||||
float_data: 0.13536448776721954
|
||||
float_data: 0.813286542892456
|
||||
float_data: 0.6017516255378723
|
||||
float_data: 0.4847572445869446
|
||||
float_data: -1.2136037349700928
|
||||
float_data: 0.16383321583271027
|
||||
float_data: 1.5106260776519775
|
||||
float_data: 1.1177502870559692
|
||||
float_data: 0.2358246147632599
|
||||
float_data: 0.8490582704544067
|
||||
float_data: 0.45121243596076965
|
||||
float_data: -1.179901361465454
|
||||
float_data: 0.13536448776721954
|
||||
float_data: 0.813286542892456
|
||||
float_data: 0.6017516255378723
|
||||
float_data: 0.4847572445869446
|
||||
float_data: -1.2136037349700928
|
||||
float_data: 0.16383321583271027
|
||||
float_data: 1.5106260776519775
|
||||
float_data: 1.1177502870559692
|
||||
float_data: 0.2358246147632599
|
||||
float_data: 0.8490582704544067
|
||||
float_data: 0.45121243596076965
|
||||
float_data: -1.179901361465454
|
||||
float_data: 0.13536448776721954
|
||||
float_data: 0.813286542892456
|
||||
float_data: 0.6017516255378723
|
||||
float_data: 0.4847572445869446
|
||||
float_data: -1.2136037349700928
|
||||
float_data: 0.16383321583271027
|
||||
float_data: 1.5106260776519775
|
||||
float_data: 1.1177502870559692
|
||||
float_data: 0.2358246147632599
|
||||
name: "const_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "W"
|
||||
input: "R"
|
||||
output: "Y"
|
||||
output: "Y_h"
|
||||
output: "Y_c"
|
||||
op_type: "LSTM"
|
||||
attribute {
|
||||
name: "direction"
|
||||
s: "forward"
|
||||
type: STRING
|
||||
}
|
||||
attribute {
|
||||
name: "hidden_size"
|
||||
i: 3
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test-model-lstm"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_h"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_c"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: -1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 12
|
||||
}
|
@ -0,0 +1,269 @@
|
||||
ir_version: 7
|
||||
producer_name: "onnx-importer-test"
|
||||
graph {
|
||||
node {
|
||||
input: "A"
|
||||
output: "shape"
|
||||
op_type: "Shape"
|
||||
}
|
||||
node {
|
||||
output: "zero"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "one"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "shape"
|
||||
input: "one"
|
||||
output: "mul"
|
||||
op_type: "Mul"
|
||||
}
|
||||
node {
|
||||
input: "mul"
|
||||
output: "constantofshape"
|
||||
op_type: "ConstantOfShape"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 1
|
||||
float_data: 1
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "constantofshape"
|
||||
input: "A"
|
||||
output: "conv"
|
||||
op_type: "Conv"
|
||||
}
|
||||
node {
|
||||
input: "conv"
|
||||
output: "transposed"
|
||||
op_type: "Transpose"
|
||||
attribute {
|
||||
name: "perm"
|
||||
ints: 2
|
||||
ints: 0
|
||||
ints: 1
|
||||
type: INTS
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "shape"
|
||||
input: "zero"
|
||||
output: "batch_size"
|
||||
op_type: "Gather"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "hidden_size"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 7
|
||||
int64_data: 2
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "one"
|
||||
input: "batch_size"
|
||||
input: "hidden_size"
|
||||
output: "concat"
|
||||
op_type: "Concat"
|
||||
attribute {
|
||||
name: "axis"
|
||||
i: 0
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "concat"
|
||||
output: "initial_hc"
|
||||
op_type: "ConstantOfShape"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
data_type: 1
|
||||
float_data: 0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "W"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 8
|
||||
dims: 3
|
||||
data_type: 1
|
||||
float_data: 4.0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "R"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 8
|
||||
dims: 2
|
||||
data_type: 1
|
||||
float_data: 2.0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "B"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 16
|
||||
data_type: 1
|
||||
float_data: 3.0
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "transposed"
|
||||
input: "W"
|
||||
input: "R"
|
||||
input: "B"
|
||||
input: ""
|
||||
input: "initial_hc"
|
||||
input: "initial_hc"
|
||||
output: "Y"
|
||||
output: "Y_h"
|
||||
output: "Y_c"
|
||||
op_type: "LSTM"
|
||||
attribute {
|
||||
name: "hidden_size"
|
||||
i: 2
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test-model-lstm"
|
||||
input {
|
||||
name: "A"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_h"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_c"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1
|
||||
}
|
||||
dim {
|
||||
dim_value: 3
|
||||
}
|
||||
dim {
|
||||
dim_value: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
domain: ""
|
||||
version: 12
|
||||
}
|
@ -566,6 +566,44 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_import_only_lstm_dynamic_batch_seq_all_i
|
||||
EXPECT_EQ(count_ops_of_type<op::v5::LSTMSequence>(function), 1);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_import_only_lstm_dynamic_batch_seq_3_inputs)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/lstm_dyn_batch_seq_3_inputs.prototxt"));
|
||||
|
||||
auto batch_size = Dimension::dynamic();
|
||||
auto seq_length = Dimension::dynamic();
|
||||
int64_t hidden_size = 3;
|
||||
int64_t num_directions = 1;
|
||||
auto Y_expected_output = PartialShape{batch_size, num_directions, seq_length, hidden_size};
|
||||
auto Y_h_expected_output = PartialShape{num_directions, batch_size, hidden_size};
|
||||
auto Y_c_expected_output = PartialShape{num_directions, batch_size, hidden_size};
|
||||
|
||||
EXPECT_EQ(function->get_output_size(), 3);
|
||||
EXPECT_EQ(function->get_output_partial_shape(0), Y_expected_output);
|
||||
EXPECT_EQ(function->get_output_partial_shape(1), Y_h_expected_output);
|
||||
EXPECT_EQ(function->get_output_partial_shape(2), Y_c_expected_output);
|
||||
|
||||
EXPECT_EQ(count_ops_of_type<op::v5::LSTMSequence>(function), 1);
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lstm_dynamic_batch_size_and_seq_len)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(
|
||||
file_util::path_join(SERIALIZED_ZOO, "onnx/lstm_dynamic_batch_size_and_seq_len.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||
test_case.add_input<float>({1, 2, 3, 4, 5, 6});
|
||||
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{1, 1, 3, 2}, {0.761594, 0.761594, 0.761594, 0.761594, 0.761594, 0.761594}); // Y
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{1, 3, 2}, {0.761594, 0.761594, 0.761594, 0.761594, 0.761594, 0.761594}); // Y_c
|
||||
test_case.add_expected_output<float>(Shape{1, 3, 2}, {1, 1, 1, 1, 1, 1}); // Y_h
|
||||
|
||||
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
|
||||
}
|
||||
|
||||
// RNNLikeSequenceOp test fixture for test setup reuse
|
||||
class GRUSequenceOp : public testing::Test
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user