revise RNNCell RNNsequence operation class (#7335)
* revise RNNCell RNNsequence operation class Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com> * fix clang syntax check error Signed-off-by: Hu, Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
parent
6df94afdcb
commit
97d937c8ea
@ -97,6 +97,20 @@ void op::v5::RNNSequence::validate_and_infer_types() {
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
|
||||
"Parameter num_directions not matched in RNNSequence.");
|
||||
|
||||
auto valid_num_directions = 0;
|
||||
if (m_direction == op::RecurrentSequenceDirection::FORWARD ||
|
||||
m_direction == op::RecurrentSequenceDirection::REVERSE) {
|
||||
valid_num_directions = 1;
|
||||
} else if (m_direction == op::RecurrentSequenceDirection::BIDIRECTIONAL) {
|
||||
valid_num_directions = 2;
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(this, false, "Parameter direction must be FORWARD or REVERSE or BIDIRECTIONAL.");
|
||||
}
|
||||
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
Dimension::merge(merged_num_directions, merged_num_directions, valid_num_directions),
|
||||
"Parameter num_directions not match direction in RNNSequence.");
|
||||
|
||||
// Validate hidden_size value for W, R, B inputs
|
||||
if (merged_hidden_size.is_static()) {
|
||||
if (w_pshape[1].is_static()) {
|
||||
|
@ -329,6 +329,7 @@ set(SRC
|
||||
visitors/op/reverse.cpp
|
||||
visitors/op/reverse_sequence.cpp
|
||||
visitors/op/rnn_cell.cpp
|
||||
visitors/op/rnn_sequence.cpp
|
||||
visitors/op/roi_pooling.cpp
|
||||
visitors/op/round.cpp
|
||||
visitors/op/scatter_elements_update.cpp
|
||||
|
@ -43,3 +43,298 @@ TEST(type_prop, rnn_sequence_forward) {
|
||||
EXPECT_EQ(sequence->get_output_element_type(1), element::f32);
|
||||
EXPECT_EQ(sequence->get_output_shape(1), (Shape{batch_size, num_directions, hidden_size}));
|
||||
}
|
||||
|
||||
// Negative type-prop cases: each sub-case swaps in one malformed input (or an
// inconsistent direction) and expects RNNSequence construction to fail with a
// specific validation message. After each failing case the offending input is
// restored to a valid shape so the next case isolates a single defect.
TEST(type_prop, rnn_sequence_invalid_input) {
    const size_t batch_size = 8;
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t input_size = 4;
    const size_t hidden_size = 128;

    // Baseline (valid) inputs: X [batch, seq_len, input], H_t [batch, dirs, hidden].
    auto X = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
    auto H_t = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});

    auto W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, input_size});
    auto R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, hidden_size});
    auto B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size});

    auto direction = op::RecurrentSequenceDirection::FORWARD;

    // Invalid W tensor shape: hidden dimension doubled relative to hidden_size.
    W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 2 * hidden_size, input_size});
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        // NOTE(review): "mistmatched" spelling must match the op's actual
        // validation message — presumably intentional; confirm against the op
        // implementation before "fixing" it here.
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in W input"));
    }

    // Invalid R tensor shape: last dimension is 1 instead of hidden_size.
    W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, input_size});
    R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, 1});
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size not matched RNNSequence."));
    }

    // Invalid H_t tensor shape: batch dimension (4) disagrees with X's batch (8).
    R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, hidden_size});
    H_t = make_shared<opset5::Parameter>(element::f32, Shape{4, num_directions, hidden_size});
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter batch_size not matched in RNNSequence."));
    }

    // Invalid B tensor shape: hidden dimension doubled relative to hidden_size.
    H_t = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
    B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, 2 * hidden_size});
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        // Truncated substring ("inpu") is deliberate: EXPECT_HAS_SUBSTRING only
        // needs a prefix of the real message — TODO confirm full message text.
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter hidden_size mistmatched in B inpu"));
    }

    // B declares 2 directions while every other input declares 1: the merged
    // num_directions dimension cannot be reconciled.
    B = make_shared<opset5::Parameter>(element::f32, Shape{2, hidden_size});
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter num_directions not matched in RNNSequence."));
    }

    // Direction says BIDIRECTIONAL (needs num_directions == 2) but all inputs
    // are shaped for a single direction.
    B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size});
    direction = op::RecurrentSequenceDirection::BIDIRECTIONAL;
    try {
        const auto rnn_sequence =
            make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
        FAIL() << "RNNSequence node was created with invalid data.";
    } catch (const NodeValidationFailure& error) {
        EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter num_directions not match direction in RNNSequence."));
    }
}
|
||||
|
||||
// Shape inference with batch, input and hidden dimensions all dynamic:
// both outputs must stay f32 and propagate the dynamic dimensions.
TEST(type_prop, rnn_sequence_dynamic_inputs) {
    const auto batch_size = Dimension::dynamic();
    const auto input_size = Dimension::dynamic();
    const auto hidden_size = Dimension::dynamic();
    const size_t num_directions = 1;
    const size_t seq_length = 6;

    const auto X = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, seq_length, input_size});
    const auto H_t =
        make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, PartialShape{batch_size});
    const auto W = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, input_size});
    const auto R = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, hidden_size});
    const auto B = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size});

    const auto direction = op::RecurrentSequenceDirection::REVERSE;
    const auto rnn_seq = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, 128, direction);

    EXPECT_EQ(rnn_seq->outputs().size(), 2);
    // Output 0 (Y): [batch, num_directions, seq_length, hidden].
    EXPECT_EQ(rnn_seq->get_output_element_type(0), element::f32);
    EXPECT_EQ(rnn_seq->get_output_partial_shape(0),
              (PartialShape{batch_size, num_directions, seq_length, hidden_size}));
    // Output 1 (Ho): [batch, num_directions, hidden].
    EXPECT_EQ(rnn_seq->get_output_element_type(1), element::f32);
    EXPECT_EQ(rnn_seq->get_output_partial_shape(1), (PartialShape{batch_size, num_directions, hidden_size}));
}
|
||||
|
||||
// Shape inference with only the batch dimension dynamic; all other
// dimensions static. Output shapes must carry the dynamic batch through.
TEST(type_prop, rnn_sequence_dynamic_batch_size) {
    const auto batch_size = Dimension::dynamic();
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t input_size = 4;
    const size_t hidden_size = 128;

    // Small factory to cut repetition when declaring the f32 inputs.
    const auto f32_param = [](const PartialShape& shape) {
        return make_shared<opset5::Parameter>(element::f32, shape);
    };

    const auto X = f32_param(PartialShape{batch_size, seq_length, input_size});
    const auto H_t = f32_param(PartialShape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, PartialShape{batch_size});
    const auto W = f32_param(PartialShape{num_directions, hidden_size, input_size});
    const auto R = f32_param(PartialShape{num_directions, hidden_size, hidden_size});
    const auto B = f32_param(PartialShape{num_directions, hidden_size});

    const auto node = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size,
                                                       op::RecurrentSequenceDirection::FORWARD);

    EXPECT_EQ(node->outputs().size(), 2);
    EXPECT_EQ(node->get_output_element_type(0), element::f32);
    EXPECT_EQ(node->get_output_partial_shape(0),
              (PartialShape{batch_size, num_directions, seq_length, hidden_size}));
    EXPECT_EQ(node->get_output_element_type(1), element::f32);
    EXPECT_EQ(node->get_output_partial_shape(1), (PartialShape{batch_size, num_directions, hidden_size}));
}
|
||||
|
||||
// Shape inference with only the input-size dimension dynamic. The input
// size does not appear in either output shape, so both outputs stay fully
// static apart from the dimensions they actually depend on.
TEST(type_prop, rnn_sequence_dynamic_input_size) {
    const size_t batch_size = 8;
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t hidden_size = 128;
    const auto input_size = Dimension::dynamic();

    const auto X = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, seq_length, input_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, PartialShape{batch_size});
    const auto H_t =
        make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, num_directions, hidden_size});

    const auto W = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, input_size});
    const auto R = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, hidden_size});
    const auto B = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size});

    const auto rnn_node = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size,
                                                           op::RecurrentSequenceDirection::FORWARD);

    EXPECT_EQ(rnn_node->outputs().size(), 2);
    EXPECT_EQ(rnn_node->get_output_element_type(0), element::f32);
    EXPECT_EQ(rnn_node->get_output_partial_shape(0),
              (PartialShape{batch_size, num_directions, seq_length, hidden_size}));
    EXPECT_EQ(rnn_node->get_output_element_type(1), element::f32);
    EXPECT_EQ(rnn_node->get_output_partial_shape(1), (PartialShape{batch_size, num_directions, hidden_size}));
}
|
||||
|
||||
// Shape inference with only the hidden-size dimension dynamic. The
// hidden_size attribute is still passed as a static 128, but the dynamic
// dimension from the inputs must flow into both output shapes.
TEST(type_prop, rnn_sequence_dynamic_hidden_size) {
    const size_t batch_size = 8;
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t input_size = 4;
    const auto hidden_size = Dimension::dynamic();

    const auto X = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, seq_length, input_size});
    const auto H_t =
        make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, PartialShape{batch_size});

    const auto W = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, input_size});
    const auto R = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, hidden_size});
    const auto B = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size});

    // hidden_size attribute is the static 128 even though the dimension is dynamic.
    const auto seq_node = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, 128,
                                                           op::RecurrentSequenceDirection::FORWARD);

    EXPECT_EQ(seq_node->outputs().size(), 2);
    EXPECT_EQ(seq_node->get_output_element_type(0), element::f32);
    EXPECT_EQ(seq_node->get_output_partial_shape(0),
              (PartialShape{batch_size, num_directions, seq_length, hidden_size}));
    EXPECT_EQ(seq_node->get_output_element_type(1), element::f32);
    EXPECT_EQ(seq_node->get_output_partial_shape(1), (PartialShape{batch_size, num_directions, hidden_size}));
}
|
||||
|
||||
// Negative cases: each input in turn is replaced by a rank-0 (scalar) tensor,
// which must make construction throw. The previously-broken input is restored
// before the next case so exactly one input is invalid at a time.
TEST(type_prop, rnn_sequence_dynamic_invalid_input_rank0) {
    const size_t batch_size = 8;
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t input_size = 4;
    const size_t hidden_size = 128;

    auto X = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
    auto H_t = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});

    auto W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, input_size});
    auto R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, hidden_size});
    auto B = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size});

    const auto direction = op::RecurrentSequenceDirection::FORWARD;

    // Invalid rank0 for X tensor.
    X = make_shared<opset5::Parameter>(element::f32, PartialShape{});
    ASSERT_THROW(make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction),
                 ngraph::CheckFailure)
        << "RNNSequence node was created with invalid data.";

    // Invalid rank0 for H_t tensor (X restored first).
    X = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
    H_t = make_shared<opset5::Parameter>(element::f32, PartialShape{});
    ASSERT_THROW(make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction),
                 ngraph::CheckFailure)
        << "RNNSequence node was created with invalid data.";

    // Invalid rank0 for W tensor (H_t restored first).
    H_t = make_shared<opset5::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
    W = make_shared<opset5::Parameter>(element::f32, PartialShape{});
    ASSERT_THROW(make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction),
                 ngraph::CheckFailure)
        << "RNNSequence node was created with invalid data.";

    // Invalid rank0 for R tensor (W restored first).
    W = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, input_size});
    R = make_shared<opset5::Parameter>(element::f32, PartialShape{});
    ASSERT_THROW(make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction),
                 ngraph::CheckFailure)
        << "RNNSequence node was created with invalid data.";

    // Invalid rank0 for B tensor (R restored first).
    R = make_shared<opset5::Parameter>(element::f32, Shape{num_directions, hidden_size, hidden_size});
    B = make_shared<opset5::Parameter>(element::f32, PartialShape{});
    ASSERT_THROW(make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction),
                 ngraph::CheckFailure)
        << "RNNSequence node was created with invalid data.";
}
|
||||
|
||||
// When any single input has a fully dynamic rank, construction must succeed
// and both outputs fall back to a fully dynamic shape while keeping the input
// element type. Each case makes one input rank-dynamic, restoring the previous
// one first so only one input is dynamic at a time.
TEST(type_prop, rnn_sequence_dynamic_invalid_input_dynamic_rank) {
    const size_t batch_size = 8;
    const size_t num_directions = 1;
    const size_t seq_length = 6;
    const size_t input_size = 4;
    const size_t hidden_size = 128;

    auto X = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, seq_length, input_size});
    auto H_t = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, PartialShape{batch_size});

    auto W = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, input_size});
    auto R = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, hidden_size});
    auto B = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size});

    const auto direction = op::RecurrentSequenceDirection::FORWARD;

    // Shared predicate: both outputs fully dynamic in shape, element type
    // preserved from input 0.
    auto check_dynamic_rnn = [](const shared_ptr<opset5::RNNSequence>& rnn) -> bool {
        return rnn->output(0).get_partial_shape() == PartialShape::dynamic() &&
               rnn->output(0).get_element_type() == rnn->input(0).get_element_type() &&
               rnn->output(1).get_partial_shape() == PartialShape::dynamic() &&
               rnn->output(1).get_element_type() == rnn->input(0).get_element_type();
    };

    // Dynamic-rank X.
    X = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
    auto rnn_x = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
    EXPECT_EQ(check_dynamic_rnn(rnn_x), true);

    // Dynamic-rank H_t (X restored first).
    X = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, seq_length, input_size});
    H_t = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
    auto rnn_h = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
    EXPECT_EQ(check_dynamic_rnn(rnn_h), true);

    // Dynamic-rank W (H_t restored first).
    H_t = make_shared<opset5::Parameter>(element::f32, PartialShape{batch_size, num_directions, hidden_size});
    W = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
    auto rnn_w = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
    EXPECT_EQ(check_dynamic_rnn(rnn_w), true);

    // Dynamic-rank R (W restored first).
    W = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, input_size});
    R = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
    auto rnn_r = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
    EXPECT_EQ(check_dynamic_rnn(rnn_r), true);

    // Dynamic-rank B (R restored first).
    R = make_shared<opset5::Parameter>(element::f32, PartialShape{num_directions, hidden_size, hidden_size});
    B = make_shared<opset5::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
    auto rnn_b = make_shared<opset5::RNNSequence>(X, H_t, sequence_lengths, W, R, B, hidden_size, direction);
    EXPECT_EQ(check_dynamic_rnn(rnn_b), true);
}
|
||||
|
63
ngraph/test/visitors/op/rnn_sequence.cpp
Normal file
63
ngraph/test/visitors/op/rnn_sequence.cpp
Normal file
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/op/util/attr_types.hpp"
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
#include "ngraph/opsets/opset3.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "ngraph/opsets/opset5.hpp"
|
||||
#include "util/visitor.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using ngraph::test::NodeBuilder;
|
||||
using ngraph::test::ValueMap;
|
||||
|
||||
// Round-trips an RNNSequence node through NodeBuilder and verifies that
// every serializable attribute survives the visitor round trip.
TEST(attributes, rnn_sequence_op) {
    NodeBuilder::get_ops().register_factory<opset5::RNNSequence>();

    const size_t batch_size = 4;
    const size_t num_directions = 2;
    const size_t seq_length = 8;
    const size_t input_size = 16;
    const size_t hidden_size = 64;

    const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
    const auto initial_hidden_state =
        make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
    const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
    const auto W = make_shared<op::Parameter>(element::f32, Shape{num_directions, hidden_size, input_size});
    const auto R = make_shared<op::Parameter>(element::f32, Shape{num_directions, hidden_size, hidden_size});
    const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, hidden_size});

    // Attribute values chosen to be non-default so a lost attribute is visible.
    const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
    const std::vector<float> alpha = {1, 2, 3};
    const std::vector<float> beta = {4, 5, 6};
    const float clip = 0.5f;

    const auto rnn_seq = make_shared<opset5::RNNSequence>(X,
                                                          initial_hidden_state,
                                                          sequence_lengths,
                                                          W,
                                                          R,
                                                          B,
                                                          hidden_size,
                                                          op::RecurrentSequenceDirection::BIDIRECTIONAL,
                                                          activations,
                                                          alpha,
                                                          beta,
                                                          clip);
    NodeBuilder builder(rnn_seq);
    const auto g_rnn_seq = ov::as_type_ptr<opset5::RNNSequence>(builder.create());

    EXPECT_EQ(g_rnn_seq->get_hidden_size(), rnn_seq->get_hidden_size());
    EXPECT_EQ(g_rnn_seq->get_activations(), rnn_seq->get_activations());
    EXPECT_EQ(g_rnn_seq->get_activations_alpha(), rnn_seq->get_activations_alpha());
    EXPECT_EQ(g_rnn_seq->get_activations_beta(), rnn_seq->get_activations_beta());
    EXPECT_EQ(g_rnn_seq->get_clip(), rnn_seq->get_clip());
    EXPECT_EQ(g_rnn_seq->get_direction(), rnn_seq->get_direction());
}
|
Loading…
Reference in New Issue
Block a user