GRU/AUGRUCell shape inference function (#13708)

* Add shape_infer function for GRUCell op

* Add shape_infer function for AUGRUCell

* Consts refactor

* Add batch_size check

* Enable GRUCell shape_infer for CPU

* Style apply

* Use OV_EXPECT_THROW in tests

* Use helper for input shapes

* Use .back() instead of index

* Change rnn_seq namespace to rnn

* Fix win warnings

* Enable AUGRUCell/Sequence shape_infer on CPU

* Fix warn

* Fix warn
This commit is contained in:
Katarzyna Mitrus 2022-11-04 21:31:58 +01:00 committed by GitHub
parent 19c2ec068a
commit c953186ff0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 556 additions and 251 deletions

View File

@ -78,9 +78,9 @@ AUGRU formula:
* **4**: `R` - 2D tensor of type *T* and shape `[3 * hidden_size, hidden_size]`. The recurrence weights for matrix multiplication, gate order: zrh. **Required.**
* **6**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.**
* **5**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.**
* **7**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.**
* **6**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.**
**Outputs**

View File

@ -6,6 +6,7 @@
#include <cmath>
#include "augru_cell_shape_inference.hpp"
#include "itt.hpp"
using namespace std;
@ -40,23 +41,6 @@ bool ov::op::internal::AUGRUCell::visit_attributes(AttributeVisitor& visitor) {
void ov::op::internal::AUGRUCell::validate_and_infer_types() {
INTERNAL_OP_SCOPE(internal_AUGRUCell_validate_and_infer_types);
for (const auto& input : inputs()) {
if (input.get_partial_shape().rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(2));
return;
}
}
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto result_et = element::dynamic;
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& w_pshape = get_input_partial_shape(2);
const auto& r_pshape = get_input_partial_shape(3);
const auto& b_pshape = get_input_partial_shape(4);
const auto& a_pshape = get_input_partial_shape(5);
NODE_VALIDATION_CHECK(this, m_clip == 0.f, "AUGRUCell doesn't support clip other than 0.");
NODE_VALIDATION_CHECK(this,
@ -69,15 +53,8 @@ void ov::op::internal::AUGRUCell::validate_and_infer_types() {
m_linear_before_reset == false,
"AUGRUCell supports only linear_before_reset equals false.");
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
// `A` input shape validation // [batch_size, 1]
NODE_VALIDATION_CHECK(this, a_pshape.rank().compatible(2), "'A' input must be a 2D tensor.");
if (a_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(this, a_pshape[1].compatible(1), "The last dimension of `A` shape must be equal to `1`.");
}
// Validate input types and save result for output type
auto result_et = element::dynamic;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
@ -87,55 +64,13 @@ void ov::op::internal::AUGRUCell::validate_and_infer_types() {
element::Type::merge(result_et, result_et, get_input_element_type(5)),
"Element types for inputs do not match.");
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, a_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
"Dimension batch_size is not matched between inputs.");
// Get input partial shape for all inputs
const auto input_shapes = get_node_input_partial_shapes(*this);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(2)};
shape_infer(this, input_shapes, output_shapes);
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
"Dimension hidden_size not matched for R and initial_hidden_state inputs.");
// Validate hidden_size value for W, B and R inputs
if (merged_hidden_size.is_static()) {
if (w_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in W input. Current value is: ",
w_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (r_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in R input. Current value is: ",
r_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (b_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)),
"Parameter hidden_size mistmatched in B input. Current value is: ",
b_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset),
".");
}
}
// Set output size, type and shape
set_output_size(1);
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
// Set output type and shape
set_output_type(0, result_et, output_shapes[0]);
}
shared_ptr<ov::Node> ov::op::internal::AUGRUCell::clone_with_new_inputs(const OutputVector& new_args) const {

View File

@ -63,17 +63,8 @@ void ov::op::internal::AUGRUSequence::validate_and_infer_types() {
element::Type::merge(result_et, result_et, get_input_element_type(6)),
"Element types for inputs do not match.");
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& sl_pshape = get_input_partial_shape(2);
const auto& w_pshape = get_input_partial_shape(3);
const auto& r_pshape = get_input_partial_shape(4);
const auto& b_pshape = get_input_partial_shape(5);
const auto& a_pshape = get_input_partial_shape(6);
std::vector<ov::PartialShape> input_shapes =
{x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape, a_pshape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}};
const auto input_shapes = get_node_input_partial_shapes(*this);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)};
shape_infer(this, input_shapes, output_shapes);
// Set output size, type and shape

View File

