ONNX RNN/GRU enable dynamic input shape (#4241)
This commit is contained in:
parent
b800f08c0c
commit
ead6427097
@ -41,26 +41,25 @@ namespace ngraph
|
||||
m_map[OpInput::W] = ng_inputs.at(1);
|
||||
m_map[OpInput::R] = ng_inputs.at(2);
|
||||
|
||||
const auto el_type = ng_inputs.at(0).get_element_type();
|
||||
|
||||
const auto x_pshape = m_map[OpInput::X].get_partial_shape();
|
||||
const auto w_pshape = m_map[OpInput::W].get_partial_shape();
|
||||
const auto r_pshape = m_map[OpInput::R].get_partial_shape();
|
||||
NGRAPH_CHECK(x_pshape.rank().is_static() && x_pshape[0].is_static() &&
|
||||
x_pshape[1].is_static(),
|
||||
"RecurrentSequence input X must have static \"seq_length\" and "
|
||||
"\"batch_size\" dimensions.");
|
||||
NGRAPH_CHECK(w_pshape.rank().is_static() && w_pshape[0].is_static(),
|
||||
"RecurrentSequence input W must have static \"num_directions\" "
|
||||
"(outermost) dimension.");
|
||||
NGRAPH_CHECK(r_pshape.rank().is_static() && r_pshape[2].is_static(),
|
||||
"RecurrentSequence input R must have static \"hidden_size\" "
|
||||
"(innermost) dimension.");
|
||||
|
||||
const std::size_t hidden_size = m_map[OpInput::R].get_shape().back();
|
||||
const std::size_t batch_size = m_map[OpInput::X].get_shape().at(0);
|
||||
const std::size_t num_directions = m_map[OpInput::W].get_shape().front();
|
||||
// Get dimensions needed for default inputs creation
|
||||
auto shape_of_x = std::make_shared<default_opset::ShapeOf>(m_map[OpInput::X]);
|
||||
auto axes = default_opset::Constant::create(element::i32, Shape{1}, {0});
|
||||
auto batch_size_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_x, default_opset::Constant::create(element::i32, Shape{1}, {0}), axes);
|
||||
auto seq_length_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_x, default_opset::Constant::create(element::i32, Shape{1}, {1}), axes);
|
||||
|
||||
auto shape_of_r = std::make_shared<default_opset::ShapeOf>(m_map[OpInput::R]);
|
||||
auto num_directions_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_r, default_opset::Constant::create(element::i32, Shape{1}, {0}), axes);
|
||||
auto hidden_size_node = std::make_shared<default_opset::Gather>(
|
||||
shape_of_r, default_opset::Constant::create(element::i32, Shape{1}, {2}), axes);
|
||||
|
||||
// ------ Optional inputs ------
|
||||
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
|
||||
{
|
||||
auto bias = ng_inputs.at(3);
|
||||
@ -72,8 +71,17 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
|
||||
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
|
||||
auto b_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{num_directions_node,
|
||||
std::make_shared<default_opset::Multiply>(
|
||||
default_opset::Constant::create(
|
||||
element::Type_t::i64, Shape{1}, {gates_count}),
|
||||
hidden_size_node)},
|
||||
0);
|
||||
m_map[OpInput::B] = std::make_shared<default_opset::Broadcast>(
|
||||
default_opset::Constant::create(
|
||||
m_map[OpInput::X].get_element_type(), Shape{}, {0}),
|
||||
b_shape);
|
||||
}
|
||||
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
|
||||
{
|
||||
@ -81,8 +89,8 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Constant>(
|
||||
element::i32, Shape{batch_size}, m_map[OpInput::X].get_shape().at(1));
|
||||
m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Broadcast>(
|
||||
seq_length_node, batch_size_node);
|
||||
}
|
||||
// The initial value of the hidden.
|
||||
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
|
||||
@ -92,8 +100,12 @@ namespace ngraph
|
||||
}
|
||||
else
|
||||
{
|
||||
m_map[OpInput::INIT_H] = std::make_shared<default_opset::Constant>(
|
||||
el_type, Shape{batch_size, num_directions, hidden_size}, 0.f);
|
||||
auto init_h_shape = std::make_shared<default_opset::Concat>(
|
||||
OutputVector{batch_size_node, num_directions_node, hidden_size_node}, 0);
|
||||
m_map[OpInput::INIT_H] = std::make_shared<default_opset::Broadcast>(
|
||||
default_opset::Constant::create(
|
||||
m_map[OpInput::X].get_element_type(), Shape{}, {0}),
|
||||
init_h_shape);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,220 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
output: "W"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1 # num_directions
|
||||
dims: 15 # gates_count*hidden_size
|
||||
dims: 2 # input_size
|
||||
data_type: 1
|
||||
float_data: 0.31403765
|
||||
float_data: -0.16793324
|
||||
float_data: 1.388258
|
||||
float_data: -0.6902954
|
||||
float_data: -0.3994045
|
||||
float_data: -0.7833511
|
||||
|
||||
float_data: -0.30992958
|
||||
float_data: 0.3557573
|
||||
float_data: -0.4682631
|
||||
float_data: 1.1741459
|
||||
float_data: -2.414789
|
||||
float_data: -0.42783254
|
||||
|
||||
float_data: -0.82199496
|
||||
float_data: -0.0390086
|
||||
float_data: -0.43670088
|
||||
float_data: -0.53810567
|
||||
float_data: -0.10769883
|
||||
float_data: 0.75242394
|
||||
|
||||
float_data: -0.2507971
|
||||
float_data: 1.0447186
|
||||
float_data: -1.4777364
|
||||
float_data: 0.19993274
|
||||
float_data: 0.925649
|
||||
float_data: -2.282516
|
||||
|
||||
float_data: 0.95039636
|
||||
float_data: 1.5379831
|
||||
float_data: -0.88576007
|
||||
float_data: 0.28566247
|
||||
float_data: 0.79292643
|
||||
float_data: -0.04261953
|
||||
name: "W_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "R"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1 # num_directions
|
||||
dims: 15 # gates_count*hidden_size
|
||||
dims: 5 # input_size
|
||||
data_type: 1
|
||||
float_data: 0.8490583
|
||||
float_data: 0.45121244
|
||||
float_data: -1.1799014
|
||||
float_data: 0.13536449
|
||||
float_data: 0.81328654
|
||||
float_data: 0.6017516
|
||||
float_data: 0.48475724
|
||||
float_data: -1.2136037
|
||||
float_data: 0.16383322
|
||||
float_data: 1.5106261
|
||||
float_data: 1.1177503
|
||||
float_data: 0.23582461
|
||||
float_data: 0.5754652
|
||||
float_data: 0.43879887
|
||||
float_data: 0.7399294
|
||||
float_data: 0.4517558
|
||||
float_data: 1.3536783
|
||||
float_data: -0.4843166
|
||||
float_data: -1.1503736
|
||||
float_data: -0.2458678
|
||||
float_data: 0.54523313
|
||||
float_data: -0.08649993
|
||||
float_data: -0.6936281
|
||||
float_data: 1.002422
|
||||
float_data: -1.770847
|
||||
float_data: -0.94642
|
||||
float_data: -1.8135757
|
||||
float_data: 1.8819852
|
||||
float_data: -0.10852333
|
||||
float_data: -0.26120332
|
||||
float_data: 1.0223165
|
||||
float_data: -0.7468837
|
||||
float_data: 0.28566906
|
||||
float_data: 0.92321056
|
||||
float_data: 0.22521864
|
||||
float_data: 1.1123824
|
||||
float_data: -0.9298287
|
||||
float_data: 1.2141289
|
||||
float_data: 1.3470556
|
||||
float_data: -0.32972014
|
||||
float_data: -1.6552197
|
||||
float_data: -1.0998285
|
||||
float_data: 0.71901864
|
||||
float_data: 0.962846
|
||||
float_data: -0.1366851
|
||||
float_data: -2.6534476
|
||||
float_data: -1.4992771
|
||||
float_data: -0.45793465
|
||||
float_data: 0.4290477
|
||||
float_data: 0.9893151
|
||||
float_data: 0.2511034
|
||||
float_data: 0.12906462
|
||||
float_data: 0.7491512
|
||||
float_data: 0.3316756
|
||||
float_data: 1.0576645
|
||||
float_data: -0.04618666
|
||||
float_data: 1.3556088
|
||||
float_data: 1.2842374
|
||||
float_data: 0.7103014
|
||||
float_data: 0.52889013
|
||||
float_data: 0.30327162
|
||||
float_data: 1.5069056
|
||||
float_data: 0.16591893
|
||||
float_data: 1.5719851
|
||||
float_data: -2.099427
|
||||
float_data: -1.010277
|
||||
float_data: -0.52800924
|
||||
float_data: -0.22292352
|
||||
float_data: -0.55177474
|
||||
float_data: 1.3432894
|
||||
float_data: 0.8731192
|
||||
float_data: -0.01055307
|
||||
float_data: -0.01138215
|
||||
float_data: 0.85698843
|
||||
float_data: -1.2615703
|
||||
name: "R_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "W"
|
||||
input: "R"
|
||||
output: "Y"
|
||||
output: "Y_h"
|
||||
op_type: "GRU"
|
||||
attribute {
|
||||
name: "hidden_size"
|
||||
i: 5
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_gru_defaults_const"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1 # seq_length
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch size
|
||||
}
|
||||
dim {
|
||||
dim_value: 2 # input size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1 # seq_length
|
||||
}
|
||||
dim {
|
||||
dim_value: 1 # num_directions
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch_size
|
||||
}
|
||||
dim {
|
||||
dim_value: 5 # hidden_size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_h"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1 # num_directions
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch_size
|
||||
}
|
||||
dim {
|
||||
dim_value: 5 # hidden_size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 9
|
||||
}
|
@ -0,0 +1,146 @@
|
||||
ir_version: 3
|
||||
producer_name: "nGraph ONNX Importer"
|
||||
graph {
|
||||
node {
|
||||
output: "W"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1
|
||||
dims: 5
|
||||
dims: 2
|
||||
data_type: 1
|
||||
float_data: 0.31403765
|
||||
float_data: -0.16793324
|
||||
float_data: 1.388258
|
||||
float_data: -0.6902954
|
||||
float_data: -0.3994045
|
||||
float_data: -0.7833511
|
||||
float_data: -0.30992958
|
||||
float_data: 0.3557573
|
||||
float_data: -0.4682631
|
||||
float_data: 1.1741459
|
||||
name: "W_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
output: "R"
|
||||
op_type: "Constant"
|
||||
attribute {
|
||||
name: "value"
|
||||
t {
|
||||
dims: 1 # num_directions
|
||||
dims: 5 # gates_count*hidden_size
|
||||
dims: 5 # input_size
|
||||
data_type: 1
|
||||
float_data: -2.414789
|
||||
float_data: -0.42783254
|
||||
float_data: -0.82199496
|
||||
float_data: -0.03900861
|
||||
float_data: -0.43670088
|
||||
float_data: -0.53810567
|
||||
float_data: -0.10769883
|
||||
float_data: 0.75242394
|
||||
float_data: -0.2507971
|
||||
float_data: 1.0447186
|
||||
float_data: -1.4777364
|
||||
float_data: 0.19993274
|
||||
float_data: 0.925649
|
||||
float_data: -2.282516
|
||||
float_data: 0.95039636
|
||||
float_data: 1.5379831
|
||||
float_data: -0.88576007
|
||||
float_data: 0.28566247
|
||||
float_data: 0.79292643
|
||||
float_data: -0.04261953
|
||||
float_data: 0.8490583
|
||||
float_data: 0.45121244
|
||||
float_data: -1.1799014
|
||||
float_data: 0.13536449
|
||||
float_data: 0.81328654
|
||||
name: "R_tensor"
|
||||
}
|
||||
type: TENSOR
|
||||
}
|
||||
}
|
||||
node {
|
||||
input: "X"
|
||||
input: "W"
|
||||
input: "R"
|
||||
output: "Y"
|
||||
output: "Y_h"
|
||||
op_type: "RNN"
|
||||
attribute {
|
||||
name: "hidden_size"
|
||||
i: 5
|
||||
type: INT
|
||||
}
|
||||
}
|
||||
name: "test_rnn_defaults"
|
||||
input {
|
||||
name: "X"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1 # seq_length
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch size
|
||||
}
|
||||
dim {
|
||||
dim_value: 2 # input size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: -1 # seq_length
|
||||
}
|
||||
dim {
|
||||
dim_value: 1 # num_directions
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch_size
|
||||
}
|
||||
dim {
|
||||
dim_value: 5 # hidden_size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
output {
|
||||
name: "Y_h"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
shape {
|
||||
dim {
|
||||
dim_value: 1 # num_directions
|
||||
}
|
||||
dim {
|
||||
dim_value: -1 # batch_size
|
||||
}
|
||||
dim {
|
||||
dim_value: 5 # hidden_size
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
opset_import {
|
||||
version: 9
|
||||
}
|
@ -1551,6 +1551,70 @@ NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_linear_before_r
|
||||
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
|
||||
}
|
||||
|
||||
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_defaults_fwd_const_dynamic)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/gru_defaults_fwd_const_dynamic.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||
test_case.add_input<float>(Shape{4, 3, 2}, in_X);
|
||||
|
||||
// Y
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{4, 1, 3, 5},
|
||||
std::vector<float>{
|
||||
-0.3224981f, -0.44282594f, 0.7499796f, -0.12240417f, 0.12079421f, 0.02534253f,
|
||||
0.02504562f, -0.0463777f, 0.01204534f, -0.01497037f, -0.04651929f, -0.6264307f,
|
||||
0.7236632f, 0.06250653f, 0.02594197f, -0.06868916f, -0.5412897f, 0.49794048f,
|
||||
0.22239858f, -0.11257736f, -0.23071964f, 0.26079988f, -0.07375772f, -0.21816255f,
|
||||
0.18764113f, -0.5228772f, 0.00575754f, 0.2514028f, -0.58864325f, 0.49843538f,
|
||||
-0.6129046f, -0.10794663f, 0.6544055f, -0.70105773f, 0.5397687f, -0.35791716f,
|
||||
0.3885092f, -0.15291792f, -0.22324723f, 0.11557932f, -0.42112932f, 0.26772985f,
|
||||
-0.38304564f, -0.05039781f, -0.5057976f, 0.5775348f, -0.6736855f, -0.20032284f,
|
||||
0.03698462f, -0.7693824f, -0.5831348f, 0.25767964f, 0.7121098f, -0.35951245f,
|
||||
0.39223647f, -0.6645166f, 0.37950075f, 0.59931314f, -0.4741001f, 0.21156166f,
|
||||
});
|
||||
// Y_h
|
||||
test_case.add_expected_output<float>(Shape{1, 3, 5},
|
||||
std::vector<float>{
|
||||
0.5775348f,
|
||||
-0.6736855f,
|
||||
-0.20032284f,
|
||||
0.03698462f,
|
||||
-0.7693824f,
|
||||
-0.5831348f,
|
||||
0.25767964f,
|
||||
0.7121098f,
|
||||
-0.35951245f,
|
||||
0.39223647f,
|
||||
-0.6645166f,
|
||||
0.37950075f,
|
||||
0.59931314f,
|
||||
-0.4741001f,
|
||||
0.21156166f,
|
||||
});
|
||||
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 7);
|
||||
}
|
||||
|
||||
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_import_only_gru_defaults_fwd_const_dynamic)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/gru_defaults_fwd_const_dynamic.prototxt"));
|
||||
|
||||
auto batch_size = Dimension::dynamic();
|
||||
auto seq_length = Dimension::dynamic();
|
||||
int64_t hidden_size = 5;
|
||||
int64_t num_directions = 1;
|
||||
auto Y_expected_output = PartialShape{batch_size, num_directions, seq_length, hidden_size};
|
||||
auto Y_h_expected_output = PartialShape{num_directions, batch_size, hidden_size};
|
||||
|
||||
EXPECT_EQ(function->get_output_size(), 2);
|
||||
EXPECT_EQ(function->get_output_partial_shape(0), Y_expected_output);
|
||||
EXPECT_EQ(function->get_output_partial_shape(1), Y_h_expected_output);
|
||||
|
||||
EXPECT_EQ(count_ops_of_type<op::v5::GRUSequence>(function), 1);
|
||||
}
|
||||
|
||||
// RNNLikeSequenceOp test fixture for test setup reuse
|
||||
class RNNSequenceOp : public testing::Test
|
||||
{
|
||||
@ -2386,3 +2450,67 @@ NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_bidirectional_const
|
||||
});
|
||||
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 6);
|
||||
}
|
||||
|
||||
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_defaults_fwd_const_dynamic)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/rnn_defaults_fwd_const_dynamic.prototxt"));
|
||||
|
||||
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
|
||||
test_case.add_input<float>(Shape{4, 3, 2}, in_X);
|
||||
|
||||
// Y
|
||||
test_case.add_expected_output<float>(
|
||||
Shape{4, 1, 3, 5},
|
||||
std::vector<float>{
|
||||
0.02254748f, 0.15776646f, -0.8229023f, 0.19205809f, 0.76984656f, -0.00603169f,
|
||||
-0.02861464f, 0.04512155f, -0.0011912f, -0.02572936f, -0.13703543f, -0.49651444f,
|
||||
-0.78868157f, 0.3566854f, 0.8758509f, 0.20788848f, 0.13481987f, -0.756822f,
|
||||
-0.121436f, 0.97542346f, 0.16959739f, 0.63496053f, 0.1245538f, -0.1970138f,
|
||||
-0.56581646f, 0.8225869f, 0.9611373f, -0.42990375f, -0.22925597f, 0.2226491f,
|
||||
0.08246052f, 0.9798831f, -0.13415998f, -0.5567714f, 0.78594816f, -0.34759718f,
|
||||
0.11376679f, -0.07107389f, -0.5420871f, -0.58504283f, -0.96065646f, 0.18588805f,
|
||||
-0.4870671f, -0.1475982f, 0.82456505f, -0.80264574f, -0.46370947f, 0.9719335f,
|
||||
-0.7374159f, 0.94937694f, 0.8814341f, 0.67015004f, 0.21958017f, -0.8332769f,
|
||||
-0.487742f, 0.9918536f, 0.99563396f, 0.94866276f, -0.98504806f, -0.42824882f,
|
||||
});
|
||||
// Y_h
|
||||
test_case.add_expected_output<float>(Shape{1, 3, 5},
|
||||
std::vector<float>{
|
||||
-0.80264574f,
|
||||
-0.46370947f,
|
||||
0.9719335f,
|
||||
-0.7374159f,
|
||||
0.94937694f,
|
||||
0.8814341f,
|
||||
0.67015004f,
|
||||
0.21958017f,
|
||||
-0.8332769f,
|
||||
-0.487742f,
|
||||
0.9918536f,
|
||||
0.99563396f,
|
||||
0.94866276f,
|
||||
-0.98504806f,
|
||||
-0.42824882f,
|
||||
});
|
||||
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
|
||||
}
|
||||
|
||||
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_import_only_rnn_defaults_fwd_const_dynamic)
|
||||
{
|
||||
auto function = onnx_import::import_onnx_model(file_util::path_join(
|
||||
SERIALIZED_ZOO, "onnx/dynamic_shapes/rnn_defaults_fwd_const_dynamic.prototxt"));
|
||||
|
||||
auto batch_size = Dimension::dynamic();
|
||||
auto seq_length = Dimension::dynamic();
|
||||
int64_t hidden_size = 5;
|
||||
int64_t num_directions = 1;
|
||||
auto Y_expected_output = PartialShape{batch_size, num_directions, seq_length, hidden_size};
|
||||
auto Y_h_expected_output = PartialShape{num_directions, batch_size, hidden_size};
|
||||
|
||||
EXPECT_EQ(function->get_output_size(), 2);
|
||||
EXPECT_EQ(function->get_output_partial_shape(0), Y_expected_output);
|
||||
EXPECT_EQ(function->get_output_partial_shape(1), Y_h_expected_output);
|
||||
|
||||
EXPECT_EQ(count_ops_of_type<op::v5::RNNSequence>(function), 1);
|
||||
}
|
||||
|
@ -249,6 +249,8 @@ IE_CPU.nothing_to_reverse
|
||||
|
||||
# Unsupported dynamic ops
|
||||
onnx_size_dyn_op
|
||||
onnx_model_gru_defaults_fwd_const_dynamic
|
||||
onnx_model_rnn_defaults_fwd_const_dynamic
|
||||
|
||||
# Constant network
|
||||
# MKLDNNGraph::CreateGraph: No inputs for the topology
|
||||
|
Loading…
Reference in New Issue
Block a user