LSTMSequence/GRUSequence - fix validate_and_infer_types (#9996)
* Fix LSTMSequence/GRUSequence validation behavior to be consistent with RNNSequence.
  Fixed the issue that no exception was thrown when num_directions=2 but 'm_direction' is not set to BIDIRECTIONAL. Previously this produced no error (and only happened to fail later in some CPU transformations during compile_network).
  Corrected several tests which used a copy-pasted num_directions=2 without setting m_direction.
  Also, for a dynamic 'num_directions' the output shape still has 1 or 2 directions, because m_direction is known; the GRU/LSTM tests are updated for this.
  Several tests also worked incorrectly for LSTM v0: they expected a specific error to be thrown, but throwing no exception was also allowed.
* Fixed clang-format
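The rule this change enforces can be summarized with a small sketch (illustrative only, not code from this PR; Direction, expected_num_directions and check_num_directions are hypothetical helpers): the 'direction' attribute determines the only legal static value of the num_directions dimension, so a mismatch must already fail in validate_and_infer_types() instead of surfacing later during compile_network.

#include <stdexcept>
#include <string>

// Hypothetical stand-in for op::RecurrentSequenceDirection.
enum class Direction { FORWARD, REVERSE, BIDIRECTIONAL };

// FORWARD/REVERSE sequences produce one direction; BIDIRECTIONAL produces two.
int expected_num_directions(Direction direction) {
    switch (direction) {
    case Direction::FORWARD:
    case Direction::REVERSE:
        return 1;
    case Direction::BIDIRECTIONAL:
        return 2;
    }
    // Guard for potential future extension of the enum, mirroring the PR.
    throw std::invalid_argument("Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
}

// A static num_directions that disagrees with the direction attribute fails
// validation up front instead of in later graph transformations.
void check_num_directions(Direction direction, int num_directions) {
    if (num_directions != expected_num_directions(direction)) {
        throw std::invalid_argument("Parameter 'num_directions' doesn't match with direction");
    }
}

For example, check_num_directions(Direction::BIDIRECTIONAL, 1) throws, which is the case the new *_invalid_input_direction_num_mismatch tests below exercise.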
@@ -85,13 +85,13 @@ void op::v5::GRUSequence::validate_and_infer_types() {
                           Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
                               Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) &&
                               Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]),
-                          "Parameter batch_size not matched in RNNSequence.");
+                          "Parameter batch_size not matched in GRUSequence.");
 
     // Merge hidden_size dimension across all inputs to evaluate output dimension
     NODE_VALIDATION_CHECK(this,
                           Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) &&
                               Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]),
-                          "Parameter hidden_size not matched RNNSequence.");
+                          "Parameter hidden_size not matched GRUSequence.");
 
     // Merge num_directions dimension across all inputs to evaluate output dimension
     NODE_VALIDATION_CHECK(this,
@@ -99,7 +99,27 @@ void op::v5::GRUSequence::validate_and_infer_types() {
                           Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) &&
                               Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) &&
                               Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
-                          "Parameter num_directions not matched in RNNSequence.");
+                          "Parameter num_directions not matched in GRUSequence.");
+
+    auto valid_num_directions = 0;
+    if (m_direction == op::RecurrentSequenceDirection::FORWARD ||
+        m_direction == op::RecurrentSequenceDirection::REVERSE) {
+        valid_num_directions = 1;
+    } else if (m_direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
+        valid_num_directions = 2;
+    } else {
+        // Guard for potential future extension of RecurrentSequenceDirection enum
+        NODE_VALIDATION_CHECK(this, false, "Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
+    }
+
+    NODE_VALIDATION_CHECK(this,
+                          Dimension::merge(merged_num_directions, merged_num_directions, valid_num_directions),
+                          "Parameter 'num_directions' doesn't match with direction '",
+                          m_direction,
+                          "' in GRUSequence. Expected ",
+                          valid_num_directions,
+                          ", actual ",
+                          merged_num_directions);
 
     // Validate hidden_size value for W, R, B inputs
     if (merged_hidden_size.is_static()) {
@@ -381,6 +381,26 @@ void op::v0::LSTMSequence::validate_and_infer_types() {
                               Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
                           "Parameter num_directions not matched in LSTMSequence.");
 
+    auto valid_num_directions = 0;
+    if (m_direction == op::RecurrentSequenceDirection::FORWARD ||
+        m_direction == op::RecurrentSequenceDirection::REVERSE) {
+        valid_num_directions = 1;
+    } else if (m_direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
+        valid_num_directions = 2;
+    } else {
+        // Guard for potential future extension of RecurrentSequenceDirection enum
+        NODE_VALIDATION_CHECK(this, false, "Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
+    }
+
+    NODE_VALIDATION_CHECK(this,
+                          Dimension::merge(merged_num_directions, merged_num_directions, valid_num_directions),
+                          "Parameter 'num_directions' doesn't match with direction '",
+                          m_direction,
+                          "' in LSTMSequence. Expected ",
+                          valid_num_directions,
+                          ", actual ",
+                          merged_num_directions);
+
     // Validate hidden_size value for W, R, B and P inputs
     if (merged_hidden_size.is_static()) {
         if (w_pshape[1].is_static()) {
@@ -546,22 +566,25 @@ void op::v5::LSTMSequence::validate_and_infer_types() {
                               Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
                           "Parameter num_directions not matched in LSTMSequence.");
 
-    auto check_direction_valid = [](const ov::PartialShape& pshape, size_t index) -> bool {
-        if (pshape[index].is_static())
-            return static_cast<direction>(pshape[index].get_length()) == direction::FORWARD ||
-                   static_cast<direction>(pshape[index].get_length()) == direction::REVERSE ||
-                   static_cast<direction>(pshape[index].get_length()) == direction::BIDIRECTIONAL;
-        return true;
-    };
+    auto valid_num_directions = 0;
+    if (m_direction == op::RecurrentSequenceDirection::FORWARD ||
+        m_direction == op::RecurrentSequenceDirection::REVERSE) {
+        valid_num_directions = 1;
+    } else if (m_direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
+        valid_num_directions = 2;
+    } else {
+        // Guard for potential future extension of RecurrentSequenceDirection enum
+        NODE_VALIDATION_CHECK(this, false, "Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
+    }
 
     NODE_VALIDATION_CHECK(this,
-                          check_direction_valid(ht_pshape, 1),
-                          "Parameter direction must be Forward or Reverse or Bidirectional.");
-
-    NODE_VALIDATION_CHECK(this,
-                          m_direction == direction::FORWARD || m_direction == direction::REVERSE ||
-                              m_direction == direction::BIDIRECTIONAL,
-                          "Parameter direction must be Forward or Reverse or Bidirectional.");
+                          Dimension::merge(merged_num_directions, merged_num_directions, valid_num_directions),
+                          "Parameter 'num_directions' doesn't match with direction '",
+                          m_direction,
+                          "' in LSTMSequence. Expected ",
+                          valid_num_directions,
+                          ", actual ",
+                          merged_num_directions);
 
     // Validate hidden_size value for W, R, B inputs
     if (merged_hidden_size.is_static()) {
@@ -104,12 +104,18 @@ void op::v5::RNNSequence::validate_and_infer_types() {
     } else if (m_direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
         valid_num_directions = 2;
     } else {
+        // Guard for potential future extension of RecurrentSequenceDirection enum
         NODE_VALIDATION_CHECK(this, false, "Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
     }
 
     NODE_VALIDATION_CHECK(this,
                           Dimension::merge(merged_num_directions, merged_num_directions, valid_num_directions),
-                          "Parameter num_directions not match direction in RNNSequence.");
+                          "Parameter 'num_directions' doesn't match with direction '",
+                          m_direction,
+                          "' in RNNSequence. Expected ",
+                          valid_num_directions,
+                          ", actual ",
+                          merged_num_directions);
 
     // Validate hidden_size value for W, R, B inputs
     if (merged_hidden_size.is_static()) {
@@ -47,6 +47,36 @@ shared_ptr<opset5::GRUSequence> gru_seq_tensor_initialization(const gru_sequence
     return gru_sequence;
 }
 
+shared_ptr<opset5::GRUSequence> gru_seq_direction_initialization(const gru_sequence_parameters& param,
+                                                                 op::RecurrentSequenceDirection direction) {
+    auto batch_size = param.batch_size;
+    auto seq_length = param.seq_length;
+    auto input_size = param.input_size;
+    auto num_directions = param.num_directions;
+    auto hidden_size = param.hidden_size;
+    auto hidden_size_value = hidden_size.is_dynamic() ? 0 : hidden_size.get_length();
+    auto et = param.et;
+
+    const auto X = make_shared<opset5::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
+    const auto initial_hidden_state =
+        make_shared<opset5::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
+    const auto sequence_lengths = make_shared<opset5::Parameter>(et, PartialShape{batch_size});
+    const auto W = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 3, input_size});
+    const auto R = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 3, hidden_size});
+    const auto B = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 3});
+
+    auto gru_sequence = make_shared<opset5::GRUSequence>(X,
+                                                         initial_hidden_state,
+                                                         sequence_lengths,
+                                                         W,
+                                                         R,
+                                                         B,
+                                                         hidden_size_value,
+                                                         direction);
+
+    return gru_sequence;
+}
+
 TEST(type_prop, gru_sequence_forward) {
     const size_t batch_size = 8;
     const size_t num_directions = 1;
@@ -84,7 +114,7 @@ TEST(type_prop, gru_sequence_forward) {
 
 TEST(type_prop, gru_sequence_bidirectional) {
     const size_t batch_size = 8;
-    const size_t num_directions = 1;
+    const size_t num_directions = 2;
     const size_t seq_length = 6;
     const size_t input_size = 4;
     const size_t hidden_size = 128;
@@ -132,7 +162,7 @@ TEST(type_prop, gru_sequence_bidirectional) {
 TEST(type_prop, gru_sequence_dynamic_batch_size) {
     gru_sequence_parameters param;
     param.batch_size = Dimension::dynamic();
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = 6;
     param.input_size = 4;
     param.hidden_size = 128;
@@ -162,9 +192,8 @@ TEST(type_prop, gru_sequence_dynamic_num_directions) {
     gru_sequence->validate_and_infer_types();
 
     EXPECT_EQ(gru_sequence->get_output_partial_shape(0),
-              (PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
-    EXPECT_EQ(gru_sequence->get_output_partial_shape(1),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
+              (PartialShape{param.batch_size, 1, param.seq_length, param.hidden_size}));
+    EXPECT_EQ(gru_sequence->get_output_partial_shape(1), (PartialShape{param.batch_size, 1, param.hidden_size}));
     EXPECT_EQ(gru_sequence->get_output_element_type(0), param.et);
     EXPECT_EQ(gru_sequence->get_output_element_type(1), param.et);
 }
@@ -235,7 +264,7 @@ TEST(type_prop, gru_sequence_invalid_input_dynamic_rank) {
     gru_sequence_parameters param;
 
     param.batch_size = 8;
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = 6;
     param.input_size = 4;
     param.hidden_size = 128;
@@ -258,3 +287,27 @@ TEST(type_prop, gru_sequence_invalid_input_dynamic_rank) {
         EXPECT_EQ(check_dynamic_gru(gru_sequence), true);
     }
 }
+
+TEST(type_prop, gru_sequence_invalid_input_direction_num_mismatch) {
+    auto check_error = [](op::RecurrentSequenceDirection direction, int num_directions) {
+        gru_sequence_parameters param;
+
+        param.batch_size = 24;
+        param.num_directions = num_directions;
+        param.seq_length = 12;
+        param.input_size = 8;
+        param.hidden_size = 256;
+        param.et = element::f32;
+        try {
+            auto gru_sequence = gru_seq_direction_initialization(param, direction);
+            gru_sequence->validate_and_infer_types();
+            FAIL() << "GRUSequence node was created with invalid data.";
+        } catch (const NodeValidationFailure& error) {
+            EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
+        }
+    };
+
+    check_error(op::RecurrentSequenceDirection::BIDIRECTIONAL, 1);
+    check_error(op::RecurrentSequenceDirection::FORWARD, 2);
+    check_error(op::RecurrentSequenceDirection::REVERSE, 2);
+}
@@ -57,6 +57,39 @@ shared_ptr<opset5::LSTMSequence> lstm_seq_tensor_initialization(const recurrent_
     return lstm_sequence;
 }
 
+shared_ptr<opset5::LSTMSequence> lstm_seq_direction_initialization(const recurrent_sequence_parameters& param,
+                                                                   opset5::LSTMSequence::direction direction) {
+    auto batch_size = param.batch_size;
+    auto seq_length = param.seq_length;
+    auto input_size = param.input_size;
+    auto num_directions = param.num_directions;
+    auto hidden_size = param.hidden_size;
+    int64_t hidden_size_num = hidden_size.is_dynamic() ? 0 : hidden_size.get_length();
+    auto et = param.et;
+
+    const auto X = make_shared<opset5::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
+    const auto initial_hidden_state =
+        make_shared<opset5::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
+    const auto initial_cell_state =
+        make_shared<opset5::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
+    const auto sequence_lengths = make_shared<opset5::Parameter>(et, PartialShape{batch_size});
+    const auto W = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4, input_size});
+    const auto R = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4, hidden_size});
+    const auto B = make_shared<opset5::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
+
+    auto lstm_sequence = make_shared<opset5::LSTMSequence>(X,
+                                                           initial_hidden_state,
+                                                           initial_cell_state,
+                                                           sequence_lengths,
+                                                           W,
+                                                           R,
+                                                           B,
+                                                           hidden_size_num,
+                                                           direction);
+
+    return lstm_sequence;
+}
+
 shared_ptr<opset1::LSTMSequence> lstm_seq_v1_tensor_initialization(const recurrent_sequence_parameters& param) {
     auto batch_size = param.batch_size;
     auto seq_length = param.seq_length;
@@ -297,7 +330,7 @@ TEST(type_prop, lstm_sequence_dynamic_batch_size) {
     recurrent_sequence_parameters param;
 
     param.batch_size = Dimension::dynamic();
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = 12;
     param.input_size = 8;
     param.hidden_size = 256;
@@ -330,12 +363,11 @@ TEST(type_prop, lstm_sequence_dynamic_num_directions) {
     auto lstm_sequence = lstm_seq_tensor_initialization(param);
     lstm_sequence->validate_and_infer_types();
 
+    // Output 'num_directions' is '1' due to default FORWARD direction
     EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
-              (PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
+              (PartialShape{param.batch_size, 1, param.seq_length, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1), (PartialShape{param.batch_size, 1, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2), (PartialShape{param.batch_size, 1, param.hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
@@ -345,7 +377,7 @@ TEST(type_prop, lstm_sequence_dynamic_seq_length) {
     recurrent_sequence_parameters param;
 
     param.batch_size = 24;
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = Dimension::dynamic();
     param.input_size = 8;
     param.hidden_size = 256;
@@ -369,7 +401,7 @@ TEST(type_prop, lstm_sequence_dynamic_hidden_size) {
     recurrent_sequence_parameters param;
 
     param.batch_size = 24;
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = 12;
     param.input_size = 8;
     param.hidden_size = Dimension::dynamic();
@@ -403,11 +435,9 @@ TEST(type_prop, lstm_sequence_dynamic_inputs) {
     lstm_sequence->validate_and_infer_types();
 
     EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
-              (PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
+              (PartialShape{param.batch_size, 1, param.seq_length, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1), (PartialShape{param.batch_size, 1, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2), (PartialShape{param.batch_size, 1, param.hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
@@ -479,12 +509,36 @@ TEST(type_prop, lstm_sequence_invalid_input_direction) {
     auto lstm_sequence = lstm_seq_tensor_initialization(param);
     try {
         lstm_sequence->validate_and_infer_types();
         FAIL() << "LSTMSequence node was created with invalid data.";
     } catch (const NodeValidationFailure& error) {
-        EXPECT_HAS_SUBSTRING(error.what(),
-                             std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
     }
 }
 
+TEST(type_prop, lstm_sequence_invalid_input_direction_num_mismatch) {
+    auto check_error = [](op::RecurrentSequenceDirection direction, int num_directions) {
+        recurrent_sequence_parameters param;
+
+        param.batch_size = 24;
+        param.num_directions = num_directions;
+        param.seq_length = 12;
+        param.input_size = 8;
+        param.hidden_size = 256;
+        param.et = element::f32;
+        try {
+            auto gru_sequence = lstm_seq_direction_initialization(param, direction);
+            gru_sequence->validate_and_infer_types();
+            FAIL() << "LSTMSequence node was created with invalid data.";
+        } catch (const NodeValidationFailure& error) {
+            EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
+        }
+    };
+
+    check_error(op::RecurrentSequenceDirection::BIDIRECTIONAL, 1);
+    check_error(op::RecurrentSequenceDirection::FORWARD, 2);
+    check_error(op::RecurrentSequenceDirection::REVERSE, 2);
+}
+
 TEST(type_prop, lstm_sequence_v1_dynamic_num_directions) {
     recurrent_sequence_parameters param;
@@ -499,11 +553,9 @@ TEST(type_prop, lstm_sequence_v1_dynamic_num_directions) {
     lstm_sequence->validate_and_infer_types();
 
     EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
-              (PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
+              (PartialShape{param.batch_size, 1, param.seq_length, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1), (PartialShape{param.batch_size, 1, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2), (PartialShape{param.batch_size, 1, param.hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
@@ -513,7 +565,7 @@ TEST(type_prop, lstm_sequence_v1_dynamic_seq_length) {
     recurrent_sequence_parameters param;
 
     param.batch_size = 24;
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = Dimension::dynamic();
     param.input_size = 8;
     param.hidden_size = 256;
@@ -537,7 +589,7 @@ TEST(type_prop, lstm_sequence_v1_dynamic_hidden_size) {
     recurrent_sequence_parameters param;
 
     param.batch_size = 24;
-    param.num_directions = 2;
+    param.num_directions = 1;
     param.seq_length = 12;
     param.input_size = 8;
     param.hidden_size = Dimension::dynamic();
@@ -571,11 +623,9 @@ TEST(type_prop, lstm_sequence_v1_dynamic_inputs) {
     lstm_sequence->validate_and_infer_types();
 
     EXPECT_EQ(lstm_sequence->get_output_partial_shape(0),
-              (PartialShape{param.batch_size, param.num_directions, param.seq_length, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
-    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2),
-              (PartialShape{param.batch_size, param.num_directions, param.hidden_size}));
+              (PartialShape{param.batch_size, 1, param.seq_length, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(1), (PartialShape{param.batch_size, 1, param.hidden_size}));
+    EXPECT_EQ(lstm_sequence->get_output_partial_shape(2), (PartialShape{param.batch_size, 1, param.hidden_size}));
     EXPECT_EQ(lstm_sequence->get_output_element_type(0), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(1), param.et);
     EXPECT_EQ(lstm_sequence->get_output_element_type(2), param.et);
@@ -647,8 +697,17 @@ TEST(type_prop, lstm_sequence_v1_invalid_input_direction) {
     auto lstm_sequence = lstm_seq_v1_tensor_initialization(param);
     try {
         lstm_sequence->validate_and_infer_types();
         FAIL() << "LSTMSequence node was created with invalid data.";
     } catch (const NodeValidationFailure& error) {
-        EXPECT_HAS_SUBSTRING(error.what(),
-                             std::string("Parameter direction must be Forward or Reverse or Bidirectional"));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
     }
+
+    param.num_directions = 2;  // 2 is also not allowed for default 'm_direction' = FORWARD
+    lstm_sequence = lstm_seq_v1_tensor_initialization(param);
+    try {
+        lstm_sequence->validate_and_infer_types();
+        FAIL() << "LSTMSequence node was created with invalid data.";
+    } catch (const NodeValidationFailure& error) {
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
+    }
 }
@@ -122,7 +122,7 @@ TEST(type_prop, rnn_sequence_invalid_input) {
             make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
         FAIL() << "RNNSequence node was created with invalid data.";
     } catch (const NodeValidationFailure& error) {
-        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter num_directions not match direction in RNNSequence."));
+        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter 'num_directions' doesn't match with direction"));
     }
 }