ONNX RNN/GRU enable dynamic input shape (#4241)

This commit is contained in:
Katarzyna Mitrus 2021-02-15 10:30:46 +01:00 committed by GitHub
parent b800f08c0c
commit ead6427097
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 529 additions and 21 deletions

View File

@ -41,26 +41,25 @@ namespace ngraph
m_map[OpInput::W] = ng_inputs.at(1);
m_map[OpInput::R] = ng_inputs.at(2);
const auto el_type = ng_inputs.at(0).get_element_type();
const auto x_pshape = m_map[OpInput::X].get_partial_shape();
const auto w_pshape = m_map[OpInput::W].get_partial_shape();
const auto r_pshape = m_map[OpInput::R].get_partial_shape();
NGRAPH_CHECK(x_pshape.rank().is_static() && x_pshape[0].is_static() &&
x_pshape[1].is_static(),
"RecurrentSequence input X must have static \"seq_length\" and "
"\"batch_size\" dimensions.");
NGRAPH_CHECK(w_pshape.rank().is_static() && w_pshape[0].is_static(),
"RecurrentSequence input W must have static \"num_directions\" "
"(outermost) dimension.");
NGRAPH_CHECK(r_pshape.rank().is_static() && r_pshape[2].is_static(),
"RecurrentSequence input R must have static \"hidden_size\" "
"(innermost) dimension.");
const std::size_t hidden_size = m_map[OpInput::R].get_shape().back();
const std::size_t batch_size = m_map[OpInput::X].get_shape().at(0);
const std::size_t num_directions = m_map[OpInput::W].get_shape().front();
// Get dimensions needed for default inputs creation
auto shape_of_x = std::make_shared<default_opset::ShapeOf>(m_map[OpInput::X]);
auto axes = default_opset::Constant::create(element::i32, Shape{1}, {0});
auto batch_size_node = std::make_shared<default_opset::Gather>(
shape_of_x, default_opset::Constant::create(element::i32, Shape{1}, {0}), axes);
auto seq_length_node = std::make_shared<default_opset::Gather>(
shape_of_x, default_opset::Constant::create(element::i32, Shape{1}, {1}), axes);
auto shape_of_r = std::make_shared<default_opset::ShapeOf>(m_map[OpInput::R]);
auto num_directions_node = std::make_shared<default_opset::Gather>(
shape_of_r, default_opset::Constant::create(element::i32, Shape{1}, {0}), axes);
auto hidden_size_node = std::make_shared<default_opset::Gather>(
shape_of_r, default_opset::Constant::create(element::i32, Shape{1}, {2}), axes);
// ------ Optional inputs ------
if (ng_inputs.size() > 3 && !ngraph::op::is_null(ng_inputs.at(3)))
{
auto bias = ng_inputs.at(3);
@ -72,8 +71,17 @@ namespace ngraph
}
else
{
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
auto b_shape = std::make_shared<default_opset::Concat>(
OutputVector{num_directions_node,
std::make_shared<default_opset::Multiply>(
default_opset::Constant::create(
element::Type_t::i64, Shape{1}, {gates_count}),
hidden_size_node)},
0);
m_map[OpInput::B] = std::make_shared<default_opset::Broadcast>(
default_opset::Constant::create(
m_map[OpInput::X].get_element_type(), Shape{}, {0}),
b_shape);
}
if (ng_inputs.size() > 4 && !ngraph::op::is_null(ng_inputs.at(4)))
{
@ -81,8 +89,8 @@ namespace ngraph
}
else
{
m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Constant>(
element::i32, Shape{batch_size}, m_map[OpInput::X].get_shape().at(1));
m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Broadcast>(
seq_length_node, batch_size_node);
}
// The initial value of the hidden.
if (ng_inputs.size() > 5 && !ngraph::op::is_null(ng_inputs.at(5)))
@ -92,8 +100,12 @@ namespace ngraph
}
else
{
m_map[OpInput::INIT_H] = std::make_shared<default_opset::Constant>(
el_type, Shape{batch_size, num_directions, hidden_size}, 0.f);
auto init_h_shape = std::make_shared<default_opset::Concat>(
OutputVector{batch_size_node, num_directions_node, hidden_size_node}, 0);
m_map[OpInput::INIT_H] = std::make_shared<default_opset::Broadcast>(
default_opset::Constant::create(
m_map[OpInput::X].get_element_type(), Shape{}, {0}),
init_h_shape);
}
}