@ -0,0 +1,44 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "gru_cell_shape_inference.hpp"
#include "ngraph_ops/augru_sequence.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace internal {
template <class ShapeType>
void shape_infer(const ov::op::internal::AUGRUCell* op,
                 const std::vector<ShapeType>& input_shapes,
                 std::vector<ShapeType>& output_shapes) {
    // AUGRUCell takes the five common GRU cell inputs (X, H_t, W, R, B)
    // plus the attention score `A` as the last input.
    constexpr size_t expected_in_shapes_count = 6;
    NODE_VALIDATION_CHECK(op,
                          input_shapes.size() == expected_in_shapes_count,
                          "Incorrect number of input shapes has been provided. Expected: ",
                          expected_in_shapes_count,
                          ", got: ",
                          input_shapes.size(),
                          ".");

    // Shared GRU cell validation and output shape evaluation.
    rnn::gru_cell_shape_infer(op, input_shapes, output_shapes);

    // Extra validation of the `A` input; expected shape: [batch_size, 1].
    const auto& data_shape = input_shapes[0];
    const auto& attention_shape = input_shapes.back();
    NODE_VALIDATION_CHECK(op, attention_shape.rank().compatible(2), "'A' input must be a 2D tensor.");
    if (attention_shape.rank().is_static()) {
        if (data_shape.rank().is_static()) {
            // `A` must agree with `X` on the batch dimension.
            NODE_VALIDATION_CHECK(op,
                                  data_shape.rank().get_length() > 1 && attention_shape[0].compatible(data_shape[0]),
                                  "Dimension `batch_size` must be the same for `X` and `A` inputs.");
        }
        NODE_VALIDATION_CHECK(op,
                              attention_shape[1].compatible(1),
                              "The last dimension of `A` shape must be equal to `1`.");
    }
}
} // namespace internal
} // namespace op
} // namespace ov

View File

@ -24,16 +24,19 @@ void shape_infer(const ov::op::internal::AUGRUSequence* op,
input_shapes.size(),
".");
rnn_seq::gru_shape_infer(op, input_shapes, output_shapes);
rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes);
// A input shape validation // [batch_size, seq_length, 1]
const auto& a_shape = input_shapes[6];
const auto& a_shape = input_shapes.back();
const auto& x_shape = input_shapes[0];
NODE_VALIDATION_CHECK(op, a_shape.rank().compatible(3), "'A' input must be a 3D tensor.");
if (a_shape.rank().is_static()) {
if (x_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op,
x_shape.rank().get_length() > 1 && a_shape[1].compatible(x_shape[1]),
x_shape.rank().get_length() > 1 && a_shape[0].compatible(x_shape[0]),
"Dimension `batch_size` must be the same for `X` and `A` inputs.");
NODE_VALIDATION_CHECK(op,
x_shape.rank().get_length() > 2 && a_shape[1].compatible(x_shape[1]),
"Dimension `seq_length` must be the same for `X` and `A` inputs.");
}
NODE_VALIDATION_CHECK(op, a_shape[2].compatible(1), "The last dimension of `A` shape must be equal to `1`.");

View File

