diff --git a/docs/ops/internal/AUGRUCell.md b/docs/ops/internal/AUGRUCell.md index b2c5cbc19cf..fa58148b714 100644 --- a/docs/ops/internal/AUGRUCell.md +++ b/docs/ops/internal/AUGRUCell.md @@ -78,9 +78,9 @@ AUGRU formula: * **4**: `R` - 2D tensor of type *T* and shape `[3 * hidden_size, hidden_size]`. The recurrence weights for matrix multiplication, gate order: zrh. **Required.** -* **6**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.** +* **5**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.** -* **7**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.** +* **6**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.** **Outputs** diff --git a/src/common/transformations/src/ngraph_ops/augru_cell.cpp b/src/common/transformations/src/ngraph_ops/augru_cell.cpp index d54b7a622b0..1f39508e3ce 100644 --- a/src/common/transformations/src/ngraph_ops/augru_cell.cpp +++ b/src/common/transformations/src/ngraph_ops/augru_cell.cpp @@ -6,6 +6,7 @@ #include +#include "augru_cell_shape_inference.hpp" #include "itt.hpp" using namespace std; @@ -40,23 +41,6 @@ bool ov::op::internal::AUGRUCell::visit_attributes(AttributeVisitor& visitor) { void ov::op::internal::AUGRUCell::validate_and_infer_types() { INTERNAL_OP_SCOPE(internal_AUGRUCell_validate_and_infer_types); - for (const auto& input : inputs()) { - if (input.get_partial_shape().rank().is_dynamic()) { - set_output_type(0, get_input_element_type(0), PartialShape::dynamic(2)); - return; - } - } - auto merged_batch_size = Dimension::dynamic(); - auto merged_hidden_size = Dimension::dynamic(); - auto result_et = element::dynamic; - - // Get input partial shape for all inputs - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& w_pshape = get_input_partial_shape(2); - const auto& r_pshape = get_input_partial_shape(3); - const auto& b_pshape = get_input_partial_shape(4); - const auto& a_pshape = get_input_partial_shape(5); NODE_VALIDATION_CHECK(this, m_clip == 0.f, "AUGRUCell doesn't support clip other than 0."); NODE_VALIDATION_CHECK(this, @@ -69,15 +53,8 @@ void ov::op::internal::AUGRUCell::validate_and_infer_types() { m_linear_before_reset == false, "AUGRUCell supports only linear_before_reset equals false."); - validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape}); - - // `A` input shape validation // [batch_size, 1] - NODE_VALIDATION_CHECK(this, a_pshape.rank().compatible(2), "'A' input must be a 2D tensor."); - if (a_pshape.rank().is_static()) { - NODE_VALIDATION_CHECK(this, a_pshape[1].compatible(1), "The last dimension of `A` shape must be equal to `1`."); - } - // Validate input types and save result for output type + auto result_et = element::dynamic; NODE_VALIDATION_CHECK(this, element::Type::merge(result_et, result_et, get_input_element_type(0)) && element::Type::merge(result_et, result_et, get_input_element_type(1)) && @@ -87,55 +64,13 @@ void 
ov::op::internal::AUGRUCell::validate_and_infer_types() { element::Type::merge(result_et, result_et, get_input_element_type(5)), "Element types for inputs do not match."); - // Merge batch_size dimension across all inputs to evaluate output[0] dimension - NODE_VALIDATION_CHECK(this, - Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) && - Dimension::merge(merged_batch_size, merged_batch_size, a_pshape[0]) && - Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]), - "Dimension batch_size is not matched between inputs."); + // Get input partial shape for all inputs + const auto input_shapes = get_node_input_partial_shapes(*this); + std::vector output_shapes = {ov::PartialShape::dynamic(2)}; + shape_infer(this, input_shapes, output_shapes); - // Merge hidden_size dimension across all inputs to evaluate output[1] dimension - NODE_VALIDATION_CHECK(this, - Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) && - Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]), - "Dimension hidden_size not matched for R and initial_hidden_state inputs."); - - // Validate hidden_size value for W, B and R inputs - if (merged_hidden_size.is_static()) { - if (w_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - w_pshape[0].compatible(merged_hidden_size * s_gates_count), - "Parameter hidden_size mistmatched in W input. Current value is: ", - w_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * s_gates_count, - "."); - } - - if (r_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - r_pshape[0].compatible(merged_hidden_size * s_gates_count), - "Parameter hidden_size mistmatched in R input. Current value is: ", - r_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * s_gates_count, - "."); - } - - if (b_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)), - "Parameter hidden_size mistmatched in B input. 
Current value is: ", - b_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset), - "."); - } - } - - // Set output size, type and shape - set_output_size(1); - set_output_type(0, result_et, {merged_batch_size, merged_hidden_size}); + // Set output type and shape + set_output_type(0, result_et, output_shapes[0]); } shared_ptr ov::op::internal::AUGRUCell::clone_with_new_inputs(const OutputVector& new_args) const { diff --git a/src/common/transformations/src/ngraph_ops/augru_sequence.cpp b/src/common/transformations/src/ngraph_ops/augru_sequence.cpp index 9947bfd56a1..03833775dc8 100644 --- a/src/common/transformations/src/ngraph_ops/augru_sequence.cpp +++ b/src/common/transformations/src/ngraph_ops/augru_sequence.cpp @@ -63,17 +63,8 @@ void ov::op::internal::AUGRUSequence::validate_and_infer_types() { element::Type::merge(result_et, result_et, get_input_element_type(6)), "Element types for inputs do not match."); - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& sl_pshape = get_input_partial_shape(2); - const auto& w_pshape = get_input_partial_shape(3); - const auto& r_pshape = get_input_partial_shape(4); - const auto& b_pshape = get_input_partial_shape(5); - const auto& a_pshape = get_input_partial_shape(6); - - std::vector input_shapes = - {x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape, a_pshape}; - std::vector output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}}; + const auto input_shapes = get_node_input_partial_shapes(*this); + std::vector output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)}; shape_infer(this, input_shapes, output_shapes); // Set output size, type and shape diff --git a/src/core/shape_inference/include/augru_cell_shape_inference.hpp b/src/core/shape_inference/include/augru_cell_shape_inference.hpp new file mode 100644 index 00000000000..95fcda944e6 --- /dev/null +++ b/src/core/shape_inference/include/augru_cell_shape_inference.hpp @@ -0,0 +1,44 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include "gru_cell_shape_inference.hpp" +#include "ngraph_ops/augru_sequence.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { + +namespace internal { +template +void shape_infer(const ov::op::internal::AUGRUCell* op, + const std::vector& input_shapes, + std::vector& output_shapes) { + constexpr size_t expected_in_shapes_count = 6; + NODE_VALIDATION_CHECK(op, + input_shapes.size() == expected_in_shapes_count, + "Incorrect number of input shapes has been provided. 
Expected: ", + expected_in_shapes_count, + ", got: ", + input_shapes.size(), + "."); + + rnn::gru_cell_shape_infer(op, input_shapes, output_shapes); + + // `A` input shape validation // [batch_size, 1] + const auto& a_shape = input_shapes.back(); + const auto& x_shape = input_shapes[0]; + NODE_VALIDATION_CHECK(op, a_shape.rank().compatible(2), "'A' input must be a 2D tensor."); + if (a_shape.rank().is_static()) { + if (x_shape.rank().is_static()) { + NODE_VALIDATION_CHECK(op, + x_shape.rank().get_length() > 1 && a_shape[0].compatible(x_shape[0]), + "Dimension `batch_size` must be the same for `X` and `A` inputs."); + } + NODE_VALIDATION_CHECK(op, a_shape[1].compatible(1), "The last dimension of `A` shape must be equal to `1`."); + } +} +} // namespace internal +} // namespace op +} // namespace ov diff --git a/src/core/shape_inference/include/augru_sequence_shape_inference.hpp b/src/core/shape_inference/include/augru_sequence_shape_inference.hpp index a8de196d65e..3f59babf3d3 100644 --- a/src/core/shape_inference/include/augru_sequence_shape_inference.hpp +++ b/src/core/shape_inference/include/augru_sequence_shape_inference.hpp @@ -24,16 +24,19 @@ void shape_infer(const ov::op::internal::AUGRUSequence* op, input_shapes.size(), "."); - rnn_seq::gru_shape_infer(op, input_shapes, output_shapes); + rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes); // A input shape validation // [batch_size, seq_length, 1] - const auto& a_shape = input_shapes[6]; + const auto& a_shape = input_shapes.back(); const auto& x_shape = input_shapes[0]; NODE_VALIDATION_CHECK(op, a_shape.rank().compatible(3), "'A' input must be a 3D tensor."); if (a_shape.rank().is_static()) { if (x_shape.rank().is_static()) { NODE_VALIDATION_CHECK(op, - x_shape.rank().get_length() > 1 && a_shape[1].compatible(x_shape[1]), + x_shape.rank().get_length() > 1 && a_shape[0].compatible(x_shape[0]), + "Dimension `batch_size` must be the same for `X` and `A` inputs."); + NODE_VALIDATION_CHECK(op, + x_shape.rank().get_length() > 2 && a_shape[1].compatible(x_shape[1]), "Dimension `seq_length` must be the same for `X` and `A` inputs."); } NODE_VALIDATION_CHECK(op, a_shape[2].compatible(1), "The last dimension of `A` shape must be equal to `1`."); diff --git a/src/core/shape_inference/include/gru_cell_shape_inference.hpp b/src/core/shape_inference/include/gru_cell_shape_inference.hpp new file mode 100644 index 00000000000..cee993661aa --- /dev/null +++ b/src/core/shape_inference/include/gru_cell_shape_inference.hpp @@ -0,0 +1,105 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once +#include +#include + +#include "gru_cell_shape_inference.hpp" +#include "gru_sequence_shape_inference.hpp" +#include "utils.hpp" + +namespace ov { +namespace op { +namespace rnn { + +// Output shape layout: +// output_shapes[0]: [batch_size, hidden_size] // Rank always 2 +template +void gru_cell_shape_infer(const OpType* op, + const std::vector& input_shapes, + std::vector& output_shapes) { + NODE_VALIDATION_CHECK(op, + input_shapes.size() >= 5 && output_shapes.size() == 1, + "Incorrect number of shapes has been provided."); + + auto& y_out_shape = output_shapes[0]; + y_out_shape.resize(2); // Rank always 2 + + rnn::validate_inputs_rank(op, input_shapes, {2, 2, 2, 2, 1}); + + const auto& x_pshape = input_shapes[0]; // [batch_size, input_size] + const auto& ht_pshape = input_shapes[1]; // [batch_size, hidden_size] + const auto& w_pshape = input_shapes[2]; // [3 * hidden_size, input_size] + const auto& 
r_pshape = input_shapes[3]; // [3 * hidden_size, hidden_size] + const auto& b_pshape = input_shapes[4]; // if linear_before_reset [4 * hidden_size], otherwise [3 * hidden_size] + + using DimType = typename std::iterator_traits::value_type; + + // Merge batch_size dimension across all inputs to evaluate output[0] dimension + DimType merged_batch_size = x_pshape.rank().is_static() ? x_pshape[0] : DimType(); + NODE_VALIDATION_CHECK( + op, + DimType::merge(merged_batch_size, merged_batch_size, ht_pshape.rank().is_static() ? ht_pshape[0] : DimType()), + "Dimension `batch_size` is not matched between inputs."); + + // Set batch_size dimension + y_out_shape[0] = merged_batch_size; + + // Merge hidden_size dimension across all inputs to evaluate output dimension + // `hidden_size` attribute is not used for backward compatibility + DimType merged_hidden_size = ht_pshape.rank().is_static() ? ht_pshape[1] : DimType(); + NODE_VALIDATION_CHECK( + op, + DimType::merge(merged_hidden_size, merged_hidden_size, r_pshape.rank().is_static() ? r_pshape[1] : DimType()), + "Dimension `hidden_size` is not matched between inputs."); + + // Validate dimensions related to hidden_size for W, R, B inputs + if (merged_hidden_size.is_static()) { + constexpr auto gru_gates_count = 3; + if (w_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(op, + w_pshape[0].compatible(merged_hidden_size * gru_gates_count), + "First dimension of W input shape is required to be compatible with ", + merged_hidden_size * gru_gates_count, + ". Got shape: ", + w_pshape[0], + "."); + } + + if (r_pshape.rank().is_static()) { + NODE_VALIDATION_CHECK(op, + r_pshape[0].compatible(merged_hidden_size * gru_gates_count), + "First dimension of R input shape is required to be compatible with ", + merged_hidden_size * gru_gates_count, + ". Got shape: ", + r_pshape[0], + "."); + } + + if (b_pshape.rank().is_static()) { + auto bias_dim_multiplier = op->get_linear_before_reset() ? (gru_gates_count + 1) : gru_gates_count; + NODE_VALIDATION_CHECK(op, + b_pshape[0].compatible(merged_hidden_size * bias_dim_multiplier), + "First dimension of B input shape is required to be compatible with ", + merged_hidden_size * bias_dim_multiplier, + ". 
Got shape: ", + b_pshape[0], + "."); + } + } + + // Set hidden_size dimension + y_out_shape[1] = merged_hidden_size; +} +} // namespace rnn +namespace v3 { +template +void shape_infer(const ov::op::v3::GRUCell* op, + const std::vector& input_shapes, + std::vector& output_shapes) { + rnn::gru_cell_shape_infer(op, input_shapes, output_shapes); +} +} // namespace v3 +} // namespace op +} // namespace ov diff --git a/src/core/shape_inference/include/gru_sequence_shape_inference.hpp b/src/core/shape_inference/include/gru_sequence_shape_inference.hpp index a6af2d40a50..d9a87ff72d8 100644 --- a/src/core/shape_inference/include/gru_sequence_shape_inference.hpp +++ b/src/core/shape_inference/include/gru_sequence_shape_inference.hpp @@ -9,7 +9,7 @@ namespace ov { namespace op { -namespace rnn_seq { +namespace rnn { template void validate_inputs_rank(const OpType* op, const std::vector& input_shapes, @@ -32,9 +32,9 @@ void validate_inputs_rank(const OpType* op, // output_shapes[0]: [batch_size, num_directions, seq_length, hidden_size] // Rank always 4 // output_shapes[1]: [batch_size, num_directions, hidden_size] // Rank always 3 template -void gru_shape_infer(const OpType* op, - const std::vector& input_shapes, - std::vector& output_shapes) { +void gru_sequence_shape_infer(const OpType* op, + const std::vector& input_shapes, + std::vector& output_shapes) { NODE_VALIDATION_CHECK(op, input_shapes.size() >= 6 && output_shapes.size() == 2, "Incorrect number of shapes has been provided."); @@ -44,14 +44,14 @@ void gru_shape_infer(const OpType* op, y_out_shape.resize(4); // Rank always 4 ho_out_shape.resize(3); // Rank always 3 - rnn_seq::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2}); + rnn::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2}); - auto x_pshape = input_shapes[0]; - auto ht_pshape = input_shapes[1]; - auto sl_pshape = input_shapes[2]; - auto w_pshape = input_shapes[3]; - auto r_pshape = input_shapes[4]; - auto b_pshape = input_shapes[5]; + const auto& x_pshape = input_shapes[0]; + const auto& ht_pshape = input_shapes[1]; + const auto& sl_pshape = input_shapes[2]; + const auto& w_pshape = input_shapes[3]; + const auto& r_pshape = input_shapes[4]; + const auto& b_pshape = input_shapes[5]; using DimType = typename std::iterator_traits::value_type; @@ -155,7 +155,7 @@ void gru_shape_infer(const OpType* op, y_out_shape[3] = merged_hidden_size; ho_out_shape[2] = merged_hidden_size; } -} // namespace rnn_seq +} // namespace rnn namespace v5 { template void shape_infer(const ov::op::v5::GRUSequence* op, @@ -170,7 +170,7 @@ void shape_infer(const ov::op::v5::GRUSequence* op, input_shapes.size(), "."); - rnn_seq::gru_shape_infer(op, input_shapes, output_shapes); + rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes); } } // namespace v5 } // namespace op diff --git a/src/core/src/op/gru_cell.cpp b/src/core/src/op/gru_cell.cpp index 4293ad4db21..cc320a6bfea 100644 --- a/src/core/src/op/gru_cell.cpp +++ b/src/core/src/op/gru_cell.cpp @@ -6,6 +6,7 @@ #include +#include "gru_cell_shape_inference.hpp" #include "itt.hpp" #include "ngraph/op/constant.hpp" #include "ngraph/shape.hpp" @@ -87,87 +88,22 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor) { void op::v3::GRUCell::validate_and_infer_types() { OV_OP_SCOPE(v3_GRUCell_validate_and_infer_types); - for (const auto& input : inputs()) { - if (input.get_partial_shape().rank().is_dynamic()) { - set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic()); - return; - } - } - auto merged_batch_size 
= Dimension::dynamic(); - auto merged_hidden_size = Dimension::dynamic(); - auto result_et = element::dynamic; - - // Get input partial shape for all inputs - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& w_pshape = get_input_partial_shape(2); - const auto& r_pshape = get_input_partial_shape(3); - const auto& b_pshape = get_input_partial_shape(4); - - validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape}); // Validate input types and save result for output type + auto result_et = element::dynamic; NODE_VALIDATION_CHECK(this, element::Type::merge(result_et, result_et, get_input_element_type(0)) && element::Type::merge(result_et, result_et, get_input_element_type(1)) && - element::Type::merge(result_et, result_et, get_input_element_type(2)) && element::Type::merge(result_et, result_et, get_input_element_type(3)) && element::Type::merge(result_et, result_et, get_input_element_type(4)), - "Element types for X, initial_hidden_state, W, R and B inputs do not match."); + "Element types for X, initial_hidden_state, W, R and B inputs do not " + "match."); - // Merge batch_size dimension across all inputs to evaluate output[0] dimension - NODE_VALIDATION_CHECK(this, - Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) && - Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]), - "Parameter batch_size not matched for X and initial_hidden_state inputs."); + const auto input_shapes = get_node_input_partial_shapes(*this); + std::vector output_shapes{ov::PartialShape::dynamic(2)}; + shape_infer(this, input_shapes, output_shapes); - // Merge hidden_size dimension across all inputs to evaluate output[1] dimension - NODE_VALIDATION_CHECK(this, - Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) && - Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]), - "Parameter hidden_size not matched for R and initial_hidden_state inputs."); - - // Validate hidden_size value for W, B and R inputs - if (merged_hidden_size.is_static()) { - if (w_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - w_pshape[0].compatible(merged_hidden_size * s_gates_count), - "Parameter hidden_size mistmatched in W input. Current value is: ", - w_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * s_gates_count, - "."); - } - - if (r_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - r_pshape[0].compatible(merged_hidden_size * s_gates_count), - "Parameter hidden_size mistmatched in R input. Current value is: ", - r_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * s_gates_count, - "."); - } - - if (b_pshape[0].is_static()) { - NODE_VALIDATION_CHECK(this, - b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)), - "Parameter hidden_size mistmatched in B input. 
Current value is: ", - b_pshape[0].get_length(), - ", expected: ", - merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset), - "."); - } - } - - // Mark inputs which are relevant to output parameters - set_input_is_relevant_to_shape(0); - set_input_is_relevant_to_shape(1); - set_input_is_relevant_to_shape(3); - - // Set output size, type and shape - set_output_size(1); - set_output_type(0, result_et, {merged_batch_size, merged_hidden_size}); + set_output_type(0, result_et, output_shapes[0]); } void op::v3::GRUCell::add_default_bias_input() { diff --git a/src/core/src/op/gru_sequence.cpp b/src/core/src/op/gru_sequence.cpp index 78db8758348..f8ed8c5957d 100644 --- a/src/core/src/op/gru_sequence.cpp +++ b/src/core/src/op/gru_sequence.cpp @@ -60,15 +60,8 @@ void op::v5::GRUSequence::validate_and_infer_types() { "Element types for X, initial_hidden_state, W, R and B inputs do not " "match."); - const auto& x_pshape = get_input_partial_shape(0); - const auto& ht_pshape = get_input_partial_shape(1); - const auto& sl_pshape = get_input_partial_shape(2); - const auto& w_pshape = get_input_partial_shape(3); - const auto& r_pshape = get_input_partial_shape(4); - const auto& b_pshape = get_input_partial_shape(5); - - std::vector input_shapes = {x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape}; - std::vector output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}}; + const auto input_shapes = get_node_input_partial_shapes(*this); + std::vector output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)}; shape_infer(this, input_shapes, output_shapes); // Set output size, type and shape diff --git a/src/core/tests/type_prop/augru_cell.cpp b/src/core/tests/type_prop/augru_cell.cpp index d7ce2cb5248..04533df9d2a 100644 --- a/src/core/tests/type_prop/augru_cell.cpp +++ b/src/core/tests/type_prop/augru_cell.cpp @@ -48,7 +48,8 @@ TEST(type_prop, augru_cell_invalid_input) { const auto gru_cell = make_shared(X, H_t, W, R, B, A, hidden_size); FAIL() << "AUGRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input.")); + EXPECT_HAS_SUBSTRING(error.what(), + std::string("First dimension of W input shape is required to be compatible")); } // Invalid R tensor shape. @@ -58,8 +59,7 @@ TEST(type_prop, augru_cell_invalid_input) { const auto gru_cell = make_shared(X, H_t, W, R, B, A, hidden_size); FAIL() << "AUGRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("Dimension hidden_size not matched for R and initial_hidden_state inputs.")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `hidden_size` is not matched between inputs")); } // Invalid H_t tensor shape. @@ -69,7 +69,7 @@ TEST(type_prop, augru_cell_invalid_input) { const auto gru_cell = make_shared(X, H_t, W, R, B, A, hidden_size); FAIL() << "AUGRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension batch_size is not matched between inputs.")); + EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `batch_size` is not matched between inputs")); } // Invalid B tensor shape. 
@@ -79,9 +79,8 @@ TEST(type_prop, augru_cell_invalid_input) { const auto gru_cell = make_shared(X, H_t, W, R, B, A, hidden_size); FAIL() << "GRUCell node was created with invalid data."; } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING( - error.what(), - std::string("Parameter hidden_size mistmatched in B input. Current value is: 3, expected: 9.")); + EXPECT_HAS_SUBSTRING(error.what(), + std::string("First dimension of B input shape is required to be compatible")); } // Invalid A tensor shape. @@ -185,11 +184,11 @@ TEST(type_prop, augru_cell_invalid_input_rank) { << "AUGRUCell node was created with invalid data."; } -TEST(type_prop, augru_cell_invalid_input_dynamic_rank) { - const size_t batch_size = 2; - const size_t input_size = 3; - const size_t hidden_size = 3; - const size_t gates_count = 3; +TEST(type_prop, augru_cell_input_dynamic_rank) { + int64_t batch_size = 2; + int64_t input_size = 3; + int64_t hidden_size = 3; + int64_t gates_count = 3; auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); auto R = make_shared(element::f32, PartialShape{gates_count * hidden_size, hidden_size}); @@ -197,41 +196,41 @@ TEST(type_prop, augru_cell_invalid_input_dynamic_rank) { auto B = make_shared(element::f32, PartialShape{gates_count * hidden_size}); auto A = make_shared(element::f32, PartialShape{batch_size, 1}); - auto check_dynamic_gru = [](const shared_ptr& augru) -> bool { - return augru->output(0).get_partial_shape() == PartialShape::dynamic(2) && + auto check_dynamic_gru = [&](const shared_ptr& augru) -> bool { + return augru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} && augru->output(0).get_element_type() == augru->input(0).get_element_type(); }; - // Invalid dynamic rank for W tensor. + // Dynamic rank for W tensor. auto W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_w = make_shared(X, H_t, W, R, B, A, hidden_size); EXPECT_TRUE(check_dynamic_gru(augru_w)); - // Invalid dynamic rank for X tensor. - W = make_shared(element::f32, PartialShape{hidden_size, input_size}); + // Dynamic rank for X tensor. + W = make_shared(element::f32, PartialShape{gates_count * hidden_size, input_size}); X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_x = make_shared(X, H_t, W, R, B, A, hidden_size); EXPECT_TRUE(check_dynamic_gru(augru_x)); - // Invalid dynamic rank for H_t tensor. + // Dynamic rank for H_t tensor. X = make_shared(element::f32, PartialShape{batch_size, input_size}); H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_h = make_shared(X, H_t, W, R, B, A, hidden_size); EXPECT_TRUE(check_dynamic_gru(augru_h)); - // Invalid dynamic rank for R tensor. + // Dynamic rank for R tensor. H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_r = make_shared(X, H_t, W, R, B, A, hidden_size); EXPECT_TRUE(check_dynamic_gru(augru_r)); - // Invalid dynamic rank for B tensor. + // Dynamic rank for B tensor. R = make_shared(element::f32, PartialShape{gates_count * hidden_size, hidden_size}); B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_b = make_shared(X, H_t, W, R, B, A, hidden_size); EXPECT_TRUE(check_dynamic_gru(augru_b)); - // Invalid dynamic rank for A tensor. + // Dynamic rank for A tensor. 
B = make_shared(element::f32, PartialShape{gates_count * hidden_size}); A = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto augru_a = make_shared(X, H_t, W, R, B, A, hidden_size); diff --git a/src/core/tests/type_prop/augru_sequence.cpp b/src/core/tests/type_prop/augru_sequence.cpp index dbc807b7dc2..0b9762eb247 100644 --- a/src/core/tests/type_prop/augru_sequence.cpp +++ b/src/core/tests/type_prop/augru_sequence.cpp @@ -4,6 +4,7 @@ #include "ngraph_ops/augru_sequence.hpp" +#include "common_test_utils/test_assertions.hpp" #include "gtest/gtest.h" #include "openvino/core/attribute_visitor.hpp" #include "openvino/opsets/opset9.hpp" @@ -11,6 +12,7 @@ using namespace std; using namespace ov; +using namespace testing; struct augru_sequence_parameters { Dimension batch_size = 8; @@ -258,7 +260,7 @@ TEST(type_prop, augru_sequence_all_inputs_dynamic_rank) { EXPECT_EQ(augru_sequence->get_output_element_type(1), param.et); } -TEST(type_prop, augru_sequence_invalid_attention_gate) { +TEST(type_prop, augru_sequence_invalid_attention_gate_seq_length) { augru_sequence_parameters params; params.batch_size = 8; @@ -272,13 +274,28 @@ TEST(type_prop, augru_sequence_invalid_attention_gate) { auto invalid_attention_gate = make_shared(params.et, PartialShape{params.batch_size, 999, 1}); augru_sequence->set_argument(6, invalid_attention_gate); - try { - augru_sequence->validate_and_infer_types(); - FAIL() << "AUGRUSequence node was created with invalid data."; - } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("Dimension `seq_length` must be the same for `X` and `A` inputs.")); - } + OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(), + ov::NodeValidationFailure, + HasSubstr("Dimension `seq_length` must be the same for `X` and `A` inputs")); +} + +TEST(type_prop, augru_sequence_invalid_attention_gate_batch) { + augru_sequence_parameters params; + + params.batch_size = 8; + params.num_directions = 1; + params.seq_length = 6; + params.input_size = 4; + params.hidden_size = 128; + params.et = element::f32; + + auto augru_sequence = augru_seq_init(params); + auto invalid_attention_gate = make_shared(params.et, PartialShape{999, params.seq_length, 1}); + augru_sequence->set_argument(6, invalid_attention_gate); + + OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(), + ov::NodeValidationFailure, + HasSubstr("Dimension `batch_size` must be the same for `X` and `A` inputs")); } namespace { diff --git a/src/core/tests/type_prop/gru_cell.cpp b/src/core/tests/type_prop/gru_cell.cpp index 34387f3e7fe..6e2a4ea0678 100644 --- a/src/core/tests/type_prop/gru_cell.cpp +++ b/src/core/tests/type_prop/gru_cell.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "common_test_utils/test_assertions.hpp" #include "gtest/gtest.h" #include "ngraph/ngraph.hpp" #include "ngraph/opsets/opset4.hpp" @@ -9,6 +10,7 @@ using namespace std; using namespace ngraph; +using namespace testing; TEST(type_prop, gru_cell) { const size_t batch_size = 2; @@ -38,44 +40,30 @@ TEST(type_prop, gru_cell_invalid_input) { // Invalid W tensor shape. 
auto W = make_shared(element::f32, Shape{hidden_size, input_size}); - try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); - FAIL() << "GRUCell node was created with invalid data."; - } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input.")); - } + OV_EXPECT_THROW(auto op = make_shared(X, H_t, W, R, hidden_size), + ov::NodeValidationFailure, + HasSubstr("First dimension of W input shape is required to be compatible")); // Invalid R tensor shape. W = make_shared(element::f32, Shape{gates_count * hidden_size, input_size}); R = make_shared(element::f32, Shape{hidden_size, 1}); - try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); - FAIL() << "GRUCell node was created with invalid data."; - } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("Parameter hidden_size not matched for R and initial_hidden_state inputs.")); - } + OV_EXPECT_THROW(auto op = make_shared(X, H_t, W, R, hidden_size), + ov::NodeValidationFailure, + HasSubstr("Dimension `hidden_size` is not matched between inputs")); // Invalid H_t tensor shape. R = make_shared(element::f32, Shape{gates_count * hidden_size, hidden_size}); H_t = make_shared(element::f32, Shape{4, hidden_size}); - try { - const auto gru_cell = make_shared(X, H_t, W, R, hidden_size); - FAIL() << "GRUCell node was created with invalid data."; - } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), - std::string("Parameter batch_size not matched for X and initial_hidden_state inputs.")); - } + OV_EXPECT_THROW(auto op = make_shared(X, H_t, W, R, hidden_size), + ov::NodeValidationFailure, + HasSubstr("Dimension `batch_size` is not matched between inputs")); // Invalid B tensor shape. 
H_t = make_shared(element::f32, Shape{batch_size, hidden_size}); auto B = make_shared(element::f32, Shape{hidden_size}); - try { - const auto gru_cell = make_shared(X, H_t, W, R, B, hidden_size); - FAIL() << "GRUCell node was created with invalid data."; - } catch (const NodeValidationFailure& error) { - EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B input.")); - } + OV_EXPECT_THROW(auto op = make_shared(X, H_t, W, R, B, hidden_size), + ov::NodeValidationFailure, + HasSubstr("First dimension of B input shape is required to be compatible")); } TEST(type_prop, gru_cell_dynamic_batch_size) { @@ -171,45 +159,45 @@ TEST(type_prop, gru_cell_invalid_input_rank0) { << "GRUCell node was created with invalid data."; } -TEST(type_prop, gru_cell_invalid_input_dynamic_rank) { - const size_t batch_size = 2; - const size_t input_size = 3; - const size_t hidden_size = 3; - const size_t gates_count = 3; +TEST(type_prop, gru_cell_input_dynamic_rank) { + int64_t batch_size = 2; + int64_t input_size = 3; + int64_t hidden_size = 3; + int64_t gates_count = 3; auto X = make_shared(element::f32, PartialShape{batch_size, input_size}); auto R = make_shared(element::f32, PartialShape{gates_count * hidden_size, hidden_size}); auto H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); - auto check_dynamic_gru = [](const shared_ptr& gru) -> bool { - return gru->output(0).get_partial_shape() == PartialShape::dynamic() && + auto check_dynamic_gru = [&](const shared_ptr& gru) -> bool { + return gru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} && gru->output(0).get_element_type() == gru->input(0).get_element_type(); }; - // Invalid dynamic rank for W tensor. + // Dynamic rank for W tensor. auto W = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto gru_w = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(check_dynamic_gru(gru_w), true); - // Invalid dynamic rank for X tensor. - W = make_shared(element::f32, PartialShape{hidden_size, input_size}); + // Dynamic rank for X tensor. + W = make_shared(element::f32, PartialShape{gates_count * hidden_size, input_size}); X = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto gru_x = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(check_dynamic_gru(gru_x), true); - // Invalid dynamic rank for H_t tensor. + // Dynamic rank for H_t tensor. X = make_shared(element::f32, PartialShape{batch_size, input_size}); H_t = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto gru_h = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(check_dynamic_gru(gru_h), true); - // Invalid dynamic rank for R tensor. + // Dynamic rank for R tensor. H_t = make_shared(element::f32, PartialShape{batch_size, hidden_size}); R = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto gru_r = make_shared(X, H_t, W, R, hidden_size); EXPECT_EQ(check_dynamic_gru(gru_r), true); - // Invalid dynamic rank for B tensor. + // Dynamic rank for B tensor. 
R = make_shared(element::f32, PartialShape{gates_count * hidden_size, hidden_size}); auto B = make_shared(element::f32, PartialShape::dynamic(Rank::dynamic())); auto gru_b = make_shared(X, H_t, W, R, B, hidden_size); diff --git a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp index 4c686b93912..345dd19b045 100644 --- a/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp +++ b/src/plugins/intel_cpu/src/utils/shape_inference/shape_inference.cpp @@ -11,7 +11,12 @@ #include #include +#include "ngraph_ops/augru_cell.hpp" +#include "ngraph_ops/augru_sequence.hpp" + #include "assign_shape_inference.hpp" +#include "augru_cell_shape_inference.hpp" +#include "augru_sequence_shape_inference.hpp" #include "batch_to_space_shape_inference.hpp" #include "broadcast_shape_inference.hpp" #include "bucketize_shape_inference.hpp" @@ -38,6 +43,7 @@ #include "gather_shape_inference.hpp" #include "gather_tree_shape_inference.hpp" #include "gru_sequence_shape_inference.hpp" +#include "gru_cell_shape_inference.hpp" #include "interpolate_shape_inference.hpp" #include "lstm_cell_shape_inference.hpp" #include "matmul_shape_inference.hpp" @@ -494,6 +500,12 @@ std::shared_ptr make_shape_inference(const std::shared_ptr(op)) { return make_shared_entryIO(node); + } else if (auto node = ov::as_type_ptr(op)) { + return make_shared_entryIO(node); + } else if (auto node = ov::as_type_ptr(op)) { + return make_shared_entryIO(node); + } else if (auto node = ov::as_type_ptr(op)) { + return make_shared_entryIO(node); } else if (auto node = ov::as_type_ptr(op)) { return make_shared_entryIOC(node); } else if (auto node = ov::as_type_ptr(op)) { diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_cell_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_cell_test.cpp new file mode 100644 index 00000000000..da4eb5be5c8 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_cell_test.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph_ops/augru_cell.hpp" + +#include + +#include +#include +#include +#include + +using namespace ov; +using namespace ov::intel_cpu; + +TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_static_rank) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto W = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto R = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto B = std::make_shared(element::f32, PartialShape::dynamic(1)); + const auto A = std::make_shared(element::f32, PartialShape::dynamic(2)); + + const auto augru = std::make_shared(X, H_t, W, R, B, A, hidden_size); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{gates_count * hidden_size}, // B + StaticShape{batch_size, 1}}; // A + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(augru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], 
StaticShape({batch_size, hidden_size})); +} + +TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_dynamic_rank) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic()); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic()); + const auto W = std::make_shared(element::f32, PartialShape::dynamic()); + const auto R = std::make_shared(element::f32, PartialShape::dynamic()); + const auto B = std::make_shared(element::f32, PartialShape::dynamic()); + const auto A = std::make_shared(element::f32, PartialShape::dynamic()); + + const auto augru = std::make_shared(X, H_t, W, R, B, A, hidden_size); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{gates_count * hidden_size}, // B + StaticShape{batch_size, 1}}; // A + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(augru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size})); +} diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_sequence_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_sequence_test.cpp new file mode 100644 index 00000000000..a48b6367ff7 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/augru_sequence_test.cpp @@ -0,0 +1,85 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "ngraph_ops/augru_sequence.hpp" + +#include + +#include +#include +#include +#include + +using namespace ov; +using namespace ov::intel_cpu; + +TEST(StaticShapeInferenceTest, AGRUSequenceTest_FORWARD_all_static_rank) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t seq_len = 4; + constexpr size_t num_directions = 1; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic(3)); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic(3)); + const auto seq_lengths = std::make_shared(element::i32, PartialShape::dynamic(1)); + const auto W = std::make_shared(element::f32, PartialShape::dynamic(3)); + const auto R = std::make_shared(element::f32, PartialShape::dynamic(3)); + const auto B = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto A = std::make_shared(element::f32, PartialShape::dynamic(3)); + + const auto augru_sequence = + std::make_shared(X, H_t, seq_lengths, W, R, B, A, hidden_size); + + std::vector static_input_shapes{ + StaticShape{batch_size, seq_len, input_size}, // X + StaticShape{batch_size, num_directions, hidden_size}, // H_t + StaticShape{batch_size}, // seq_lengths + StaticShape{num_directions, gates_count * hidden_size, input_size}, // W + StaticShape{num_directions, gates_count * hidden_size, hidden_size}, // R + StaticShape{num_directions, gates_count * hidden_size}, // B + StaticShape{batch_size, seq_len, 1}}; // A + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, 
hidden_size})); + EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size})); +} + +TEST(StaticShapeInferenceTest, AGRUSequenceTest_FORWARD_all_inputs_dynamic_rank) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t seq_len = 4; + constexpr size_t num_directions = 1; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic()); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic()); + const auto seq_lengths = std::make_shared(element::i32, PartialShape::dynamic()); + const auto W = std::make_shared(element::f32, PartialShape::dynamic()); + const auto R = std::make_shared(element::f32, PartialShape::dynamic()); + const auto B = std::make_shared(element::f32, PartialShape::dynamic()); + const auto A = std::make_shared(element::f32, PartialShape::dynamic()); + + const auto augru_sequence = + std::make_shared(X, H_t, seq_lengths, W, R, B, A, hidden_size); + + std::vector static_input_shapes{ + StaticShape{batch_size, seq_len, input_size}, // X + StaticShape{batch_size, num_directions, hidden_size}, // H_t + StaticShape{batch_size}, // seq_lengths + StaticShape{num_directions, gates_count * hidden_size, input_size}, // W + StaticShape{num_directions, gates_count * hidden_size, hidden_size}, // R + StaticShape{num_directions, gates_count * hidden_size}, // B + StaticShape{batch_size, seq_len, 1}}; // A + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, hidden_size})); + EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size})); +} diff --git a/src/plugins/intel_cpu/tests/unit/shape_inference_test/gru_cell_test.cpp b/src/plugins/intel_cpu/tests/unit/shape_inference_test/gru_cell_test.cpp new file mode 100644 index 00000000000..403f8e86e87 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/shape_inference_test/gru_cell_test.cpp @@ -0,0 +1,126 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include + +using namespace ov; +using namespace ov::intel_cpu; + +TEST(StaticShapeInferenceTest, GRUCellTest_default_bias) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto W = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto R = std::make_shared(element::f32, PartialShape::dynamic(2)); + + // Default `B` input is created as Constant by GRUCell constructor + const auto gru = std::make_shared(X, H_t, W, R, hidden_size); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{gates_count * hidden_size}}; // B + + std::vector static_output_shapes{StaticShape{}}; + + shape_inference(gru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size})); +} + +TEST(StaticShapeInferenceTest, 
GRUCellTest_with_bias) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto W = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto R = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto B = std::make_shared(element::f32, PartialShape::dynamic(1)); + const auto gru = std::make_shared(X, H_t, W, R, B, hidden_size); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{gates_count * hidden_size}}; // B + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(gru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size})); +} + +TEST(StaticShapeInferenceTest, GRUCellTest_linear_before) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto W = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto R = std::make_shared(element::f32, PartialShape::dynamic(2)); + const auto B = std::make_shared(element::f32, PartialShape::dynamic(1)); + + const auto gru = std::make_shared(X, + H_t, + W, + R, + B, + hidden_size, + std::vector{"sigmoid", "tanh"}, + std::vector{}, + std::vector{}, + 0.f, + true); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{(gates_count + 1) * hidden_size}}; // B + + std::vector static_output_shapes{StaticShape{}, StaticShape{}}; + + shape_inference(gru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size})); +} + +TEST(StaticShapeInferenceTest, GRUCellTest_dynamic_rank_inputs) { + constexpr size_t batch_size = 2; + constexpr size_t input_size = 3; + constexpr size_t hidden_size = 5; + constexpr size_t gates_count = 3; + + const auto X = std::make_shared(element::f32, PartialShape::dynamic()); + const auto H_t = std::make_shared(element::f32, PartialShape::dynamic()); + const auto W = std::make_shared(element::f32, PartialShape::dynamic()); + const auto R = std::make_shared(element::f32, PartialShape::dynamic()); + const auto B = std::make_shared(element::f32, PartialShape::dynamic()); + + const auto gru = std::make_shared(X, H_t, W, R, B, hidden_size); + + std::vector static_input_shapes{StaticShape{batch_size, input_size}, // X + StaticShape{batch_size, hidden_size}, // H_t + StaticShape{gates_count * hidden_size, input_size}, // W + StaticShape{gates_count * hidden_size, hidden_size}, // R + StaticShape{gates_count * hidden_size}}; // B + + std::vector static_output_shapes{StaticShape{}}; + + shape_inference(gru.get(), static_input_shapes, static_output_shapes); + EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size})); +}
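
Note on the pattern applied in this patch: GRU/AUGRU output-shape computation is moved into templated free functions so the same logic serves both graph-level validate_and_infer_types (PartialShape) and the CPU plugin's static shape inference (StaticShape). The snippet below is a minimal standalone sketch of that idea, not OpenVINO code; StaticShapeStub and gru_cell_shape_infer_stub are hypothetical names, and the real routines additionally merge dynamic dimensions and validate the 3 * hidden_size factor on W, R and B.

// Minimal standalone sketch of a templated GRU-cell shape inference (hypothetical names,
// not the OpenVINO API). The same template could be instantiated with a dynamic-shape
// type or a fully static one; here std::vector<int64_t> stands in for a static shape.
#include <cassert>
#include <cstdint>
#include <vector>

using StaticShapeStub = std::vector<int64_t>;  // every dimension is known

// Inputs: X [batch, input_size], H_t [batch, hidden], W [3*hidden, input_size],
//         R [3*hidden, hidden], B [3*hidden]. Output: [batch, hidden].
template <class TShape>
void gru_cell_shape_infer_stub(const std::vector<TShape>& in, std::vector<TShape>& out) {
    assert(in.size() >= 5 && out.size() == 1);
    const auto& x = in[0];
    const auto& ht = in[1];
    const auto& r = in[3];
    // batch_size must agree between X and H_t; hidden_size between H_t and R.
    assert(ht[0] == x[0]);
    assert(ht[1] == r[1]);
    // A full implementation would also check that W[0], R[0] and B[0] equal 3 * hidden_size.
    out[0] = TShape{x[0], r[1]};  // [batch_size, hidden_size]
}

int main() {
    const int64_t batch = 2, input = 3, hidden = 5;
    std::vector<StaticShapeStub> in{{batch, input},       // X
                                    {batch, hidden},      // H_t
                                    {3 * hidden, input},  // W
                                    {3 * hidden, hidden}, // R
                                    {3 * hidden}};        // B
    std::vector<StaticShapeStub> out(1);
    gru_cell_shape_infer_stub(in, out);
    assert((out[0] == StaticShapeStub{batch, hidden}));
    return 0;
}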