[ONNX] GRU and RNN operators. (#607)

* Create generic RecurrentSequenceDirection enum.

* Helper class RecurrentSequenceOp.

* Add ONNX GRU & RNN operators.

* Use OutputVector.

* Update doc.

* Add UTs for GRU and skip them on IE_CPU

* Add UT for bidirectional mode and fix it.

* Normalize activation function name case.

* Add unit-tests for RNN operator.

* UT for GRU with linear_before_reset set to true.

* Fix ONNX GRU for linear_before_reset case.

* Remove unnecessary symbol export macro.

* Fix CentOS error.

* Update UTs.

- Update accuracy tolerance for a few tests
- Update rnn_fwd_activations with new reference values and model.

* Review comment: add check for static shape

* Add UT for RNN with constant inputs W, R.

* Skip UT with const W,R on IE_CPU
Adam Osewski
2020-06-03 11:01:56 +02:00
committed by GitHub
parent 4e0c7a217f
commit 3a80f0476b
36 changed files with 4120 additions and 70 deletions
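The ONNX recurrent operators handled here differ mainly in their gate count (RNN: 1, GRU: 3, per the ONNX specification), which determines the W, R, and B shapes referenced throughout the importer code and the test models below. A standalone sketch, illustrative only and not part of the change set, that prints those shapes:

#include <cstddef>
#include <iostream>

// Illustrative only: prints the ONNX weight/bias shapes assumed by the
// importer code and the test models (W: [num_directions, gates*hidden, input],
// R: [num_directions, gates*hidden, hidden], B: [num_directions, 2*gates*hidden]).
int main()
{
    const std::size_t num_directions = 1; // 2 for bidirectional models
    const std::size_t hidden_size = 5;
    const std::size_t input_size = 2;
    for (std::size_t gates_count : {std::size_t{1} /* RNN */, std::size_t{3} /* GRU */})
    {
        std::cout << "gates_count=" << gates_count << "\n"
                  << "  W: [" << num_directions << ", " << gates_count * hidden_size << ", "
                  << input_size << "]\n"
                  << "  R: [" << num_directions << ", " << gates_count * hidden_size << ", "
                  << hidden_size << "]\n"
                  << "  B: [" << num_directions << ", " << 2 * gates_count * hidden_size << "]\n";
    }
    return 0;
}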

View File

@@ -153,6 +153,7 @@ namespace ngraph
/// new axis is placed.
///
/// \return Reshape:v1 op.
NGRAPH_API
std::shared_ptr<Node> expand_dims(const Output<Node>& value, std::size_t axis = 0);
/// \brief Remove empty axes from input tensor.
@@ -161,6 +162,7 @@ namespace ngraph
/// \param[in] axes The vector defining indexes of axes to be removed.
///
/// \return Reshape:v1 op.
NGRAPH_API
std::shared_ptr<Node> squeeze(const Output<Node>& value,
std::vector<std::size_t> axes = {0});
}

View File

@@ -110,6 +110,8 @@ add_library(onnx_importer SHARED
op/global_max_pool.cpp
op/global_max_pool.hpp
op/greater.hpp
op/gru.cpp
op/gru.hpp
op/hard_sigmoid.cpp
op/hard_sigmoid.hpp
op/hardmax.cpp
@@ -179,6 +181,8 @@ add_library(onnx_importer SHARED
op/resize.hpp
op/reverse_sequence.cpp
op/reverse_sequence.hpp
op/rnn.cpp
op/rnn.hpp
op/roi_align.cpp
op/roi_align.hpp
op/round.cpp
@@ -241,6 +245,8 @@ add_library(onnx_importer SHARED
utils/onnx_importer_visibility.hpp
utils/pooling_factory.cpp
utils/pooling_factory.hpp
utils/recurrent.cpp
utils/recurrent.hpp
utils/reduction.cpp
utils/reduction.hpp
utils/reshape.cpp

View File

@@ -0,0 +1,148 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <string>
#include <vector>
#include "default_opset.hpp"
#include "gru.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/shape.hpp"
#include "utils/recurrent.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
namespace
{
struct GRUInputMap : public recurrent::OpInputMap
{
GRUInputMap(const Node& node, std::size_t gates_count)
: OpInputMap(node, gates_count)
{
bool linear_before_reset = static_cast<bool>(
node.get_attribute_value<std::int64_t>("linear_before_reset", 0));
// Override bias, since we need separate W and R biases for the `h` gate.
if (linear_before_reset)
{
const auto& ng_inputs = node.get_ng_inputs();
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{
auto bias = ng_inputs.at(3);
// gates_count * 2 since B is: [Wb, Rb]
const int split_parts = 2 * 3;
const auto split_bias =
builder::opset1::split(bias, split_parts, 1);
const auto wr_z_bias = split_bias.at(0) + split_bias.at(3);
const auto wr_r_bias = split_bias.at(1) + split_bias.at(4);
// The result has shape: [num_directions, 4 * hidden_size]
// and data layout:
// [
// [Wb_z + Rb_z],
// [Wb_r + Rb_r],
// [Wb_h],
// [Rb_h],
// // num_directions times
// ]
m_map[recurrent::OpInput::B] =
std::make_shared<default_opset::Concat>(
NodeVector{wr_z_bias,
wr_r_bias,
split_bias.at(2),
split_bias.at(5)},
1);
}
else
{
const std::size_t hidden_size =
m_map[recurrent::OpInput::R]->get_shape().back();
const std::size_t num_directions =
m_map[recurrent::OpInput::W]->get_shape().front();
m_map[recurrent::OpInput::B] =
std::make_shared<default_opset::Constant>(
el_type,
Shape{num_directions, (gates_count + 1) * hidden_size},
0.f);
}
}
}
virtual ~GRUInputMap() = default;
};
struct GRUAttributes : public recurrent::OpAttributes
{
GRUAttributes(const Node& node)
: OpAttributes(node)
, m_linear_before_reset{static_cast<bool>(
node.get_attribute_value<std::int64_t>("linear_before_reset", 0))}
{
m_activations = node.get_attribute_value<std::vector<std::string>>(
"activations", {"sigmoid", "tanh"});
}
virtual ~GRUAttributes() = default;
bool m_linear_before_reset;
};
}
NodeVector gru(const Node& node)
{
constexpr std::size_t gates_count = 3;
GRUInputMap input_map{node, gates_count};
GRUAttributes attributes{node};
recurrent::RecurrentSequence sequence_op(input_map, attributes.m_direction);
auto results =
sequence_op.run_sequence([&attributes](const recurrent::OpInputMap& args,
const Output<ngraph::Node>& in_Xt,
const Output<ngraph::Node> H_t) {
const GRUInputMap& gru_args = dynamic_cast<const GRUInputMap&>(args);
return std::make_shared<default_opset::GRUCell>(
in_Xt,
H_t,
gru_args.at(recurrent::OpInput::W),
gru_args.at(recurrent::OpInput::R),
gru_args.at(recurrent::OpInput::B),
attributes.m_hidden_size,
attributes.m_activations,
attributes.m_activations_alpha,
attributes.m_activations_beta,
attributes.m_clip_threshold,
attributes.m_linear_before_reset);
});
return results;
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph
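The bias handling above is the subtle part of the linear_before_reset path: ONNX packs B per direction as [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h], while the GRUCell with linear_before_reset expects [Wb_z + Rb_z, Wb_r + Rb_r, Wb_h, Rb_h]. A standalone sketch of the same rearrangement on plain vectors for one direction (illustrative only, not the importer code):

#include <cstddef>
#include <vector>

// Illustrative rearrangement of the ONNX GRU bias for linear_before_reset.
// Input layout (one direction):  [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h], each of size hidden_size.
// Output layout:                 [Wb_z + Rb_z, Wb_r + Rb_r, Wb_h, Rb_h] -> 4 * hidden_size.
std::vector<float> rearrange_gru_bias(const std::vector<float>& b, std::size_t hidden_size)
{
    const std::size_t h = hidden_size;
    std::vector<float> out(4 * h);
    for (std::size_t i = 0; i < h; ++i)
    {
        out[0 * h + i] = b[0 * h + i] + b[3 * h + i]; // Wb_z + Rb_z
        out[1 * h + i] = b[1 * h + i] + b[4 * h + i]; // Wb_r + Rb_r
        out[2 * h + i] = b[2 * h + i];                // Wb_h
        out[3 * h + i] = b[5 * h + i];                // Rb_h
    }
    return out;
}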

View File

@@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector gru(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -26,11 +26,13 @@
#include "lstm.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/enum_names.hpp"
#include "ngraph/frontend/onnx_import/op/lstm.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_sequence.hpp"
#include "ngraph/op/get_output_element.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@@ -188,26 +190,12 @@ namespace ngraph
m_clip_threshold = std::abs(m_clip_threshold);
std::string direction = ngraph::to_lower(
node.get_attribute_value<std::string>("direction", "forward"));
NGRAPH_CHECK(direction == "bidirectional" || direction == "forward" ||
direction == "reverse",
"Provided direction: ",
direction,
" is invalid");
if (direction == "forward")
{
m_direction = default_opset::LSTMSequence::direction::FORWARD;
}
else if (direction == "reverse")
{
m_direction = default_opset::LSTMSequence::direction::REVERSE;
}
else // (direction == "bidirectional")
{
m_direction = default_opset::LSTMSequence::direction::BIDIRECTIONAL;
}
m_direction =
ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
}
ngraph::op::LSTMSequence::direction m_direction;
ngraph::op::RecurrentSequenceDirection m_direction;
std::int64_t m_hidden_size;
float m_clip_threshold;
std::vector<std::string> m_activations;
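The refactor above replaces the hand-rolled if/else chain with ngraph::as_enum, which resolves the lower-cased direction string through the EnumNames table registered for RecurrentSequenceDirection (see the attr_types changes further down). A simplified standalone sketch of that string-to-enum lookup pattern, not the actual ngraph implementation:

#include <map>
#include <stdexcept>
#include <string>

enum class Direction { FORWARD, REVERSE, BIDIRECTIONAL };

// Simplified stand-in for ngraph::as_enum<RecurrentSequenceDirection>: map the
// (already lower-cased) attribute string onto the enum, rejecting unknown values.
Direction direction_from_string(const std::string& name)
{
    static const std::map<std::string, Direction> names{
        {"forward", Direction::FORWARD},
        {"reverse", Direction::REVERSE},
        {"bidirectional", Direction::BIDIRECTIONAL}};
    const auto it = names.find(name);
    if (it == names.end())
    {
        throw std::invalid_argument("Unsupported direction: " + name);
    }
    return it->second;
}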

View File

@@ -0,0 +1,83 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "rnn.hpp"
#include "default_opset.hpp"
#include "utils/recurrent.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
namespace
{
struct RNNInputMap : public recurrent::OpInputMap
{
RNNInputMap(const onnx_import::Node& node, std::size_t gates_count)
: OpInputMap(node, gates_count)
{
}
virtual ~RNNInputMap() = default;
};
struct RNNAttributes : public recurrent::OpAttributes
{
RNNAttributes(const Node& node)
: OpAttributes(node)
{
}
virtual ~RNNAttributes() = default;
};
}
NodeVector rnn(const Node& node)
{
constexpr std::size_t gates_count = 1;
RNNInputMap input_map{node, gates_count};
RNNAttributes attributes{node};
recurrent::RecurrentSequence sequence_op(input_map, attributes.m_direction);
auto results =
sequence_op.run_sequence([&attributes](const recurrent::OpInputMap& args,
const Output<ngraph::Node>& in_Xt,
const Output<ngraph::Node> H_t) {
const RNNInputMap& rnn_args = dynamic_cast<const RNNInputMap&>(args);
return std::make_shared<default_opset::RNNCell>(
in_Xt,
H_t,
rnn_args.at(recurrent::OpInput::W),
rnn_args.at(recurrent::OpInput::R),
rnn_args.at(recurrent::OpInput::B),
attributes.m_hidden_size,
attributes.m_activations,
attributes.m_activations_alpha,
attributes.m_activations_beta,
attributes.m_clip_threshold);
});
return results;
}
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include "core/node.hpp"
#include "ngraph/node.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace op
{
namespace set_1
{
NodeVector rnn(const Node& node);
} // namespace set_1
} // namespace op
} // namespace onnx_import
} // namespace ngraph

View File

@@ -65,6 +65,7 @@
#include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp"
#include "op/greater.hpp"
#include "op/gru.hpp"
#include "op/hard_sigmoid.hpp"
#include "op/hardmax.hpp"
#include "op/identity.hpp"
@@ -105,6 +106,7 @@
#include "op/reshape.hpp"
#include "op/resize.hpp"
#include "op/reverse_sequence.hpp"
#include "op/rnn.hpp"
#include "op/roi_align.hpp"
#include "op/round.hpp"
#include "op/scatter_elements.hpp"
@@ -293,6 +295,7 @@ namespace ngraph
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("GRU", 1, gru);
REGISTER_OPERATOR("Hardmax", 1, hardmax);
REGISTER_OPERATOR("HardSigmoid", 1, hard_sigmoid);
REGISTER_OPERATOR("Identity", 1, identity);
@@ -346,6 +349,7 @@ namespace ngraph
REGISTER_OPERATOR("Reshape", 1, reshape);
REGISTER_OPERATOR("Resize", 1, resize);
REGISTER_OPERATOR("ReverseSequence", 1, reverse_sequence);
REGISTER_OPERATOR("RNN", 1, rnn);
REGISTER_OPERATOR("RoiAlign", 1, roi_align);
REGISTER_OPERATOR("Round", 1, round);
REGISTER_OPERATOR("Scatter", 1, scatter_elements);

View File

@@ -0,0 +1,309 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cstdint>
#include <cstdlib>
#include <vector>
#include "default_opset.hpp"
#include "ngraph/builder/autobroadcast.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/check.hpp"
#include "ngraph/enum_names.hpp"
#include "recurrent.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace recurrent
{
OpInputMap::OpInputMap(const onnx_import::Node& node, std::size_t gates_count)
{
const auto& ng_inputs = node.get_ng_inputs();
m_map[OpInput::X] = ng_inputs.at(0);
m_map[OpInput::W] = ng_inputs.at(1);
m_map[OpInput::R] = ng_inputs.at(2);
const auto el_type = ng_inputs.at(0)->get_output_element_type(0);
const auto x_pshape = m_map[OpInput::X]->get_output_partial_shape(0);
const auto w_pshape = m_map[OpInput::W]->get_output_partial_shape(0);
const auto r_pshape = m_map[OpInput::R]->get_output_partial_shape(0);
NGRAPH_CHECK(x_pshape.rank().is_static() &&
x_pshape[0].is_static() &&
x_pshape[1].is_static(),
"RecurrentSequence input X must have static \"seq_length\" and "
"\"batch_size\" dimensions.");
NGRAPH_CHECK(w_pshape.rank().is_static() &&
w_pshape[0].is_static(),
"RecurrentSequence input W must have static \"num_directions\" "
"(outermost) dimension.");
NGRAPH_CHECK(r_pshape.rank().is_static() &&
r_pshape[2].is_static(),
"RecurrentSequence input R must have static \"hidden_size\" "
"(innermost) dimension.");
const std::size_t hidden_size = m_map[OpInput::R]->get_shape().back();
const std::size_t batch_size = m_map[OpInput::X]->get_shape().at(1);
const std::size_t num_directions = m_map[OpInput::W]->get_shape().front();
if (ng_inputs.size() > 3 && !ng_inputs.at(3)->is_null())
{
auto bias = ng_inputs.at(3);
auto split_bias = builder::opset1::split(bias, 2, 1);
m_map[OpInput::B] = split_bias.at(0) + split_bias.at(1);
}
else
{
m_map[OpInput::B] = std::make_shared<default_opset::Constant>(
el_type, Shape{num_directions, gates_count * hidden_size}, 0.f);
}
if (ng_inputs.size() > 4 && !ng_inputs.at(4)->is_null())
{
m_map[OpInput::SEQ_LENGTHS] = ng_inputs.at(4);
}
else
{
m_map[OpInput::SEQ_LENGTHS] = std::make_shared<default_opset::Constant>(
element::i32, Shape{batch_size}, m_map[OpInput::X]->get_shape().at(0));
}
// The initial value of the hidden state.
if (ng_inputs.size() > 5 && !ng_inputs.at(5)->is_null())
{
m_map[OpInput::INIT_H] = ng_inputs.at(5);
}
else
{
m_map[OpInput::INIT_H] = std::make_shared<default_opset::Constant>(
el_type, Shape{num_directions, batch_size, hidden_size}, 0.f);
}
}
OpInputMap::OpInputMap(container_type&& map)
: m_map(std::move(map))
{
}
std::shared_ptr<ngraph::Node>& OpInputMap::at(const OpInput& key)
{
return m_map.at(key);
}
const std::shared_ptr<ngraph::Node>& OpInputMap::at(const OpInput& key) const
{
return m_map.at(key);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
OpAttributes::OpAttributes(const Node& node)
: m_hidden_size{node.get_attribute_value<std::int64_t>("hidden_size")}
, m_clip_threshold{node.get_attribute_value<float>("clip", 0.f)}
// Recurrent operators which have more activation functions should override
// this value in the constructor of their respective Attributes struct.
, m_activations{node.get_attribute_value<std::vector<std::string>>("activations",
{"tanh"})}
// Default values for activation functions are the same as for the
// corresponding ONNX operator.
, m_activations_alpha{node.get_attribute_value<std::vector<float>>(
"activation_alpha", std::vector<float>{})}
, m_activations_beta{node.get_attribute_value<std::vector<float>>(
"activation_beta", std::vector<float>{})}
{
m_clip_threshold = std::abs(m_clip_threshold);
std::string direction =
ngraph::to_lower(node.get_attribute_value<std::string>("direction", "forward"));
m_direction = ngraph::as_enum<ngraph::op::RecurrentSequenceDirection>(direction);
}
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sequence Computations ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
RecurrentSequence::RecurrentSequence(OpInputMap& args,
ngraph::op::RecurrentSequenceDirection direction)
: m_args(args)
, m_direction(direction)
{
}
NodeVector RecurrentSequence::run_sequence(const RecurrentCellFunction& kernel)
{
NodeVector results;
if (m_direction == ngraph::op::RecurrentSequenceDirection::FORWARD ||
m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE)
{
results = recurrent_sequence_pass(
kernel, m_direction == ngraph::op::RecurrentSequenceDirection::REVERSE);
}
else if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
{
NodeVector fwd_results{recurrent_sequence_pass(kernel)};
NodeVector rev_results{recurrent_sequence_pass(kernel, true)};
// Stack together respective outputs from both forward and reverse passes.
std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(
NodeVector{fwd_results.at(0), rev_results.at(0)}, 1)};
results.push_back(Y);
std::shared_ptr<ngraph::Node> Y_h{std::make_shared<default_opset::Concat>(
NodeVector{fwd_results.at(1), rev_results.at(1)}, 0)};
results.push_back(Y_h);
}
else
{
throw ngraph_error(
"RecurrentSequence: unhandled direction mode during decomposition.");
}
return results;
}
NodeVector
RecurrentSequence::recurrent_sequence_pass(const RecurrentCellFunction& kernel,
bool is_reverse)
{
OutputVector h_list;
// Back up the nodes which we may later modify.
std::shared_ptr<ngraph::Node> orig_W = m_args.at(OpInput::W);
std::shared_ptr<ngraph::Node> orig_R = m_args.at(OpInput::R);
std::shared_ptr<ngraph::Node> orig_B = m_args.at(OpInput::B);
std::shared_ptr<ngraph::Node> X = m_args.at(OpInput::X);
std::shared_ptr<ngraph::Node> H_t =
prepare_input(m_args.at(OpInput::INIT_H), is_reverse);
std::shared_ptr<ngraph::Node> W = prepare_input(m_args.at(OpInput::W), is_reverse);
std::shared_ptr<ngraph::Node> R = prepare_input(m_args.at(OpInput::R), is_reverse);
std::shared_ptr<ngraph::Node> B = prepare_input(m_args.at(OpInput::B), is_reverse);
std::shared_ptr<ngraph::Node> seq_lengths = m_args.at(OpInput::SEQ_LENGTHS);
m_args.at(OpInput::W) = W;
m_args.at(OpInput::R) = R;
m_args.at(OpInput::B) = B;
if (is_reverse)
{
X = std::make_shared<default_opset::ReverseSequence>(
X, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
OutputVector in_seq_steps =
as_output_vector(builder::opset1::split(X, X->get_shape().at(0)));
for (auto& in_x : in_seq_steps)
{
// Remove the first empty dim left by the split above.
in_x = builder::opset1::squeeze(in_x);
}
int32_t time_step{1};
for (const auto& in_x : in_seq_steps)
{
Output<ngraph::Node> H = kernel(m_args, in_x, H_t);
// Expand tensors with empty outermost dim, so we can later concatenate
// them.
// Mask hidden state tensor in order to handle mixed sequence lengths.
// This results in zeroing out values in batches with sequences shorter
// than the current time_step.
h_list.push_back(
get_masked_node(builder::opset1::expand_dims(H), time_step, 1));
// Here we make sure that only appropriate batches (with respect to their
// sequence length) are updated. Batches which have shorter sequences preserve
// the last value.
H_t = get_masked_node(H, time_step, 0, H_t);
time_step++;
}
// Get back original nodes.
m_args.at(OpInput::W) = orig_W;
m_args.at(OpInput::R) = orig_R;
m_args.at(OpInput::B) = orig_B;
// The tensor that concatenates all the intermediate output values of the hidden state.
// It has shape [seq_length, batch_size, hidden_size]
std::shared_ptr<ngraph::Node> Y{std::make_shared<default_opset::Concat>(h_list, 0)};
// Get back the original order of the output data.
if (is_reverse)
{
Y = std::make_shared<default_opset::ReverseSequence>(
Y, seq_lengths, 1 /*batch_axis*/, 0 /*seq_axis*/);
}
// Expand Y so that it has expected shape:
// [seq_length, num_directions, batch_size, hidden_size]
Y = builder::opset1::expand_dims(Y, 1);
// Expand H_t so that it has expected shape:
// [num_directions, batch_size, hidden_size]
auto Y_h = builder::opset1::expand_dims(H_t);
return {Y, Y_h};
}
std::shared_ptr<ngraph::Node>
RecurrentSequence::get_masked_node(const Output<ngraph::Node>& data,
int32_t time_step,
size_t batch_axis,
const Output<ngraph::Node>& default_value) const
{
Output<ngraph::Node> mask_value = default_value;
// Create zero mask value node.
if (!mask_value.get_node_shared_ptr())
{
mask_value = std::make_shared<default_opset::Constant>(
data.get_element_type(), data.get_shape(), 0.f);
}
// Create predicate nodes. The condition is whether the current time step value
// is greater than the sequence length for the respective batch inputs.
std::shared_ptr<ngraph::Node> curr_time_step_node =
std::make_shared<default_opset::Constant>(
element::i32, data.get_shape(), time_step);
Output<ngraph::Node> batch_seq_length =
builder::opset1::legacy_broadcast_for_binary_operation(
curr_time_step_node, m_args.at(OpInput::SEQ_LENGTHS), batch_axis);
// Create mask node deciding whether or not to mask batch data.
std::shared_ptr<ngraph::Node> mask_condition =
std::make_shared<default_opset::Greater>(curr_time_step_node, batch_seq_length);
// Select values depending on mask_condition.
// Select(<condition>, <true_value>, <false_value>)
return std::make_shared<default_opset::Select>(mask_condition, mask_value, data);
}
std::shared_ptr<ngraph::Node>
RecurrentSequence::prepare_input(Output<ngraph::Node> node, bool is_reverse) const
{
// In bidirectional mode inputs are stacked together, so we must split them.
std::shared_ptr<ngraph::Node> tmp = node.get_node_shared_ptr();
if (m_direction == ngraph::op::RecurrentSequenceDirection::BIDIRECTIONAL)
{
tmp = builder::opset1::split(node, 2).at(is_reverse ? 1 : 0);
}
// Since we work in forward pass mode, we can squeeze the `num_directions` axis
// from the input.
return builder::opset1::squeeze(tmp);
}
} // recurrent
} // onnx_import
} // ngraph
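The get_masked_node logic above is what makes mixed sequence lengths work: once the time step exceeds a batch entry's sequence length, its freshly computed values are replaced, with zeros for the concatenated Y output and with the previous hidden state for Y_h. A small standalone sketch of that per-batch selection, using plain vectors instead of ngraph nodes (illustrative only):

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration of the masking rule implemented by get_masked_node: for every
// batch entry whose sequence length is shorter than the current time step,
// replace the freshly computed row with a default value (zeros for the Y
// output, the previous hidden state for Y_h).
std::vector<std::vector<float>> mask_step(const std::vector<std::vector<float>>& data,
                                          const std::vector<std::vector<float>>& default_value,
                                          const std::vector<std::int32_t>& seq_lengths,
                                          std::int32_t time_step)
{
    std::vector<std::vector<float>> result = data;
    for (std::size_t batch = 0; batch < data.size(); ++batch)
    {
        // Mirrors Select(Greater(time_step, seq_length), default_value, data).
        if (time_step > seq_lengths[batch])
        {
            result[batch] = default_value[batch];
        }
    }
    return result;
}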

View File

@@ -0,0 +1,186 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <functional>
#include <map>
#include <memory>
#include "core/node.hpp"
#include "ngraph/node.hpp"
#include "ngraph/op/util/attr_types.hpp"
namespace ngraph
{
namespace onnx_import
{
namespace recurrent
{
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INPUT NODES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
///
/// \brief This enum describes the recurrent operation input names.
///
enum class OpInput
{
X, // Packed input sequences.
// Shape: [seq_length, batch_size, input_size]
W, // Weight tensor for the gates.
// Shape: [num_directions, gates_count*hidden_size, input_size]
R, // The recurrence weight tensor.
// Shape: [num_directions, gates_count*hidden_size, hidden_size]
B, // The bias tensor for gates.
// Shape [num_directions, gates_count*hidden_size]
SEQ_LENGTHS, // The lengths of the sequences in a batch. Shape [batch_size]
INIT_H, // The initial value of the hidden state.
// Shape [num_directions, batch_size, hidden_size]
};
///
/// \brief This structure aggregates operator's inputs in a key-value map.
///
struct OpInputMap
{
using container_type = std::map<OpInput, std::shared_ptr<ngraph::Node>>;
explicit OpInputMap(const onnx_import::Node& node, std::size_t gates_count);
OpInputMap(container_type&& map);
virtual ~OpInputMap() = default;
std::shared_ptr<ngraph::Node>& at(const OpInput& key);
const std::shared_ptr<ngraph::Node>& at(const OpInput& key) const;
container_type m_map;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ATTRIBUTES PARSING ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
///
/// \brief This structure aggregates operator's attributes.
///
struct OpAttributes
{
explicit OpAttributes(const Node& node);
virtual ~OpAttributes() = default;
ngraph::op::RecurrentSequenceDirection m_direction;
std::int64_t m_hidden_size;
float m_clip_threshold;
std::vector<std::string> m_activations;
std::vector<float> m_activations_alpha;
std::vector<float> m_activations_beta;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Helper classes~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
///
/// \brief Callable object defining recurrent cell computations.
///
/// The function returns a node output representing the cell hidden state after the
/// cell computations. The arguments are:
/// * the input node map,
/// * the cell input data,
/// * the cell hidden state from the previous step.
///
using RecurrentCellFunction = std::function<Output<ngraph::Node>(
const OpInputMap&, const Output<ngraph::Node>&, const Output<ngraph::Node>)>;
///
/// \brief This class describes a recurrent (RNN-like) sequence operation.
///
/// \paragraph Outline. This class takes care of orchestrating the computations carried
/// out on a data sequence. The user only has to provide the kernel
/// function, which is executed on the current time-step input data,
/// and the sequence direction mode.
///
/// \paragraph Assumptions. This class assumes an RNN-like sequence operation. This
/// means that the operator should have the same inputs and
/// outputs as the RNN operator. In particular, the cell/kernel
/// should take the hidden cell state as an input.
///
class RecurrentSequence
{
public:
///
/// \brief Constructs a RecurrentSequence class object.
///
/// \param[in] args The map with recurrent sequence operator inputs.
/// \param[in] attrs The structure containing operator attributes.
/// \param[in] direction The sequence direction mode {FORWARD, REVERSE,
/// BIDIRECTIONAL}.
///
RecurrentSequence(OpInputMap& args,
ngraph::op::RecurrentSequenceDirection direction);
///
/// \brief Carry out all steps of recurrent sequence with provided cell kernel.
///
/// \param[in] kernel The cell kernel function.
///
/// \return The node vector containing results from all sequence steps.
///
NodeVector run_sequence(const RecurrentCellFunction& kernel);
private:
///
/// \brief Gets the masked value according to the sequence length in a batch.
///
/// \note Zeros out values or sets them to the default value for inputs with
/// sequence length shorter than the currently processed time step.
///
/// \param[in] data The input value.
/// \param[in] time_step The current time step (compared against sequence lengths).
/// \param[in] batch_axis The batch axis index of data tensor.
/// \param[in] default_value The default value for masked elements.
///
/// \return The masked value.
///
std::shared_ptr<ngraph::Node> get_masked_node(
const Output<ngraph::Node>& data,
std::int32_t time_step,
std::size_t batch_axis = 0,
const Output<ngraph::Node>& default_value = Output<ngraph::Node>()) const;
///
/// \brief Split and squeeze input data to remove 'num_direction' dimension.
///
/// \param[in] node The node to update.
/// \param[in] is_reverse Indicates whether to configure for the reverse pass.
///
/// \return Updated node for forward/reverse pass.
///
std::shared_ptr<ngraph::Node> prepare_input(Output<ngraph::Node> node,
bool is_reverse) const;
///
/// \brief Perform the computation through all input sequence steps in a single direction.
///
/// \param[in] kernel The cell kernel function.
/// \param[in] is_reverse Indicates whether to carry out the reverse or the forward pass.
///
/// \return The node vector with pass results.
///
NodeVector recurrent_sequence_pass(const RecurrentCellFunction& kernel,
bool is_reverse = false);
OpInputMap& m_args;
ngraph::op::RecurrentSequenceDirection m_direction;
};
} // recurrent
} // onnx_import
} // ngraph
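The RecurrentCellFunction/run_sequence split declared above is what lets the GRU and RNN importers share the orchestration code: the sequence class walks the time steps, and each operator only supplies a kernel that maps the current input and hidden state to the next hidden state. A minimal standalone sketch of the same callback pattern, with placeholder float state instead of ngraph nodes:

#include <functional>
#include <vector>

// Simplified illustration of the RecurrentCellFunction pattern: the sequence
// driver owns the time-step loop, the cell kernel only maps (input, previous
// state) to the next state. Plain floats stand in for ngraph node outputs.
using CellFn = std::function<float(float /*x_t*/, float /*h_prev*/)>;

std::vector<float> run_sequence_sketch(const std::vector<float>& inputs,
                                       float h_init,
                                       const CellFn& cell)
{
    std::vector<float> hidden_states; // would become the Y output
    float h = h_init;
    for (float x : inputs)
    {
        h = cell(x, h); // one recurrent step
        hidden_states.push_back(h);
    }
    return hidden_states; // the final h corresponds to Y_h
}

// Usage: an RNN-like kernel is just a lambda, e.g.
// run_sequence_sketch({1.f, 2.f, 3.f}, 0.f,
//                     [](float x, float h) { return 0.5f * x + 0.5f * h; });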

View File

@@ -255,24 +255,3 @@ shared_ptr<Node> op::v0::LSTMSequence::prepare_input(Output<Node> node,
// Since we have forward LSTM we can squeeze `num_directions` axis from inputs.
return builder::opset1::squeeze(tmp, {num_direction_axis});
}
namespace ngraph
{
template <>
EnumNames<op::v0::LSTMSequence::direction>& EnumNames<op::v0::LSTMSequence::direction>::get()
{
static auto enum_names = EnumNames<op::v0::LSTMSequence::direction>(
"op::v0::LSTMSequence::direction",
{{"forward", op::v0::LSTMSequence::direction::FORWARD},
{"reverse", op::v0::LSTMSequence::direction::REVERSE},
{"bidirectional", op::v0::LSTMSequence::direction::BIDIRECTIONAL}});
return enum_names;
}
constexpr DiscreteTypeInfo AttributeAdapter<op::v0::LSTMSequence::direction>::type_info;
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type)
{
return s << as_string(type);
}
} // namespace ngraph

View File

@@ -25,6 +25,7 @@
#include "ngraph/node.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/fused/lstm_cell.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"
namespace ngraph
@@ -49,14 +50,9 @@ namespace ngraph
const NodeTypeInfo& get_type_info() const override { return type_info; }
LSTMSequence() = default;
size_t get_default_output_index() const override { return no_default_index(); }
enum class direction
{
FORWARD,
REVERSE,
BIDIRECTIONAL
};
using direction = RecurrentSequenceDirection;
size_t get_default_output_index() const override { return no_default_index(); }
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
@@ -190,21 +186,4 @@ namespace ngraph
using v0::LSTMSequence;
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const op::v0::LSTMSequence::direction& type);
template <>
class NGRAPH_API AttributeAdapter<op::v0::LSTMSequence::direction>
: public EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>
{
public:
AttributeAdapter(op::v0::LSTMSequence::direction& value)
: EnumAttributeAdapterBase<op::v0::LSTMSequence::direction>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::v0::LSTMSequence::direction>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph

View File

@@ -45,9 +45,8 @@ namespace ngraph
/// * - Is a dot product,
/// f - is activation functions.
///
/// \note This class represents only single *cell* (for current time step) and not
/// the
/// whole LSTM Sequence layer
/// \note This class represents only single *cell* (for current time step)
/// and not the whole RNN Sequence layer
///
/// \sa LSTMSequence, LSTMCell, GRUCell
///

View File

@@ -203,4 +203,23 @@ namespace ngraph
}
constexpr DiscreteTypeInfo AttributeAdapter<op::BroadcastModeSpec>::type_info;
NGRAPH_API
constexpr DiscreteTypeInfo AttributeAdapter<op::RecurrentSequenceDirection>::type_info;
std::ostream& op::operator<<(std::ostream& s, const op::RecurrentSequenceDirection& direction)
{
return s << as_string(direction);
}
template <>
NGRAPH_API EnumNames<op::RecurrentSequenceDirection>&
EnumNames<op::RecurrentSequenceDirection>::get()
{
static auto enum_names = EnumNames<op::RecurrentSequenceDirection>(
"op::RecurrentSequenceDirection",
{{"forward", op::RecurrentSequenceDirection::FORWARD},
{"reverse", op::RecurrentSequenceDirection::REVERSE},
{"bidirectional", op::RecurrentSequenceDirection::BIDIRECTIONAL}});
return enum_names;
}
}

View File

@@ -420,4 +420,35 @@ namespace ngraph
protected:
op::BroadcastModeSpec& m_ref;
};
namespace op
{
///
/// \brief This class defines possible recurrent sequence directions.
///
enum class RecurrentSequenceDirection
{
FORWARD,
REVERSE,
BIDIRECTIONAL
};
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const RecurrentSequenceDirection& direction);
}
template <>
class NGRAPH_API AttributeAdapter<op::RecurrentSequenceDirection>
: public EnumAttributeAdapterBase<op::RecurrentSequenceDirection>
{
public:
AttributeAdapter(op::RecurrentSequenceDirection& value)
: EnumAttributeAdapterBase<op::RecurrentSequenceDirection>(value)
{
}
static constexpr DiscreteTypeInfo type_info{
"AttributeAdapter<op::RecurrentSequenceDirection>", 1};
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
}

View File

@@ -16,6 +16,7 @@
#include <algorithm>
#include <iterator>
#include <locale>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/op/add.hpp"
@@ -61,7 +62,14 @@ bool ngraph::op::util::RNNCellBase::visit_attributes(AttributeVisitor& visitor)
op::util::ActivationFunction op::util::RNNCellBase::get_activation_function(size_t idx) const
{
op::util::ActivationFunction afunc = get_activation_func_by_name(m_activations.at(idx));
// Normalize activation function case.
std::string func_name = m_activations.at(idx);
std::locale loc;
std::transform(func_name.begin(), func_name.end(), func_name.begin(), [&loc](char c) {
return std::tolower(c, loc);
});
op::util::ActivationFunction afunc = get_activation_func_by_name(func_name);
// Set activation functions parameters (if any)
if (m_activations_alpha.size() > idx)
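The change above lower-cases the activation name before the lookup in get_activation_func_by_name, so attribute values such as "Relu" (used by the new rnn_fwd_activations test model) resolve the same way as "relu". A standalone sketch of that normalization step, using the unsigned-char overload of std::tolower instead of the locale overload (an equivalent but slightly different spelling):

#include <algorithm>
#include <cctype>
#include <string>

// Lower-case an activation name so "Relu", "RELU" and "relu" all resolve to
// the same entry before the activation-function lookup.
std::string normalize_activation_name(std::string name)
{
    std::transform(name.begin(), name.end(), name.begin(), [](unsigned char c) {
        return static_cast<char>(std::tolower(c));
    });
    return name;
}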

View File

@@ -1124,7 +1124,7 @@ TEST(attributes, lstm_sequence_op)
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
const auto lstm_direction = op::RecurrentSequenceDirection::BIDIRECTIONAL;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};

View File

@@ -0,0 +1,124 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "bidirectional"
type: STRING
}
}
name: "test_gru_bidirectional"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,119 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_gru_defaults"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,130 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "activations"
strings: "relu"
strings: "hardsigmoid"
type: STRINGS
}
}
name: "test_gru_fwd_activations"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,157 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: ""
input: "initial_h"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_gru_fwd_bias_initial_h"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 30 # 2 * gates_count * hidden_size
}
}
}
}
}
input {
name: "initial_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 5 # hidden size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,141 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "linear_before_reset"
i: 1
type: INT
}
}
name: "test_gru_fwd_linear_before_reset_bias"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 30 # 2 * gates_count * hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,170 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: "sequence_lens"
input: "initial_h"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_gru_fwd_mixed_seq_len"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 15
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 15
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 30
}
}
}
}
}
input {
name: "sequence_lens"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "initial_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,129 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "clip"
f: 1.752
type: FLOAT
}
attribute {
name: "direction"
s: "reverse"
type: STRING
}
}
name: "test_gru_rev_clip"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,124 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "GRU"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "reverse"
type: STRING
}
}
name: "test_gru_reverse"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 15 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,124 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "bidirectional"
type: STRING
}
}
name: "test_rnn_bidirectional"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,187 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
output: "W"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2 # num_directions
dims: 5 # gates_count*hidden_size
dims: 2 # input_size
data_type: 1
float_data: 0.31403765
float_data: -0.16793324
float_data: 1.388258
float_data: -0.6902954
float_data: -0.3994045
float_data: -0.7833511
float_data: -0.30992958
float_data: 0.3557573
float_data: -0.4682631
float_data: 1.1741459
float_data: -2.414789
float_data: -0.42783254
float_data: -0.82199496
float_data: -0.03900861
float_data: -0.43670088
float_data: -0.53810567
float_data: -0.10769883
float_data: 0.75242394
float_data: -0.2507971
float_data: 1.0447186
name: "W_tensor"
}
type: TENSOR
}
}
node {
output: "R"
op_type: "Constant"
attribute {
name: "value"
t {
dims: 2 # num_directions
dims: 5 # gates_count*hidden_size
dims: 5 # hidden_size
data_type: 1
float_data: -1.4777364
float_data: 0.19993274
float_data: 0.925649
float_data: -2.282516
float_data: 0.95039636
float_data: 1.5379831
float_data: -0.88576007
float_data: 0.28566247
float_data: 0.79292643
float_data: -0.04261953
float_data: 0.8490583
float_data: 0.45121244
float_data: -1.1799014
float_data: 0.13536449
float_data: 0.81328654
float_data: 0.6017516
float_data: 0.48475724
float_data: -1.2136037
float_data: 0.16383322
float_data: 1.5106261
float_data: 1.1177503
float_data: 0.23582461
float_data: 0.5754652
float_data: 0.43879887
float_data: 0.7399294
float_data: 0.4517558
float_data: 1.3536783
float_data: -0.4843166
float_data: -1.1503736
float_data: -0.2458678
float_data: 0.54523313
float_data: -0.08649993
float_data: -0.6936281
float_data: 1.002422
float_data: -1.770847
float_data: -0.94642
float_data: -1.8135757
float_data: 1.8819852
float_data: -0.10852333
float_data: -0.26120332
float_data: 1.0223165
float_data: -0.7468837
float_data: 0.28566906
float_data: 0.92321056
float_data: 0.22521864
float_data: 1.1123824
float_data: -0.9298287
float_data: 1.2141289
float_data: 1.3470556
float_data: -0.32972014
name: "R_tensor"
}
type: TENSOR
}
}
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "bidirectional"
type: STRING
}
}
name: "test_rnn_bidirectional"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,119 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_rnn_defaults"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,129 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "forward"
type: STRING
}
attribute {
name: "activations"
strings: "Relu"
type: STRINGS
}
}
name: "test_rnn_fwd_activations"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,157 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: ""
input: "initial_h"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_rnn_fwd_bias_initial_h"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 10 # 2 * gates_count*hidden_size
}
}
}
}
}
input {
name: "initial_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 5 # hidden size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,170 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
input: "B"
input: "sequence_lens"
input: "initial_h"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
}
name: "test_rnn_fwd_mixed_seq_len"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 5
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 5
}
dim {
dim_value: 5
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 10
}
}
}
}
}
input {
name: "sequence_lens"
type {
tensor_type {
elem_type: 6
shape {
dim {
dim_value: 3
}
}
}
}
}
input {
name: "initial_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 3
}
dim {
dim_value: 5
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,129 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "clip"
f: 1.752
type: FLOAT
}
attribute {
name: "direction"
s: "reverse"
type: STRING
}
}
name: "test_rnn_rev_clip"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -0,0 +1,124 @@
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "X"
input: "W"
input: "R"
output: "Y"
output: "Y_h"
op_type: "RNN"
attribute {
name: "hidden_size"
i: 5
type: INT
}
attribute {
name: "direction"
s: "reverse"
type: STRING
}
}
name: "test_rnn_reverse"
input {
name: "X"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 3 # batch size
}
dim {
dim_value: 2 # input size
}
}
}
}
}
input {
name: "W"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 2 # input_size
}
}
}
}
}
input {
name: "R"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 5 # gates_count*hidden_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 4 # seq_length
}
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
output {
name: "Y_h"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1 # num_directions
}
dim {
dim_value: 3 # batch_size
}
dim {
dim_value: 5 # hidden_size
}
}
}
}
}
}
opset_import {
version: 9
}

View File

@@ -352,3 +352,983 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_model_lstm_mixed_seq_reverse)
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 1);
}
// Test fixture that shares input data across the GRU sequence operator tests below.
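// Shared tensor shapes (hidden_size = 5, 3 gates):
//   X{4, 3, 2}          [seq_length, batch_size, input_size]
//   W{1, 15, 2}         [num_directions, 3*hidden_size, input_size]
//   R{1, 15, 5}         [num_directions, 3*hidden_size, hidden_size]
//   B{1, 30}            [num_directions, 6*hidden_size]
//   sequence_lens{3}    [batch_size]
//   initial_h{1, 3, 5}  [num_directions, batch_size, hidden_size]
//   bdir_W{2, 15, 2} and bdir_R{2, 15, 5} are the bidirectional variants.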
class GRUSequenceOp : public testing::Test
{
public:
std::vector<float> in_X{0.68172926f, 1.1405563f, -0.03931177f, -0.03759607f, 0.22778925f,
1.2471468f, 0.2785642f, 0.5198979f, 0.3712886f, -0.3194908f,
0.8448233f, -0.62065625f, 1.2968333f, -0.20370148f, 0.40204826f,
-0.23721986f, 0.3629822f, -0.3819832f, -0.7766345f, 0.19374144f,
1.1397027f, 0.60444903f, 1.3246384f, -0.28191715f};
std::vector<float> in_W{
0.31403765f, -0.16793324f, 1.388258f, -0.6902954f, -0.3994045f, -0.7833511f,
-0.30992958f, 0.3557573f, -0.4682631f, 1.1741459f, -2.414789f, -0.42783254f,
-0.82199496f, -0.03900861f, -0.43670088f, -0.53810567f, -0.10769883f, 0.75242394f,
-0.2507971f, 1.0447186f, -1.4777364f, 0.19993274f, 0.925649f, -2.282516f,
0.95039636f, 1.5379831f, -0.88576007f, 0.28566247f, 0.79292643f, -0.04261953f,
};
std::vector<float> in_R{
0.8490583f, 0.45121244f, -1.1799014f, 0.13536449f, 0.81328654f, 0.6017516f,
0.48475724f, -1.2136037f, 0.16383322f, 1.5106261f, 1.1177503f, 0.23582461f,
0.5754652f, 0.43879887f, 0.7399294f, 0.4517558f, 1.3536783f, -0.4843166f,
-1.1503736f, -0.2458678f, 0.54523313f, -0.08649993f, -0.6936281f, 1.002422f,
-1.770847f, -0.94642f, -1.8135757f, 1.8819852f, -0.10852333f, -0.26120332f,
1.0223165f, -0.7468837f, 0.28566906f, 0.92321056f, 0.22521864f, 1.1123824f,
-0.9298287f, 1.2141289f, 1.3470556f, -0.32972014f, -1.6552197f, -1.0998285f,
0.71901864f, 0.962846f, -0.1366851f, -2.6534476f, -1.4992771f, -0.45793465f,
0.4290477f, 0.9893151f, 0.2511034f, 0.12906462f, 0.7491512f, 0.3316756f,
1.0576645f, -0.04618666f, 1.3556088f, 1.2842374f, 0.7103014f, 0.52889013f,
0.30327162f, 1.5069056f, 0.16591893f, 1.5719851f, -2.099427f, -1.010277f,
-0.52800924f, -0.22292352f, -0.55177474f, 1.3432894f, 0.8731192f, -0.01055307f,
-0.01138215f, 0.85698843f, -1.2615703f,
};
std::vector<float> in_B{
0.5336702f, 1.6593654f, -1.150011f, 0.00342217f, 0.799371f, 0.43780383f,
-0.55082625f, 1.0774187f, -0.6065135f, 0.6434064f, -1.5693754f, 1.4923384f,
1.1554348f, -1.328159f, 0.24995533f, 0.15112682f, -0.34698758f, -0.10088819f,
-0.2931625f, -0.47319615f, 0.66167855f, -1.1646721f, -0.09588219f, 0.5212928f,
0.37182367f, 0.27342287f, 1.1613405f, -0.75196224f, -1.5143642f, 0.20604452f,
};
std::vector<int32_t> in_sequence_lens{2, 3, 4};
std::vector<float> in_initial_h{
-0.4840693f,
-1.4054376f,
0.84533644f,
-0.1160888f,
-1.3724717f,
1.978259f,
-0.8500094f,
-2.0120409f,
0.89959633f,
-0.5367942f,
0.21188478f,
1.7603784f,
0.38752958f,
-0.06706902f,
-1.4290836f,
};
std::vector<float> in_bdir_W{
0.31403765f, -0.16793324f, 1.388258f, -0.6902954f, -0.3994045f, -0.7833511f,
-0.30992958f, 0.3557573f, -0.4682631f, 1.1741459f, -2.414789f, -0.42783254f,
-0.82199496f, -0.03900861f, -0.43670088f, -0.53810567f, -0.10769883f, 0.75242394f,
-0.2507971f, 1.0447186f, -1.4777364f, 0.19993274f, 0.925649f, -2.282516f,
0.95039636f, 1.5379831f, -0.88576007f, 0.28566247f, 0.79292643f, -0.04261953f,
0.8490583f, 0.45121244f, -1.1799014f, 0.13536449f, 0.81328654f, 0.6017516f,
0.48475724f, -1.2136037f, 0.16383322f, 1.5106261f, 1.1177503f, 0.23582461f,
0.5754652f, 0.43879887f, 0.7399294f, 0.4517558f, 1.3536783f, -0.4843166f,
-1.1503736f, -0.2458678f, 0.54523313f, -0.08649993f, -0.6936281f, 1.002422f,
-1.770847f, -0.94642f, -1.8135757f, 1.8819852f, -0.10852333f, -0.26120332f,
};
std::vector<float> in_bdir_R{
1.02231646e+00f, -7.46883690e-01f, 2.85669059e-01f, 9.23210561e-01f, 2.25218639e-01f,
1.11238241e+00f, -9.29828703e-01f, 1.21412885e+00f, 1.34705555e+00f, -3.29720140e-01f,
-1.65521967e+00f, -1.09982848e+00f, 7.19018638e-01f, 9.62845981e-01f, -1.36685103e-01f,
-2.65344763e+00f, -1.49927711e+00f, -4.57934648e-01f, 4.29047704e-01f, 9.89315093e-01f,
2.51103401e-01f, 1.29064620e-01f, 7.49151170e-01f, 3.31675589e-01f, 1.05766451e+00f,
-4.61866595e-02f, 1.35560882e+00f, 1.28423738e+00f, 7.10301399e-01f, 5.28890133e-01f,
3.03271621e-01f, 1.50690556e+00f, 1.65918931e-01f, 1.57198513e+00f, -2.09942698e+00f,
-1.01027703e+00f, -5.28009236e-01f, -2.22923517e-01f, -5.51774740e-01f, 1.34328938e+00f,
8.73119175e-01f, -1.05530666e-02f, -1.13821477e-02f, 8.56988430e-01f, -1.26157033e+00f,
5.33670187e-01f, 1.65936542e+00f, -1.15001094e+00f, 3.42216762e-03f, 7.99371004e-01f,
4.37803835e-01f, -5.50826252e-01f, 1.07741868e+00f, -6.06513500e-01f, 6.43406391e-01f,
-1.56937540e+00f, 1.49233842e+00f, 1.15543485e+00f, -1.32815897e+00f, 2.49955326e-01f,
1.51126817e-01f, -3.46987575e-01f, -1.00888193e-01f, -2.93162495e-01f, -4.73196149e-01f,
6.61678553e-01f, -1.16467214e+00f, -9.58821923e-02f, 5.21292806e-01f, 3.71823668e-01f,
2.73422867e-01f, 1.16134048e+00f, -7.51962245e-01f, -1.51436424e+00f, 2.06044525e-01f,
-4.84069288e-01f, -1.40543759e+00f, 8.45336437e-01f, -1.16088800e-01f, -1.37247169e+00f,
1.97825897e+00f, -8.50009382e-01f, -2.01204085e+00f, 8.99596334e-01f, -5.36794186e-01f,
2.11884782e-01f, 1.76037836e+00f, 3.87529582e-01f, -6.70690164e-02f, -1.42908359e+00f,
8.20716441e-01f, 7.34144002e-02f, 2.08005775e-02f, -3.74185145e-01f, 2.27367446e-01f,
-4.54556733e-01f, -2.24295408e-01f, 3.42533922e+00f, -3.13701063e-01f, 1.25070000e+00f,
-1.29529154e+00f, -4.87530619e-01f, 6.51176691e-01f, -8.81322920e-02f, -1.84014812e-01f,
-6.68743193e-01f, -2.83598930e-01f, 1.24322104e+00f, -1.03440486e-01f, -4.63501781e-01f,
1.72944975e+00f, -2.54249543e-01f, -1.60864544e+00f, 4.86483961e-01f, 7.00442135e-01f,
-1.71952701e+00f, -2.44922549e-01f, -5.80028534e-01f, 6.99496418e-02f, 3.74598980e-01f,
-1.19728017e+00f, 9.30128455e-01f, -2.42379427e-01f, 6.40181661e-01f, 2.04856300e+00f,
-1.27523863e+00f, -4.75532770e-01f, 3.02047610e-01f, -2.54939228e-01f, -1.33242559e+00f,
-8.23140562e-01f, -1.09450793e+00f, -1.70845091e-01f, 1.31205237e+00f, 2.28988096e-01f,
-5.51795721e-01f, -9.49851334e-01f, 1.28619313e+00f, 1.28273416e+00f, 2.92767227e-01f,
-3.92974496e-01f, 2.09084296e+00f, -1.28314102e+00f, -1.19076264e+00f, -3.52258608e-02f,
-4.47186083e-02f, 6.82157278e-01f, -2.59570718e-01f, 1.50172567e+00f, -2.76523419e-02f,
};
protected:
virtual void SetUp() override {}
};
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_defaults_fwd)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_defaults_fwd.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.3224981f, -0.44282594f, 0.7499796f, -0.12240417f, 0.12079421f, 0.02534253f,
0.02504562f, -0.0463777f, 0.01204534f, -0.01497037f, -0.04651929f, -0.6264307f,
0.7236632f, 0.06250653f, 0.02594197f, -0.06868916f, -0.5412897f, 0.49794048f,
0.22239858f, -0.11257736f, -0.23071964f, 0.26079988f, -0.07375772f, -0.21816255f,
0.18764113f, -0.5228772f, 0.00575754f, 0.2514028f, -0.58864325f, 0.49843538f,
-0.6129046f, -0.10794663f, 0.6544055f, -0.70105773f, 0.5397687f, -0.35791716f,
0.3885092f, -0.15291792f, -0.22324723f, 0.11557932f, -0.42112932f, 0.26772985f,
-0.38304564f, -0.05039781f, -0.5057976f, 0.5775348f, -0.6736855f, -0.20032284f,
0.03698462f, -0.7693824f, -0.5831348f, 0.25767964f, 0.7121098f, -0.35951245f,
0.39223647f, -0.6645166f, 0.37950075f, 0.59931314f, -0.4741001f, 0.21156166f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.5775348f,
-0.6736855f,
-0.20032284f,
0.03698462f,
-0.7693824f,
-0.5831348f,
0.25767964f,
0.7121098f,
-0.35951245f,
0.39223647f,
-0.6645166f,
0.37950075f,
0.59931314f,
-0.4741001f,
0.21156166f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 7);
}
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_activations)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_fwd_activations.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.33636323f, 0.08874974f, 0.9804139f, 0.35797057f, -0.01193848f, 0.51011515f,
0.50988495f, 0.4592467f, 0.5048162f, 0.4940862f, 0.4825466f, 0.f,
0.9269162f, 0.3328298f, -0.18795171f, 0.69541144f, 0.7612694f, 0.937299f,
0.3463983f, 0.38764104f, 0.49957055f, 0.27359068f, 0.38423678f, 0.3618936f,
0.55977404f, 0.5223568f, 0.46266305f, 1.016379f, 0.22654215f, 0.6347567f,
0.53541327f, 0.46684968f, 1.0639775f, 0.21325049f, 0.70507824f, 0.48425108f,
-0.05370265f, 0.3055008f, 0.38166368f, 0.5645658f, 0.5998517f, 0.42573926f,
1.4539189f, 0.31789488f, 0.5567502f, 1.f, 0.92153484f, 1.4015231f,
0.24147032f, 0.5783859f, 0.42785907f, -0.5690068f, 0.69624555f, 0.32291538f,
0.68179333f, 0.50179297f, 0.0067991f, 2.043301f, 0.12669492f, 0.7062868f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
1.f,
0.92153484f,
1.4015231f,
0.24147032f,
0.5783859f,
0.42785907f,
-0.5690068f,
0.69624555f,
0.32291538f,
0.68179333f,
0.50179297f,
0.0067991f,
2.043301f,
0.12669492f,
0.7062868f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 5);
}
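// sequence_lens = {2, 3, 4}: for each batch entry, Y is zero-filled past its own
// sequence length and Y_h keeps the hidden state from its last valid step.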
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_mixed_seq_len)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_fwd_mixed_seq_len.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
test_case.add_input<float>(in_B);
test_case.add_input<int>(in_sequence_lens);
test_case.add_input<float>(in_initial_h);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.9559332f, 0.4372494f, 0.9967716f, -0.9079381f, -1.2538278f, 1.9265908f,
-0.8437393f, -1.2057271f, -0.25887525f, -0.52679026f, -0.3619178f, 0.67928517f,
0.9486744f, -0.12006134f, -1.3862017f, -0.98941356f, 0.80389524f, 0.97586197f,
-0.9343586f, -0.74858856f, 1.797039f, -0.7873732f, -0.72469383f, -0.5866635f,
-0.42103744f, -0.8406298f, 0.85877097f, 0.6349921f, -0.55897295f, -0.6168443f,
0.f, 0.f, 0.f, 0.f, 0.f, 1.577129f,
-0.6935871f, -0.304804f, -0.75392795f, -0.20703818f, -0.93796504f, 0.9220495f,
0.36017662f, -0.7007159f, 0.06962098f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, -0.96323603f, 0.9265786f, 0.54976916f, -0.8037839f, 0.73501444f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.98941356f,
0.80389524f,
0.97586197f,
-0.9343586f,
-0.74858856f,
1.577129f,
-0.6935871f,
-0.304804f,
-0.75392795f,
-0.20703818f,
-0.96323603f,
0.9265786f,
0.54976916f,
-0.8037839f,
0.73501444f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_rev_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_rev_clip.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.50679326f, -0.8251296f, 0.7804218f, -0.1813852f, 0.00147036f, -0.18647355f,
0.38888037f, -0.07898733f, -0.05150563f, -0.23335457f, -0.21705005f, -0.2966391f,
0.67461425f, -0.1695634f, -0.09241624f, -0.10538863f, -0.6444952f, -0.01815936f,
-0.09695458f, -0.15107796f, -0.5036379f, 0.56125206f, 0.12785181f, -0.22290717f,
0.08662428f, -0.5849108f, 0.4789885f, -0.03569929f, -0.42043984f, 0.33464667f,
-0.01091215f, -0.42090097f, 0.24428985f, -0.6002675f, 0.27305228f, -0.35063627f,
0.3717615f, -0.00495788f, -0.00491725f, -0.27061304f, -0.3190831f, 0.3542383f,
-0.17784928f, -0.12995736f, -0.30778408f, 0.47168806f, -0.6330014f, -0.1905269f,
0.26708886f, -0.19741398f, -0.3995853f, -0.07459997f, 0.6749513f, -0.36566192f,
0.32173023f, -0.36364347f, 0.13916425f, 0.3908174f, -0.53085154f, 0.56740737f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.50679326f,
-0.8251296f,
0.7804218f,
-0.1813852f,
0.00147036f,
-0.18647355f,
0.38888037f,
-0.07898733f,
-0.05150563f,
-0.23335457f,
-0.21705005f,
-0.2966391f,
0.67461425f,
-0.1695634f,
-0.09241624f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 8);
}
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_reverse)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_reverse.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.51097775f, -0.85767376f, 0.8065842f, -0.1832461f, -0.00109532f, -0.18766233f,
0.3910985f, -0.0617601f, -0.05733761f, -0.23259571f, -0.22787738f, -0.3715533f,
0.70320934f, -0.17635077f, -0.0972611f, -0.11218601f, -0.660165f, -0.03494868f,
-0.07503931f, -0.15422714f, -0.5053969f, 0.5710621f, 0.1448728f, -0.225453f,
0.07250313f, -0.5988957f, 0.48768237f, 0.00665835f, -0.42196327f, 0.2749501f,
-0.02106231f, -0.44533628f, 0.24044508f, -0.5907899f, 0.26883256f, -0.3462156f,
0.3782666f, 0.00699124f, -0.00378288f, -0.2990779f, -0.32031405f, 0.3363319f,
-0.1877775f, -0.10781199f, -0.40970552f, 0.47168806f, -0.6330014f, -0.1905269f,
0.26708886f, -0.19741398f, -0.3995853f, -0.07459997f, 0.691666f, -0.36566192f,
0.32173023f, -0.37267625f, 0.1103513f, 0.3908174f, -0.53085154f, 0.56740737f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.51097775f,
-0.85767376f,
0.8065842f,
-0.1832461f,
-0.00109532f,
-0.18766233f,
0.3910985f,
-0.0617601f,
-0.05733761f,
-0.23259571f,
-0.22787738f,
-0.3715533f,
0.70320934f,
-0.17635077f,
-0.0972611f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 8);
}
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_bias_initial_h)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_fwd_bias_initial_h.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
test_case.add_input<float>(in_B);
test_case.add_input<float>(in_initial_h);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.9559332f, 0.4372494f, 0.9967716f, -0.9079381f, -1.2538278f, 1.9265908f,
-0.8437393f, -1.2057271f, -0.25887525f, -0.52679026f, -0.3619178f, 0.67928517f,
0.9486744f, -0.12006134f, -1.3862017f, -0.98941356f, 0.80389524f, 0.97586197f,
-0.9343586f, -0.74858856f, 1.797039f, -0.7873732f, -0.72469383f, -0.5866635f,
-0.42103744f, -0.8406298f, 0.85877097f, 0.6349921f, -0.55897295f, -0.6168443f,
-0.99686503f, 0.87408733f, 0.87070423f, -0.9564345f, 0.52932394f, 1.577129f,
-0.6935871f, -0.304804f, -0.75392795f, -0.20703818f, -0.93796504f, 0.9220495f,
0.36017662f, -0.7007159f, 0.06962098f, -0.22581682f, 0.9119905f, -0.64628327f,
-0.79374063f, -0.82321495f, 1.2853851f, -0.6176347f, 0.6865668f, -0.85147655f,
0.0379298f, -0.96323603f, 0.9265786f, 0.54976916f, -0.8037839f, 0.73501444f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.22581682f,
0.9119905f,
-0.64628327f,
-0.79374063f,
-0.82321495f,
1.2853851f,
-0.6176347f,
0.6865668f,
-0.85147655f,
0.0379298f,
-0.96323603f,
0.9265786f,
0.54976916f,
-0.8037839f,
0.73501444f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
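// Bidirectional case: Y has shape [seq_length, num_directions, batch_size, hidden_size],
// with direction 0 holding the forward pass and direction 1 the reverse pass;
// Y_h stacks the final hidden state of each direction.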
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_bidirectional)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_bidirectional.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_bdir_W);
test_case.add_input<float>(in_bdir_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 2, 3, 5},
std::vector<float>{
-0.3224981f, -0.44282594f, 0.7499796f, -0.12240417f, 0.12079421f, 0.02534253f,
0.02504562f, -0.0463777f, 0.01204534f, -0.01497037f, -0.04651929f, -0.6264307f,
0.7236632f, 0.06250653f, 0.02594197f, 0.06575559f, 0.34565696f, -0.3178988f,
0.6183835f, -0.02136152f, 0.11640755f, -0.45138f, -0.64678776f, -0.09675756f,
-0.37742358f, 0.20918667f, -0.59024405f, -0.845524f, 0.60705113f, -0.6336088f,
-0.0833023f, -0.40062034f, 0.7579466f, -0.12340625f, 0.04415433f, -0.24662055f,
0.27420586f, -0.09122991f, -0.22768986f, 0.19980885f, -0.218649f, -0.5560231f,
0.56177044f, -0.25098884f, 0.15462328f, 0.02859182f, 0.22456945f, -0.16747908f,
-0.10665483f, 0.06054133f, 0.18795699f, -0.49318847f, -0.6660372f, -0.5589901f,
-0.42696574f, 0.25369287f, -0.7369056f, -0.73285f, -0.5750758f, -0.533177f,
-0.34549737f, -0.33324608f, 0.74590445f, -0.48038307f, 0.40253335f, -0.45753813f,
0.5987347f, -0.07046633f, -0.35819566f, 0.3916747f, -0.18096107f, -0.24415034f,
0.38435352f, -0.29881003f, 0.07738188f, -0.04626282f, -0.34389234f, 0.2419839f,
-0.01195046f, 0.12158976f, 0.1648429f, -0.4124067f, -0.4792929f, -0.498473f,
-0.28167045f, 0.19370168f, -0.6386781f, -0.42919028f, -0.47081998f, -0.2954276f,
0.47018337f, 0.01509789f, 0.43945605f, -0.31491262f, 0.14951898f, -0.7645583f,
0.2566264f, 0.7295435f, -0.5008343f, 0.57549477f, -0.50112087f, -0.11085765f,
0.5155622f, -0.5635352f, 0.54762024f, -0.26451954f, 0.17519262f, 0.5203082f,
0.6119683f, 0.01544304f, 0.11548323f, -0.14230084f, -0.2133323f, -0.3981219f,
-0.06852704f, 0.17058733f, -0.6941011f, -0.27862304f, -0.27050856f, -0.03864266f,
});
// Y_h
test_case.add_expected_output<float>(
Shape{2, 3, 5},
std::vector<float>{
0.47018337f, 0.01509789f, 0.43945605f, -0.31491262f, 0.14951898f, -0.7645583f,
0.2566264f, 0.7295435f, -0.5008343f, 0.57549477f, -0.50112087f, -0.11085765f,
0.5155622f, -0.5635352f, 0.54762024f, 0.06575559f, 0.34565696f, -0.3178988f,
0.6183835f, -0.02136152f, 0.11640755f, -0.45138f, -0.64678776f, -0.09675756f,
-0.37742358f, 0.20918667f, -0.59024405f, -0.845524f, 0.60705113f, -0.6336088f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 6);
}
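// linear_before_reset = 1 changes the hidden gate to
//   Ht~ = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
// i.e. the reset gate is applied after the recurrent matmul rather than before it
// (per the ONNX GRU specification).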
NGRAPH_TEST_F(${BACKEND_NAME}, GRUSequenceOp, onnx_model_gru_fwd_linear_before_reset)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/gru_fwd_linear_before_reset.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
test_case.add_input<float>(in_B);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
-0.32330805f, -0.06708707f, 0.9148428f, -0.5182527f, 0.15030569f, -0.29070354f,
0.20353599f, 0.36028495f, -0.5524303f, 0.15076958f, -0.3330416f, -0.2412689f,
0.90464234f, -0.46817362f, 0.08000847f, -0.63514394f, 0.25109228f, 0.7674645f,
-0.7781104f, -0.07633221f, -0.5679979f, 0.32793444f, 0.18232828f, -0.756521f,
0.07898282f, -0.7205035f, -0.02278003f, -0.14991446f, -0.86801296f, 0.4434091f,
-0.8497459f, 0.35516143f, 0.8932138f, -0.8957482f, 0.4693949f, -0.74337614f,
0.43600178f, 0.51654255f, -0.8376663f, -0.18606272f, -0.8050637f, 0.06592449f,
0.13366115f, -0.8945458f, -0.66395104f, 0.140306f, 0.42112982f, -0.15852913f,
-0.74940586f, -0.7907575f, -0.89268315f, 0.5274858f, 0.97432905f, -0.89276016f,
0.15256537f, -0.91766477f, 0.22483218f, 0.9143838f, -0.9442929f, 0.33684915f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.140306f,
0.42112982f,
-0.15852913f,
-0.74940586f,
-0.7907575f,
-0.89268315f,
0.5274858f,
0.97432905f,
-0.89276016f,
0.15256537f,
-0.91766477f,
0.22483218f,
0.9143838f,
-0.9442929f,
0.33684915f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
// Test fixture that shares input data across the RNN sequence operator tests below.
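// Shared tensor shapes (hidden_size = 5, single gate):
//   X{4, 3, 2}          [seq_length, batch_size, input_size]
//   W{1, 5, 2}          [num_directions, hidden_size, input_size]
//   R{1, 5, 5}          [num_directions, hidden_size, hidden_size]
//   B{1, 10}            [num_directions, 2*hidden_size]
//   sequence_lens{3}    [batch_size]
//   initial_h{1, 3, 5}  [num_directions, batch_size, hidden_size]
//   bdir_W{2, 5, 2} and bdir_R{2, 5, 5} are the bidirectional variants.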
class RNNSequenceOp : public testing::Test
{
public:
std::vector<float> in_X{
0.68172926f, 1.1405563f, -0.03931177f, -0.03759607f, 0.22778925f, 1.2471468f,
0.2785642f, 0.5198979f, 0.3712886f, -0.3194908f, 0.8448233f, -0.62065625f,
1.2968333f, -0.20370148f, 0.40204826f, -0.23721986f, 0.3629822f, -0.3819832f,
-0.7766345f, 0.19374144f, 1.1397027f, 0.60444903f, 1.3246384f, -0.28191715f,
};
std::vector<float> in_W{
0.31403765f,
-0.16793324f,
1.388258f,
-0.6902954f,
-0.3994045f,
-0.7833511f,
-0.30992958f,
0.3557573f,
-0.4682631f,
1.1741459f,
};
std::vector<float> in_R{
-2.414789f, -0.42783254f, -0.82199496f, -0.03900861f, -0.43670088f,
-0.53810567f, -0.10769883f, 0.75242394f, -0.2507971f, 1.0447186f,
-1.4777364f, 0.19993274f, 0.925649f, -2.282516f, 0.95039636f,
1.5379831f, -0.88576007f, 0.28566247f, 0.79292643f, -0.04261953f,
0.8490583f, 0.45121244f, -1.1799014f, 0.13536449f, 0.81328654f,
};
std::vector<float> in_B{
0.6017516f,
0.48475724f,
-1.2136037f,
0.16383322f,
1.5106261f,
1.1177503f,
0.23582461f,
0.5754652f,
0.43879887f,
0.7399294f,
};
std::vector<int32_t> in_sequence_lens{2, 3, 4};
std::vector<float> in_initial_h{
0.4517558f,
1.3536783f,
-0.4843166f,
-1.1503736f,
-0.2458678f,
0.54523313f,
-0.08649993f,
-0.6936281f,
1.002422f,
-1.770847f,
-0.94642f,
-1.8135757f,
1.8819852f,
-0.10852333f,
-0.26120332f,
};
std::vector<float> in_bdir_W{
0.31403765f, -0.16793324f, 1.388258f, -0.6902954f, -0.3994045f,
-0.7833511f, -0.30992958f, 0.3557573f, -0.4682631f, 1.1741459f,
-2.414789f, -0.42783254f, -0.82199496f, -0.03900861f, -0.43670088f,
-0.53810567f, -0.10769883f, 0.75242394f, -0.2507971f, 1.0447186f,
};
std::vector<float> in_bdir_R{
-1.4777364f, 0.19993274f, 0.925649f, -2.282516f, 0.95039636f, 1.5379831f,
-0.88576007f, 0.28566247f, 0.79292643f, -0.04261953f, 0.8490583f, 0.45121244f,
-1.1799014f, 0.13536449f, 0.81328654f, 0.6017516f, 0.48475724f, -1.2136037f,
0.16383322f, 1.5106261f, 1.1177503f, 0.23582461f, 0.5754652f, 0.43879887f,
0.7399294f, 0.4517558f, 1.3536783f, -0.4843166f, -1.1503736f, -0.2458678f,
0.54523313f, -0.08649993f, -0.6936281f, 1.002422f, -1.770847f, -0.94642f,
-1.8135757f, 1.8819852f, -0.10852333f, -0.26120332f, 1.0223165f, -0.7468837f,
0.28566906f, 0.92321056f, 0.22521864f, 1.1123824f, -0.9298287f, 1.2141289f,
1.3470556f, -0.32972014f,
};
protected:
virtual void SetUp() override {}
};
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_defaults_fwd)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_defaults_fwd.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.02254748f, 0.15776646f, -0.8229023f, 0.19205809f, 0.76984656f, -0.00603169f,
-0.02861464f, 0.04512155f, -0.0011912f, -0.02572936f, -0.13703543f, -0.49651444f,
-0.78868157f, 0.3566854f, 0.8758509f, 0.20788848f, 0.13481987f, -0.756822f,
-0.121436f, 0.97542346f, 0.16959739f, 0.63496053f, 0.1245538f, -0.1970138f,
-0.56581646f, 0.8225869f, 0.9611373f, -0.42990375f, -0.22925597f, 0.2226491f,
0.08246052f, 0.9798831f, -0.13415998f, -0.5567714f, 0.78594816f, -0.34759718f,
0.11376679f, -0.07107389f, -0.5420871f, -0.58504283f, -0.96065646f, 0.18588805f,
-0.4870671f, -0.1475982f, 0.82456505f, -0.80264574f, -0.46370947f, 0.9719335f,
-0.7374159f, 0.94937694f, 0.8814341f, 0.67015004f, 0.21958017f, -0.8332769f,
-0.487742f, 0.9918536f, 0.99563396f, 0.94866276f, -0.98504806f, -0.42824882f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.80264574f,
-0.46370947f,
0.9719335f,
-0.7374159f,
0.94937694f,
0.8814341f,
0.67015004f,
0.21958017f,
-0.8332769f,
-0.487742f,
0.9918536f,
0.99563396f,
0.94866276f,
-0.98504806f,
-0.42824882f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 4);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_fwd_activations)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_fwd_activations.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.02255133f, 0.15909529f, 0.f, 0.19447318f, 1.019951f, 0.f,
0.f, 0.04515222f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.37308297f, 1.3576671f, 0.f, 1.015355f, 0.00543064f,
0.10311858f, 1.426765f, 0.13313684f, 0.769961f, 0.14377424f, 0.f,
0.f, 0.f, 2.9260807f, 0.5875195f, 0.f, 0.030334f,
0.f, 3.300393f, 0.97026074f, 0.f, 0.7796261f, 0.f,
0.6755121f, 0.1155303f, 0.f, 0.f, 0.f, 0.92621297f,
1.3119358f, 0.f, 0.03326398f, 0.f, 0.f, 2.4573548f,
0.f, 1.5695758f, 0.f, 1.1791289f, 0.f, 0.f,
0.34451577f, 0.f, 2.9556773f, 1.12296f, 0.f, 0.f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.f,
0.f,
2.4573548f,
0.f,
1.5695758f,
0.f,
1.1791289f,
0.f,
0.f,
0.34451577f,
0.f,
2.9556773f,
1.12296f,
0.f,
0.f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_fwd_mixed_seq_len)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_fwd_mixed_seq_len.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
test_case.add_input<float>(in_B);
test_case.add_input<int>(in_sequence_lens);
test_case.add_input<float>(in_initial_h);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.55277014f, 0.15672898f, -0.25152922f, -0.63345766f, 0.99974346f, 0.94002223f,
-0.97647303f, -0.9999884f, 0.9752002f, 0.97388494f, 0.9967754f, 0.96745205f,
0.7899921f, 0.92003024f, -0.43116868f, 0.11219919f, 0.895327f, 0.21749747f,
0.6617017f, 0.99962795f, 0.37670398f, 0.7918401f, -0.99966455f, 0.9961897f,
0.9995159f, -0.84224236f, 0.92083716f, -0.99834263f, 0.9435711f, 0.8485148f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.75459063f,
0.8326433f, -0.99705976f, 0.62511444f, 0.99979305f, 0.99925995f, 0.94032586f,
-0.86841005f, -0.8692311f, 0.9974319f, 0.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, -0.30979204f, 0.99138904f, -0.10645419f, -0.18203181f, 0.9996245f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.11219919f,
0.895327f,
0.21749747f,
0.6617017f,
0.99962795f,
0.75459063f,
0.8326433f,
-0.99705976f,
0.62511444f,
0.99979305f,
-0.30979204f,
0.99138904f,
-0.10645419f,
-0.18203181f,
0.9996245f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
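// clip = 1.752 bounds the pre-activation values, so saturated outputs settle at
// tanh(1.752) ~= +/-0.9416027, which is the value repeated in the expected Y below.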
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_rev_clip)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_rev_clip.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.9416027f, 0.6461365f, -0.8407804f, -0.33646506f, 0.92833483f, -0.9416027f,
0.65075886f, 0.9416027f, -0.33576548f, -0.10364902f, -0.9416027f, -0.832458f,
-0.18187332f, 0.5103179f, 0.5227027f, -0.9416027f, -0.90681225f, -0.9416027f,
0.5091027f, 0.8053496f, 0.6005076f, 0.92147183f, 0.9416027f, -0.8985506f,
0.28120112f, 0.9416027f, 0.9416027f, 0.9416027f, -0.92463756f, -0.9416027f,
0.79248047f, 0.9416027f, -0.1611281f, 0.11231542f, -0.8230629f, -0.2566173f,
0.16398644f, -0.36077273f, -0.70470357f, 0.8011706f, -0.59314847f, -0.41942674f,
-0.20039755f, -0.6877927f, -0.13850075f, -0.26959598f, -0.8372509f, 0.15711153f,
0.3000977f, 0.53072214f, 0.25092757f, 0.82264745f, -0.72998637f, -0.13731742f,
0.17423475f, 0.43279397f, 0.9416027f, -0.2988227f, -0.4705984f, -0.74036705f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.9416027f,
0.6461365f,
-0.8407804f,
-0.33646506f,
0.92833483f,
-0.9416027f,
0.65075886f,
0.9416027f,
-0.33576548f,
-0.10364902f,
-0.9416027f,
-0.832458f,
-0.18187332f,
0.5103179f,
0.5227027f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_reverse)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_reverse.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.9963336f, 0.63758683f, -0.82404625f, -0.38524252f, 0.9350034f, -0.9918621f,
0.67038023f, 0.9884596f, -0.32398474f, -0.15730727f, -0.9970634f, -0.831641f,
-0.19750828f, 0.5491314f, 0.5148814f, -0.9517943f, -0.9077764f, -0.9906229f,
0.4751265f, 0.81323147f, 0.6005076f, 0.92147183f, 0.9878793f, -0.8985506f,
0.28120112f, 0.97769725f, 0.95308435f, 0.9777889f, -0.9270168f, -0.9459193f,
0.79248047f, 0.99223363f, -0.1611281f, 0.11231542f, -0.8230629f, -0.2566173f,
0.16398644f, -0.36077273f, -0.70470357f, 0.8011706f, -0.59996057f, -0.42161822f,
-0.19564903f, -0.6991576f, -0.12754434f, -0.26959598f, -0.8372509f, 0.15711153f,
0.3000977f, 0.53072214f, 0.25092757f, 0.82264745f, -0.72998637f, -0.13731742f,
0.17423475f, 0.43279397f, 0.96632254f, -0.2988227f, -0.4705984f, -0.74036705f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
0.9963336f,
0.63758683f,
-0.82404625f,
-0.38524252f,
0.9350034f,
-0.9918621f,
0.67038023f,
0.9884596f,
-0.32398474f,
-0.15730727f,
-0.9970634f,
-0.831641f,
-0.19750828f,
0.5491314f,
0.5148814f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 3);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_fwd_bias_initial_h)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_fwd_bias_initial_h.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_W);
test_case.add_input<float>(in_R);
test_case.add_input<float>(in_B);
test_case.add_input<float>(in_initial_h);
// Y
test_case.add_expected_output<float>(
Shape{4, 1, 3, 5},
std::vector<float>{
0.55277014f, 0.15672898f, -0.25152922f, -0.63345766f, 0.99974346f, 0.94002223f,
-0.97647303f, -0.9999884f, 0.9752002f, 0.97388494f, 0.9967754f, 0.96745205f,
0.7899921f, 0.92003024f, -0.43116868f, 0.11219919f, 0.895327f, 0.21749747f,
0.6617017f, 0.99962795f, 0.37670398f, 0.7918401f, -0.99966455f, 0.9961897f,
0.9995159f, -0.84224236f, 0.92083716f, -0.99834263f, 0.9435711f, 0.8485148f,
0.699257f, 0.9983405f, -0.87222385f, 0.05191362f, 0.9878634f, 0.75459063f,
0.8326433f, -0.99705976f, 0.62511444f, 0.99979305f, 0.99925995f, 0.94032586f,
-0.86841005f, -0.8692311f, 0.9974319f, -0.37055743f, -0.54580235f, -0.8618355f,
0.6927968f, 0.99997866f, 0.15482295f, 0.90996563f, -0.9992051f, 0.784014f,
0.9999677f, -0.30979204f, 0.99138904f, -0.10645419f, -0.18203181f, 0.9996245f,
});
// Y_h
test_case.add_expected_output<float>(Shape{1, 3, 5},
std::vector<float>{
-0.37055743f,
-0.54580235f,
-0.8618355f,
0.6927968f,
0.99997866f,
0.15482295f,
0.90996563f,
-0.9992051f,
0.784014f,
0.9999677f,
-0.30979204f,
0.99138904f,
-0.10645419f,
-0.18203181f,
0.9996245f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 5);
}
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_bidirectional)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_bidirectional.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
test_case.add_input<float>(in_bdir_W);
test_case.add_input<float>(in_bdir_R);
// Y
test_case.add_expected_output<float>(
Shape{4, 2, 3, 5},
std::vector<float>{
0.02254748f, 0.15776646f, -0.8229023f, 0.19205809f, 0.76984656f, -0.00603169f,
-0.02861464f, 0.04512155f, -0.0011912f, -0.02572936f, -0.13703543f, -0.49651444f,
-0.78868157f, 0.3566854f, 0.8758509f, -0.9964788f, -0.15236056f, 0.5478349f,
0.14500666f, 0.61871886f, 0.03722596f, -0.81331265f, 0.99774206f, -0.888188f,
-0.5575663f, -0.9284624f, -0.5595875f, 0.9986867f, -0.18373811f, 0.8451735f,
-0.43823165f, -0.1904698f, 0.8320786f, 0.9830735f, 0.61861455f, 0.19109797f,
0.6440699f, 0.00962079f, -0.32752872f, -0.5050589f, -0.23455954f, 0.9517933f,
0.9050665f, 0.91091585f, -0.77941567f, -0.9915407f, -0.23976672f, 0.04337811f,
0.2958206f, -0.3979709f, -0.9083327f, -0.21814531f, 0.9981259f, -0.8650538f,
-0.4886601f, -0.8349008f, -0.7880142f, 0.99017143f, -0.9816452f, -0.93827677f,
0.16374564f, 0.98451114f, -0.821692f, -0.6319715f, -0.01324981f, 0.28117967f,
0.20685172f, 0.01166677f, -0.5441829f, -0.5463746f, -0.85301256f, 0.52109087f,
-0.8317892f, -0.9676957f, -0.30258918f, -0.9810498f, -0.83153796f, -0.9676579f,
0.5483788f, 0.42533123f, -0.9851954f, -0.5354376f, 0.6905062f, -0.46665573f,
-0.851916f, -0.9073148f, 0.16276085f, 0.9518349f, -0.8635942f, -0.92539954f,
0.33436012f, -0.988292f, 0.9238765f, 0.94239855f, 0.24151397f, 0.5482547f,
0.76547384f, -0.81047577f, -0.6625802f, -0.09694612f, 0.9948462f, -0.6242633f,
-0.19065344f, -0.36072153f, -0.99407107f, 0.94602585f, 0.55862486f, 0.2306763f,
0.22547626f, 0.37753606f, -0.9951596f, -0.74445903f, -0.6766813f, 0.32036817f,
0.33250773f, -0.9957684f, -0.7924f, -0.40261805f, -0.34061068f, -0.55580306f,
});
// Y_h
test_case.add_expected_output<float>(
Shape{2, 3, 5},
std::vector<float>{
0.33436012f, -0.988292f, 0.9238765f, 0.94239855f, 0.24151397f, 0.5482547f,
0.76547384f, -0.81047577f, -0.6625802f, -0.09694612f, 0.9948462f, -0.6242633f,
-0.19065344f, -0.36072153f, -0.99407107f, -0.9964788f, -0.15236056f, 0.5478349f,
0.14500666f, 0.61871886f, 0.03722596f, -0.81331265f, 0.99774206f, -0.888188f,
-0.5575663f, -0.9284624f, -0.5595875f, 0.9986867f, -0.18373811f, 0.8451735f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 6);
}
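// Same expected values as onnx_model_rnn_bidirectional, but W and R are constants
// embedded in the model, so only X is supplied at runtime.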
NGRAPH_TEST_F(${BACKEND_NAME}, RNNSequenceOp, onnx_model_rnn_bidirectional_const)
{
auto function = onnx_import::import_onnx_model(
file_util::path_join(SERIALIZED_ZOO, "onnx/rnn_bidirectional_const.prototxt"));
auto test_case = ngraph::test::NgraphTestCase(function, "${BACKEND_NAME}");
test_case.add_input<float>(in_X);
// Y
test_case.add_expected_output<float>(
Shape{4, 2, 3, 5},
std::vector<float>{
0.02254748f, 0.15776646f, -0.8229023f, 0.19205809f, 0.76984656f, -0.00603169f,
-0.02861464f, 0.04512155f, -0.0011912f, -0.02572936f, -0.13703543f, -0.49651444f,
-0.78868157f, 0.3566854f, 0.8758509f, -0.9964788f, -0.15236056f, 0.5478349f,
0.14500666f, 0.61871886f, 0.03722596f, -0.81331265f, 0.99774206f, -0.888188f,
-0.5575663f, -0.9284624f, -0.5595875f, 0.9986867f, -0.18373811f, 0.8451735f,
-0.43823165f, -0.1904698f, 0.8320786f, 0.9830735f, 0.61861455f, 0.19109797f,
0.6440699f, 0.00962079f, -0.32752872f, -0.5050589f, -0.23455954f, 0.9517933f,
0.9050665f, 0.91091585f, -0.77941567f, -0.9915407f, -0.23976672f, 0.04337811f,
0.2958206f, -0.3979709f, -0.9083327f, -0.21814531f, 0.9981259f, -0.8650538f,
-0.4886601f, -0.8349008f, -0.7880142f, 0.99017143f, -0.9816452f, -0.93827677f,
0.16374564f, 0.98451114f, -0.821692f, -0.6319715f, -0.01324981f, 0.28117967f,
0.20685172f, 0.01166677f, -0.5441829f, -0.5463746f, -0.85301256f, 0.52109087f,
-0.8317892f, -0.9676957f, -0.30258918f, -0.9810498f, -0.83153796f, -0.9676579f,
0.5483788f, 0.42533123f, -0.9851954f, -0.5354376f, 0.6905062f, -0.46665573f,
-0.851916f, -0.9073148f, 0.16276085f, 0.9518349f, -0.8635942f, -0.92539954f,
0.33436012f, -0.988292f, 0.9238765f, 0.94239855f, 0.24151397f, 0.5482547f,
0.76547384f, -0.81047577f, -0.6625802f, -0.09694612f, 0.9948462f, -0.6242633f,
-0.19065344f, -0.36072153f, -0.99407107f, 0.94602585f, 0.55862486f, 0.2306763f,
0.22547626f, 0.37753606f, -0.9951596f, -0.74445903f, -0.6766813f, 0.32036817f,
0.33250773f, -0.9957684f, -0.7924f, -0.40261805f, -0.34061068f, -0.55580306f,
});
// Y_h
test_case.add_expected_output<float>(
Shape{2, 3, 5},
std::vector<float>{
0.33436012f, -0.988292f, 0.9238765f, 0.94239855f, 0.24151397f, 0.5482547f,
0.76547384f, -0.81047577f, -0.6625802f, -0.09694612f, 0.9948462f, -0.6242633f,
-0.19065344f, -0.36072153f, -0.99407107f, -0.9964788f, -0.15236056f, 0.5478349f,
0.14500666f, 0.61871886f, 0.03722596f, -0.81331265f, 0.99774206f, -0.888188f,
-0.5575663f, -0.9284624f, -0.5595875f, 0.9986867f, -0.18373811f, 0.8451735f,
});
test_case.run(DEFAULT_FLOAT_TOLERANCE_BITS + 6);
}

View File

@@ -1349,6 +1349,25 @@ IE_CPU.onnx_dyn_shapes_model_flatten
IE_CPU.onnx_dyn_shapes_slice_10_default_axes
IE_CPU.fused_clamp_float
# GRUCell operation has a form that is not supported
IE_CPU.onnx_model_gru_defaults_fwd
IE_CPU.onnx_model_gru_fwd_activations
IE_CPU.onnx_model_gru_fwd_mixed_seq_len
IE_CPU.onnx_model_gru_rev_clip
IE_CPU.onnx_model_gru_reverse
IE_CPU.onnx_model_gru_fwd_bias_initial_h
IE_CPU.onnx_model_gru_bidirectional
IE_CPU.onnx_model_gru_fwd_linear_before_reset
# RNNCell operation has a form that is not supported
IE_CPU.onnx_model_rnn_defaults_fwd
IE_CPU.onnx_model_rnn_fwd_activations
IE_CPU.onnx_model_rnn_fwd_mixed_seq_len
IE_CPU.onnx_model_rnn_rev_clip
IE_CPU.onnx_model_rnn_reverse
IE_CPU.onnx_model_rnn_fwd_bias_initial_h
IE_CPU.onnx_model_rnn_bidirectional
IE_CPU.onnx_model_rnn_bidirectional_const
#-------------------------------------------------------------------------------
#
# Inference Engine GPU plugin excludes

View File

@@ -42,7 +42,7 @@ TEST(type_prop, lstm_sequence_forward)
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::LSTMSequence::direction::FORWARD;
const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD;
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
@@ -53,8 +53,9 @@ TEST(type_prop, lstm_sequence_forward)
B,
hidden_size,
lstm_direction);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::FORWARD);
EXPECT_EQ(lstm_sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());