@ -0,0 +1,105 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/gru_cell.hpp>
#include "gru_cell_shape_inference.hpp"
#include "gru_sequence_shape_inference.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace rnn {
// Common shape inference for GRU-like cell operations (GRUCell, AUGRUCell).
//
// Expected input shape layout (first five inputs):
//   input_shapes[0]: X   [batch_size, input_size]
//   input_shapes[1]: H_t [batch_size, hidden_size]
//   input_shapes[2]: W   [3 * hidden_size, input_size]
//   input_shapes[3]: R   [3 * hidden_size, hidden_size]
//   input_shapes[4]: B   [4 * hidden_size] if linear_before_reset, else [3 * hidden_size]
//
// Output shape layout:
//   output_shapes[0]: [batch_size, hidden_size]  // Rank always 2
template <class OpType, class ShapeType>
void gru_cell_shape_infer(const OpType* op,
                          const std::vector<ShapeType>& input_shapes,
                          std::vector<ShapeType>& output_shapes) {
    NODE_VALIDATION_CHECK(op,
                          input_shapes.size() >= 5 && output_shapes.size() == 1,
                          "Incorrect number of shapes has been provided.");

    auto& y_out_shape = output_shapes[0];
    y_out_shape.resize(2);  // Rank always 2

    rnn::validate_inputs_rank(op, input_shapes, {2, 2, 2, 2, 1});

    const auto& x_pshape = input_shapes[0];   // [batch_size, input_size]
    const auto& ht_pshape = input_shapes[1];  // [batch_size, hidden_size]
    const auto& w_pshape = input_shapes[2];   // [3 * hidden_size, input_size]
    const auto& r_pshape = input_shapes[3];   // [3 * hidden_size, hidden_size]
    const auto& b_pshape = input_shapes[4];   // if linear_before_reset [4 * hidden_size], otherwise [3 * hidden_size]

    using DimType = typename std::iterator_traits<typename ShapeType::iterator>::value_type;

    // Merge batch_size dimension across all inputs to evaluate output[0] dimension
    DimType merged_batch_size = x_pshape.rank().is_static() ? x_pshape[0] : DimType();
    NODE_VALIDATION_CHECK(
        op,
        DimType::merge(merged_batch_size, merged_batch_size, ht_pshape.rank().is_static() ? ht_pshape[0] : DimType()),
        "Dimension `batch_size` is not matched between inputs.");

    // Set batch_size dimension
    y_out_shape[0] = merged_batch_size;

    // Merge hidden_size dimension across all inputs to evaluate output dimension
    // `hidden_size` attribute is not used for backward compatibility
    DimType merged_hidden_size = ht_pshape.rank().is_static() ? ht_pshape[1] : DimType();
    NODE_VALIDATION_CHECK(
        op,
        DimType::merge(merged_hidden_size, merged_hidden_size, r_pshape.rank().is_static() ? r_pshape[1] : DimType()),
        "Dimension `hidden_size` is not matched between inputs.");

    // Validate dimensions related to hidden_size for W, R, B inputs
    if (merged_hidden_size.is_static()) {
        constexpr auto gru_gates_count = 3;
        if (w_pshape.rank().is_static()) {
            NODE_VALIDATION_CHECK(op,
                                  w_pshape[0].compatible(merged_hidden_size * gru_gates_count),
                                  "First dimension of W input shape is required to be compatible with ",
                                  merged_hidden_size * gru_gates_count,
                                  ". Got shape: ",
                                  w_pshape[0],
                                  ".");
        }
        if (r_pshape.rank().is_static()) {
            // NOTE: fixed typo in the user-facing message ("Fisrt" -> "First")
            // to match the W and B diagnostics.
            NODE_VALIDATION_CHECK(op,
                                  r_pshape[0].compatible(merged_hidden_size * gru_gates_count),
                                  "First dimension of R input shape is required to be compatible with ",
                                  merged_hidden_size * gru_gates_count,
                                  ". Got shape: ",
                                  r_pshape[0],
                                  ".");
        }
        if (b_pshape.rank().is_static()) {
            // linear_before_reset adds one extra bias chunk for the h gate.
            auto bias_dim_multiplier = op->get_linear_before_reset() ? (gru_gates_count + 1) : gru_gates_count;
            NODE_VALIDATION_CHECK(op,
                                  b_pshape[0].compatible(merged_hidden_size * bias_dim_multiplier),
                                  "First dimension of B input shape is required to be compatible with ",
                                  merged_hidden_size * bias_dim_multiplier,
                                  ". Got shape: ",
                                  b_pshape[0],
                                  ".");
        }
    }

    // Set hidden_size dimension
    y_out_shape[1] = merged_hidden_size;
}
} // namespace rnn
namespace v3 {
// Shape inference entry point for v3::GRUCell.
// GRUCell has exactly the common GRU cell inputs (X, H_t, W, R, B), so all
// validation and output shape evaluation is delegated to the shared helper.
template <class ShapeType>
void shape_infer(const ov::op::v3::GRUCell* op,
                 const std::vector<ShapeType>& input_shapes,
                 std::vector<ShapeType>& output_shapes) {
    rnn::gru_cell_shape_infer(op, input_shapes, output_shapes);
}
} // namespace v3
} // namespace op
} // namespace ov

View File

