GRU/AUGRUCell shape inference function (#13708)
* Add shape_infer function for GRUCell op * Add shape_infer function for AUGRUCell * Consts refactor * Add batch_size check * Enable GRUCell shape_infer for CPU * Style apply * Use OV_EXPECT_THROW in tests * Use helper for input shapes * Use .back() instead of index * Change rnn_seq namespace to rnn * Fix win warnings * Enable AUGRUCell/Sequence shape_infer on CPU * Fix warn * Fix warn
This commit is contained in:
parent
19c2ec068a
commit
c953186ff0
@ -78,9 +78,9 @@ AUGRU formula:
|
||||
|
||||
* **4**: `R` - 2D tensor of type *T* and shape `[3 * hidden_size, hidden_size]`. The recurrence weights for matrix multiplication, gate order: zrh. **Required.**
|
||||
|
||||
* **6**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.**
|
||||
* **5**: `B` - 2D tensor of type *T*. The biases. If *linear_before_reset* is set to `False`, then the shape is `[3 * hidden_size]`, gate order: zrh. Otherwise the shape is `[4 * hidden_size]` - the sum of biases for z and r gates (weights and recurrence weights), the biases for h gate are placed separately. **Required.**
|
||||
|
||||
* **7**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.**
|
||||
* **6**: `A` - 2D tensor of type *T* and shape `[batch_size, 1]`, the attention score. **Required.**
|
||||
|
||||
|
||||
**Outputs**
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "augru_cell_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -40,23 +41,6 @@ bool ov::op::internal::AUGRUCell::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
||||
void ov::op::internal::AUGRUCell::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(internal_AUGRUCell_validate_and_infer_types);
|
||||
for (const auto& input : inputs()) {
|
||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), PartialShape::dynamic(2));
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
// Get input partial shape for all inputs
|
||||
const auto& x_pshape = get_input_partial_shape(0);
|
||||
const auto& ht_pshape = get_input_partial_shape(1);
|
||||
const auto& w_pshape = get_input_partial_shape(2);
|
||||
const auto& r_pshape = get_input_partial_shape(3);
|
||||
const auto& b_pshape = get_input_partial_shape(4);
|
||||
const auto& a_pshape = get_input_partial_shape(5);
|
||||
|
||||
NODE_VALIDATION_CHECK(this, m_clip == 0.f, "AUGRUCell doesn't support clip other than 0.");
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
@ -69,15 +53,8 @@ void ov::op::internal::AUGRUCell::validate_and_infer_types() {
|
||||
m_linear_before_reset == false,
|
||||
"AUGRUCell supports only linear_before_reset equals false.");
|
||||
|
||||
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
|
||||
|
||||
// `A` input shape validation // [batch_size, 1]
|
||||
NODE_VALIDATION_CHECK(this, a_pshape.rank().compatible(2), "'A' input must be a 2D tensor.");
|
||||
if (a_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(this, a_pshape[1].compatible(1), "The last dimension of `A` shape must be equal to `1`.");
|
||||
}
|
||||
|
||||
// Validate input types and save result for output type
|
||||
auto result_et = element::dynamic;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
|
||||
@ -87,55 +64,13 @@ void ov::op::internal::AUGRUCell::validate_and_infer_types() {
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(5)),
|
||||
"Element types for inputs do not match.");
|
||||
|
||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, a_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
|
||||
"Dimension batch_size is not matched between inputs.");
|
||||
// Get input partial shape for all inputs
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(2)};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
|
||||
"Dimension hidden_size not matched for R and initial_hidden_state inputs.");
|
||||
|
||||
// Validate hidden_size value for W, B and R inputs
|
||||
if (merged_hidden_size.is_static()) {
|
||||
if (w_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
||||
w_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * s_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
||||
r_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * s_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)),
|
||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
||||
b_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset),
|
||||
".");
|
||||
}
|
||||
}
|
||||
|
||||
// Set output size, type and shape
|
||||
set_output_size(1);
|
||||
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
|
||||
// Set output type and shape
|
||||
set_output_type(0, result_et, output_shapes[0]);
|
||||
}
|
||||
|
||||
shared_ptr<ov::Node> ov::op::internal::AUGRUCell::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
|
@ -63,17 +63,8 @@ void ov::op::internal::AUGRUSequence::validate_and_infer_types() {
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(6)),
|
||||
"Element types for inputs do not match.");
|
||||
|
||||
const auto& x_pshape = get_input_partial_shape(0);
|
||||
const auto& ht_pshape = get_input_partial_shape(1);
|
||||
const auto& sl_pshape = get_input_partial_shape(2);
|
||||
const auto& w_pshape = get_input_partial_shape(3);
|
||||
const auto& r_pshape = get_input_partial_shape(4);
|
||||
const auto& b_pshape = get_input_partial_shape(5);
|
||||
const auto& a_pshape = get_input_partial_shape(6);
|
||||
|
||||
std::vector<ov::PartialShape> input_shapes =
|
||||
{x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape, a_pshape};
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}};
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
// Set output size, type and shape
|
||||
|
@ -0,0 +1,44 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
|
||||
namespace internal {
|
||||
template <class ShapeType>
|
||||
void shape_infer(const ov::op::internal::AUGRUCell* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
std::vector<ShapeType>& output_shapes) {
|
||||
constexpr size_t expected_in_shapes_count = 6;
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes.size() == expected_in_shapes_count,
|
||||
"Incorrect number of input shapes has been provided. Expected: ",
|
||||
expected_in_shapes_count,
|
||||
", got: ",
|
||||
input_shapes.size(),
|
||||
".");
|
||||
|
||||
rnn::gru_cell_shape_infer(op, input_shapes, output_shapes);
|
||||
|
||||
// `A` input shape validation // [batch_size, 1]
|
||||
const auto& a_shape = input_shapes.back();
|
||||
const auto& x_shape = input_shapes[0];
|
||||
NODE_VALIDATION_CHECK(op, a_shape.rank().compatible(2), "'A' input must be a 2D tensor.");
|
||||
if (a_shape.rank().is_static()) {
|
||||
if (x_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
x_shape.rank().get_length() > 1 && a_shape[0].compatible(x_shape[0]),
|
||||
"Dimension `batch_size` must be the same for `X` and `A` inputs.");
|
||||
}
|
||||
NODE_VALIDATION_CHECK(op, a_shape[1].compatible(1), "The last dimension of `A` shape must be equal to `1`.");
|
||||
}
|
||||
}
|
||||
} // namespace internal
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -24,16 +24,19 @@ void shape_infer(const ov::op::internal::AUGRUSequence* op,
|
||||
input_shapes.size(),
|
||||
".");
|
||||
|
||||
rnn_seq::gru_shape_infer(op, input_shapes, output_shapes);
|
||||
rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes);
|
||||
|
||||
// A input shape validation // [batch_size, seq_length, 1]
|
||||
const auto& a_shape = input_shapes[6];
|
||||
const auto& a_shape = input_shapes.back();
|
||||
const auto& x_shape = input_shapes[0];
|
||||
NODE_VALIDATION_CHECK(op, a_shape.rank().compatible(3), "'A' input must be a 3D tensor.");
|
||||
if (a_shape.rank().is_static()) {
|
||||
if (x_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
x_shape.rank().get_length() > 1 && a_shape[1].compatible(x_shape[1]),
|
||||
x_shape.rank().get_length() > 1 && a_shape[0].compatible(x_shape[0]),
|
||||
"Dimension `batch_size` must be the same for `X` and `A` inputs.");
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
x_shape.rank().get_length() > 2 && a_shape[1].compatible(x_shape[1]),
|
||||
"Dimension `seq_length` must be the same for `X` and `A` inputs.");
|
||||
}
|
||||
NODE_VALIDATION_CHECK(op, a_shape[2].compatible(1), "The last dimension of `A` shape must be equal to `1`.");
|
||||
|
105
src/core/shape_inference/include/gru_cell_shape_inference.hpp
Normal file
105
src/core/shape_inference/include/gru_cell_shape_inference.hpp
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#pragma once
|
||||
#include <openvino/core/validation_util.hpp>
|
||||
#include <openvino/op/gru_cell.hpp>
|
||||
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "gru_sequence_shape_inference.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace rnn {
|
||||
|
||||
// Output shape layout:
|
||||
// output_shapes[0]: [batch_size, hidden_size] // Rank always 2
|
||||
template <class OpType, class ShapeType>
|
||||
void gru_cell_shape_infer(const OpType* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
std::vector<ShapeType>& output_shapes) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes.size() >= 5 && output_shapes.size() == 1,
|
||||
"Incorrect number of shapes has been provided.");
|
||||
|
||||
auto& y_out_shape = output_shapes[0];
|
||||
y_out_shape.resize(2); // Rank always 2
|
||||
|
||||
rnn::validate_inputs_rank(op, input_shapes, {2, 2, 2, 2, 1});
|
||||
|
||||
const auto& x_pshape = input_shapes[0]; // [batch_size, input_size]
|
||||
const auto& ht_pshape = input_shapes[1]; // [batch_size, hidden_size]
|
||||
const auto& w_pshape = input_shapes[2]; // [3 * hidden_size, input_size]
|
||||
const auto& r_pshape = input_shapes[3]; // [3 * hidden_size, hidden_size]
|
||||
const auto& b_pshape = input_shapes[4]; // if linear_before_reset [4 * hidden_size], otherwise [3 * hidden_size]
|
||||
|
||||
using DimType = typename std::iterator_traits<typename ShapeType::iterator>::value_type;
|
||||
|
||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
||||
DimType merged_batch_size = x_pshape.rank().is_static() ? x_pshape[0] : DimType();
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
DimType::merge(merged_batch_size, merged_batch_size, ht_pshape.rank().is_static() ? ht_pshape[0] : DimType()),
|
||||
"Dimension `batch_size` is not matched between inputs.");
|
||||
|
||||
// Set batch_size dimension
|
||||
y_out_shape[0] = merged_batch_size;
|
||||
|
||||
// Merge hidden_size dimension across all inputs to evaluate output dimension
|
||||
// `hidden_size` attribute is not used for backward compatibility
|
||||
DimType merged_hidden_size = ht_pshape.rank().is_static() ? ht_pshape[1] : DimType();
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
DimType::merge(merged_hidden_size, merged_hidden_size, r_pshape.rank().is_static() ? r_pshape[1] : DimType()),
|
||||
"Dimension `hidden_size` is not matched between inputs.");
|
||||
|
||||
// Validate dimensions related to hidden_size for W, R, B inputs
|
||||
if (merged_hidden_size.is_static()) {
|
||||
constexpr auto gru_gates_count = 3;
|
||||
if (w_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
w_pshape[0].compatible(merged_hidden_size * gru_gates_count),
|
||||
"First dimension of W input shape is required to be compatible with ",
|
||||
merged_hidden_size * gru_gates_count,
|
||||
". Got shape: ",
|
||||
w_pshape[0],
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
r_pshape[0].compatible(merged_hidden_size * gru_gates_count),
|
||||
"Fisrt dimension of R input shape is required to be compatible with ",
|
||||
merged_hidden_size * gru_gates_count,
|
||||
". Got shape: ",
|
||||
r_pshape[0],
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape.rank().is_static()) {
|
||||
auto bias_dim_multiplier = op->get_linear_before_reset() ? (gru_gates_count + 1) : gru_gates_count;
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
b_pshape[0].compatible(merged_hidden_size * bias_dim_multiplier),
|
||||
"First dimension of B input shape is required to be compatible with ",
|
||||
merged_hidden_size * bias_dim_multiplier,
|
||||
". Got shape: ",
|
||||
b_pshape[0],
|
||||
".");
|
||||
}
|
||||
}
|
||||
|
||||
// Set hidden_size dimension
|
||||
y_out_shape[1] = merged_hidden_size;
|
||||
}
|
||||
} // namespace rnn
|
||||
namespace v3 {
|
||||
template <class ShapeType>
|
||||
void shape_infer(const ov::op::v3::GRUCell* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
std::vector<ShapeType>& output_shapes) {
|
||||
rnn::gru_cell_shape_infer(op, input_shapes, output_shapes);
|
||||
}
|
||||
} // namespace v3
|
||||
} // namespace op
|
||||
} // namespace ov
|
@ -9,7 +9,7 @@
|
||||
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace rnn_seq {
|
||||
namespace rnn {
|
||||
template <class OpType, class ShapeType>
|
||||
void validate_inputs_rank(const OpType* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
@ -32,9 +32,9 @@ void validate_inputs_rank(const OpType* op,
|
||||
// output_shapes[0]: [batch_size, num_directions, seq_length, hidden_size] // Rank always 4
|
||||
// output_shapes[1]: [batch_size, num_directions, hidden_size] // Rank always 3
|
||||
template <class OpType, class ShapeType>
|
||||
void gru_shape_infer(const OpType* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
std::vector<ShapeType>& output_shapes) {
|
||||
void gru_sequence_shape_infer(const OpType* op,
|
||||
const std::vector<ShapeType>& input_shapes,
|
||||
std::vector<ShapeType>& output_shapes) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
input_shapes.size() >= 6 && output_shapes.size() == 2,
|
||||
"Incorrect number of shapes has been provided.");
|
||||
@ -44,14 +44,14 @@ void gru_shape_infer(const OpType* op,
|
||||
y_out_shape.resize(4); // Rank always 4
|
||||
ho_out_shape.resize(3); // Rank always 3
|
||||
|
||||
rnn_seq::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2});
|
||||
rnn::validate_inputs_rank(op, input_shapes, {3, 3, 1, 3, 3, 2});
|
||||
|
||||
auto x_pshape = input_shapes[0];
|
||||
auto ht_pshape = input_shapes[1];
|
||||
auto sl_pshape = input_shapes[2];
|
||||
auto w_pshape = input_shapes[3];
|
||||
auto r_pshape = input_shapes[4];
|
||||
auto b_pshape = input_shapes[5];
|
||||
const auto& x_pshape = input_shapes[0];
|
||||
const auto& ht_pshape = input_shapes[1];
|
||||
const auto& sl_pshape = input_shapes[2];
|
||||
const auto& w_pshape = input_shapes[3];
|
||||
const auto& r_pshape = input_shapes[4];
|
||||
const auto& b_pshape = input_shapes[5];
|
||||
|
||||
using DimType = typename std::iterator_traits<typename ShapeType::iterator>::value_type;
|
||||
|
||||
@ -155,7 +155,7 @@ void gru_shape_infer(const OpType* op,
|
||||
y_out_shape[3] = merged_hidden_size;
|
||||
ho_out_shape[2] = merged_hidden_size;
|
||||
}
|
||||
} // namespace rnn_seq
|
||||
} // namespace rnn
|
||||
namespace v5 {
|
||||
template <class ShapeType>
|
||||
void shape_infer(const ov::op::v5::GRUSequence* op,
|
||||
@ -170,7 +170,7 @@ void shape_infer(const ov::op::v5::GRUSequence* op,
|
||||
input_shapes.size(),
|
||||
".");
|
||||
|
||||
rnn_seq::gru_shape_infer(op, input_shapes, output_shapes);
|
||||
rnn::gru_sequence_shape_infer(op, input_shapes, output_shapes);
|
||||
}
|
||||
} // namespace v5
|
||||
} // namespace op
|
||||
|
@ -6,6 +6,7 @@
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/op/constant.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
@ -87,87 +88,22 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor) {
|
||||
|
||||
void op::v3::GRUCell::validate_and_infer_types() {
|
||||
OV_OP_SCOPE(v3_GRUCell_validate_and_infer_types);
|
||||
for (const auto& input : inputs()) {
|
||||
if (input.get_partial_shape().rank().is_dynamic()) {
|
||||
set_output_type(0, get_input_element_type(0), ov::PartialShape::dynamic());
|
||||
return;
|
||||
}
|
||||
}
|
||||
auto merged_batch_size = Dimension::dynamic();
|
||||
auto merged_hidden_size = Dimension::dynamic();
|
||||
auto result_et = element::dynamic;
|
||||
|
||||
// Get input partial shape for all inputs
|
||||
const auto& x_pshape = get_input_partial_shape(0);
|
||||
const auto& ht_pshape = get_input_partial_shape(1);
|
||||
const auto& w_pshape = get_input_partial_shape(2);
|
||||
const auto& r_pshape = get_input_partial_shape(3);
|
||||
const auto& b_pshape = get_input_partial_shape(4);
|
||||
|
||||
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
|
||||
|
||||
// Validate input types and save result for output type
|
||||
auto result_et = element::dynamic;
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
|
||||
element::Type::merge(result_et, result_et, get_input_element_type(4)),
|
||||
"Element types for X, initial_hidden_state, W, R and B inputs do not match.");
|
||||
"Element types for X, initial_hidden_state, W, R and B inputs do not "
|
||||
"match.");
|
||||
|
||||
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
|
||||
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
|
||||
"Parameter batch_size not matched for X and initial_hidden_state inputs.");
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
std::vector<ov::PartialShape> output_shapes{ov::PartialShape::dynamic(2)};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
|
||||
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
|
||||
"Parameter hidden_size not matched for R and initial_hidden_state inputs.");
|
||||
|
||||
// Validate hidden_size value for W, B and R inputs
|
||||
if (merged_hidden_size.is_static()) {
|
||||
if (w_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
||||
"Parameter hidden_size mistmatched in W input. Current value is: ",
|
||||
w_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * s_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (r_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
|
||||
"Parameter hidden_size mistmatched in R input. Current value is: ",
|
||||
r_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * s_gates_count,
|
||||
".");
|
||||
}
|
||||
|
||||
if (b_pshape[0].is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
b_pshape[0].compatible(merged_hidden_size * (s_gates_count + m_linear_before_reset)),
|
||||
"Parameter hidden_size mistmatched in B input. Current value is: ",
|
||||
b_pshape[0].get_length(),
|
||||
", expected: ",
|
||||
merged_hidden_size.get_length() * (s_gates_count + m_linear_before_reset),
|
||||
".");
|
||||
}
|
||||
}
|
||||
|
||||
// Mark inputs which are relevant to output parameters
|
||||
set_input_is_relevant_to_shape(0);
|
||||
set_input_is_relevant_to_shape(1);
|
||||
set_input_is_relevant_to_shape(3);
|
||||
|
||||
// Set output size, type and shape
|
||||
set_output_size(1);
|
||||
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
|
||||
set_output_type(0, result_et, output_shapes[0]);
|
||||
}
|
||||
|
||||
void op::v3::GRUCell::add_default_bias_input() {
|
||||
|
@ -60,15 +60,8 @@ void op::v5::GRUSequence::validate_and_infer_types() {
|
||||
"Element types for X, initial_hidden_state, W, R and B inputs do not "
|
||||
"match.");
|
||||
|
||||
const auto& x_pshape = get_input_partial_shape(0);
|
||||
const auto& ht_pshape = get_input_partial_shape(1);
|
||||
const auto& sl_pshape = get_input_partial_shape(2);
|
||||
const auto& w_pshape = get_input_partial_shape(3);
|
||||
const auto& r_pshape = get_input_partial_shape(4);
|
||||
const auto& b_pshape = get_input_partial_shape(5);
|
||||
|
||||
std::vector<ov::PartialShape> input_shapes = {x_pshape, ht_pshape, sl_pshape, w_pshape, r_pshape, b_pshape};
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape{4}, ov::PartialShape{3}};
|
||||
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic(4), ov::PartialShape::dynamic(3)};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
// Set output size, type and shape
|
||||
|
@ -48,7 +48,8 @@ TEST(type_prop, augru_cell_invalid_input) {
|
||||
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
FAIL() << "AUGRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input."));
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("First dimension of W input shape is required to be compatible"));
|
||||
}
|
||||
|
||||
// Invalid R tensor shape.
|
||||
@ -58,8 +59,7 @@ TEST(type_prop, augru_cell_invalid_input) {
|
||||
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
FAIL() << "AUGRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Dimension hidden_size not matched for R and initial_hidden_state inputs."));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `hidden_size` is not matched between inputs"));
|
||||
}
|
||||
|
||||
// Invalid H_t tensor shape.
|
||||
@ -69,7 +69,7 @@ TEST(type_prop, augru_cell_invalid_input) {
|
||||
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
FAIL() << "AUGRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension batch_size is not matched between inputs."));
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Dimension `batch_size` is not matched between inputs"));
|
||||
}
|
||||
|
||||
// Invalid B tensor shape.
|
||||
@ -79,9 +79,8 @@ TEST(type_prop, augru_cell_invalid_input) {
|
||||
const auto gru_cell = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
FAIL() << "GRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(
|
||||
error.what(),
|
||||
std::string("Parameter hidden_size mistmatched in B input. Current value is: 3, expected: 9."));
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("First dimension of B input shape is required to be compatible"));
|
||||
}
|
||||
|
||||
// Invalid A tensor shape.
|
||||
@ -185,11 +184,11 @@ TEST(type_prop, augru_cell_invalid_input_rank) {
|
||||
<< "AUGRUCell node was created with invalid data.";
|
||||
}
|
||||
|
||||
TEST(type_prop, augru_cell_invalid_input_dynamic_rank) {
|
||||
const size_t batch_size = 2;
|
||||
const size_t input_size = 3;
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 3;
|
||||
TEST(type_prop, augru_cell_input_dynamic_rank) {
|
||||
int64_t batch_size = 2;
|
||||
int64_t input_size = 3;
|
||||
int64_t hidden_size = 3;
|
||||
int64_t gates_count = 3;
|
||||
|
||||
auto X = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
auto R = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
|
||||
@ -197,41 +196,41 @@ TEST(type_prop, augru_cell_invalid_input_dynamic_rank) {
|
||||
auto B = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
|
||||
auto A = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, 1});
|
||||
|
||||
auto check_dynamic_gru = [](const shared_ptr<op::internal::AUGRUCell>& augru) -> bool {
|
||||
return augru->output(0).get_partial_shape() == PartialShape::dynamic(2) &&
|
||||
auto check_dynamic_gru = [&](const shared_ptr<op::internal::AUGRUCell>& augru) -> bool {
|
||||
return augru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} &&
|
||||
augru->output(0).get_element_type() == augru->input(0).get_element_type();
|
||||
};
|
||||
|
||||
// Invalid dynamic rank for W tensor.
|
||||
// Dynamic rank for W tensor.
|
||||
auto W = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_w = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
EXPECT_TRUE(check_dynamic_gru(augru_w));
|
||||
|
||||
// Invalid dynamic rank for X tensor.
|
||||
W = make_shared<opset9::Parameter>(element::f32, PartialShape{hidden_size, input_size});
|
||||
// Dynamic rank for X tensor.
|
||||
W = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||
X = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_x = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
EXPECT_TRUE(check_dynamic_gru(augru_x));
|
||||
|
||||
// Invalid dynamic rank for H_t tensor.
|
||||
// Dynamic rank for H_t tensor.
|
||||
X = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_h = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
EXPECT_TRUE(check_dynamic_gru(augru_h));
|
||||
|
||||
// Invalid dynamic rank for R tensor.
|
||||
// Dynamic rank for R tensor.
|
||||
H_t = make_shared<opset9::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_r = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
EXPECT_TRUE(check_dynamic_gru(augru_r));
|
||||
|
||||
// Invalid dynamic rank for B tensor.
|
||||
// Dynamic rank for B tensor.
|
||||
R = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
|
||||
B = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_b = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
EXPECT_TRUE(check_dynamic_gru(augru_b));
|
||||
|
||||
// Invalid dynamic rank for A tensor.
|
||||
// Dynamic rank for A tensor.
|
||||
B = make_shared<opset9::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
|
||||
A = make_shared<opset9::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto augru_a = make_shared<op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/core/attribute_visitor.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
@ -11,6 +12,7 @@
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace testing;
|
||||
|
||||
struct augru_sequence_parameters {
|
||||
Dimension batch_size = 8;
|
||||
@ -258,7 +260,7 @@ TEST(type_prop, augru_sequence_all_inputs_dynamic_rank) {
|
||||
EXPECT_EQ(augru_sequence->get_output_element_type(1), param.et);
|
||||
}
|
||||
|
||||
TEST(type_prop, augru_sequence_invalid_attention_gate) {
|
||||
TEST(type_prop, augru_sequence_invalid_attention_gate_seq_length) {
|
||||
augru_sequence_parameters params;
|
||||
|
||||
params.batch_size = 8;
|
||||
@ -272,13 +274,28 @@ TEST(type_prop, augru_sequence_invalid_attention_gate) {
|
||||
auto invalid_attention_gate = make_shared<opset9::Parameter>(params.et, PartialShape{params.batch_size, 999, 1});
|
||||
augru_sequence->set_argument(6, invalid_attention_gate);
|
||||
|
||||
try {
|
||||
augru_sequence->validate_and_infer_types();
|
||||
FAIL() << "AUGRUSequence node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Dimension `seq_length` must be the same for `X` and `A` inputs."));
|
||||
}
|
||||
OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Dimension `seq_length` must be the same for `X` and `A` inputs"));
|
||||
}
|
||||
|
||||
TEST(type_prop, augru_sequence_invalid_attention_gate_batch) {
|
||||
augru_sequence_parameters params;
|
||||
|
||||
params.batch_size = 8;
|
||||
params.num_directions = 1;
|
||||
params.seq_length = 6;
|
||||
params.input_size = 4;
|
||||
params.hidden_size = 128;
|
||||
params.et = element::f32;
|
||||
|
||||
auto augru_sequence = augru_seq_init(params);
|
||||
auto invalid_attention_gate = make_shared<opset9::Parameter>(params.et, PartialShape{999, params.seq_length, 1});
|
||||
augru_sequence->set_argument(6, invalid_attention_gate);
|
||||
|
||||
OV_EXPECT_THROW(augru_sequence->validate_and_infer_types(),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Dimension `batch_size` must be the same for `X` and `A` inputs"));
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
@ -2,6 +2,7 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "common_test_utils/test_assertions.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
@ -9,6 +10,7 @@
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, gru_cell) {
|
||||
const size_t batch_size = 2;
|
||||
@ -38,44 +40,30 @@ TEST(type_prop, gru_cell_invalid_input) {
|
||||
|
||||
// Invalid W tensor shape.
|
||||
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
|
||||
try {
|
||||
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
FAIL() << "GRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input."));
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("First dimension of W input shape is required to be compatible"));
|
||||
|
||||
// Invalid R tensor shape.
|
||||
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
|
||||
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
|
||||
try {
|
||||
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
FAIL() << "GRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Parameter hidden_size not matched for R and initial_hidden_state inputs."));
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Dimension `hidden_size` is not matched between inputs"));
|
||||
|
||||
// Invalid H_t tensor shape.
|
||||
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
|
||||
try {
|
||||
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
FAIL() << "GRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Parameter batch_size not matched for X and initial_hidden_state inputs."));
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("Dimension `batch_size` is not matched between inputs"));
|
||||
|
||||
// Invalid B tensor shape.
|
||||
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
|
||||
try {
|
||||
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
|
||||
FAIL() << "GRUCell node was created with invalid data.";
|
||||
} catch (const NodeValidationFailure& error) {
|
||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B input."));
|
||||
}
|
||||
OV_EXPECT_THROW(auto op = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
|
||||
ov::NodeValidationFailure,
|
||||
HasSubstr("First dimension of B input shape is required to be compatible"));
|
||||
}
|
||||
|
||||
TEST(type_prop, gru_cell_dynamic_batch_size) {
|
||||
@ -171,45 +159,45 @@ TEST(type_prop, gru_cell_invalid_input_rank0) {
|
||||
<< "GRUCell node was created with invalid data.";
|
||||
}
|
||||
|
||||
TEST(type_prop, gru_cell_invalid_input_dynamic_rank) {
|
||||
const size_t batch_size = 2;
|
||||
const size_t input_size = 3;
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 3;
|
||||
TEST(type_prop, gru_cell_input_dynamic_rank) {
|
||||
int64_t batch_size = 2;
|
||||
int64_t input_size = 3;
|
||||
int64_t hidden_size = 3;
|
||||
int64_t gates_count = 3;
|
||||
|
||||
auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
auto R = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
auto check_dynamic_gru = [](const shared_ptr<opset4::GRUCell>& gru) -> bool {
|
||||
return gru->output(0).get_partial_shape() == PartialShape::dynamic() &&
|
||||
auto check_dynamic_gru = [&](const shared_ptr<opset4::GRUCell>& gru) -> bool {
|
||||
return gru->output(0).get_partial_shape() == PartialShape{batch_size, hidden_size} &&
|
||||
gru->output(0).get_element_type() == gru->input(0).get_element_type();
|
||||
};
|
||||
|
||||
// Invalid dynamic rank for W tensor.
|
||||
// Dynamic rank for W tensor.
|
||||
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto gru_w = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_w), true);
|
||||
|
||||
// Invalid dynamic rank for X tensor.
|
||||
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
|
||||
// Dynamic rank for X tensor.
|
||||
W = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto gru_x = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_x), true);
|
||||
|
||||
// Invalid dynamic rank for H_t tensor.
|
||||
// Dynamic rank for H_t tensor.
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto gru_h = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_h), true);
|
||||
|
||||
// Invalid dynamic rank for R tensor.
|
||||
// Dynamic rank for R tensor.
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto gru_r = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
EXPECT_EQ(check_dynamic_gru(gru_r), true);
|
||||
|
||||
// Invalid dynamic rank for B tensor.
|
||||
// Dynamic rank for B tensor.
|
||||
R = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
|
||||
auto gru_b = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
|
||||
|
@ -11,7 +11,12 @@
|
||||
#include <openvino/opsets/opset7.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
|
||||
#include "assign_shape_inference.hpp"
|
||||
#include "augru_cell_shape_inference.hpp"
|
||||
#include "augru_sequence_shape_inference.hpp"
|
||||
#include "batch_to_space_shape_inference.hpp"
|
||||
#include "broadcast_shape_inference.hpp"
|
||||
#include "bucketize_shape_inference.hpp"
|
||||
@ -38,6 +43,7 @@
|
||||
#include "gather_shape_inference.hpp"
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
#include "gru_sequence_shape_inference.hpp"
|
||||
#include "gru_cell_shape_inference.hpp"
|
||||
#include "interpolate_shape_inference.hpp"
|
||||
#include "lstm_cell_shape_inference.hpp"
|
||||
#include "matmul_shape_inference.hpp"
|
||||
@ -494,6 +500,12 @@ std::shared_ptr<IShapeInfer> make_shape_inference(const std::shared_ptr<ngraph::
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset5::GRUSequence>(op)) {
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::op::internal::AUGRUSequence>(op)) {
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset3::GRUCell>(op)) {
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::op::internal::AUGRUCell>(op)) {
|
||||
return make_shared_entryIO(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset1::OneHot>(op)) {
|
||||
return make_shared_entryIOC(node);
|
||||
} else if (auto node = ov::as_type_ptr<ov::opset4::CTCLoss>(op)) {
|
||||
|
@ -0,0 +1,71 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_static_rank) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(1));
|
||||
const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
|
||||
const auto augru = std::make_shared<ov::op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{gates_count * hidden_size}, // B
|
||||
StaticShape{batch_size, 1}}; // A
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(augru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, AUGRUCellTest_all_inputs_dynamic_rank) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
const auto augru = std::make_shared<ov::op::internal::AUGRUCell>(X, H_t, W, R, B, A, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{gates_count * hidden_size}, // B
|
||||
StaticShape{batch_size, 1}}; // A
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(augru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
@ -0,0 +1,85 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, AGRUSequenceTest_FORWARD_all_static_rank) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t seq_len = 4;
|
||||
constexpr size_t num_directions = 1;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto seq_lengths = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic(1));
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(3));
|
||||
|
||||
const auto augru_sequence =
|
||||
std::make_shared<ov::op::internal::AUGRUSequence>(X, H_t, seq_lengths, W, R, B, A, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{
|
||||
StaticShape{batch_size, seq_len, input_size}, // X
|
||||
StaticShape{batch_size, num_directions, hidden_size}, // H_t
|
||||
StaticShape{batch_size}, // seq_lengths
|
||||
StaticShape{num_directions, gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{num_directions, gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{num_directions, gates_count * hidden_size}, // B
|
||||
StaticShape{batch_size, seq_len, 1}}; // A
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, hidden_size}));
|
||||
EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, AGRUSequenceTest_FORWARD_all_inputs_dynamic_rank) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t seq_len = 4;
|
||||
constexpr size_t num_directions = 1;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto seq_lengths = std::make_shared<op::v0::Parameter>(element::i32, PartialShape::dynamic());
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto A = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
const auto augru_sequence =
|
||||
std::make_shared<ov::op::internal::AUGRUSequence>(X, H_t, seq_lengths, W, R, B, A, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{
|
||||
StaticShape{batch_size, seq_len, input_size}, // X
|
||||
StaticShape{batch_size, num_directions, hidden_size}, // H_t
|
||||
StaticShape{batch_size}, // seq_lengths
|
||||
StaticShape{num_directions, gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{num_directions, gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{num_directions, gates_count * hidden_size}, // B
|
||||
StaticShape{batch_size, seq_len, 1}}; // A
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(augru_sequence.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, num_directions, seq_len, hidden_size}));
|
||||
EXPECT_EQ(static_output_shapes[1], StaticShape({batch_size, num_directions, hidden_size}));
|
||||
}
|
@ -0,0 +1,126 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, GRUCellTest_default_bias) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
|
||||
// Default `B` input is created as Constant by GRUCell contructor
|
||||
const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{gates_count * hidden_size}}; // B
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}};
|
||||
|
||||
shape_inference(gru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, GRUCellTest_with_bias) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(1));
|
||||
const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, B, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{gates_count * hidden_size}}; // B
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(gru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, GRUCellTest_linear_before) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(2));
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic(1));
|
||||
|
||||
const auto gru = std::make_shared<op::v3::GRUCell>(X,
|
||||
H_t,
|
||||
W,
|
||||
R,
|
||||
B,
|
||||
hidden_size,
|
||||
std::vector<std::string>{"sigmoid", "tanh"},
|
||||
std::vector<float>{},
|
||||
std::vector<float>{},
|
||||
0.f,
|
||||
true);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{(gates_count + 1) * hidden_size}}; // B
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}, StaticShape{}};
|
||||
|
||||
shape_inference(gru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
||||
|
||||
TEST(StaticShapeInferenceTest, GRUCellTest_dynamic_rank_inputs) {
|
||||
constexpr size_t batch_size = 2;
|
||||
constexpr size_t input_size = 3;
|
||||
constexpr size_t hidden_size = 5;
|
||||
constexpr size_t gates_count = 3;
|
||||
|
||||
const auto X = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto H_t = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto W = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto R = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
const auto B = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||
|
||||
const auto gru = std::make_shared<op::v3::GRUCell>(X, H_t, W, R, B, hidden_size);
|
||||
|
||||
std::vector<StaticShape> static_input_shapes{StaticShape{batch_size, input_size}, // X
|
||||
StaticShape{batch_size, hidden_size}, // H_t
|
||||
StaticShape{gates_count * hidden_size, input_size}, // W
|
||||
StaticShape{gates_count * hidden_size, hidden_size}, // R
|
||||
StaticShape{gates_count * hidden_size}}; // B
|
||||
|
||||
std::vector<StaticShape> static_output_shapes{StaticShape{}};
|
||||
|
||||
shape_inference(gru.get(), static_input_shapes, static_output_shapes);
|
||||
EXPECT_EQ(static_output_shapes[0], StaticShape({batch_size, hidden_size}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user