View File

@ -0,0 +1,220 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
output: "W"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1 # num_directions
dims: 15 # gates_count*hidden_size
dims: 2 # input_size
data_type: 1
float_data: 0.31403765
float_data: -0.16793324
float_data: 1.388258
float_data: -0.6902954
float_data: -0.3994045
float_data: -0.7833511
float_data: -0.30992958
float_data: 0.3557573
float_data: -0.4682631
float_data: 1.1741459
float_data: -2.414789
float_data: -0.42783254
float_data: -0.82199496
float_data: -0.0390086
float_data: -0.43670088
float_data: -0.53810567
float_data: -0.10769883
float_data: 0.75242394
float_data: -0.2507971
float_data: 1.0447186
float_data: -1.4777364
float_data: 0.19993274
float_data: 0.925649
float_data: -2.282516
float_data: 0.95039636
float_data: 1.5379831
float_data: -0.88576007
float_data: 0.28566247
float_data: 0.79292643
float_data: -0.04261953
name: "W_tensor"
}
type: TENSOR
}
}
node {
output: "R"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1 # num_directions
dims: 15 # gates_count*hidden_size
dims: 5 # hidden_size
data_type: 1
float_data: 0.8490583
float_data: 0.45121244
float_data: -1.1799014
float_data: 0.13536449
float_data: 0.81328654
float_data: 0.6017516
float_data: 0.48475724
float_data: -1.2136037
float_data: 0.16383322
float_data: 1.5106261
float_data: 1.1177503
float_data: 0.23582461
float_data: 0.5754652
float_data: 0.43879887
float_data: 0.7399294
float_data: 0.4517558
float_data: 1.3536783
float_data: -0.4843166
float_data: -1.1503736
float_data: -0.2458678
float_data: 0.54523313
float_data: -0.08649993
float_data: -0.6936281
float_data: 1.002422
float_data: -1.770847
float_data: -0.94642
float_data: -1.8135757
float_data: 1.8819852
float_data: -0.10852333
float_data: -0.26120332
float_data: 1.0223165
float_data: -0.7468837
float_data: 0.28566906
float_data: 0.92321056
float_data: 0.22521864
float_data: 1.1123824
float_data: -0.9298287
float_data: 1.2141289
float_data: 1.3470556
float_data: -0.32972014
float_data: -1.6552197
float_data: -1.0998285
float_data: 0.71901864
float_data: 0.962846
float_data: -0.1366851
float_data: -2.6534476
float_data: -1.4992771
float_data: -0.45793465
float_data: 0.4290477
float_data: 0.9893151
float_data: 0.2511034
float_data: 0.12906462
float_data: 0.7491512
float_data: 0.3316756
float_data: 1.0576645
float_data: -0.04618666
float_data: 1.3556088
float_data: 1.2842374
float_data: 0.7103014
float_data: 0.52889013
float_data: 0.30327162
float_data: 1.5069056
float_data: 0.16591893
float_data: 1.5719851
float_data: -2.099427
float_data: -1.010277
float_data: -0.52800924
float_data: -0.22292352
float_data: -0.55177474
float_data: 1.3432894
float_data: 0.8731192
float_data: -0.01055307
float_data: -0.01138215
float_data: 0.85698843
float_data: -1.2615703
name: "R_tensor"
}
type: TENSOR
}
}
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_gru_defaults_const"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: -1 # seq_length
}
dim {
dim_value: -1 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: -1 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: -1 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: -1 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -0,0 +1,146 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
output: "W"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1
dims: 5
dims: 2
data_type: 1
float_data: 0.31403765
float_data: -0.16793324
float_data: 1.388258
float_data: -0.6902954
float_data: -0.3994045
float_data: -0.7833511
float_data: -0.30992958
float_data: 0.3557573
float_data: -0.4682631
float_data: 1.1741459
name: "W_tensor"
}
type: TENSOR
}
}
node {
output: "R"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 1 # num_directions
dims: 5 # gates_count*hidden_size
dims: 5 # hidden_size
data_type: 1
float_data: -2.414789
float_data: -0.42783254
float_data: -0.82199496
float_data: -0.03900861
float_data: -0.43670088
float_data: -0.53810567
float_data: -0.10769883
float_data: 0.75242394
float_data: -0.2507971
float_data: 1.0447186
float_data: -1.4777364
float_data: 0.19993274
float_data: 0.925649
float_data: -2.282516
float_data: 0.95039636
float_data: 1.5379831
float_data: -0.88576007
float_data: 0.28566247
float_data: 0.79292643
float_data: -0.04261953
float_data: 0.8490583
float_data: 0.45121244
float_data: -1.1799014
float_data: 0.13536449
float_data: 0.81328654
name: "R_tensor"
}
type: TENSOR
}
}
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_rnn_defaults"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: -1 # seq_length
}
dim {
dim_value: -1 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: -1 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: -1 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: -1 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@ -1551,6 +1551,70 @@ NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_linear_before_r
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
// GRU with all optional inputs (B, sequence_lens, initial_h) omitted; the model
// declares the seq_length and batch_size dims of X as dynamic (-1), so the
// importer must build the default inputs from runtime shapes, not constants.
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_defaults_fwd_const_dynamic)
{
auto function = onnx_import::import_onnx_model(file_util::path_join(
SERIALIZED_ZOO, "onnx/dynamic_shapes/gru_defaults_fwd_const_dynamic.prototxt"));
// DYNAMIC test case: the concrete input shape is supplied only at call time.
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
// X: [seq_length=4, batch_size=3, input_size=2]
test_case.add_input<float>(Shape{4, 3, 2}, in_X);
// Y: [seq_length, num_directions, batch_size, hidden_size]
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.3224981f, -0.44282594f, 0.7499796f, -0.12240417f, 0.12079421f, 0.02534253f,
0.02504562f, -0.0463777f, 0.01204534f, -0.01497037f, -0.04651929f, -0.6264307f,
0.7236632f, 0.06250653f, 0.02594197f, -0.06868916f, -0.5412897f, 0.49794048f,
0.22239858f, -0.11257736f, -0.23071964f, 0.26079988f, -0.07375772f, -0.21816255f,
0.18764113f, -0.5228772f, 0.00575754f, 0.2514028f, -0.58864325f, 0.49843538f,
-0.6129046f, -0.10794663f, 0.6544055f, -0.70105773f, 0.5397687f, -0.35791716f,
0.3885092f, -0.15291792f, -0.22324723f, 0.11557932f, -0.42112932f, 0.26772985f,
-0.38304564f, -0.05039781f, -0.5057976f, 0.5775348f, -0.6736855f, -0.20032284f,
0.03698462f, -0.7693824f, -0.5831348f, 0.25767964f, 0.7121098f, -0.35951245f,
0.39223647f, -0.6645166f, 0.37950075f, 0.59931314f, -0.4741001f, 0.21156166f,
});
// Y_h: [num_directions, batch_size, hidden_size] (last hidden state; equals the
// final seq_length slice of Y above)
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.5775348f,
-0.6736855f,
-0.20032284f,
0.03698462f,
-0.7693824f,
-0.5831348f,
0.25767964f,
0.7121098f,
-0.35951245f,
0.39223647f,
-0.6645166f,
0.37950075f,
0.59931314f,
-0.4741001f,
0.21156166f,
});
// Relaxed float tolerance (+7 bits) for this test case.
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 7);
}
// Import-only check: a GRU model with dynamic seq_length/batch_size must import
// into a function whose outputs carry the matching dynamic dimensions.
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_import_only_gru_defaults_fwd_const_dynamic)
{
    auto function = onnx_import::import_onnx_model(file_util::path_join(
        SERIALIZED_ZOO, "onnx/dynamic_shapes/gru_defaults_fwd_const_dynamic.prototxt"));

    auto batch_size = Dimension::dynamic();
    auto seq_length = Dimension::dynamic();
    int64_t hidden_size = 5;
    int64_t num_directions = 1;
    // ONNX GRU output Y is [seq_length, num_directions, batch_size, hidden_size].
    // The original listed batch_size and seq_length in swapped positions, which
    // only compared equal because both dimensions are dynamic here.
    auto Y_expected_output = PartialShape{seq_length, num_directions, batch_size, hidden_size};
    // Y_h is [num_directions, batch_size, hidden_size].
    auto Y_h_expected_output = PartialShape{num_directions, batch_size, hidden_size};

    EXPECT_EQ(function->get_output_size(), 2);
    EXPECT_EQ(function->get_output_partial_shape(0), Y_expected_output);
    EXPECT_EQ(function->get_output_partial_shape(1), Y_h_expected_output);
    EXPECT_EQ(count_ops_of_type<op::v5::GRUSequence>(function), 1);
}
// RNNLikeSequenceOp test fixture for test setup reuse
class RNNSequenceOp : public testing::Test
{
@ -2386,3 +2450,67 @@ NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_bidirectional_const
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 6);
}
// RNN with all optional inputs omitted; seq_length and batch_size of X are
// dynamic (-1) in the model, so default inputs must be derived at runtime.
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_defaults_fwd_const_dynamic)
{
auto function = onnx_import::import_onnx_model(file_util::path_join(
SERIALIZED_ZOO, "onnx/dynamic_shapes/rnn_defaults_fwd_const_dynamic.prototxt"));
// DYNAMIC test case: the concrete input shape is supplied only at call time.
auto test_case = test::TestCase<TestEngine, test::TestCaseType::DYNAMIC>(function);
// X: [seq_length=4, batch_size=3, input_size=2]
test_case.add_input<float>(Shape{4, 3, 2}, in_X);
// Y: [seq_length, num_directions, batch_size, hidden_size]
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.02254748f, 0.15776646f, -0.8229023f, 0.19205809f, 0.76984656f, -0.00603169f,
-0.02861464f, 0.04512155f, -0.0011912f, -0.02572936f, -0.13703543f, -0.49651444f,
-0.78868157f, 0.3566854f, 0.8758509f, 0.20788848f, 0.13481987f, -0.756822f,
-0.121436f, 0.97542346f, 0.16959739f, 0.63496053f, 0.1245538f, -0.1970138f,
-0.56581646f, 0.8225869f, 0.9611373f, -0.42990375f, -0.22925597f, 0.2226491f,
0.08246052f, 0.9798831f, -0.13415998f, -0.5567714f, 0.78594816f, -0.34759718f,
0.11376679f, -0.07107389f, -0.5420871f, -0.58504283f, -0.96065646f, 0.18588805f,
-0.4870671f, -0.1475982f, 0.82456505f, -0.80264574f, -0.46370947f, 0.9719335f,
-0.7374159f, 0.94937694f, 0.8814341f, 0.67015004f, 0.21958017f, -0.8332769f,
-0.487742f, 0.9918536f, 0.99563396f, 0.94866276f, -0.98504806f, -0.42824882f,
});
// Y_h: [num_directions, batch_size, hidden_size] (last hidden state; equals the
// final seq_length slice of Y above)
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.80264574f,
-0.46370947f,
0.9719335f,
-0.7374159f,
0.94937694f,
0.8814341f,
0.67015004f,
0.21958017f,
-0.8332769f,
-0.487742f,
0.9918536f,
0.99563396f,
0.94866276f,
-0.98504806f,
-0.42824882f,
});
// Relaxed float tolerance (+4 bits) for this test case.
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
// Import-only check: an RNN model with dynamic seq_length/batch_size must import
// into a function whose outputs carry the matching dynamic dimensions.
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_import_only_rnn_defaults_fwd_const_dynamic)
{
    auto function = onnx_import::import_onnx_model(file_util::path_join(
        SERIALIZED_ZOO, "onnx/dynamic_shapes/rnn_defaults_fwd_const_dynamic.prototxt"));

    auto batch_size = Dimension::dynamic();
    auto seq_length = Dimension::dynamic();
    int64_t hidden_size = 5;
    int64_t num_directions = 1;
    // ONNX RNN output Y is [seq_length, num_directions, batch_size, hidden_size].
    // The original listed batch_size and seq_length in swapped positions, which
    // only compared equal because both dimensions are dynamic here.
    auto Y_expected_output = PartialShape{seq_length, num_directions, batch_size, hidden_size};
    // Y_h is [num_directions, batch_size, hidden_size].
    auto Y_h_expected_output = PartialShape{num_directions, batch_size, hidden_size};

    EXPECT_EQ(function->get_output_size(), 2);
    EXPECT_EQ(function->get_output_partial_shape(0), Y_expected_output);
    EXPECT_EQ(function->get_output_partial_shape(1), Y_h_expected_output);
    EXPECT_EQ(count_ops_of_type<op::v5::RNNSequence>(function), 1);
}

View File

@ -249,6 +249,8 @@ IE_CPU.nothing_to_reverse
# Unsupported dynamic ops
onnx_size_dyn_op
onnx_model_gru_defaults_fwd_const_dynamic
onnx_model_rnn_defaults_fwd_const_dynamic
# Constant network
# MKLDNNGraph::CreateGraph: No inputs for the topology