@ -9,7 +9,7 @@
namespace ov {
namespace op {
namespace rnn_seq {
namespace rnn {
template <class OpType, class ShapeType>
void validate_inputs_rank(const OpType* op,
const std::vector<ShapeType>& input_shapes,
@ -32,9 +32,9 @@ void validate_inputs_rank(const OpType* op,
// output_shapes[0]: [batch_size, num_directions, seq_length, hidden_size] // Rank always 4
// output_shapes[1]: [batch_size, num_directions, hidden_size] // Rank always 3
template <class OpType, class ShapeType>
void gru_shape_infer(const OpType* op,
const std::vector<ShapeType>& input_shapes,
std::vector<ShapeType>& output_shapes) {
void gru_sequence_shape_infer(const OpType* op,
const std::vector<ShapeType>& input_shapes,
std::vector<ShapeType>& output_shapes) {
NODE_VALIDATION_CHECK(op,
input_shapes.size() >= 6 && output_shapes.size() == 2,
"Incorrect number of shapes has been provided.");
@ -44,14 +44,14 @@ void gru_shape_infer(const OpType* op,
y_out_shape.resize(4); // Rank always 4
ho_out_shape.resize(3); // Rank always 3
rnn_seq::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2});
rnn::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2});
auto x_pshape = input_shapes[0];
auto ht_pshape = input_shapes[1];
auto sl_pshape = input_shapes[2];
auto w_pshape = input_shapes[3];
auto r_pshape = input_shapes[4];
auto b_pshape = input_shapes[5];
const auto& x_pshape = input_shapes[0];
const auto& ht_pshape = input_shapes[1];
const auto& sl_pshape = input_shapes[2];
const auto& w_pshape = input_shapes[3];
const auto& r_pshape = input_shapes[4];
const auto& b_pshape = input_shapes[5];
using DimType = typename std::iterator_traits<typename ShapeType::iterator>::value_type;
@ -155,7 +155,7 @@ void gru_shape_infer(const OpType* op,
y_out_shape[3] = merged_hidden_size;
ho_out_shape[2] = merged_hidden_size;
}
} // namespace rnn_seq
} // namespace rnn
namespace v5 {
template <class ShapeType>
void shape_infer(const ov::op::v5::GRUSequence* op,
@ -170,7 +170,7 @@ void shape_infer(const ov::op::v5::GRUSequence* op,
input_shapes.size(),
".");
rnn_seq::gru_shape_infer(op, input_shapes, output_shapes);
rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes);
}
} // namespace v5
} // namespace op

View File

@ -6,6 +6,7 @@
#include <cmath>
#include "gru_cell_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/shape.hpp"
@ -87,87 +88,22 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor) {
void op::v3::GRUCell::validate_and_infer_types() {
OV_OP_SCOPE(v3_GRUCell_validate_and_infer_types);
for (const auto& input : inputs()) {
if (input.get_partial_shape().rank().is_dynamic()) {
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic());
return;
}
}
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto result_et = element::dynamic;
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& w_pshape = get_input_partial_shape(2);
const auto& r_pshape = get_input_partial_shape(3);
const auto& b_pshape = get_input_partial_shape(4);
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
// Validate input types and save result for output type
auto result_et = element::dynamic;
NODE_VALIDATION_CHECK(this,
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
element::Type::merge(result_et, result_et, get_input_element_type(4)),
"Element types for X, initial_hidden_state, W, R and B inputs do not match.");
"Element types for X, initial_hidden_state, W, R and B inputs do not "
"match.");
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
"Parameter batch_size not matched for X and initial_hidden_state inputs.");
const auto input_shapes = get_node_input_partial_shapes(*this);
std::vector<ov::PartialShape> output_shapes{ov::PartialShape::dynamic(2)};
shape_infer(this, input_shapes, output_shapes);
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
NODE_VALIDATION_CHECK(this,
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
"Parameter hidden_size not matched for R and initial_hidden_state inputs.");
// Validate hidden_size value for W, B and R inputs
if (merged_hidden_size.is_static()) {
if (w_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in W input. Current value is: ",
w_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (r_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in R input. Current value is: ",
r_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (b_pshape[0].is_static()) {
NODE_VALIDATION_CHECK(this,
b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)),
"Parameter hidden_size mistmatched in B input. Current value is: ",
b_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset),
".");
}
}
// Mark inputs which are relevant to output parameters
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(3);
// Set output size, type and shape
set_output_size(1);
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
set_output_type(0, result_et, output_shapes[0]);
}
void op::v3::GRUCell::add_default_bias_input() {

View File

@ -60,15 +60,8 @@ void op::v5::GRUSequence::validate_and_infer_types() {
"Element types for X, initial_hidden_state, W, R and B inputs do not "
"match.");
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& sl_pshape = get_input_partial_shape(2);
const auto& w_pshape = get_input_partial_shape(3);
const auto& r_pshape = get_input_partial_shape(4);
const auto& b_pshape = get_input_partial_shape(5);
std::vector<ov::PartialShape> input_shapes = {x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape};
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}};
const auto input_shapes = get_node_input_partial_shapes(*this);
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)};
shape_infer(this, input_shapes, output_shapes);
// Set output size, type and shape

View File

@ -48,7 +48,8 @@ TEST(type_prop, augru_cell_invalid_input) {
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
FAIL() << "AUGRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input."));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("First dimension of W input shape is required to be compatible"));
}
// Invalid R tensor shape.
@ -58,8 +59,7 @@ TEST(type_prop, augru_cell_invalid_input) {
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
FAIL() << "AUGRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Dimension hidden_size not matched for R and initial_hidden_state inputs."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `hidden_size` is not matched between inputs"));
}
// Invalid H_t tensor shape.
@ -69,7 +69,7 @@ TEST(type_prop, augru_cell_invalid_input) {
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
FAIL() << "AUGRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension batch_size is not matched between inputs."));
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `batch_size` is not matched between inputs"));
}
// Invalid B tensor shape.
@ -79,9 +79,8 @@ TEST(type_prop, augru_cell_invalid_input) {
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(
error.what(),
std::string("Parameter hidden_size mistmatched in B input. Current value is: 3, expected: 9."));
EXPECT_HAS_SUBSTRING(error.what(),
std::string("First dimension of B input shape is required to be compatible"));
}
// Invalid A tensor shape.
@ -185,11 +184,11 @@ TEST(type_prop, augru_cell_invalid_input_rank) {
<< "AUGRUCell node was created with invalid data.";
}
TEST(type_prop, augru_cell_invalid_input_dynamic_rank) {
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
TEST(type_prop, augru_cell_input_dynamic_rank) {
int64_t batch_size = 2;
int64_t input_size = 3;
int64_t hidden_size = 3;
int64_t gates_count = 3;
auto X = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto R = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
@ -197,41 +196,41 @@ TEST(type_prop, augru_cell_invalid_input_dynamic_rank) {
auto B = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
auto A = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, 1});
auto check_dynamic_gru = [](const shared_ptr<op::internal::AUGRUCell>& augru) -> bool {
return augru->output(0).get_partial_shape() == PartialShape::dynamic(2) &&
auto check_dynamic_gru = [&](const shared_ptr<op::internal::AUGRUCell>& augru) -> bool {
return augru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} &&
augru->output(0).get_element_type() == augru->input(0).get_element_type();
};
// Invalid dynamic rank for W tensor.
// Dynamic rank for W tensor.
auto W = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_w = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
EXPECT_TRUE(check_dynamic_gru(augru_w));
// Invalid dynamic rank for X tensor.
W = make_shared<opset9::Parameter>(element::f32, PartialShape{hidden_size, input_size});
// Dynamic rank for X tensor.
W = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
X = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_x = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
EXPECT_TRUE(check_dynamic_gru(augru_x));
// Invalid dynamic rank for H_t tensor.
// Dynamic rank for H_t tensor.
X = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_h = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
EXPECT_TRUE(check_dynamic_gru(augru_h));
// Invalid dynamic rank for R tensor.
// Dynamic rank for R tensor.
H_t = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_r = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
EXPECT_TRUE(check_dynamic_gru(augru_r));
// Invalid dynamic rank for B tensor.
// Dynamic rank for B tensor.
R = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
B = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_b = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
EXPECT_TRUE(check_dynamic_gru(augru_b));
// Invalid dynamic rank for A tensor.
// Dynamic rank for A tensor.
B = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
A = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto augru_a = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);

View File

@ -4,6 +4,7 @@
#include "ngraph_ops/augru_sequence.hpp"
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "openvino/core/attribute_visitor.hpp"
#include "openvino/opsets/opset9.hpp"
@ -11,6 +12,7 @@
using namespace std;
using namespace ov;
using namespace testing;
struct augru_sequence_parameters {
Dimension batch_size = 8;
@ -258,7 +260,7 @@ TEST(type_prop, augru_sequence_all_inputs_dynamic_rank) {
EXPECT_EQ(augru_sequence->get_output_element_type(1), param.et);
}
TEST(type_prop, augru_sequence_invalid_attention_gate) {
TEST(type_prop, augru_sequence_invalid_attention_gate_seq_length) {
augru_sequence_parameters params;
params.batch_size = 8;
@ -272,13 +274,28 @@ TEST(type_prop, augru_sequence_invalid_attention_gate) {
auto invalid_attention_gate = make_shared<opset9::Parameter>(params.et, PartialShape{params.batch_size, 999, 1});
augru_sequence->set_argument(6, invalid_attention_gate);
try {
augru_sequence->validate_and_infer_types();
FAIL() << "AUGRUSequence node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Dimension `seq_length` must be the same for `X` and `A` inputs."));
}
OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(),
ov::NodeValidationFailure,
HasSubstr("Dimension `seq_length` must be the same for `X` and `A` inputs"));
}
TEST(type_prop, augru_sequence_invalid_attention_gate_batch) {
augru_sequence_parameters params;
params.batch_size = 8;
params.num_directions = 1;
params.seq_length = 6;
params.input_size = 4;
params.hidden_size = 128;
params.et = element::f32;
auto augru_sequence = augru_seq_init(params);
auto invalid_attention_gate = make_shared<opset9::Parameter>(params.et, PartialShape{999, params.seq_length, 1});
augru_sequence->set_argument(6, invalid_attention_gate);
OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(),
ov::NodeValidationFailure,
HasSubstr("Dimension `batch_size` must be the same for `X` and `A` inputs"));
}
namespace {

View File

@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
@ -9,6 +10,7 @@
using namespace std;
using namespace ngraph;
using namespace testing;
TEST(type_prop, gru_cell) {
const size_t batch_size = 2;
@ -38,44 +40,30 @@ TEST(type_prop, gru_cell_invalid_input) {
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
try {
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input."));
}
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ov::NodeValidationFailure,
HasSubstr("First dimension of W input shape is required to be compatible"));
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try {
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Parameter hidden_size not matched for R and initial_hidden_state inputs."));
}
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ov::NodeValidationFailure,
HasSubstr("Dimension `hidden_size` is not matched between inputs"));
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try {
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Parameter batch_size not matched for X and initial_hidden_state inputs."));
}
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ov::NodeValidationFailure,
HasSubstr("Dimension `batch_size` is not matched between inputs"));
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try {
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B input."));
}
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
ov::NodeValidationFailure,
HasSubstr("First dimension of B input shape is required to be compatible"));
}
TEST(type_prop, gru_cell_dynamic_batch_size) {
@ -171,45 +159,45 @@ TEST(type_prop, gru_cell_invalid_input_rank0) {
<< "GRUCell node was created with invalid data.";
}
TEST(type_prop, gru_cell_invalid_input_dynamic_rank) {
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 3;
TEST(type_prop, gru_cell_input_dynamic_rank) {
int64_t batch_size = 2;
int64_t input_size = 3;
int64_t hidden_size = 3;
int64_t gates_count = 3;
auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto check_dynamic_gru = [](const shared_ptr<opset4::GRUCell>& gru) -> bool {
return gru->output(0).get_partial_shape() == PartialShape::dynamic() &&
auto check_dynamic_gru = [&](const shared_ptr<opset4::GRUCell>& gru) -> bool {
return gru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} &&
gru->output(0).get_element_type() == gru->input(0).get_element_type();
};
// Invalid dynamic rank for W tensor.
// Dynamic rank for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto gru_w = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(check_dynamic_gru(gru_w), true);
// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
// Dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto gru_x = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(check_dynamic_gru(gru_x), true);
// Invalid dynamic rank for H_t tensor.
// Dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto gru_h = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(check_dynamic_gru(gru_h), true);
// Invalid dynamic rank for R tensor.
// Dynamic rank for R tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto gru_r = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(check_dynamic_gru(gru_r), true);
// Invalid dynamic rank for B tensor.
// Dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto gru_b = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);

View File

@ -11,7 +11,12 @@
#include <openvino/opsets/opset7.hpp>
#include <openvino/opsets/opset8.hpp>
#include "ngraph_ops/augru_cell.hpp"
#include "ngraph_ops/augru_sequence.hpp"
#include "assign_shape_inference.hpp"
#include "augru_cell_shape_inference.hpp"
#include "augru_sequence_shape_inference.hpp"
#include "batch_to_space_shape_inference.hpp"
#include "broadcast_shape_inference.hpp"
#include "bucketize_shape_inference.hpp"
@ -38,6 +43,7 @@
#include "gather_shape_inference.hpp"
#include "gather_tree_shape_inference.hpp"
#include "gru_sequence_shape_inference.hpp"
#include "gru_cell_shape_inference.hpp"
#include "interpolate_shape_inference.hpp"
#include "lstm_cell_shape_inference.hpp"
#include "matmul_shape_inference.hpp"
@ -494,6 +500,12 @@ std::shared_ptr<IShapeInfer> make_shape_inference(const std::shared_ptr<ngraph::
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::opset5::GRUSequence>(op)) {
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::op::internal::AUGRUSequence>(op)) {
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::opset3::GRUCell>(op)) {
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::op::internal::AUGRUCell>(op)) {
return make_shared_entryIO(node);
} else if (auto node = ov::as_type_ptr<ov::opset1::OneHot>(op)) {
return make_shared_entryIOC(node);
} else if (auto node = ov::as_type_ptr<ov::opset4::CTCLoss>(op)) {

View File

@ -0,0 +1,71 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_ops/augru_cell.hpp"
#include <gtest/gtest.h>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
// AUGRUCell static shape inference: every input Parameter carries only a static
// rank (all dimensions dynamic); concrete dims come from the StaticShapes below.
TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_static_rank) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    // Helper: f32 Parameter with the given static rank and fully dynamic dims.
    const auto dyn_rank_param = [](int64_t rank) {
        return std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(rank));
    };

    const auto X = dyn_rank_param(2);
    const auto H_t = dyn_rank_param(2);
    const auto W = dyn_rank_param(2);
    const auto R = dyn_rank_param(2);
    const auto B = dyn_rank_param(1);
    const auto A = dyn_rank_param(2);

    const auto augru = std::make_shared<ov::op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);

    std::vector<StaticShape> input_shapes;
    input_shapes.push_back(StaticShape{batch_size, input_size});                  // X
    input_shapes.push_back(StaticShape{batch_size, hidden_size});                 // H_t
    input_shapes.push_back(StaticShape{gates_count * hidden_size, input_size});   // W
    input_shapes.push_back(StaticShape{gates_count * hidden_size, hidden_size});  // R
    input_shapes.push_back(StaticShape{gates_count * hidden_size});               // B
    input_shapes.push_back(StaticShape{batch_size, 1});                           // A

    std::vector<StaticShape> output_shapes(2);
    shape_inference(augru.get(), input_shapes, output_shapes);

    EXPECT_EQ(output_shapes[0], StaticShape({batch_size, hidden_size}));
}
// AUGRUCell static shape inference: every input Parameter has a fully dynamic
// rank; all shape information is supplied through the StaticShapes below.
TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_dynamic_rank) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    // Helper: f32 Parameter with dynamic rank.
    const auto dyn_param = []() {
        return std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    };

    const auto X = dyn_param();
    const auto H_t = dyn_param();
    const auto W = dyn_param();
    const auto R = dyn_param();
    const auto B = dyn_param();
    const auto A = dyn_param();

    const auto augru = std::make_shared<ov::op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);

    std::vector<StaticShape> input_shapes;
    input_shapes.push_back(StaticShape{batch_size, input_size});                  // X
    input_shapes.push_back(StaticShape{batch_size, hidden_size});                 // H_t
    input_shapes.push_back(StaticShape{gates_count * hidden_size, input_size});   // W
    input_shapes.push_back(StaticShape{gates_count * hidden_size, hidden_size});  // R
    input_shapes.push_back(StaticShape{gates_count * hidden_size});               // B
    input_shapes.push_back(StaticShape{batch_size, 1});                           // A

    std::vector<StaticShape> output_shapes(2);
    shape_inference(augru.get(), input_shapes, output_shapes);

    EXPECT_EQ(output_shapes[0], StaticShape({batch_size, hidden_size}));
}

View File

@ -0,0 +1,85 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_ops/augru_sequence.hpp"
#include <gtest/gtest.h>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
// AUGRUSequence (FORWARD) shape inference: all inputs have static rank,
// concrete dimensions are provided via StaticShapes.
// Fix: test name corrected from "AGRUSequenceTest" to "AUGRUSequenceTest"
// to match the op under test (ov::op::internal::AUGRUSequence).
TEST(StaticShapeInferenceTest, AUGRUSequenceTest_FORWARD_all_static_rank) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t seq_len = 4;
    constexpr size_t num_directions = 1;  // AUGRUSequence runs forward only
    constexpr size_t gates_count = 3;     // update (z), reset (r), hidden (h)

    const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto seq_lengths = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic(1));
    const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
    const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));

    const auto augru_sequence =
        std::make_shared<ov::op::internal::AUGRUSequence>(X, H_t, seq_lengths, W, R, B, A, hidden_size);

    std::vector<StaticShape> static_input_shapes{
        StaticShape{batch_size, seq_len, input_size},                         // X
        StaticShape{batch_size, num_directions, hidden_size},                 // H_t
        StaticShape{batch_size},                                              // seq_lengths
        StaticShape{num_directions, gates_count * hidden_size, input_size},   // W
        StaticShape{num_directions, gates_count * hidden_size, hidden_size},  // R
        StaticShape{num_directions, gates_count * hidden_size},               // B
        StaticShape{batch_size, seq_len, 1}};                                 // A

    // Two outputs: Y (whole sequence) and Ho (last hidden state).
    std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
    shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes);

    EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, hidden_size}));
    EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size}));
}
// AUGRUSequence (FORWARD) shape inference: all input Parameters have fully
// dynamic rank; shape information comes only from the StaticShapes.
// Fix: test name corrected from "AGRUSequenceTest" to "AUGRUSequenceTest"
// to match the op under test (ov::op::internal::AUGRUSequence).
TEST(StaticShapeInferenceTest, AUGRUSequenceTest_FORWARD_all_inputs_dynamic_rank) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t seq_len = 4;
    constexpr size_t num_directions = 1;  // AUGRUSequence runs forward only
    constexpr size_t gates_count = 3;     // update (z), reset (r), hidden (h)

    const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    const auto seq_lengths = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
    const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());

    const auto augru_sequence =
        std::make_shared<ov::op::internal::AUGRUSequence>(X, H_t, seq_lengths, W, R, B, A, hidden_size);

    std::vector<StaticShape> static_input_shapes{
        StaticShape{batch_size, seq_len, input_size},                         // X
        StaticShape{batch_size, num_directions, hidden_size},                 // H_t
        StaticShape{batch_size},                                              // seq_lengths
        StaticShape{num_directions, gates_count * hidden_size, input_size},   // W
        StaticShape{num_directions, gates_count * hidden_size, hidden_size},  // R
        StaticShape{num_directions, gates_count * hidden_size},               // B
        StaticShape{batch_size, seq_len, 1}};                                 // A

    // Two outputs: Y (whole sequence) and Ho (last hidden state).
    std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
    shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes);

    EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, hidden_size}));
    EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size}));
}

View File

@ -0,0 +1,126 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
// GRUCell shape inference when the user does not pass a bias input:
// the GRUCell constructor creates the default `B` input as a Constant,
// so five input shapes are still required below.
TEST(StaticShapeInferenceTest, GRUCellTest_default_bias) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    // Helper: f32 Parameter with the given static rank and fully dynamic dims.
    const auto dyn_rank_param = [](int64_t rank) {
        return std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(rank));
    };

    const auto X = dyn_rank_param(2);
    const auto H_t = dyn_rank_param(2);
    const auto W = dyn_rank_param(2);
    const auto R = dyn_rank_param(2);

    const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, hidden_size);

    std::vector<StaticShape> input_shapes;
    input_shapes.push_back(StaticShape{batch_size, input_size});                  // X
    input_shapes.push_back(StaticShape{batch_size, hidden_size});                 // H_t
    input_shapes.push_back(StaticShape{gates_count * hidden_size, input_size});   // W
    input_shapes.push_back(StaticShape{gates_count * hidden_size, hidden_size});  // R
    input_shapes.push_back(StaticShape{gates_count * hidden_size});               // B (implicit Constant)

    std::vector<StaticShape> output_shapes(1);
    shape_inference(gru.get(), input_shapes, output_shapes);

    EXPECT_EQ(output_shapes[0], StaticShape({batch_size, hidden_size}));
}
// GRUCell shape inference with an explicitly provided bias input `B`.
// Fix: GRUCell has a single output (Ho), so declare one output placeholder,
// consistent with the GRUCellTest_default_bias test (previously two were
// declared; the extra entry was never checked).
TEST(StaticShapeInferenceTest, GRUCellTest_with_bias) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(1));

    const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, B, hidden_size);

    std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size},                 // X
                                                 StaticShape{batch_size, hidden_size},                // H_t
                                                 StaticShape{gates_count * hidden_size, input_size},  // W
                                                 StaticShape{gates_count * hidden_size, hidden_size}, // R
                                                 StaticShape{gates_count * hidden_size}};             // B
    // Single output: Ho.
    std::vector<StaticShape> static_output_shapes{StaticShape{}};
    shape_inference(gru.get(), static_input_shapes, static_output_shapes);
    EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
}
// GRUCell shape inference with linear_before_reset == true: the bias `B`
// then carries a separate chunk for the `h` gate, i.e. its length is
// (gates_count + 1) * hidden_size instead of gates_count * hidden_size.
// Fix: GRUCell has a single output (Ho), so declare one output placeholder,
// consistent with the GRUCellTest_default_bias test (previously two were
// declared; the extra entry was never checked).
TEST(StaticShapeInferenceTest, GRUCellTest_linear_before) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
    const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(1));

    const auto gru = std::make_shared<op::v3::GRUCell>(X,
                                                       H_t,
                                                       W,
                                                       R,
                                                       B,
                                                       hidden_size,
                                                       std::vector<std::string>{"sigmoid", "tanh"},
                                                       std::vector<float>{},
                                                       std::vector<float>{},
                                                       0.f,
                                                       true);  // linear_before_reset

    std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size},                 // X
                                                 StaticShape{batch_size, hidden_size},                // H_t
                                                 StaticShape{gates_count * hidden_size, input_size},  // W
                                                 StaticShape{gates_count * hidden_size, hidden_size}, // R
                                                 StaticShape{(gates_count + 1) * hidden_size}};       // B
    // Single output: Ho.
    std::vector<StaticShape> static_output_shapes{StaticShape{}};
    shape_inference(gru.get(), static_input_shapes, static_output_shapes);
    EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
}
// GRUCell shape inference when every input Parameter has a fully dynamic rank;
// all shape information is supplied through the StaticShapes below.
TEST(StaticShapeInferenceTest, GRUCellTest_dynamic_rank_inputs) {
    constexpr size_t batch_size = 2;
    constexpr size_t input_size = 3;
    constexpr size_t hidden_size = 5;
    constexpr size_t gates_count = 3;

    // Helper: f32 Parameter with dynamic rank.
    const auto dyn_param = []() {
        return std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
    };

    const auto X = dyn_param();
    const auto H_t = dyn_param();
    const auto W = dyn_param();
    const auto R = dyn_param();
    const auto B = dyn_param();

    const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, B, hidden_size);

    std::vector<StaticShape> input_shapes;
    input_shapes.push_back(StaticShape{batch_size, input_size});                  // X
    input_shapes.push_back(StaticShape{batch_size, hidden_size});                 // H_t
    input_shapes.push_back(StaticShape{gates_count * hidden_size, input_size});   // W
    input_shapes.push_back(StaticShape{gates_count * hidden_size, hidden_size});  // R
    input_shapes.push_back(StaticShape{gates_count * hidden_size});               // B

    std::vector<StaticShape> output_shapes(1);
    shape_inference(gru.get(), input_shapes, output_shapes);

    EXPECT_EQ(output_shapes[0], StaticShape({batch_size, hidden_size}));
}