LSTMCell/Sequence v1, reference implementations and decompose transformations for LSTM/GRU/RNN Cells (#2000)

* validate_and_infer_types() implementation

* input parameter validation for LSTM, GRU and RNN

* style-check applied

* Add LSTMSequence dynamic shape validation and test props for RNNCell, GRUCell, LSTMCell and LSTMSequence.

* recurrent_sequence.hpp moved to ngraph/core/include/ngraph/op/util/

* style check applied

* removed unused variable from LSTMSequence::validate_and_infer_types

* Add missing newline mark at the end of file.

* Add suppression macro for FusedOp deprecation.

* Add element type initialization

* transpose, RNN cell reference implementations

* Apply PR review remarks

* reference implementations for cell ops, single layer tests, align LSTM cell/sequence with the spec

* lstm/gru/rnn cell decomposition transformations

* ngraph codestyle

* clean up

* ngraph code style

* change inheritance of Cells, fix build

* fix build

* fix build again

* remove Peepholes from LSTMSeq, fix copy_runtime_info in transformations

* Rewrite tests to use gtest exception assertions.

* resolve tests issues

* ngraph codestyle

* add missed files

* fix typeprop tests

* fix lstm sequence checks

* fix arm build

* fix arm again

* delete unnecessary file

* add convert weights format function, enable lstm test, resolve review comments

* add ngraph builders

* ngraph codestyle

* fix unit tests

* revert transpose reference implementation

* revert LSTM Cell v0, add LSTMCell v1, update transformation lstm_cell_to_cell_ie

* v1 version of LSTMCell op

* LSTMSequence v1 operation, exclude LSTMSeq from opset4

* fix python api tests

* resolve review comments, tests for decomposition transformations, switch lstm cell to opset4 in mo

Co-authored-by: Szymon Durawa <szymon.durawa@intel.com>
Ivan Tikhonov 2020-09-04 09:04:36 +03:00 committed by GitHub
parent 28eed7708e
commit 2f5a28d44f
65 changed files with 4695 additions and 1248 deletions


@ -410,6 +410,29 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
return res;
});
addSpecificCreator({"LSTMCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string> params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), "LSTMCell",
details::convertPrecision(node->get_output_element_type(0))};
auto res = std::make_shared<LSTMCell>(attrs);
res->params = params;
Builder::NodeConverter<ngraph::op::Constant> converter;
const auto weightsNode = node->input_value(3).get_node_shared_ptr();
if (converter.canCreate(weightsNode)) {
const auto& weights = converter.createLayer(weightsNode);
res->blobs["weights"] = weights->blobs["custom"];
res->_weights = weights->blobs["custom"];
}
const auto biasNode = node->input_value(4).get_node_shared_ptr();
if (converter.canCreate(biasNode)) {
const auto& bias = converter.createLayer(biasNode);
res->blobs["biases"] = bias->blobs["custom"];
res->_biases = bias->blobs["custom"];
}
return res;
});
addSpecificCreator({"RNNCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
const std::map<std::string, std::string>& params) -> CNNLayerPtr {
LayerParams attrs = {node->get_friendly_name(), "RNNCell",
@ -672,7 +695,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
std::make_shared<Builder::NodeConverter<::ngraph::op::TopKIE>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::Unsqueeze>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::TensorIterator>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::LSTMCellIE>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::HardSigmoid_IE>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::v1::LogicalNot>>(),
std::make_shared<Builder::NodeConverter<::ngraph::op::ShuffleChannels>>(),


@ -1866,54 +1866,6 @@ CNNLayer::Ptr NodeConverter<ngraph::op::FullyConnected>::createLayer(const std::
return res;
}
template <>
CNNLayer::Ptr NodeConverter<ngraph::op::LSTMCellIE>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
LayerParams params = {layer->get_friendly_name(), "LSTMCell",
details::convertPrecision(layer->get_output_element_type(0))};
auto castedLayer = ngraph::as_type_ptr<ngraph::op::LSTMCellIE>(layer);
if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
auto res = std::make_shared<InferenceEngine::LSTMCell>(params);
res->params["hidden_size"] = asString(castedLayer->get_hidden_size());
std::string value;
for (const auto& val : castedLayer->get_activations()) {
if (!value.empty()) value += ",";
value += val;
}
res->params["activations"] = value;
value.clear();
for (const auto& val : castedLayer->get_activations_alpha()) {
if (!value.empty()) value += ",";
value += val;
}
res->params["activations_alpha"] = value;
value.clear();
for (const auto& val : castedLayer->get_activations_beta()) {
if (!value.empty()) value += ",";
value += val;
}
res->params["activations_beta"] = value;
res->params["clip"] = asString(castedLayer->get_clip());
NodeConverter<ngraph::op::Constant> converter;
const auto weightsNode = layer->input_value(3).get_node_shared_ptr();
if (converter.canCreate(weightsNode)) {
const auto& weights = converter.createLayer(weightsNode);
res->blobs["weights"] = weights->blobs["custom"];
res->_weights = weights->blobs["custom"];
}
const auto biasNode = layer->input_value(4).get_node_shared_ptr();
if (converter.canCreate(biasNode)) {
const auto& bias = converter.createLayer(biasNode);
res->blobs["biases"] = bias->blobs["custom"];
res->_biases = bias->blobs["custom"];
}
return res;
}
template <>
CNNLayer::Ptr NodeConverter<ngraph::op::MatMul>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
LayerParams params = {layer->get_friendly_name(), "Gemm",


@ -439,7 +439,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
std::make_shared<LayerCreator<ngraph::op::v1::Select>>("Select"),
std::make_shared<LayerCreator<ngraph::op::LRN>>("LRN"),
std::make_shared<LayerCreator<ngraph::op::MVN>>("MVN"),
std::make_shared<LayerCreator<ngraph::op::LSTMCell>>("LSTMCell"),
std::make_shared<LayerCreator<ngraph::op::v0::LSTMCell>>("LSTMCell"),
std::make_shared<LayerCreator<ngraph::op::v1::MaxPool>>("MaxPool"),
std::make_shared<LayerCreator<ngraph::op::v1::Maximum>>("Maximum"),
std::make_shared<LayerCreator<ngraph::op::v1::Minimum>>("Minimum"),
@ -910,7 +910,7 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::Convert>::crea
// LSTMCell layer
template <>
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::LSTMCell>::createLayer(
std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::v0::LSTMCell>::createLayer(
const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
const GenericLayerParams& layerParsePrms) {
checkParameters(inputs, layerParsePrms, 6);
@ -922,7 +922,7 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::LSTMCell>::cre
std::vector<float> activations_alpha = getParameters<float>(dn, "activations_alpha", {});
std::vector<float> activations_beta = getParameters<float>(dn, "activations_beta", {});
float clip = GetFloatAttr(dn, "clip", 0.f);
return std::make_shared<ngraph::op::LSTMCell>(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
return std::make_shared<ngraph::op::v0::LSTMCell>(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
GetUInt64Attr(dn, "hidden_size"), ngraph::op::LSTMWeightsFormat::IFCO,
activations, activations_alpha, activations_beta, clip);
}


@ -41,13 +41,14 @@ public:
const std::vector<float>& get_activations_alpha() { return m_activations_alpha; }
const std::vector<float>& get_activations_beta() { return m_activations_beta; }
float get_clip() {return m_clip;}
bool visit_attributes(AttributeVisitor& visitor) override;
protected:
int64_t m_hidden_size{};
const std::vector<std::string> m_activations;
const std::vector<float> m_activations_alpha;
const std::vector<float> m_activations_beta;
std::vector<std::string> m_activations;
std::vector<float> m_activations_alpha;
std::vector<float> m_activations_beta;
float m_clip;
};


@ -0,0 +1,41 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API GRUCellDecomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief GRUCellDecomposition transformation decomposes GRUCell layer with inputs X, H, W, R, B
* to Add, Split, MatMul, Multiply and Subtract ops according to the formula:
(.) - Denotes element-wise multiplication.
* - Denotes dot product.
f, g - are activation functions
zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset := false # (default)
ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset:= true
Ht = (1 - zt) (.) ht + zt (.) Ht-1
*/
class ngraph::pass::GRUCellDecomposition: public ngraph::pass::MatcherPass {
public:
GRUCellDecomposition();
};


@ -0,0 +1,42 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API LSTMCellDecomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief LSTMCellDecomposition transformation decomposes LSTMCell layer with inputs X, H, C, W, R, B
* to Add, Split, MatMul, Multiply ops according to the formula:
* (.) - Denotes element-wise multiplication.
* - Denotes dot product.
f, g, h - are activation functions.
* it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
Ct = ft (.) Ct-1 + it (.) ct
Ht = ot (.) h(Ct)
*/
class ngraph::pass::LSTMCellDecomposition: public ngraph::pass::MatcherPass {
public:
LSTMCellDecomposition();
};


@ -0,0 +1,36 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API RNNCellDecomposition;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief RNNCellDecomposition transformation decomposes RNNCell layer with inputs X, H, W, R, B
* to Add, MatMul ops according to the formula:
* - Denotes dot product.
f - is an activation function.
* Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
*/
class ngraph::pass::RNNCellDecomposition: public ngraph::pass::MatcherPass {
public:
RNNCellDecomposition();
};
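The three passes above are standard MatcherPass rewrites, so they can be chained through the regular pass manager. A minimal usage sketch (assuming `f` is a std::shared_ptr<ngraph::Function> that contains the cell ops; this mirrors how the unit tests below drive the passes):

#include <ngraph/pass/manager.hpp>
#include <transformations/lstm_cell_decomposition.hpp>
#include <transformations/gru_cell_decomposition.hpp>
#include <transformations/rnn_cell_decomposition.hpp>

// Decompose every LSTM/GRU/RNN cell in `f` into elementary ops
// (MatMul, Add, Split, Multiply, ...) per the formulas documented above.
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::LSTMCellDecomposition>();
manager.register_pass<ngraph::pass::GRUCellDecomposition>();
manager.register_pass<ngraph::pass::RNNCellDecomposition>();
manager.run_passes(f);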


@ -101,6 +101,9 @@ TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr<const ngraph::F
TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape);
TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name,
const ngraph::Output<ngraph::Node>& apply_to);
} // namespace util
} // namespace op
} // namespace ngraph


@ -37,6 +37,15 @@ void op::LSTMCellIE::validate_and_infer_types() {
set_output_type(1, arg_type, output_shape);
}
bool ngraph::op::LSTMCellIE::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("hidden_size", m_hidden_size);
visitor.on_attribute("activations", m_activations);
visitor.on_attribute("activations_alpha", m_activations_alpha);
visitor.on_attribute("activations_beta", m_activations_beta);
visitor.on_attribute("clip", m_clip);
return true;
}
shared_ptr<Node> op::LSTMCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<op::LSTMCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4),


@ -9,22 +9,25 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/op/util/rnn_cell_base.hpp>
#include <ngraph_ops/lstm_cell_ie.hpp>
#include <ngraph_ops/gru_cell_ie.hpp>
#include <ngraph_ops/rnn_cell_ie.hpp>
ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
auto lstm_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset1::LSTMCell>();
auto is_supported_lstm_cell = [](const std::shared_ptr<Node>& n) {
return pattern::has_class<ngraph::opset1::LSTMCell>()(n) || pattern::has_class<ngraph::opset4::LSTMCell>()(n);
};
auto any_lstm = std::make_shared<pattern::op::Label>(element::f32, Shape{}, is_supported_lstm_cell);
ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset1::LSTMCell> (m.get_match_root());
auto lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(m.get_match_root());
if (!lstm_cell) {
return false;
}
auto W = std::dynamic_pointer_cast<ngraph::opset1::Constant> (lstm_cell->input_value(3).get_node_shared_ptr());
if (!W) {
return false;
@ -53,7 +56,7 @@ ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell_ngraph, "ConvertLSTMCellToLSTMCellIE");
auto m = std::make_shared<ngraph::pattern::Matcher>(any_lstm, "ConvertLSTMCellToLSTMCellIE");
this->register_matcher(m, callback);
}


@ -0,0 +1,104 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/gru_cell_decomposition.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
ngraph::pass::GRUCellDecomposition::GRUCellDecomposition() {
auto gru_cell = ngraph::pattern::wrap_type<opset4::GRUCell>();
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
auto gru_cell = std::dynamic_pointer_cast<ngraph::opset4::GRUCell> (m.get_match_root());
if (!gru_cell) {
return false;
}
const Output<Node>& X = gru_cell->input_value(0);
const Output<Node>& H_t = gru_cell->input_value(1);
const Output<Node>& W = gru_cell->input_value(2);
const Output<Node>& R = gru_cell->input_value(3);
const Output<Node>& B = gru_cell->input_value(4);
// Xt*(W^T)
auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
// Ht-1*(R^T)
auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
// split to gates:
auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
auto Xt_W_zrh = std::make_shared<opset4::Split>(Xt_W, axis_1, 3);
auto R_zrh = std::make_shared<opset4::Split>(R, axis_0, 3);
auto Ht_R_zrh = std::make_shared<opset4::Split>(Ht_R, axis_1, 3);
auto biases_zrh = std::make_shared<opset4::Split>(B, axis_0, gru_cell->get_linear_before_reset() ? 4 : 3);
// Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz
auto add_z_1 = std::make_shared<opset4::Add>(Ht_R_zrh->output(0), biases_zrh->output(0));
auto add_z_2 = std::make_shared<opset4::Add>(Xt_W_zrh->output(0), add_z_1);
// Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
auto add_r_1 = std::make_shared<opset4::Add>(Ht_R_zrh->output(1), biases_zrh->output(1));
auto add_r_2 = std::make_shared<opset4::Add>(Xt_W_zrh->output(1), add_r_1);
auto clip = gru_cell->get_clip();
std::shared_ptr<Node> clamp_z = add_z_2;
std::shared_ptr<Node> clamp_r = add_r_2;
if (clip > 0.f) {
clamp_z = std::make_shared<opset4::Clamp>(add_z_2, -clip, clip);
clamp_r = std::make_shared<opset4::Clamp>(add_r_2, -clip, clip);
ngraph::copy_runtime_info(gru_cell, {clamp_z, clamp_r});
}
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
auto z_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_z);
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto r_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_r);
std::shared_ptr<Node> _h;
if (gru_cell->get_linear_before_reset()) {
// _h = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh
auto Ht_Rh_Rbh = std::make_shared<opset4::Add>(Ht_R_zrh->output(2), biases_zrh->output(3));
auto mul_h_1 = std::make_shared<opset4::Multiply>(r_t, Ht_Rh_Rbh);
auto add_h_1 = std::make_shared<opset4::Add>(mul_h_1, biases_zrh->output(2));
_h = std::make_shared<opset4::Add>(Xt_W_zrh->output(2), add_h_1);
ngraph::copy_runtime_info(gru_cell, {Ht_Rh_Rbh, mul_h_1, add_h_1, _h});
} else {
// _h = Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh
auto rt_Ht = std::make_shared<opset4::Multiply>(r_t, H_t);
auto mul_h_1 = std::make_shared<opset4::MatMul>(rt_Ht, R_zrh->output(2), false, true);
auto add_h_1 = std::make_shared<opset4::Add>(mul_h_1, biases_zrh->output(2));
_h = std::make_shared<opset4::Add>(Xt_W_zrh->output(2), add_h_1);
ngraph::copy_runtime_info(gru_cell, {rt_Ht, mul_h_1, add_h_1, _h});
}
// ht = g(_h)
std::shared_ptr<Node> clamp_h = _h;
if (clip > 0.f) {
clamp_h = std::make_shared<opset4::Clamp>(_h, -clip, clip);
ngraph::copy_runtime_info(gru_cell, clamp_h);
}
auto h_t = ngraph::op::util::activation(gru_cell->get_activations()[1], clamp_h);
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
auto one = opset4::Constant::create(z_t->get_element_type(), Shape{1}, {1.f});
auto sub = std::make_shared<opset4::Subtract>(one, z_t);
auto mul_1 = std::make_shared<opset4::Multiply>(sub, h_t);
auto mul_2 = std::make_shared<opset4::Multiply>(z_t, H_t);
auto out_H = std::make_shared<opset4::Add>(mul_1, mul_2);
out_H->set_friendly_name(gru_cell->get_friendly_name());
ngraph::copy_runtime_info(gru_cell, {Xt_W, Ht_R, axis_0, Xt_W_zrh, R_zrh, Ht_R_zrh, biases_zrh,
add_z_1, add_z_2, add_r_1, add_r_2, h_t, one, sub, mul_1, mul_2, out_H});
ngraph::replace_node(gru_cell, out_H);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(gru_cell, "GRUCellDecomposition");
register_matcher(m, callback);
}
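A quick post-condition check for this rewrite; a sketch assuming `f` is the ngraph::Function the pass just ran on:

#include <cassert>
#include <memory>
#include <ngraph/opsets/opset4.hpp>

// After GRUCellDecomposition no opset4::GRUCell should remain in the graph;
// each matched cell is replaced by the MatMul/Add/Split/Multiply subgraph built above.
for (const auto& node : f->get_ops()) {
    assert(!std::dynamic_pointer_cast<ngraph::opset4::GRUCell>(node));
}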


@ -0,0 +1,85 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/lstm_cell_decomposition.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() {
auto lstm_cell = ngraph::pattern::wrap_type<opset4::LSTMCell>();
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset4::LSTMCell> (m.get_match_root());
if (!lstm_cell) {
return false;
}
const Output<Node>& X = lstm_cell->input_value(0);
const Output<Node>& H_t = lstm_cell->input_value(1);
const Output<Node>& C_t = lstm_cell->input_value(2);
const Output<Node>& W = lstm_cell->input_value(3);
const Output<Node>& R = lstm_cell->input_value(4);
const Output<Node>& bias = lstm_cell->input_value(5);
// Xt*(W^T)
auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
// Ht-1*(R^T)
auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto add = std::make_shared<opset4::Add>(Ht_R, bias);
auto XHB = std::make_shared<opset4::Add>(Xt_W, add);
auto axis_node = ngraph::opset4::Constant::create(element::u64, Shape{}, {1});
auto split = std::make_shared<opset4::Split>(XHB, axis_node, 4);
Output<Node> f = split->output(0);
Output<Node> i = split->output(1);
Output<Node> c = split->output(2);
Output<Node> o = split->output(3);
auto clip = lstm_cell->get_clip();
if (clip > 0.f) {
auto clamp_f = std::make_shared<opset4::Clamp>(f, -clip, clip);
auto clamp_i = std::make_shared<opset4::Clamp>(i, -clip, clip);
auto clamp_c = std::make_shared<opset4::Clamp>(c, -clip, clip);
auto clamp_o = std::make_shared<opset4::Clamp>(o, -clip, clip);
f = clamp_f;
i = clamp_i;
c = clamp_c;
o = clamp_o;
ngraph::copy_runtime_info(lstm_cell, {clamp_f, clamp_i, clamp_c, clamp_o});
}
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
auto f_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], f);
auto i_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], i);
auto c_t = ngraph::op::util::activation(lstm_cell->get_activations()[1], c);
auto o_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], o);
// Ct = ft (.) Ct-1 + it (.) ct
auto mul1 = std::make_shared<opset4::Multiply>(f_t, C_t);
auto mul2 = std::make_shared<opset4::Multiply>(i_t, c_t);
auto out_C = std::make_shared<opset4::Add>(mul1, mul2);
// H = ot (.) h(Ct)
auto hC = ngraph::op::util::activation(lstm_cell->get_activations()[2], out_C);
auto out_H = std::make_shared<opset4::Multiply>(o_t, hC);
out_H->set_friendly_name(lstm_cell->get_friendly_name()+".0");
out_C->set_friendly_name(lstm_cell->get_friendly_name()+".1");
ngraph::copy_runtime_info(lstm_cell, {Xt_W, Ht_R, add, split, mul1, mul2, out_H, hC, out_C, axis_node, XHB,
f_t, i_t, c_t, o_t});
ngraph::replace_node(lstm_cell, {out_H->output(0), out_C->output(0)});
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell, "LSTMCellDecomposition");
register_matcher(m, callback);
}


@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/rnn_cell_decomposition.hpp"
#include <memory>
#include <transformations/utils/utils.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/op/util/activation_functions.hpp>
ngraph::pass::RNNCellDecomposition::RNNCellDecomposition() {
auto rnn_cell = ngraph::pattern::wrap_type<opset4::RNNCell>();
ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
auto rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::RNNCell> (m.get_match_root());
if (!rnn_cell) {
return false;
}
const Output<Node>& X = rnn_cell->input_value(0);
const Output<Node>& H_t = rnn_cell->input_value(1);
const Output<Node>& W = rnn_cell->input_value(2);
const Output<Node>& R = rnn_cell->input_value(3);
const Output<Node>& bias = rnn_cell->input_value(4);
// Xt*(W^T)
auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
// Ht-1*(R^T)
auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto add = std::make_shared<opset4::Add>(Ht_R, bias);
auto i_t = std::make_shared<opset4::Add>(Xt_W, add);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
auto clip = rnn_cell->get_clip();
std::shared_ptr<Node> clamp = i_t;
if (clip > 0.f) {
clamp = std::make_shared<opset4::Clamp>(i_t, -clip, clip);
ngraph::copy_runtime_info(rnn_cell, clamp);
}
auto out = ngraph::op::util::activation(rnn_cell->get_activations()[0], clamp);
out->set_friendly_name(rnn_cell->get_friendly_name());
ngraph::copy_runtime_info(rnn_cell, {Xt_W, Ht_R, add, i_t, out});
ngraph::replace_node(rnn_cell, out);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_cell, "RNNCellDecomposition");
register_matcher(m, callback);
}


@ -108,6 +108,18 @@ bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &ot
return false;
}
std::shared_ptr<ngraph::Node> activation(const std::string& activation_name, const ngraph::Output<ngraph::Node>& apply_to) {
if (activation_name == "relu") {
return std::make_shared<ngraph::opset4::Relu>(apply_to);
} else if (activation_name == "sigmoid") {
return std::make_shared<ngraph::opset4::Sigmoid>(apply_to);
} else if (activation_name == "tanh") {
return std::make_shared<ngraph::opset4::Tanh>(apply_to);
} else {
throw ngraph_error("Unsupported activation function");
}
}
} // namespace util
} // namespace op
} // namespace ngraph
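The decomposition passes resolve activation names through this helper. A minimal usage sketch (`input` is a hypothetical ngraph::Output<ngraph::Node>):

// Builds a Sigmoid node over `input`; any name other than
// "relu", "sigmoid" or "tanh" throws ngraph_error.
std::shared_ptr<ngraph::Node> gate = ngraph::op::util::activation("sigmoid", input);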


@ -14,6 +14,7 @@
#include <ngraph/ops.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/function.hpp>
#include <transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.hpp>
#include <transformations/init_node_info.hpp>
@ -129,7 +130,7 @@ TEST(TransformationTests, RNNCellConversionTest) {
ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertRNNCellToRNNCellIE should keep output names.\n";
}
TEST(TransformationTests, LSTMCellConversionTest) {
TEST(TransformationTests, LSTMCellConversionTest_opset3) {
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
@ -186,4 +187,76 @@ TEST(TransformationTests, LSTMCellConversionTest) {
auto result_node_of_converted_f = f->get_output_op(0);
auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n";
}
}
TEST(TransformationTests, LSTMCellConversionTest_opset4) {
const size_t batch_size = 2;
const size_t input_size = 3;
const size_t hidden_size = 3;
const size_t gates_count = 4;
std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
std::shared_ptr<ngraph::opset4::LSTMCell> cell;
{
const auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, input_size});
const auto W =
std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size, input_size});
const auto R =
std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size, hidden_size});
const auto H_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, hidden_size});
const auto C_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, hidden_size});
const auto B = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size});
cell = std::make_shared<ngraph::opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
cell->set_friendly_name("test_cell");
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t, C_t});
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::ConvertLSTMCellMatcher>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
}
{
const auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, input_size});
const auto W =
std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size, input_size});
const auto R =
std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size, hidden_size});
const auto H_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, hidden_size});
const auto C_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
ngraph::Shape{batch_size, hidden_size});
const auto B = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
ngraph::Shape{gates_count * hidden_size});
auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::NodeVector({W, R}), 1);
auto cell_ie = std::make_shared<ngraph::op::LSTMCellIE>(X, H_t, C_t, concat, B,
cell->get_hidden_size(),
cell->get_activations(),
cell->get_activations_alpha(),
cell->get_activations_beta(),
cell->get_clip());
cell_ie->set_friendly_name("test_cell");
f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell_ie}, ngraph::ParameterVector{X, H_t, C_t});
}
auto res = compare_functions(f, f_ref);
ASSERT_TRUE(res.first) << res.second;
auto result_node_of_converted_f = f->get_output_op(0);
auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell")
<< "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n";
}


@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/gru_cell.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
std::vector<bool> should_decompose{false, true};
std::vector<size_t> batch{5};
std::vector<size_t> hidden_size{1, 10};
std::vector<size_t> input_size{1, 30};
std::vector<std::vector<std::string>> activations = {{"relu", "tanh"}, {"tanh", "sigmoid"}, {"sigmoid", "tanh"},
{"tanh", "relu"}};
std::vector<float> clip = {0.0f, 0.7f};
std::vector<bool> linear_before_reset = {true, false};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
INSTANTIATE_TEST_CASE_P(GRUCellCommon, GRUCellTest,
::testing::Combine(
::testing::ValuesIn(should_decompose),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
::testing::ValuesIn(input_size),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
::testing::ValuesIn(linear_before_reset),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
GRUCellTest::getTestCaseName);
} // namespace


@ -0,0 +1,36 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/lstm_cell.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
std::vector<bool> should_decompose{false, true};
std::vector<size_t> batch{5};
std::vector<size_t> hidden_size{1, 10};
std::vector<size_t> input_size{1, 30};
std::vector<std::vector<std::string>> activations = {{"relu", "sigmoid", "tanh"}, {"sigmoid", "tanh", "tanh"},
{"tanh", "relu", "sigmoid"}, {"sigmoid", "sigmoid", "sigmoid"},
{"tanh", "tanh", "tanh"}, {"relu", "relu", "relu"}};
std::vector<float> clip{0.f, 0.7f};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
INSTANTIATE_TEST_CASE_P(LSTMCellCommon, LSTMCellTest,
::testing::Combine(
::testing::ValuesIn(should_decompose),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
::testing::ValuesIn(input_size),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
LSTMCellTest::getTestCaseName);
} // namespace


@ -0,0 +1,34 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "single_layer_tests/rnn_cell.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
namespace {
std::vector<bool> should_decompose{false, true};
std::vector<size_t> batch{1, 5};
std::vector<size_t> hidden_size{1, 10};
std::vector<size_t> input_size{1, 30};
std::vector<std::vector<std::string>> activations = {{"relu"}, {"sigmoid"}, {"tanh"}};
std::vector<float> clip = {0.f, 0.7f};
std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16};
INSTANTIATE_TEST_CASE_P(RNNCellCommon, RNNCellTest,
::testing::Combine(
::testing::ValuesIn(should_decompose),
::testing::ValuesIn(batch),
::testing::ValuesIn(hidden_size),
::testing::ValuesIn(input_size),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
::testing::ValuesIn(netPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
RNNCellTest::getTestCaseName);
} // namespace


@ -0,0 +1,38 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace LayerTestsDefinitions {
using GRUCellParams = typename std::tuple<
bool, // using decompose to sub-ops transformation
size_t, // batch
size_t, // hidden size
size_t, // input size
std::vector<std::string>, // activations
float, // clip
bool, // linear_before_reset
InferenceEngine::Precision, // Network precision
std::string>; // Device name
class GRUCellTest : public testing::WithParamInterface<GRUCellParams >,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<GRUCellParams> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions


@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace LayerTestsDefinitions {
using LSTMCellParams = typename std::tuple<
bool, // using decompose to sub-ops transformation
size_t, // batch
size_t, // hidden size
size_t, // input size
std::vector<std::string>, // activations
float, // clip
InferenceEngine::Precision, // Network precision
std::string>; // Device name
class LSTMCellTest : public testing::WithParamInterface<LSTMCellParams >,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<LSTMCellParams> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions


@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"
namespace LayerTestsDefinitions {
using RNNCellParams = typename std::tuple<
bool, // using decompose to sub-ops transformation
size_t, // batch
size_t, // hidden size
size_t, // input size
std::vector<std::string>, // activations
float, // clip
InferenceEngine::Precision, // Network precision
std::string>; // Device name
class RNNCellTest : public testing::WithParamInterface<RNNCellParams >,
virtual public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<RNNCellParams> &obj);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions


@ -0,0 +1,90 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include "ie_core.hpp"
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include <transformations/gru_cell_decomposition.hpp>
#include "single_layer_tests/gru_cell.hpp"
namespace LayerTestsDefinitions {
std::string GRUCellTest::getTestCaseName(const testing::TestParamInfo<GRUCellParams> &obj) {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
float clip;
bool linear_before_reset;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
linear_before_reset, netPrecision, targetDevice) = obj.param;
// initialize inputShapes, otherwise the "IS=" field below is printed empty
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
{3 * hidden_size, hidden_size}, {(linear_before_reset ? 4 : 3) * hidden_size}},
};
std::ostringstream result;
result << "decomposition" << should_decompose << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
result << "input_size=" << input_size << "_";
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
result << "clip=" << clip << "_";
result << "linear_before_reset=" << linear_before_reset << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
return result.str();
}
void GRUCellTest::SetUp() {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
float clip;
bool linear_before_reset;
InferenceEngine::Precision netPrecision;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, linear_before_reset,
netPrecision, targetDevice) = this->GetParam();
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
{3 * hidden_size, hidden_size}, {(linear_before_reset? 4 : 3) * hidden_size}},
};
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
auto gru_cell = ngraph::builder::makeGRUCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip, linear_before_reset);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_cell->output(0))};
function = std::make_shared<ngraph::Function>(results, params, "gru_cell");
if (should_decompose) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::GRUCellDecomposition>();
m.run_passes(function);
}
}
TEST_P(GRUCellTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions


@ -0,0 +1,89 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include "ie_core.hpp"
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include <transformations/lstm_cell_decomposition.hpp>
#include "single_layer_tests/lstm_cell.hpp"
namespace LayerTestsDefinitions {
std::string LSTMCellTest::getTestCaseName(const testing::TestParamInfo<LSTMCellParams> &obj) {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
float clip;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision,
targetDevice) = obj.param;
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
{4 * hidden_size, hidden_size}, {4 * hidden_size}},
};
std::ostringstream result;
result << "decomposition" << should_decompose << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
result << "input_size=" << input_size << "_";
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
result << "clip=" << clip << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
return result.str();
}
void LSTMCellTest::SetUp() {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
float clip;
InferenceEngine::Precision netPrecision;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision,
targetDevice) = this->GetParam();
std::vector<std::vector<size_t>> inputShapes = {
{{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
{4 * hidden_size, hidden_size}, {4 * hidden_size}},
};
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5]};
auto lstm_cell = ngraph::builder::makeLSTMCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_cell->output(0)),
std::make_shared<ngraph::opset1::Result>(lstm_cell->output(1))};
function = std::make_shared<ngraph::Function>(results, params, "lstm_cell");
if (should_decompose) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::LSTMCellDecomposition>();
m.run_passes(function);
}
}
TEST_P(LSTMCellTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions


@ -0,0 +1,82 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>
#include "ie_core.hpp"
#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"
#include <transformations/rnn_cell_decomposition.hpp>
#include "single_layer_tests/rnn_cell.hpp"
namespace LayerTestsDefinitions {
std::string RNNCellTest::getTestCaseName(const testing::TestParamInfo<RNNCellParams> &obj) {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
float clip;
InferenceEngine::Precision netPrecision;
std::string targetDevice;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
netPrecision, targetDevice) = obj.param;
std::vector<std::vector<size_t>> inputShapes = {{batch, input_size}, {batch, hidden_size},
{hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}};
std::ostringstream result;
result << "decomposition" << should_decompose << "_";
result << "batch=" << batch << "_";
result << "hidden_size=" << hidden_size << "_";
result << "input_size=" << input_size << "_";
result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
result << "clip=" << clip << "_";
result << "netPRC=" << netPrecision.name() << "_";
result << "targetDevice=" << targetDevice << "_";
return result.str();
}
void RNNCellTest::SetUp() {
bool should_decompose;
size_t batch;
size_t hidden_size;
size_t input_size;
std::vector<std::string> activations;
std::vector<float> activations_alpha;
std::vector<float> activations_beta;
float clip;
InferenceEngine::Precision netPrecision;
std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
netPrecision, targetDevice) = this->GetParam();
std::vector<std::vector<size_t>> inputShapes = {{batch, input_size}, {batch, hidden_size},
{hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}};
auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
auto rnn_cell = ngraph::builder::makeRNNCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
WRB, hidden_size, activations, {}, {}, clip);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_cell)};
function = std::make_shared<ngraph::Function>(results, params, "rnn_cell");
if (should_decompose) {
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::RNNCellDecomposition>();
m.run_passes(function);
}
}
TEST_P(RNNCellTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions


@ -72,7 +72,7 @@ void Basic_LSTM_S::SetUp() {
//lstm [1, 10], [1, 118], [1, 118] -> [1, 118], [1, 118]
outFormShapes1 = { batch_size, reshape1_shape[2] };
auto constantX = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{2}, outFormShapes1);
auto lstm1 = std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
auto lstm1 = std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
H_t, C_t,
weightsNode, reccurrenceWeightsNode, hidden_size);
@ -137,7 +137,7 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::CreateGraphWithUnrolledTI() {
ngraph::Output<ngraph::Node> H[iterations + 1];
ngraph::Output<ngraph::Node> C[iterations + 1];
std::shared_ptr<ngraph::opset1::LSTMCell> lstm[iterations];
std::shared_ptr<ngraph::opset4::LSTMCell> lstm[iterations];
H[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
C[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
auto reshape1_shape = reshape1->output(0).get_shape();
@ -149,7 +149,7 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::CreateGraphWithUnrolledTI() {
for (size_t i = 0; i < iterations; ++i) {
auto X = split1->output(i);
lstm[i] = std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
lstm[i] = std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
H[i], C[i],
weightsNode, reccurrenceWeightsNode, hidden_size);


@ -389,5 +389,31 @@ std::shared_ptr<ngraph::Node> makePad(const ngraph::Output<Node>& data,
std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>& data,
double epsilon);
std::shared_ptr<ngraph::Node> makeLSTMCell(const OutputVector& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f);
std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh"},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f,
bool linear_before_reset = false);
std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f);
} // namespace builder
} // namespace ngraph
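A sketch of driving the new builders the way the single-layer tests do; the concrete sizes (batch=2, input_size=3, hidden_size=3) are illustrative:

// X: [batch, input_size], H_t and C_t: [batch, hidden_size]
auto params = ngraph::builder::makeParams(ngraph::element::f32, {{2, 3}, {2, 3}, {2, 3}});
// W: [4*hidden_size, input_size], R: [4*hidden_size, hidden_size], B: [4*hidden_size]
std::vector<ngraph::Shape> WRB = {{12, 3}, {12, 3}, {12}};
auto lstm = ngraph::builder::makeLSTMCell(
    ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
    WRB, 3 /* hidden_size */);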


@ -130,7 +130,7 @@ static std::shared_ptr<ngraph::Function> makeTIwithLSTMcell(InferenceEngine::Pre
inShape = {N, I};
auto constantX = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{2}, inShape);
auto LSTM_cell =
std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
std::make_shared<ngraph::opset1::Reshape>(H_t, constantH, false),
std::make_shared<ngraph::opset1::Reshape>(C_t, constantH, false),
W_body,


@ -0,0 +1,30 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <memory>
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta,
float clip,
bool linear_before_reset) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
return std::make_shared<ngraph::opset4::GRUCell>(in[0], in[1], W, R, B, hidden_size, activations,
activations_alpha, activations_beta, clip, linear_before_reset);
}
} // namespace builder
} // namespace ngraph


@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <memory>
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeLSTMCell(const std::vector<ngraph::Output<Node>>& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta,
float clip) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
return std::make_shared<ngraph::opset4::LSTMCell>(in[0], in[1], in[2], W, R, B, hidden_size, activations,
activations_alpha, activations_beta, clip);
}
} // namespace builder
} // namespace ngraph


@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <memory>
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
const std::vector<ngraph::Shape>& WRB,
std::size_t hidden_size,
const std::vector<std::string>& activations,
const std::vector<float>& activations_alpha,
const std::vector<float>& activations_beta,
float clip) {
std::vector<float> empty;
auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
return std::make_shared<ngraph::opset4::RNNCell>(in[0], in[1], W, R, B, hidden_size, activations,
activations_alpha, activations_beta, clip);
}
} // namespace builder
} // namespace ngraph


@ -42,7 +42,7 @@ class LSTMCell(Op):
mandatory_props = {
'type': __class__.op,
'op': __class__.op,
'version': 'opset1',
'version': 'opset4',
'infer': __class__.infer,
'in_ports_count': 5,
'out_ports_count': 2,


@ -26,8 +26,6 @@
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph
{
namespace op
@ -42,7 +40,7 @@ namespace ngraph
///
/// Note this class represents only single *cell* and not whole GRU *layer*.
///
class NGRAPH_API GRUCell : public util::FusedOp, public util::RNNCellBase
class NGRAPH_API GRUCell : public util::RNNCellBase
{
public:
static constexpr NodeTypeInfo type_info{"GRUCell", 3};
@ -151,8 +149,6 @@ namespace ngraph
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual OutputVector decompose_op() const override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
@ -180,8 +176,5 @@ namespace ngraph
bool m_linear_before_reset;
};
}
using v3::GRUCell;
}
}
NGRAPH_SUPPRESS_DEPRECATED_END


@ -69,7 +69,7 @@ namespace ngraph
///
/// \sa LSTMSequence, RNNCell, GRUCell
///
class NGRAPH_API LSTMCell : public util::FusedOp, public util::RNNCellBase
class NGRAPH_API LSTMCell : public util::RNNCellBase
{
public:
static constexpr NodeTypeInfo type_info{"LSTMCell", 0};
@ -216,24 +216,11 @@ namespace ngraph
virtual void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual OutputVector decompose_op() const override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
bool get_input_forget() const { return m_input_forget; }
LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
///
/// \brief Change data format of provided node into IFCO.
///
/// \node The IFCO format was chosen because it's default DNNL format.
///
/// \param[in] node The input node to be permuted.
///
/// \return Node representing reshaped tensor according to IFCO weights format.
///
std::shared_ptr<Node> convert_node_format(const Output<Node>& node) const;
private:
///
/// \brief Creates the default bias input initialized with zeros.
@ -273,9 +260,149 @@ namespace ngraph
static constexpr std::size_t s_gates_count{4};
static constexpr std::size_t s_peepholes_count{3};
};
}
using v0::LSTMCell;
} // namespace op
} // v0
namespace v4
{
///
/// \brief Class for single lstm cell node.
///
/// \note Following implementation supports:
/// \li \c peepholes Gers & Schmidhuber (2000)
/// https://ieeexplore.ieee.org/document/861302
/// \li Coupling input and forget gates.
///
/// \note It calculates following equations:
///
/// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
/// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
/// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
/// Ct = ft (.) Ct-1 + it (.) ct
/// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
/// Ht = ot (.) h(Ct)
///
/// * - Is a dot product,
/// (.) - is a Hadamard product (element-wise),
/// f, g, h - are activation functions.
///
/// \note This class represents only single *cell* (for current time step) and not
/// the whole LSTM Sequence layer
///
/// \sa LSTMSequence, RNNCell, GRUCell
///
class NGRAPH_API LSTMCell : public util::RNNCellBase
{
public:
static constexpr NodeTypeInfo type_info{"LSTMCell", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LSTMCell();
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size,
/// input_size].
/// \param[in] initial_hidden_state The hidden state tensor at current time step
/// with shape: [batch_size, hidden_size].
/// \param[in] initial_cell_state The cell state tensor at current time step
/// with shape: [batch_size, hidden_size].
/// \param[in] W The gate weights tensor with shape:
/// [4*hidden_size, input_size].
/// \param[in] R The recurrence weights tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] hidden_size The number of hidden units for recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// recurrent cell.
/// \param[in] activations_alpha The vector of alpha parameters for activation
/// functions in order respective to activation
/// list.
/// \param[in] activations_beta The vector of beta parameters for activation
/// functions in order respective to activation
/// list.
/// \param[in] clip The value defining clipping range [-clip,
/// clip] on input of activation functions.
LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
std::size_t hidden_size,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f);
///
/// \brief Constructs LSTMCell node.
///
/// \param[in] X The input tensor with shape: [batch_size,
/// input_size].
/// \param[in] initial_hidden_state The hidden state tensor at current time step
/// with shape: [batch_size, hidden_size].
/// \param[in] initial_cell_state The cell state tensor at current time step
/// with shape: [batch_size, hidden_size].
/// \param[in] W The weight tensor with shape: [4*hidden_size,
/// input_size].
/// \param[in] R The recurrence weight tensor with shape:
/// [4*hidden_size, hidden_size].
/// \param[in] B The bias tensor for gates with shape:
/// [4*hidden_size].
/// \param[in] hidden_size The number of hidden units for the recurrent cell.
/// \param[in] activations The vector of activation functions used inside
/// the recurrent cell.
/// \param[in] activations_alpha The vector of alpha parameters for the
/// activation functions, in the order matching
/// the activations list.
/// \param[in] activations_beta The vector of beta parameters for the
/// activation functions, in the order matching
/// the activations list.
/// \param[in] clip The value defining the clipping range [-clip,
/// clip] on the input of activation functions.
///
LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
std::size_t hidden_size,
const std::vector<std::string>& activations =
std::vector<std::string>{"sigmoid", "tanh", "tanh"},
const std::vector<float>& activations_alpha = {},
const std::vector<float>& activations_beta = {},
float clip = 0.f);
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
private:
///
/// \brief Creates the default bias input initialized with zeros.
///
/// \return The object of Output class.
///
Output<Node> get_default_bias_input() const;
///
/// \brief The Activation function f.
///
util::ActivationFunction m_activation_f;
///
/// \brief The Activation function g.
///
util::ActivationFunction m_activation_g;
///
/// \brief The Activation function h.
///
util::ActivationFunction m_activation_h;
static constexpr std::size_t s_gates_count{4};
};
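// Usage sketch (illustrative only; parameter names and shapes are assumed
// from the constructor documentation above, with batch_size=2, input_size=3,
// hidden_size=4):
//
//   auto X  = std::make_shared<op::Parameter>(element::f32, Shape{2, 3});
//   auto Ht = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
//   auto Ct = std::make_shared<op::Parameter>(element::f32, Shape{2, 4});
//   auto W  = std::make_shared<op::Parameter>(element::f32, Shape{4 * 4, 3});
//   auto R  = std::make_shared<op::Parameter>(element::f32, Shape{4 * 4, 4});
//   auto cell = std::make_shared<op::v4::LSTMCell>(X, Ht, Ct, W, R, 4);
//   // cell->output(0) -> Ht with shape [2, 4]; cell->output(1) -> Ct with shape [2, 4]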
} // namespace v4
} // namespace op
NGRAPH_API
std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
@ -294,5 +421,3 @@ namespace ngraph
const DiscreteTypeInfo& get_type_info() const override { return type_info; }
};
} // namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_END


@ -27,8 +27,7 @@
#include "ngraph/op/lstm_cell.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
#include "ngraph/op/util/rnn_cell_base.hpp"
namespace ngraph
{
@ -186,9 +185,66 @@ namespace ngraph
LSTMWeightsFormat m_weights_format;
};
}
using v0::LSTMSequence;
namespace v1
{
///
/// \brief Class for the LSTM sequence node.
///
/// \note It follows the notation and equations defined in the ONNX standard:
/// https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
///
/// \sa LSTMCell, RNNCell, GRUCell
///
class NGRAPH_API LSTMSequence : public util::RNNCellBase
{
public:
static constexpr NodeTypeInfo type_info{"LSTMSequence", 1};
const NodeTypeInfo& get_type_info() const override { return type_info; }
LSTMSequence() = default;
using direction = RecurrentSequenceDirection;
size_t get_default_output_index() const override { return no_default_index(); }
explicit LSTMSequence(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& sequence_lengths,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const std::int64_t hidden_size,
const direction lstm_direction,
const std::vector<float> activations_alpha = {},
const std::vector<float> activations_beta = {},
const std::vector<std::string> activations = {"sigmoid",
"tanh",
"tanh"},
const float clip = 0.f)
: RNNCellBase(
{X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_direction(lstm_direction)
{
constructor_validate_and_infer_types();
}
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
direction get_direction() const { return m_direction; }
private:
direction m_direction;
};
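// Usage sketch (illustrative; the input tensors are assumed to follow the
// v0::LSTMSequence layouts, e.g. X: [batch_size, seq_length, input_size],
// W/R: [num_directions, 4*hidden_size, ...]):
//
//   auto seq = std::make_shared<op::v1::LSTMSequence>(
//       X, Ht, Ct, seq_lengths, W, R, B, hidden_size,
//       op::v1::LSTMSequence::direction::FORWARD);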
}
} // namespace op
} // namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_END


@ -26,8 +26,6 @@
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
namespace ngraph
{
namespace op
@ -52,7 +50,7 @@ namespace ngraph
///
/// \sa LSTMSequence, LSTMCell, GRUCell
///
class NGRAPH_API RNNCell : public util::FusedOp, public util::RNNCellBase
class NGRAPH_API RNNCell : public util::RNNCellBase
{
public:
static constexpr NodeTypeInfo type_info{"RNNCell", 0};
@ -129,11 +127,9 @@ namespace ngraph
const std::vector<float>& activations_beta = {},
float clip = 0.f);
virtual void validate_and_infer_types() override;
void validate_and_infer_types() override;
bool visit_attributes(AttributeVisitor& visitor) override;
virtual void pre_validate_and_infer_types() override;
virtual OutputVector decompose_op() const override;
virtual std::shared_ptr<Node>
std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;
private:
@ -152,8 +148,5 @@ namespace ngraph
static constexpr std::size_t s_gates_count{1};
};
}
using v0::RNNCell;
} // namespace op
} // namespace ngraph
NGRAPH_SUPPRESS_DEPRECATED_END


@ -30,11 +30,39 @@ namespace ngraph
{
namespace util
{
enum class LSTMWeightsFormat
{
FICO, // IE
ICOF, // PyTorch
IFCO, // DNNL, TF, MxNet
IFOC, // Caffe
IOFC, // ONNX
};
///
/// \brief Changes the weights format of the provided node.
///
/// \param[in] node The input node to be permuted.
/// \param[in] from_format The original weights format of the node.
/// \param[in] to_format The weights format to convert to.
///
/// \return Node representing the tensor reshaped according to the `to_format`
/// weights format.
///
std::shared_ptr<Node> NGRAPH_API
convert_lstm_node_format(const Output<Node>& node,
LSTMWeightsFormat from_format,
LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO);
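// Usage sketch (assumes a hypothetical `w_onnx` output holding LSTM weights
// in the ONNX (IOFC) gate order; converts them to the default FICO order):
//
//   auto w_fico = op::util::convert_lstm_node_format(
//       w_onnx, op::util::LSTMWeightsFormat::IOFC, op::util::LSTMWeightsFormat::FICO);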
/// \brief Base class for all recurrent network cells.
///
/// \note It holds all common attributes.
///
class NGRAPH_API RNNCellBase
class NGRAPH_API RNNCellBase : public Op
{
public:
///
@ -50,7 +78,8 @@ namespace ngraph
/// \param[in] activations_beta The vector of beta parameters for activation
/// functions in order respective to activation list.
///
RNNCellBase(std::size_t hidden_size,
RNNCellBase(const OutputVector& args,
std::size_t hidden_size,
float clip,
const std::vector<std::string>& activations,
const std::vector<float>& activations_alpha,


@ -70,8 +70,7 @@ NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v0)
NGRAPH_OP(LSTMSequence, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v4)
NGRAPH_OP(MatMul, ngraph::op::v0)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)


@ -0,0 +1,316 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/multiply.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/split.hpp>
#include <ngraph/runtime/reference/subtract.hpp>
#include <ngraph/runtime/reference/tanh.hpp>
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void gru_cell(const T* X,
const Shape& X_shape,
const T* H,
const Shape& H_shape,
const T* W,
const Shape& W_shape,
const T* R,
const Shape& R_shape,
const T* B,
const Shape& B_shape,
T* dst_data,
const std::string& activation_f,
const std::string& activation_g,
float clip,
bool linear_before_reset)
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ ACRONYMS ------
// z_t - update gate at current time step
// r_t - reset gate at current time step
// h_t - hidden gate at current time step
// t - time step (t-1 means previous time step)
// X - The input data tensor. Shape: [batch_size, input_size].
// W[zrh] - The weight tensor for update, reset and hidden gates.
// Shape: [gates_count * hidden_size, input_size].
// R[zrh] - The recurrence weight tensor for update, reset and hidden gates.
// Shape: [gates_count * hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size,
// hidden_size].
// B - The sum of biases (weight and recurrence) for update, reset and hidden
// gates. If linear_before_reset := true then biases for hidden gates are
// placed separately (weight and recurrence).
// Shape: [gates_count * hidden_size] when linear_before_reset := false
// Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := true
// Wb[zrh] - W bias vectors for update, reset and hidden gates.
// Rb[zrh] - R bias vectors for update, reset and hidden gates.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f, g - are activation functions
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)   # when linear_before_reset := false (default)
// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset := true
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
// -------------------
Shape gate_shape{X_shape[0], H_shape[1]};
Shape all_gates_shape{X_shape[0], 3 * H_shape[1]};
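// Note: despite its name, bias_shape below is the per-gate recurrence weight
// shape [hidden_size, hidden_size]; it is used when splitting R and in the
// linear_before_reset == false matmul further down.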
Shape bias_shape{H_shape[1], H_shape[1]};
auto gate_shape_size = X_shape[0] * H_shape[1];
auto all_gates_shape_size = gate_shape_size * 3;
auto bias_shape_size = H_shape[1] * H_shape[1];
// Xt*(W^T)
std::vector<T> Xt_W(all_gates_shape_size);
reference::matmul(
X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true);
// Ht-1*(R^T)
std::vector<T> Ht_R(all_gates_shape_size);
reference::matmul(
H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true);
std::vector<std::vector<T>> X_W_zrh(3, std::vector<T>(gate_shape_size));
std::vector<char*> pointers_XW = {reinterpret_cast<char*>(X_W_zrh[0].data()),
reinterpret_cast<char*>(X_W_zrh[1].data()),
reinterpret_cast<char*>(X_W_zrh[2].data())};
std::vector<std::vector<T>> R_zrh(3, std::vector<T>(bias_shape_size));
std::vector<char*> pointers_R = {reinterpret_cast<char*>(R_zrh[0].data()),
reinterpret_cast<char*>(R_zrh[1].data()),
reinterpret_cast<char*>(R_zrh[2].data())};
std::vector<std::vector<T>> Ht_R_zrh(3, std::vector<T>(gate_shape_size));
std::vector<char*> pointers_H_R = {reinterpret_cast<char*>(Ht_R_zrh[0].data()),
reinterpret_cast<char*>(Ht_R_zrh[1].data()),
reinterpret_cast<char*>(Ht_R_zrh[2].data())};
size_t num_b_splits = linear_before_reset ? 4 : 3;
std::vector<std::vector<T>> biases_zrh(num_b_splits,
std::vector<T>(B_shape[0] / num_b_splits));
std::vector<char*> pointers_biases = {
reinterpret_cast<char*>(biases_zrh[0].data()),
reinterpret_cast<char*>(biases_zrh[1].data()),
reinterpret_cast<char*>(biases_zrh[2].data())};
if (linear_before_reset)
{
pointers_biases.push_back(reinterpret_cast<char*>(biases_zrh[3].data()));
}
// split on gates
reference::split(reinterpret_cast<char*>(Xt_W.data()),
all_gates_shape,
sizeof(T),
1,
3,
pointers_XW.data());
reference::split(
reinterpret_cast<const char*>(R), R_shape, sizeof(T), 0, 3, pointers_R.data());
reference::split(reinterpret_cast<char*>(Ht_R.data()),
all_gates_shape,
sizeof(T),
1,
3,
pointers_H_R.data());
reference::split(reinterpret_cast<const char*>(B),
B_shape,
sizeof(T),
0,
num_b_splits,
pointers_biases.data());
auto clip_activation = [&clip](std::vector<T>& gate,
const std::string& activation) {
if (clip > 0.f)
{
reference::clamp(gate.data(),
gate.data(),
static_cast<T>(-clip),
static_cast<T>(clip),
gate.size());
}
if (activation == "relu")
{
reference::relu(gate.data(), gate.data(), gate.size());
}
else if (activation == "sigmoid")
{
reference::sigmoid(gate.data(), gate.data(), gate.size());
}
else if (activation == "tanh")
{
reference::tanh(gate.data(), gate.data(), gate.size());
}
else
{
throw ngraph_error("Activation function " + activation +
" is not supported.");
}
};
// calculate z_t
// steps:
// Ht-1*(Rz^T) + Wbz + Rbz
// Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
std::vector<T> z_t(gate_shape_size);
reference::add(Ht_R_zrh[0].data(),
biases_zrh[0].data(),
z_t.data(),
gate_shape,
{B_shape[0] / num_b_splits},
op::AutoBroadcastSpec::NUMPY);
reference::add(X_W_zrh[0].data(),
z_t.data(),
z_t.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
clip_activation(z_t, activation_f);
// calculate r_t
// steps:
// Ht-1*(Rr^T) + Wbr + Rbr
// Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
std::vector<T> r_t(gate_shape_size);
reference::add(Ht_R_zrh[1].data(),
biases_zrh[1].data(),
r_t.data(),
gate_shape,
{B_shape[0] / num_b_splits},
op::AutoBroadcastSpec::NUMPY);
reference::add(X_W_zrh[1].data(),
r_t.data(),
r_t.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
clip_activation(r_t, activation_f);
// calculate h_t
std::vector<T> h_t(gate_shape_size);
if (linear_before_reset)
{
// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
reference::add(Ht_R_zrh[2].data(),
biases_zrh[3].data(),
h_t.data(),
gate_shape,
{B_shape[0] / num_b_splits},
op::AutoBroadcastSpec::NUMPY);
reference::multiply(r_t.data(),
h_t.data(),
h_t.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
reference::add(h_t.data(),
biases_zrh[2].data(),
h_t.data(),
gate_shape,
{B_shape[0] / num_b_splits},
op::AutoBroadcastSpec::NUMPY);
reference::add(X_W_zrh[2].data(),
h_t.data(),
h_t.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
}
else
{
// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
reference::multiply(r_t.data(),
H,
h_t.data(),
gate_shape,
H_shape,
op::AutoBroadcastSpec::NUMPY);
std::vector<T> matmul(gate_shape_size);
reference::matmul(h_t.data(),
R_zrh[2].data(),
matmul.data(),
gate_shape,
bias_shape,
gate_shape,
false,
true);
reference::add(matmul.data(),
biases_zrh[2].data(),
h_t.data(),
gate_shape,
{B_shape[0] / num_b_splits},
op::AutoBroadcastSpec::NUMPY);
reference::add(X_W_zrh[2].data(),
h_t.data(),
h_t.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
}
clip_activation(h_t, activation_g);
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
std::vector<T> mul1(gate_shape_size);
std::vector<T> mul2(gate_shape_size);
T one[] = {1};
reference::subtract(
one, z_t.data(), mul1.data(), {1}, gate_shape, op::AutoBroadcastSpec::NUMPY);
reference::multiply(mul1.data(),
h_t.data(),
mul1.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
reference::multiply(z_t.data(),
H,
mul2.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
reference::add(mul1.data(),
mul2.data(),
dst_data,
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
}
}
}
}
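// Minimal call sketch (hypothetical data; batch_size=1, input_size=2,
// hidden_size=2, so W/R have 3*hidden_size = 6 rows and B has 6 elements
// with linear_before_reset = false):
//
//   std::vector<float> X(1 * 2), H(1 * 2), W(6 * 2), R(6 * 2), B(6), Ht(1 * 2);
//   ngraph::runtime::reference::gru_cell(
//       X.data(), ngraph::Shape{1, 2}, H.data(), ngraph::Shape{1, 2},
//       W.data(), ngraph::Shape{6, 2}, R.data(), ngraph::Shape{6, 2},
//       B.data(), ngraph::Shape{6}, Ht.data(), "sigmoid", "tanh",
//       0.f, false);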


@ -0,0 +1,217 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <cstring>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/multiply.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/split.hpp>
#include <ngraph/runtime/reference/subtract.hpp>
#include <ngraph/runtime/reference/tanh.hpp>
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void lstm_cell(const T* X,
const Shape& X_shape,
const T* H,
const Shape& H_shape,
const T* C,
const Shape& C_shape,
const T* W,
const Shape& W_shape,
const T* R,
const Shape& R_shape,
const T* B,
const Shape& B_shape,
T* out_Ht,
T* out_Ct,
const std::string& activation_f,
const std::string& activation_g,
const std::string& activation_h,
float clip)
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// ------ VARIABLE NAMES ------
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight matrix for input, forget, cell and output gates
// Shape: [4*hidden_size, input_size]
// R - The recurrence weight matrix for input, forget, cell and output gates.
// Shape: [4*hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size,
// hidden_size].
// C_t - The cell state tensor at current time step. Shape: [batch_size,
// hidden_size].
// bias - The sum of biases (weight and recurrence) for input, forget, cell and
// output gates.
// Shape: [4 * hidden_size]
//
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
//
// ---- Equations ----
// f, g, h - are activation functions.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
// Ct = ft (.) Ct-1 + it (.) ct
// Ht = ot (.) h(Ct)
// --------------------
Shape gate_shape{X_shape[0], H_shape[1]};
Shape all_gates_shape{X_shape[0], 4 * H_shape[1]};
auto gate_shape_size = X_shape[0] * H_shape[1];
auto all_gates_shape_size = gate_shape_size * 4;
// Xt*(W^T)
std::vector<T> Xt_W(all_gates_shape_size);
reference::matmul(
X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true);
// Ht-1*(R^T)
std::vector<T> Ht_R(all_gates_shape_size);
reference::matmul(
H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true);
// Ht-1*(R^T) + Wb + Rb
std::vector<T> Ht_R_B(all_gates_shape_size);
reference::add(Ht_R.data(),
B,
Ht_R_B.data(),
all_gates_shape,
B_shape,
op::AutoBroadcastSpec::NUMPY);
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
std::vector<T> XHB(all_gates_shape_size);
reference::add(Xt_W.data(),
Ht_R_B.data(),
XHB.data(),
all_gates_shape,
all_gates_shape,
op::AutoBroadcastSpec::NUMPY);
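// The fused pre-activation tensor XHB is split below into the four gates;
// as the per-gate comments further down show, the storage order along the
// gates axis is f, i, c, o (FICO).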
std::vector<std::vector<T>> X_W_fico(4, std::vector<T>(all_gates_shape_size / 4));
std::vector<char*> pointers = {reinterpret_cast<char*>(X_W_fico[0].data()),
reinterpret_cast<char*>(X_W_fico[1].data()),
reinterpret_cast<char*>(X_W_fico[2].data()),
reinterpret_cast<char*>(X_W_fico[3].data())};
// split on gates
reference::split(reinterpret_cast<char*>(XHB.data()),
all_gates_shape,
sizeof(T),
1,
4,
pointers.data());
auto clip_activation = [&clip](
std::vector<T>& gate, const std::string& activation, bool enable_clip = true) {
if (clip > 0.f && enable_clip)
{
reference::clamp(gate.data(),
gate.data(),
static_cast<T>(-clip),
static_cast<T>(clip),
gate.size());
}
if (activation == "relu")
{
reference::relu(gate.data(), gate.data(), gate.size());
}
else if (activation == "sigmoid")
{
reference::sigmoid(gate.data(), gate.data(), gate.size());
}
else if (activation == "tanh")
{
reference::tanh(gate.data(), gate.data(), gate.size());
}
else
{
throw ngraph_error("Activation function " + activation +
" is not supported.");
}
};
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
clip_activation(X_W_fico[0], activation_f);
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
clip_activation(X_W_fico[1], activation_f);
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
clip_activation(X_W_fico[2], activation_g);
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
clip_activation(X_W_fico[3], activation_f);
std::vector<T> mul1(gate_shape_size);
std::vector<T> mul2(gate_shape_size);
std::vector<T> Ct(gate_shape_size);
// ft (.) Ct-1
reference::multiply(X_W_fico[0].data(),
C,
mul1.data(),
gate_shape,
C_shape,
op::AutoBroadcastSpec::NUMPY);
// it (.) ct
reference::multiply(X_W_fico[1].data(),
X_W_fico[2].data(),
mul2.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
// Ct = ft (.) Ct-1 + it (.) ct
reference::add(mul1.data(),
mul2.data(),
Ct.data(),
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
std::memcpy(out_Ct, Ct.data(), Ct.size() * sizeof(T));
clip_activation(Ct, activation_h, false);
// Ht = ot (.) h(Ct)
reference::multiply(X_W_fico[3].data(),
Ct.data(),
out_Ht,
gate_shape,
gate_shape,
op::AutoBroadcastSpec::NUMPY);
}
}
}
}
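// Minimal call sketch (hypothetical data; batch_size=1, input_size=2,
// hidden_size=2, so W/R have 4*hidden_size = 8 rows):
//
//   std::vector<float> X(2), H(2), C(2), W(8 * 2), R(8 * 2), B(8), Ht(2), Ct(2);
//   ngraph::runtime::reference::lstm_cell(
//       X.data(), ngraph::Shape{1, 2}, H.data(), ngraph::Shape{1, 2},
//       C.data(), ngraph::Shape{1, 2}, W.data(), ngraph::Shape{8, 2},
//       R.data(), ngraph::Shape{8, 2}, B.data(), ngraph::Shape{8},
//       Ht.data(), Ct.data(), "sigmoid", "tanh", "tanh", 0.f);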


@ -0,0 +1,132 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cmath>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/tanh.hpp>
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void rnn_cell(const T* X,
const Shape& X_shape,
const T* H,
const Shape& H_shape,
const T* W,
const Shape& W_shape,
const T* R,
const Shape& R_shape,
const T* B,
const Shape& B_shape,
T* dst_data,
const std::string& activation_f,
float clip)
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the ones used in the ONNX documentation.
//
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight tensor for input gate. Shape: [hidden_size, input_size].
// R - The recurrence weight tensor for input gate. Shape: [hidden_size,
// hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size,
// hidden_size].
// B - The bias tensor for the input gate. Shape: [hidden_size].
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f - is the activation function.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
// Xt*(W^T)
std::vector<T> Xt_W(X_shape[0] * W_shape[0]);
reference::matmul(
X, W, Xt_W.data(), X_shape, W_shape, {X_shape[0], W_shape[0]}, false, true);
// Ht-1*(R^T)
std::vector<T> Ht_R(H_shape[0] * R_shape[0]);
reference::matmul(
H, R, Ht_R.data(), H_shape, R_shape, {H_shape[0], R_shape[0]}, false, true);
// Ht-1*(R^T) + Wb + Rb
std::vector<T> Ht_R_B(H_shape[0] * R_shape[0]);
reference::add(Ht_R.data(),
B,
Ht_R_B.data(),
{H_shape[0], R_shape[0]},
B_shape,
op::AutoBroadcastSpec::NUMPY);
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
std::vector<T> i_t(H_shape[0] * R_shape[0]);
reference::add(Xt_W.data(),
Ht_R_B.data(),
i_t.data(),
{X_shape[0], W_shape[0]},
{H_shape[0], R_shape[0]},
op::AutoBroadcastSpec::NUMPY);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
if (clip != 0.f)
{
reference::clamp(i_t.data(),
i_t.data(),
static_cast<T>(-clip),
static_cast<T>(clip),
i_t.size());
}
if (activation_f == "relu")
{
reference::relu(i_t.data(), dst_data, i_t.size());
}
else if (activation_f == "sigmoid")
{
reference::sigmoid(i_t.data(), dst_data, i_t.size());
}
else if (activation_f == "tanh")
{
reference::tanh(i_t.data(), dst_data, i_t.size());
}
else
{
throw ngraph_error("Activation function " + activation_f +
" is not supported.");
}
}
}
}
}
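// Minimal call sketch (hypothetical data; batch_size=1, input_size=2,
// hidden_size=2; the single gate means W/R have hidden_size rows):
//
//   std::vector<float> X(2), H(2), W(2 * 2), R(2 * 2), B(2), Ht(2);
//   ngraph::runtime::reference::rnn_cell(
//       X.data(), ngraph::Shape{1, 2}, H.data(), ngraph::Shape{1, 2},
//       W.data(), ngraph::Shape{2, 2}, R.data(), ngraph::Shape{2, 2},
//       B.data(), ngraph::Shape{2}, Ht.data(), "tanh", 0.f);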


@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#pragma once
#include <cstddef>
#include <cstdint>
#include "ngraph/runtime/reference/slice.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
void split(const char* data,
const Shape& data_shape,
size_t elem_size,
int64_t axis,
size_t num_splits,
char** out_data);
}
}
}


@ -0,0 +1,54 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include <cmath>
#include <stdio.h>
#include "ngraph/check.hpp"
#include "ngraph/runtime/reference/split.hpp"
using namespace ngraph;
void runtime::reference::split(const char* data,
const Shape& data_shape,
size_t elem_size,
int64_t axis,
size_t num_splits,
char** out_data)
{
const size_t part_length = data_shape.at(axis) / num_splits;
Shape output_shape = data_shape;
output_shape.at(axis) = part_length;
std::vector<size_t> lower_bounds(data_shape.size(), 0);
std::vector<size_t> upper_bounds = data_shape;
upper_bounds.at(axis) = part_length;
for (size_t i = 0; i < num_splits; ++i)
{
runtime::reference::slice(data,
out_data[i],
data_shape,
lower_bounds,
upper_bounds,
Strides(lower_bounds.size(), 1),
output_shape,
elem_size);
lower_bounds.at(axis) += part_length;
upper_bounds.at(axis) += part_length;
}
}
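// Usage sketch (splits a hypothetical [2, 6] float tensor into three
// [2, 2] parts along axis 1):
//
//   std::vector<float> data(12);
//   std::vector<float> p0(4), p1(4), p2(4);
//   char* outs[] = {reinterpret_cast<char*>(p0.data()),
//                   reinterpret_cast<char*>(p1.data()),
//                   reinterpret_cast<char*>(p2.data())};
//   runtime::reference::split(reinterpret_cast<const char*>(data.data()),
//                             Shape{2, 6}, sizeof(float), 1, 3, outs);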


@ -15,12 +15,9 @@
//*****************************************************************************
#include <cmath>
#include <functional>
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/gru_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@ -28,8 +25,6 @@
using namespace std;
using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
constexpr NodeTypeInfo op::v3::GRUCell::type_info;
op::v3::GRUCell::GRUCell()
@ -68,8 +63,12 @@ op::v3::GRUCell::GRUCell(const Output<Node>& X,
const vector<float>& activations_beta,
float clip,
bool linear_before_reset)
: FusedOp({X, initial_hidden_state, W, R})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
: RNNCellBase({X, initial_hidden_state, W, R},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_linear_before_reset{linear_before_reset}
@ -89,8 +88,12 @@ op::v3::GRUCell::GRUCell(const Output<Node>& X,
const vector<float>& activations_beta,
float clip,
bool linear_before_reset)
: FusedOp({X, initial_hidden_state, W, R, B})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
: RNNCellBase({X, initial_hidden_state, W, R, B},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_linear_before_reset{linear_before_reset}
@ -104,83 +107,12 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor)
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::v3::GRUCell::pre_validate_and_infer_types()
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& w_pshape = get_input_partial_shape(2);
const auto& r_pshape = get_input_partial_shape(3);
const auto& b_pshape = get_input_partial_shape(4);
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
s_gates_count * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
s_gates_count * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor initial_hidden_state must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
w_shape,
".");
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(
this,
(b_shape == Shape{(s_gates_count + m_linear_before_reset) * get_hidden_size()}),
"Input tensor B must have shape (",
(s_gates_count + m_linear_before_reset) * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
}
void op::v3::GRUCell::validate_and_infer_types()
{
std::vector<ngraph::PartialShape> input_param{};
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto result_et = element::dynamic;
// Copy all inputs for further validation
for (size_t i = 0; i < get_input_size(); i++)
{
input_param.push_back(get_input_partial_shape(i));
}
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
@ -188,7 +120,7 @@ void op::v3::GRUCell::validate_and_infer_types()
const auto& r_pshape = get_input_partial_shape(3);
const auto& b_pshape = get_input_partial_shape(4);
validate_input_rank_dimension(input_param);
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
// Validate input types and save result for output type
NODE_VALIDATION_CHECK(
@ -265,90 +197,6 @@ void op::v3::GRUCell::validate_and_infer_types()
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
}
OutputVector op::v3::GRUCell::decompose_op() const
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ ACRONYMS ------
// z_t - update gate at current time step
// r_t - reset gate at current time step
// h_t - hidden gate at current time step
// t - time step (t-1 means previous time step)
// X The input data tensor. Shape: [batch_size, input_size].
// W[zrh] - The weight tensor for update, reset and hidden gates.
// Shape: [gates_count * hidden_size, input_size].
// R[zrh] - The recurrence weight tensor for update, reset and hidden gates.
// Shape: [gates_count * hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// B - The sum of biases (weight and recurrence) for update, reset and hidden gates.
// If linear_before_reset := true then biases for hidden gates are placed separately
// (weight and recurrence).
// Shape: [gates_count * hidden_size] when linear_before_reset := false
// Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := true
// Wb[zrh] - W bias vectors for update, reset and hidden gates.
// Rb[zrh] - R bias vectors for update, reset and hidden gates.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f, g - are activation functions
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # when linear_before_reset := false
// # (default)
// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset := true
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
// -------------------
Output<Node> X = input_value(0);
Output<Node> H_t = input_value(1);
Output<Node> W = input_value(2);
Output<Node> R = input_value(3);
Output<Node> B = input_value(4);
// Xt*(W^T)
auto Xt_W = make_shared<op::Dot>(X, builder::opset1::transpose(W));
auto R_transpose = builder::opset1::transpose(R);
// Ht-1*(R^T)
auto Ht_R = make_shared<op::Dot>(H_t, R_transpose);
// split to gates:
OutputVector Xt_W_zrh = builder::split(Xt_W, 3, 1);
OutputVector R_zrh = builder::split(R_transpose, 3, 1);
OutputVector Ht_R_zrh = builder::split(Ht_R, 3, 1);
OutputVector biases_zrh = m_linear_before_reset ? builder::split(B, 4) : builder::split(B, 3);
// zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
auto z_t = m_activation_f(clip(add(Xt_W_zrh[0], add(Ht_R_zrh[0], biases_zrh[0]))));
// rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
auto r_t = m_activation_f(clip(add(Xt_W_zrh[1], add(Ht_R_zrh[1], biases_zrh[1]))));
Output<Node> h_t;
if (m_linear_before_reset)
{
// ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
auto Ht_Rh_Rbh = add(Ht_R_zrh[2], biases_zrh[3]);
h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(mul(r_t, Ht_Rh_Rbh), biases_zrh[2]))));
}
else
{
// ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
auto rt_Ht = mul(r_t, H_t);
auto rt_Ht_Rh = make_shared<op::Dot>(rt_Ht, R_zrh[2]);
// Tensor shape: [batch_size, hidden_size]
h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(rt_Ht_Rh, biases_zrh[2]))));
}
auto one = op::Constant::create(z_t->get_element_type(),
z_t->get_shape(),
vector<float>(shape_size(z_t->get_shape()), 1.f));
// Ht = (1 - zt) (.) ht + zt (.) Ht-1
H_t = add(mul(sub(one, z_t), h_t), mul(z_t, H_t));
return {H_t.get_node_shared_ptr()};
}
void op::v3::GRUCell::add_default_bias_input()
{
Output<Node> B = op::Constant::create(


@ -18,12 +18,8 @@
#include <functional>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/lstm_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@ -31,11 +27,10 @@
using namespace std;
using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
constexpr NodeTypeInfo op::v4::LSTMCell::type_info;
constexpr NodeTypeInfo op::v0::LSTMCell::type_info;
constexpr NodeTypeInfo op::LSTMCell::type_info;
op::LSTMCell::LSTMCell()
op::v0::LSTMCell::LSTMCell()
: m_input_forget(false)
, m_weights_format(LSTMWeightsFormat::IFCO)
{
@ -45,20 +40,24 @@ op::LSTMCell::LSTMCell()
m_activation_h = get_activation_function(2);
}
op::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: FusedOp({X, initial_hidden_state, initial_cell_state, W, R})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
@ -70,21 +69,25 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types();
}
op::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
@ -95,22 +98,26 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
constructor_validate_and_infer_types();
}
op::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const Output<Node>& P,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B, P})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
const Output<Node>& P,
size_t hidden_size,
op::LSTMWeightsFormat weights_format,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip,
bool input_forget)
: RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B, P},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
@ -133,101 +140,7 @@ bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
return true;
}
void op::LSTMCell::pre_validate_and_infer_types()
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& ct_pshape = get_input_partial_shape(2);
const auto& w_pshape = get_input_partial_shape(3);
const auto& r_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() || w_pshape.is_static() || r_pshape.is_static() ||
ht_pshape.is_static() || ct_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
const Shape& ct_shape{ct_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
"Input tensor W must have shape (",
s_gates_count * get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
s_gates_count * get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
r_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor initial_hidden_state must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
ht_shape,
".");
NODE_VALIDATION_CHECK(this,
(ct_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor initial_cell_state must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
ct_shape,
".");
const auto& b_pshape = get_input_partial_shape(5);
const auto& p_pshape = get_input_partial_shape(6);
NODE_VALIDATION_CHECK(this,
(b_pshape.is_static() || p_pshape.is_static()),
"LSTMCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
const Shape& p_shape{p_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{s_gates_count * get_hidden_size()}),
"Input tensor B must have shape (",
s_gates_count * get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
NODE_VALIDATION_CHECK(this,
(p_shape == Shape{s_peepholes_count * get_hidden_size()}),
"Input tensor P must have shape (",
s_peepholes_count * get_hidden_size(),
"). Actual shape is:",
p_shape,
".");
}
void op::LSTMCell::validate_and_infer_types()
void op::v0::LSTMCell::validate_and_infer_types()
{
std::vector<ngraph::PartialShape> input_param{};
@ -367,186 +280,69 @@ void op::LSTMCell::validate_and_infer_types()
set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
}
OutputVector op::LSTMCell::decompose_op() const
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ ACRONYMS ------
// i - input gate
// o - output gate
// f - forget gate
// c - cell gate
// t - time step (t-1 means previous time step)
// Wb - W bias vectors for input, output, forget, and cell gates.
// Rb - R bias vectors for input, output, forget, and cell gates.
// P - The peephole weights for input, output and forget gates.
// ------ VARIABLE NAMES ------
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight matrix for input, forget, cell and output gates
// Shape: [4*hidden_size, input_size]
// R - The recurrence weight matrix for input, forget, cell and output gates.
// Shape: [4*hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// C_t - The cell state tensor at current time step. Shape: [batch_size, hidden_size].
// bias - The sum of biases (weight and recurrence) for input, forget, cell and output gates.
// Shape: [4 * hidden_size]
// p_[iof] - The peephole weight vector for respectively: input, output, and forget gates.
// Each peephole has shape [hidden_size].
//
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
//
// ---- Equations ----
// f, g, h - are activation functions.
// it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
// ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
// ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
// Ct = ft (.) Ct-1 + it (.) ct
// ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
// Ht = ot (.) h(Ct)
// --------------------
Output<Node> X = input_value(0);
Output<Node> H_t = input_value(1);
Output<Node> C_t = input_value(2);
Output<Node> W = input_value(3);
Output<Node> R = input_value(4);
Output<Node> bias = input_value(5);
OutputVector p_iof = builder::split(input_value(6), s_peepholes_count);
// Converting to IFCO format since it's DNNL default.
if (m_weights_format != op::LSTMWeightsFormat::IFCO)
{
W = convert_node_format(W);
R = convert_node_format(R);
bias = convert_node_format(bias);
}
const auto& p_i = p_iof.at(0);
const auto& p_o = p_iof.at(1);
const auto& p_f = p_iof.at(2);
// Xt*(W^T) -- for [iofc] gates.
auto Xt_W = make_shared<op::Dot>(X, builder::opset1::transpose(W));
// Ht-1*(R^T) -- for [iofc] gates.
auto Ht_R = make_shared<op::Dot>(H_t, builder::opset1::transpose(R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
auto gates = add(Xt_W, add(Ht_R, bias));
OutputVector split_gates = builder::split(gates, 4, -1);
auto i_t = split_gates.at(0);
auto f_t = split_gates.at(1);
auto c_t = split_gates.at(2);
auto o_t = split_gates.at(3);
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t))));
if (m_input_forget)
{
// Couple input with forget gate: 1 - i_t
f_t = sub(op::Constant::create(i_t.get_element_type(),
i_t.get_shape(),
vector<float>(shape_size(i_t.get_shape()), 1.f)),
i_t);
}
else
{
// f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t))));
}
// ft (.) Ct-1 + it (.) ct
auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t))));
// f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
o_t = m_activation_f(clip(add(o_t, mul(p_o, C))));
// ot (.) h(Ct)
auto H = mul(o_t, m_activation_h(clip(C)));
return {H, C};
}
Output<Node> op::LSTMCell::get_default_bias_input() const
Output<Node> op::v0::LSTMCell::get_default_bias_input() const
{
return Output<Node>{op::Constant::create(
get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector<float>{0.f})};
}
Output<Node> op::LSTMCell::get_default_peepholes_input() const
Output<Node> op::v0::LSTMCell::get_default_peepholes_input() const
{
return Output<Node>{op::Constant::create(get_input_element_type(0),
Shape{s_peepholes_count * get_hidden_size()},
vector<float>{0.f})};
}
shared_ptr<Node> op::LSTMCell::convert_node_format(const Output<Node>& node) const
{
static const std::map<op::LSTMWeightsFormat, std::vector<size_t>> gate_order_conversion_map{
{op::LSTMWeightsFormat::FICO, {1, 0, 2, 3}},
{op::LSTMWeightsFormat::ICOF, {0, 3, 1, 2}},
{op::LSTMWeightsFormat::IFOC, {0, 1, 3, 2}},
{op::LSTMWeightsFormat::IOFC, {0, 2, 3, 1}},
};
OutputVector splitted_node = builder::split(node, s_gates_count);
OutputVector nodes_in_new_format;
nodes_in_new_format.reserve(s_gates_count);
for (const auto& axis : gate_order_conversion_map.at(m_weights_format))
{
nodes_in_new_format.push_back(splitted_node.at(axis));
}
return make_shared<op::Concat>(nodes_in_new_format, 0);
}
shared_ptr<Node> op::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 5)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
return make_shared<op::v0::LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
}
else if (new_args.size() == 6)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
return make_shared<op::v0::LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
}
else if (new_args.size() == 7)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
return make_shared<op::v0::LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
new_args.at(6),
get_hidden_size(),
get_weights_format(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip(),
m_input_forget);
}
else
{
@ -576,3 +372,212 @@ namespace ngraph
return s << as_string(type);
}
} // namespace ngraph
op::v4::LSTMCell::LSTMCell()
{
m_activations = {"sigmoid", "tanh", "tanh"};
m_activation_f = get_activation_function(0);
m_activation_g = get_activation_function(1);
m_activation_h = get_activation_function(2);
}
op::v4::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
{
set_argument(5, get_default_bias_input());
constructor_validate_and_infer_types();
}
op::v4::LSTMCell::LSTMCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& initial_cell_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
, m_activation_g{get_activation_function(1)}
, m_activation_h{get_activation_function(2)}
{
constructor_validate_and_infer_types();
}
bool ngraph::op::v4::LSTMCell::visit_attributes(AttributeVisitor& visitor)
{
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::v4::LSTMCell::validate_and_infer_types()
{
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto result_et = element::dynamic;
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& ct_pshape = get_input_partial_shape(2);
const auto& w_pshape = get_input_partial_shape(3);
const auto& r_pshape = get_input_partial_shape(4);
const auto& b_pshape = get_input_partial_shape(5);
// Validate rank and dimension for initial_cell_state input
NODE_VALIDATION_CHECK(this,
(ct_pshape.rank().is_static()),
"LSTMCell input tensor initial_cell_state shall have static rank.");
NODE_VALIDATION_CHECK(this,
(ct_pshape.rank().get_length() == 2),
"LSTMCell input tensor initial_cell_state shall have dimension 2D.");
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
// Validate input element types and save result for output type
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
element::Type::merge(result_et, result_et, get_input_element_type(5)),
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not match.");
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
"Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state "
"inputs.");
// Merge hidden_size dimension across all inputs to evaluate output[1] dimension
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[1]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
"Parameter hidden_size not matched for R, initial_hidden_state and initial_cell_state "
"inputs.");
// Validate hidden_size value for W, R and B inputs
if (merged_hidden_size.is_static())
{
if (w_pshape[0].is_static())
{
NODE_VALIDATION_CHECK(
this,
w_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in W input. Current value is: ",
w_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (r_pshape[0].is_static())
{
NODE_VALIDATION_CHECK(
this,
r_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in R input. Current value is: ",
r_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
if (b_pshape[0].is_static())
{
NODE_VALIDATION_CHECK(
this,
b_pshape[0].compatible(merged_hidden_size * s_gates_count),
"Parameter hidden_size mistmatched in B input. Current value is: ",
b_pshape[0].get_length(),
", expected: ",
merged_hidden_size.get_length() * s_gates_count,
".");
}
}
// Mark inputs which are relevant to output parameters
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(4);
// Set output size, type and shape
set_output_size(2);
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
}
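// Shape inference sketch (illustrative): with X: [2, 3], Ht/Ct: [2, 4],
// W: [16, 3], R: [16, 4] and B: [16], batch_size and hidden_size merge to
// 2 and 4, so both outputs become [2, 4] with the merged element type.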
Output<Node> op::v4::LSTMCell::get_default_bias_input() const
{
return Output<Node>{op::Constant::create(
get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector<float>{0.f})};
}
shared_ptr<Node> op::v4::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 5)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
get_hidden_size(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip());
}
else if (new_args.size() == 6)
{
return make_shared<LSTMCell>(new_args.at(0),
new_args.at(1),
new_args.at(2),
new_args.at(3),
new_args.at(4),
new_args.at(5),
get_hidden_size(),
get_activations(),
get_activations_alpha(),
get_activations_beta(),
get_clip());
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}


@ -22,13 +22,16 @@
#include "ngraph/builder/split.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/op/util/recurrent_sequence.hpp"
using namespace ngraph;
using namespace std;
constexpr NodeTypeInfo op::v1::LSTMSequence::type_info;
constexpr NodeTypeInfo op::v0::LSTMSequence::type_info;
bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("hidden_size", m_hidden_size);
@ -415,3 +418,165 @@ void op::v0::LSTMSequence::validate_and_infer_types()
set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
}
bool ngraph::op::v1::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
visitor.on_attribute("direction", m_direction);
return op::util::RNNCellBase::visit_attributes(visitor);
}
shared_ptr<Node> op::v1::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 7)
{
return make_shared<op::v1::LSTMSequence>(new_args.at(0), // X
new_args.at(1), // initial_hidden_state
new_args.at(2), // initial_cell_state
new_args.at(3), // sequence_lengths
new_args.at(4), // W
new_args.at(5), // R
new_args.at(6), // B
m_hidden_size,
m_direction,
m_activations_alpha,
m_activations_beta,
m_activations,
m_clip);
}
else
{
throw ngraph_error("Incorrect number of new arguments");
}
}
void op::v1::LSTMSequence::validate_and_infer_types()
{
std::vector<ngraph::PartialShape> input_param{};
auto lstm_seq_gates_count = 4;
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto merged_num_directions = Dimension::dynamic();
auto result_et = element::dynamic;
// Collect all input shapes except initial_cell_state for further validation
for (size_t i = 0; i < get_input_size(); i++)
{
// exclude initial_cell_state from the loop
if (i != 2)
{
input_param.push_back(get_input_partial_shape(i));
}
}
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& ct_pshape = get_input_partial_shape(2);
const auto& sl_pshape = get_input_partial_shape(3);
const auto& w_pshape = get_input_partial_shape(4);
const auto& r_pshape = get_input_partial_shape(5);
const auto& b_pshape = get_input_partial_shape(6);
ngraph::op::util::validate_seq_input_rank_dimension(input_param);
// Validate rank and dimension for initial_cell_state input
NODE_VALIDATION_CHECK(this,
(ct_pshape.rank().is_static()),
"LSTMSequence input tensor initial_cell_state shall have static rank.");
NODE_VALIDATION_CHECK(this,
(ct_pshape.rank().get_length() == 3),
"LSTMSequence input tensor initial_cell_state shall have dimension 3D.");
// Validate input types and save result for output type
NODE_VALIDATION_CHECK(
this,
element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
element::Type::merge(result_et, result_et, get_input_element_type(5)) &&
element::Type::merge(result_et, result_et, get_input_element_type(6)),
"Element types for X, initial_hidden_state, initial_cell_state, W, R and B inputs do not "
"match.");
// Merge batch_size dimension across all inputs to evaluate output[0] dimension
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) &&
Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]),
"Parameter batch_size not matched in LSTMSequence.");
// Merge hidden_size dimension across all inputs to evaluate output dimension
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[2]) &&
Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]),
"Parameter hidden_size not matched LSTMSequence.");
// Merge num_directions dimension across all inputs to evaluate output dimension
NODE_VALIDATION_CHECK(
this,
Dimension::merge(merged_num_directions, merged_num_directions, ht_pshape[1]) &&
Dimension::merge(merged_num_directions, merged_num_directions, ct_pshape[1]) &&
Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) &&
Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) &&
Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
"Parameter num_directions not matched in LSTMSequence.");
// Validate hidden_size value for W, R, B inputs
if (merged_hidden_size.is_static())
{
if (w_pshape[1].is_static())
{
NODE_VALIDATION_CHECK(
this,
w_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
"Parameter hidden_size mistmatched in W input. Current value is: ",
w_pshape[1].get_length(),
", expected: ",
merged_hidden_size.get_length() * lstm_seq_gates_count,
".");
}
if (r_pshape[1].is_static())
{
NODE_VALIDATION_CHECK(
this,
r_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
"Parameter hidden_size mistmatched in R input. Current value is: ",
r_pshape[1].get_length(),
", expected: ",
merged_hidden_size.get_length() * lstm_seq_gates_count,
".");
}
if (b_pshape[1].is_static())
{
NODE_VALIDATION_CHECK(
this,
b_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
"Parameter hidden_size mistmatched in B input. Current value is: ",
b_pshape[1].get_length(),
", expected: ",
merged_hidden_size.get_length() * lstm_seq_gates_count,
".");
}
}
// Mark inputs which are relevant to output parameters
for (size_t i = 0; i <= 6; ++i)
set_input_is_relevant_to_shape(i);
// Set output size, type and shape
set_output_size(3);
set_output_type(
0, result_et, {merged_batch_size, merged_num_directions, x_pshape[1], merged_hidden_size});
set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
}
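A hedged sketch of shapes this validation accepts (hypothetical sizes: batch_size = 1, seq_length = 5, input_size = 3, hidden_size = 4, num_directions = 2; the constructor arguments mirror the attribute test later in this change):

const auto X = make_shared<op::Parameter>(element::f32, Shape{1, 5, 3});    // [batch, seq_length, input]
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});  // [batch, num_directions, hidden]
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{1, 2, 4});  // [batch, num_directions, hidden]
const auto SL = make_shared<op::Parameter>(element::i32, Shape{1});         // [batch]
const auto W = make_shared<op::Parameter>(element::f32, Shape{2, 16, 3});   // [num_directions, 4 * hidden, input]
const auto R = make_shared<op::Parameter>(element::f32, Shape{2, 16, 4});   // [num_directions, 4 * hidden, hidden]
const auto B = make_shared<op::Parameter>(element::f32, Shape{2, 16});      // [num_directions, 4 * hidden]
const auto seq = make_shared<op::v1::LSTMSequence>(X, H_t, C_t, SL, W, R, B,
    4, op::RecurrentSequenceDirection::BIDIRECTIONAL,
    std::vector<float>{}, std::vector<float>{},
    std::vector<std::string>{"sigmoid", "tanh", "tanh"}, 0.f);
// Output 0 is inferred as {1, 2, 5, 4}; outputs 1 and 2 as {1, 2, 4}.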

@ -14,156 +14,79 @@
// limitations under the License.
//*****************************************************************************
#include "ngraph/op/rnn_cell.hpp"
#include <cmath>
#include <functional>
#include "itt.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/rnn_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
using namespace std;
using namespace ngraph;
NGRAPH_SUPPRESS_DEPRECATED_START
constexpr NodeTypeInfo op::v0::RNNCell::type_info;
constexpr NodeTypeInfo op::RNNCell::type_info;
op::RNNCell::RNNCell()
op::v0::RNNCell::RNNCell()
{
m_activations = {"tanh"};
m_activation_f = get_activation_function(0);
}
op::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: FusedOp({X, initial_hidden_state, W, R})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: RNNCellBase({X, initial_hidden_state, W, R},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
{
set_argument(4, get_default_bias_input());
constructor_validate_and_infer_types();
}
op::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: FusedOp({X, initial_hidden_state, W, R, B})
, RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::RNNCell::RNNCell(const Output<Node>& X,
const Output<Node>& initial_hidden_state,
const Output<Node>& W,
const Output<Node>& R,
const Output<Node>& B,
size_t hidden_size,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta,
float clip)
: RNNCellBase({X, initial_hidden_state, W, R, B},
hidden_size,
clip,
activations,
activations_alpha,
activations_beta)
, m_activation_f{get_activation_function(0)}
{
constructor_validate_and_infer_types();
}
bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
bool op::v0::RNNCell::visit_attributes(AttributeVisitor& visitor)
{
return op::util::RNNCellBase::visit_attributes(visitor);
}
void op::RNNCell::pre_validate_and_infer_types()
void op::v0::RNNCell::validate_and_infer_types()
{
set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
if (is_dynamic())
{
return;
}
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
const auto& w_pshape = get_input_partial_shape(2);
const auto& r_pshape = get_input_partial_shape(3);
NODE_VALIDATION_CHECK(this,
(x_pshape.is_static() && w_pshape.is_static() && r_pshape.is_static() &&
ht_pshape.is_static()),
"RNNCell supports only static input tensors.");
const Shape& x_shape{x_pshape.to_shape()};
const size_t batch_size = x_shape.at(0);
const size_t input_size = x_shape.at(1);
const Shape& w_shape{w_pshape.to_shape()};
const Shape& r_shape{r_pshape.to_shape()};
const Shape& ht_shape{ht_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(w_shape == Shape{get_hidden_size(), input_size}),
"Input tensor W must have shape (",
get_hidden_size(),
", ",
input_size,
"). Actual shape is:",
w_shape,
".");
NODE_VALIDATION_CHECK(this,
(r_shape == Shape{get_hidden_size(), get_hidden_size()}),
"Input tensor R must have shape (",
get_hidden_size(),
", ",
get_hidden_size(),
"). Actual shape is:",
r_shape,
".");
NODE_VALIDATION_CHECK(this,
(ht_shape == Shape{batch_size, get_hidden_size()}),
"Input tensor initial_hidden_state must have shape (",
batch_size,
", ",
get_hidden_size(),
"). Actual shape is:",
ht_shape,
".");
const auto& b_pshape = get_input_partial_shape(4);
NODE_VALIDATION_CHECK(
this, b_pshape.is_static(), "RNNCell supports only static input tensors.");
const Shape& b_shape{b_pshape.to_shape()};
NODE_VALIDATION_CHECK(this,
(b_shape == Shape{get_hidden_size()}),
"Input tensor B must have shape (",
get_hidden_size(),
"). Actual shape is:",
b_shape,
".");
}
void op::RNNCell::validate_and_infer_types()
{
std::vector<ngraph::PartialShape> input_param{};
auto merged_batch_size = Dimension::dynamic();
auto merged_hidden_size = Dimension::dynamic();
auto result_et = element::dynamic;
// Copy all inputs for further validation
for (size_t i = 0; i < get_input_size(); i++)
{
input_param.push_back(get_input_partial_shape(i));
}
// Get input partial shape for all inputs
const auto& x_pshape = get_input_partial_shape(0);
const auto& ht_pshape = get_input_partial_shape(1);
@ -171,7 +94,7 @@ void op::RNNCell::validate_and_infer_types()
const auto& r_pshape = get_input_partial_shape(3);
const auto& b_pshape = get_input_partial_shape(4);
validate_input_rank_dimension(input_param);
validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});
// Validate input types and save result for output type
NODE_VALIDATION_CHECK(
@ -238,72 +161,23 @@ void op::RNNCell::validate_and_infer_types()
}
// Mark inputs which are relevant to output parameters
set_input_is_relevant_to_shape(0);
set_input_is_relevant_to_shape(1);
set_input_is_relevant_to_shape(2);
set_input_is_relevant_to_shape(3);
set_input_is_relevant_to_shape(4);
for (size_t i = 0; i <= 4; ++i)
set_input_is_relevant_to_shape(i);
// Set output size, type and shape
set_output_size(1);
set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
}
OutputVector op::RNNCell::decompose_op() const
{
// ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
// The names used below are analogous to the one used in ONNX documentation.
//
// ------ ACRONYMS ------
// i_t - input gate at current time step
// t - time step (t-1 means previous time step)
// X - The input data tensor. Shape: [batch_size, input_size].
// W - The weight tensor for input gate. Shape: [hidden_size, input_size].
// R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
// H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
// B - The bias tensor for the input gate. Shape: [hidden_size].
// Wb - W bias vectors for input gate.
// Rb - R bias vectors for input gate.
// ------ VARIABLE NAMES ------
// Xt_W - Input sequence multiplied by weights tensor at current time step.
// Ht_R - Hidden state multiplied by weights tensor at current time step.
// (.) - Denotes element-wise multiplication.
// * - Denotes dot product.
// ---- Equations ----
// f - activation function.
// Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
// --------------------
Output<Node> X = input_value(0);
Output<Node> H_t = input_value(1);
Output<Node> W = input_value(2);
Output<Node> R = input_value(3);
Output<Node> bias = input_value(4);
// Xt*(W^T)
auto Xt_W = std::make_shared<op::Dot>(X, builder::opset1::transpose(W));
// Ht-1*(R^T)
auto Ht_R = std::make_shared<op::Dot>(H_t, builder::opset1::transpose(R));
// Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
auto i_t = add(Xt_W, add(Ht_R, bias));
// f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
i_t = m_activation_f(clip(i_t));
return {i_t};
}
Output<Node> op::RNNCell::get_default_bias_input() const
Output<Node> op::v0::RNNCell::get_default_bias_input() const
{
return Output<Node>{
op::Constant::create(get_input_element_type(0),
Shape{s_gates_count * get_hidden_size()},
vector<float>(s_gates_count * get_hidden_size(), 0.f))};
op::v0::Constant::create(get_input_element_type(0),
Shape{s_gates_count * get_hidden_size()},
vector<float>(s_gates_count * get_hidden_size(), 0.f))};
}
shared_ptr<Node> op::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
if (new_args.size() == 4)

@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/reference/split.hpp"
#include <numeric>
#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
@ -23,8 +23,6 @@
#include "ngraph/validation_util.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/slice.hpp"
NGRAPH_SUPPRESS_DEPRECATED_START
using namespace std;
@ -196,20 +194,25 @@ shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_ar
namespace
{
inline bool evaluate(const HostTensorPtr& in,
const HostTensorPtr& out,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds)
inline bool evaluate(const HostTensorPtr& data_tensor,
const HostTensorVector& outputs,
const int64_t axis,
const int64_t num_splits)
{
runtime::reference::slice(in->get_data_ptr<const char>(),
out->get_data_ptr<char>(),
in->get_shape(),
lower_bounds,
upper_bounds,
Strides(lower_bounds.size(), 1),
out->get_shape(),
in->get_element_type().size());
Shape output_shape = data_tensor->get_shape();
std::vector<char*> outputs_data(num_splits);
output_shape.at(axis) /= num_splits;
for (size_t i = 0; i < outputs.size(); ++i)
{
outputs[i]->set_shape(output_shape);
outputs_data[i] = outputs[i]->get_data_ptr<char>();
}
ngraph::runtime::reference::split(data_tensor->get_data_ptr<char>(),
data_tensor->get_shape(),
data_tensor->get_element_type().size(),
axis,
num_splits,
outputs_data.data());
return true;
}
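For intuition, a rough standalone sketch of the reference call used in evaluate() above (illustrative buffers and names; assumes the split reference header included in this file): splitting a {2, 6} row-major tensor along axis 1 into three chunks yields {2, 2} outputs.

std::vector<float> data(12);                                      // input of shape {2, 6}
std::iota(data.begin(), data.end(), 0.f);                         // 0, 1, ..., 11
std::vector<std::vector<float>> chunks(3, std::vector<float>(4)); // three {2, 2} outputs
std::vector<char*> out_ptrs;
for (auto& chunk : chunks)
    out_ptrs.push_back(reinterpret_cast<char*>(chunk.data()));
runtime::reference::split(reinterpret_cast<char*>(data.data()),
                          Shape{2, 6},
                          sizeof(float), // element size in bytes
                          1,             // axis
                          3,             // num_splits
                          out_ptrs.data());
// chunks[0] == {0, 1, 6, 7}, chunks[1] == {2, 3, 8, 9}, chunks[2] == {4, 5, 10, 11}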
@ -236,26 +239,7 @@ namespace
break;
}
axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());
const auto data_shape = data_tensor->get_shape();
const size_t axis_dim_length = data_shape.at(axis);
const size_t part_length = axis_dim_length / num_splits;
Shape output_shape = data_shape;
output_shape.at(axis) = part_length;
std::vector<size_t> lower_bounds(data_shape.size(), 0);
std::vector<size_t> upper_bounds = data_shape;
upper_bounds.at(axis) = part_length;
for (const auto& output : outputs)
{
output->set_shape(output_shape);
evaluate(data_tensor, output, lower_bounds, upper_bounds);
lower_bounds.at(axis) += part_length;
upper_bounds.at(axis) += part_length;
}
evaluate(data_tensor, outputs, axis, num_splits);
return true;
}
}

@ -24,11 +24,38 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/util.hpp"
using namespace std;
using namespace ngraph;
std::shared_ptr<Node> ngraph::op::util::convert_lstm_node_format(const Output<Node>& node,
LSTMWeightsFormat from_format,
LSTMWeightsFormat to_format)
{
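// Each vector lists the gates of the given format as indices into the
// FICO (forget, input, cell, output) reference order.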
static const std::map<op::util::LSTMWeightsFormat, std::vector<size_t>> gate_order_map{
{op::util::LSTMWeightsFormat::FICO, {0, 1, 2, 3}},
{op::util::LSTMWeightsFormat::ICOF, {1, 2, 3, 0}},
{op::util::LSTMWeightsFormat::IFOC, {1, 0, 3, 2}},
{op::util::LSTMWeightsFormat::IOFC, {1, 3, 0, 2}},
{op::util::LSTMWeightsFormat::IFCO, {1, 0, 2, 3}},
};
const auto& from = gate_order_map.at(from_format);
const auto& to = gate_order_map.at(to_format);
size_t num_gates = 4;
auto axis_const = std::make_shared<opset4::Constant>(element::i64, Shape{}, 0);
OutputVector splitted_node =
std::make_shared<opset4::Split>(node, axis_const, num_gates)->outputs();
OutputVector nodes_in_new_format(num_gates);
for (size_t i = 0; i < num_gates; ++i)
{
nodes_in_new_format[to[from[i]]] = splitted_node[i];
}
return std::make_shared<opset4::Concat>(nodes_in_new_format, 0);
}
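A hedged usage sketch (hypothetical shapes: hidden_size = 3, input_size = 2): re-packing ONNX-ordered IOFC weights into the FICO order expected by the opset4 LSTMCell, as the backend tests later in this change do.

const auto W_iofc = std::make_shared<opset4::Parameter>(element::f32, Shape{12, 2}); // [4 * hidden, input]
const auto W_fico = op::util::convert_lstm_node_format(
    W_iofc, op::util::LSTMWeightsFormat::IOFC, op::util::LSTMWeightsFormat::FICO);
// W_fico is a Split->Concat subgraph producing the same tensor with the
// four gate blocks reordered along axis 0.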
// Modify input vector in-place and return reference to modified vector.
static vector<string> to_lower_case(const vector<string>& vs)
{
@ -43,12 +70,14 @@ op::util::RNNCellBase::RNNCellBase()
{
}
op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
op::util::RNNCellBase::RNNCellBase(const OutputVector& args,
size_t hidden_size,
float clip,
const vector<string>& activations,
const vector<float>& activations_alpha,
const vector<float>& activations_beta)
: m_hidden_size(hidden_size)
: Op(args)
, m_hidden_size(hidden_size)
, m_clip(clip)
, m_activations(to_lower_case(activations))
, m_activations_alpha(activations_alpha)

@ -29,6 +29,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/lstm_sequence.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx_import/core/null_node.hpp"
@ -212,7 +213,10 @@ namespace ngraph
LSTMNgInputMap input_map{node};
LSTMAttributes attributes{node};
auto lstmSequence = std::make_shared<default_opset::LSTMSequence>(
// LSTMSequence is not fully supported in OpenVINO and is excluded from
// opset4 (currently the latest opset version), so use one of the previous
// opsets instead of the default one
auto lstmSequence = std::make_shared<opset3::LSTMSequence>(
input_map.at(LSTMInput::LSTM_INPUT_X),
input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
input_map.at(LSTMInput::LSTM_INPUT_INIT_C),

@ -82,7 +82,7 @@ from ngraph.opset1.ops import logical_not
from ngraph.opset1.ops import logical_or
from ngraph.opset1.ops import logical_xor
from ngraph.opset1.ops import lrn
from ngraph.opset1.ops import lstm_cell
from ngraph.opset4.ops import lstm_cell
from ngraph.opset1.ops import lstm_sequence
from ngraph.opset1.ops import matmul
from ngraph.opset1.ops import max_pool

@ -367,3 +367,54 @@ def reduce_l2(
return _get_node_factory_opset4().create(
"ReduceL2", as_nodes(node, reduction_axes), {"keep_dims": keep_dims}
)
@nameable_op
def lstm_cell(
X: NodeInput,
initial_hidden_state: NodeInput,
initial_cell_state: NodeInput,
W: NodeInput,
R: NodeInput,
B: NodeInput,
hidden_size: int,
activations: List[str] = None,
activations_alpha: List[float] = None,
activations_beta: List[float] = None,
clip: float = 0.0,
name: Optional[str] = None,
) -> Node:
"""Return a node which performs LSTMCell operation.
:param X: The input tensor with shape: [batch_size, input_size].
:param initial_hidden_state: The hidden state tensor with shape: [batch_size, hidden_size].
:param initial_cell_state: The cell state tensor with shape: [batch_size, hidden_size].
:param W: The weight tensor with shape: [4*hidden_size, input_size].
:param R: The recurrence weight tensor with shape: [4*hidden_size, hidden_size].
:param B: The bias tensor for gates with shape: [4*hidden_size].
:param hidden_size: Specifies hidden state size.
:param activations: The list of three activation functions for gates.
:param activations_alpha: The list of alpha parameters for activation functions.
:param activations_beta: The list of beta parameters for activation functions.
:param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
:param name: An optional name of the output node.
:return: The new node representing the LSTMCell operation. Node outputs count: 2.
"""
if activations is None:
activations = ["sigmoid", "tanh", "tanh"]
if activations_alpha is None:
activations_alpha = []
if activations_beta is None:
activations_beta = []
node_inputs = as_nodes(X, initial_hidden_state, initial_cell_state, W, R, B)
attributes = {
"hidden_size": hidden_size,
"activations": activations,
"activations_alpha": activations_alpha,
"activations_beta": activations_beta,
"clip": clip,
}
return _get_node_factory_opset4().create("LSTMCell", node_inputs, attributes)

@ -18,6 +18,7 @@ import pytest
from _pyngraph import PartialShape
import ngraph as ng
import ngraph.opset1 as ng_opset1
from ngraph.impl import Type
np_types = [np.float32, np.int32]
@ -230,6 +231,62 @@ def test_lstm_cell_operator(dtype):
assert list(node_param.get_output_shape(1)) == expected_shape
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_cell_operator_opset1(dtype):
batch_size = 1
input_size = 16
hidden_size = 128
X_shape = [batch_size, input_size]
H_t_shape = [batch_size, hidden_size]
C_t_shape = [batch_size, hidden_size]
W_shape = [4 * hidden_size, input_size]
R_shape = [4 * hidden_size, hidden_size]
B_shape = [4 * hidden_size]
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
expected_shape = [1, 128]
node_default = ng_opset1.lstm_cell(
parameter_X, parameter_H_t, parameter_C_t, parameter_W, parameter_R, parameter_B, hidden_size,
)
assert node_default.get_type_name() == "LSTMCell"
assert node_default.get_output_size() == 2
assert list(node_default.get_output_shape(0)) == expected_shape
assert list(node_default.get_output_shape(1)) == expected_shape
activations = ["tanh", "Sigmoid", "RELU"]
activation_alpha = [1.0, 2.0, 3.0]
activation_beta = [3.0, 2.0, 1.0]
clip = 0.5
node_param = ng_opset1.lstm_cell(
parameter_X,
parameter_H_t,
parameter_C_t,
parameter_W,
parameter_R,
parameter_B,
hidden_size,
activations,
activation_alpha,
activation_beta,
clip,
)
assert node_param.get_type_name() == "LSTMCell"
assert node_param.get_output_size() == 2
assert list(node_param.get_output_shape(0)) == expected_shape
assert list(node_param.get_output_shape(1)) == expected_shape
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_lstm_sequence_operator_bidirectional(dtype):
batch_size = 1
@ -255,7 +312,7 @@ def test_lstm_sequence_operator_bidirectional(dtype):
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
direction = "BIDIRECTIONAL"
node = ng.lstm_sequence(
node = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
@ -275,7 +332,7 @@ def test_lstm_sequence_operator_bidirectional(dtype):
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.lstm_sequence(
node_param = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
@ -321,7 +378,7 @@ def test_lstm_sequence_operator_reverse(dtype):
direction = "REVERSE"
node_default = ng.lstm_sequence(
node_default = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
@ -341,7 +398,7 @@ def test_lstm_sequence_operator_reverse(dtype):
activation_beta = [3.0, 2.0, 1.0]
clip = 1.22
node_param = ng.lstm_sequence(
node_param = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
@ -387,7 +444,7 @@ def test_lstm_sequence_operator_forward(dtype):
direction = "forward"
node_default = ng.lstm_sequence(
node_default = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,
@ -407,7 +464,7 @@ def test_lstm_sequence_operator_forward(dtype):
activation_beta = [1.0]
clip = 0.5
node = ng.lstm_sequence(
node = ng_opset1.lstm_sequence(
parameter_X,
parameter_H_t,
parameter_C_t,

@ -20,6 +20,7 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/visitor.hpp"
@ -1063,7 +1064,7 @@ TEST(attributes, lrn_op)
TEST(attributes, lstm_cell_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMCell>();
FactoryRegistry<Node>::get().register_factory<opset4::LSTMCell>();
auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto W = make_shared<op::Parameter>(element::f32, Shape{12, 3});
@ -1072,40 +1073,33 @@ TEST(attributes, lstm_cell_op)
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});
const auto hidden_size = 3;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
auto activations_alpha = std::vector<float>{1.0, 1.5};
auto activations_beta = std::vector<float>{2.0, 1.0};
const float clip = 0.5f;
bool input_forget = true;
const auto lstm_cell = make_shared<opset1::LSTMCell>(X,
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
initial_hidden_state,
initial_cell_state,
W,
R,
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
clip);
NodeBuilder builder(lstm_cell);
auto g_lstm_cell = as_type_ptr<opset1::LSTMCell>(builder.create());
auto g_lstm_cell = as_type_ptr<opset4::LSTMCell>(builder.create());
EXPECT_EQ(g_lstm_cell->get_hidden_size(), lstm_cell->get_hidden_size());
EXPECT_EQ(g_lstm_cell->get_activations(), lstm_cell->get_activations());
EXPECT_EQ(g_lstm_cell->get_activations_alpha(), lstm_cell->get_activations_alpha());
EXPECT_EQ(g_lstm_cell->get_activations_beta(), lstm_cell->get_activations_beta());
EXPECT_EQ(g_lstm_cell->get_clip(), lstm_cell->get_clip());
EXPECT_EQ(g_lstm_cell->get_input_forget(), lstm_cell->get_input_forget());
EXPECT_EQ(g_lstm_cell->get_weights_format(), lstm_cell->get_weights_format());
}
TEST(attributes, lstm_sequence_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();
FactoryRegistry<Node>::get().register_factory<op::v1::LSTMSequence>();
const size_t batch_size = 4;
const size_t num_directions = 2;
@ -1127,14 +1121,12 @@ TEST(attributes, lstm_sequence_op)
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::RecurrentSequenceDirection::BIDIRECTIONAL;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
const float clip_threshold = 0.5f;
const bool input_forget = true;
const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
@ -1143,23 +1135,19 @@ TEST(attributes, lstm_sequence_op)
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget);
clip_threshold);
NodeBuilder builder(lstm_sequence);
auto g_lstm_sequence = as_type_ptr<opset1::LSTMSequence>(builder.create());
auto g_lstm_sequence = as_type_ptr<op::v1::LSTMSequence>(builder.create());
EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha());
EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta());
EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold());
EXPECT_EQ(g_lstm_sequence->get_clip(), lstm_sequence->get_clip());
EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}
TEST(attributes, shuffle_channels_op)

@ -33,7 +33,9 @@
#include "gtest/gtest.h"
#include "ngraph/check.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "op/group_conv.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
@ -1629,11 +1631,17 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(
X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC);
const auto lstm_cell = make_shared<opset4::LSTMCell>(
X,
H_t,
C_t,
op::util::convert_lstm_node_format(W, op::util::LSTMWeightsFormat::IOFC),
op::util::convert_lstm_node_format(R, op::util::LSTMWeightsFormat::IOFC),
op::util::convert_lstm_node_format(B, op::util::LSTMWeightsFormat::IOFC),
hidden_size);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);
// X
@ -1665,18 +1673,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes)
// P
vector<float> in_P(3 * hidden_size, 0.f);
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f});
ht_test_case.run();
auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f});
@ -1700,11 +1706,10 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(
X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);
// X
@ -1755,18 +1760,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes)
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.9218244f, 0.78787273f, 0.8754273f, 0.7361462f, 0.70927656f, 0.83522964f});
ht_test_case.run();
auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.7094649f, 1.1259761f, 1.444019f, 1.086587f, 0.9762144f, 1.3066899f});
@ -1792,22 +1795,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
P,
hidden_size,
op::LSTMWeightsFormat::IOFC,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
clip_threshold,
input_forget);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
clip_threshold);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);
// X
@ -1858,18 +1858,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget)
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.71485436f, 0.71844107f, 0.72704613f, 0.6235602f, 0.68306124f, 0.6978715f});
ht_test_case.run();
auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
@ -1898,22 +1896,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
P,
hidden_size,
op::LSTMWeightsFormat::IOFC,
activations,
activation_alpha,
activation_beta,
clip_threshold,
input_forget);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
hidden_size,
activations,
activation_alpha,
activation_beta,
clip_threshold);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);
// X
@ -1964,18 +1959,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions)
0.13840231f,
0.24175227f};
ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.96834344f, 0.9695254f, 0.97068775f, 0.9077866f, 0.94161016f, 0.96599925f});
ht_test_case.run();
auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
@ -2168,7 +2161,7 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_no_bias)
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R});
auto test_case = test::TestCase<TestEngine>(function);
@ -2219,16 +2212,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_bias_clip)
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"tanh"},
vector<float>{},
vector<float>{},
clip);
const auto rnn_cell = make_shared<opset4::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"tanh"},
vector<float>{},
vector<float>{},
clip);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R, B});
auto test_case = test::TestCase<TestEngine>(function);
@ -2281,16 +2274,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_activation_function)
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid"},
vector<float>{},
vector<float>{},
clip);
const auto rnn_cell = make_shared<opset4::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid"},
vector<float>{},
vector<float>{},
clip);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R, B});
auto test_case = test::TestCase<TestEngine>(function);
@ -2347,17 +2340,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_bias_clip)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});
auto test_case = test::TestCase<TestEngine>(function);
@ -2420,17 +2413,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_linear_before_reset)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{(gates_count + 1) * hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});
auto test_case = test::TestCase<TestEngine>(function);
@ -2492,17 +2485,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{(gates_count + 1) * hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"hardsigmoid", "hardsigmoid"},
vector<float>{1.8345f, 1.8345f},
vector<float>{3.05f, 3.05f},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"hardsigmoid", "hardsigmoid"},
vector<float>{1.8345f, 1.8345f},
vector<float>{3.05f, 3.05f},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});
auto test_case = test::TestCase<TestEngine>(function);

@ -346,7 +346,7 @@ namespace
void op_is_GRUCell()
{
op::GRUCell node;
op::v3::GRUCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -472,7 +472,7 @@ namespace
void op_is_LSTMCell()
{
op::LSTMCell node;
op::v4::LSTMCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -481,7 +481,7 @@ namespace
void op_is_LSTMSequence()
{
op::LSTMSequence node;
op::v0::LSTMSequence node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -733,7 +733,7 @@ namespace
void op_is_RNNCell()
{
op::RNNCell node;
op::v0::RNNCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));

@ -1085,14 +1085,14 @@ IE_CPU.builder_opset1_collapse_dyn_shape
# IE_CPU.interpolate_down_scales_const_linear
# GRUCell operation has a form that is not supported
IE_CPU.onnx_model_gru_defaults_fwd
IE_CPU.onnx_model_gru_fwd_activations
IE_CPU.onnx_model_gru_fwd_mixed_seq_len
IE_CPU.onnx_model_gru_rev_clip
IE_CPU.onnx_model_gru_reverse
IE_CPU.onnx_model_gru_fwd_bias_initial_h
IE_CPU.onnx_model_gru_bidirectional
IE_CPU.onnx_model_gru_fwd_linear_before_reset
onnx_model_gru_defaults_fwd
onnx_model_gru_fwd_activations
onnx_model_gru_fwd_mixed_seq_len
onnx_model_gru_rev_clip
onnx_model_gru_reverse
onnx_model_gru_fwd_bias_initial_h
onnx_model_gru_bidirectional
onnx_model_gru_fwd_linear_before_reset
# Not implemented Interpolate-4:
IE_CPU.onnx_model_resize10_import_only

@ -59,8 +59,10 @@
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/gru_cell.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/lstm_cell.hpp"
#include "ngraph/runtime/reference/matmul.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
@ -77,6 +79,7 @@
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/rnn_cell.hpp"
#include "ngraph/runtime/reference/round.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "ngraph/runtime/reference/select.hpp"
@ -692,6 +695,67 @@ protected:
}
break;
}
case OP_TYPEID::GRUCell_v3:
{
const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
gru_cell->get_activations()[0],
gru_cell->get_activations()[1],
gru_cell->get_clip(),
gru_cell->get_linear_before_reset());
break;
}
case OP_TYPEID::LSTMCell_v4:
{
const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
args[5]->get_data_ptr<T>(),
args[5]->get_shape(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
lstm_cell->get_activations()[0],
lstm_cell->get_activations()[1],
lstm_cell->get_activations()[2],
lstm_cell->get_clip());
break;
}
case OP_TYPEID::RNNCell_v0:
{
const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
rnn_cell->get_activations()[0],
rnn_cell->get_clip());
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
@ -1203,15 +1267,12 @@ protected:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolution:
case OP_TYPEID::GroupConvolutionBackpropData:
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::Interpolate:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::LSTMSequence:
case OP_TYPEID::MVN:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScatterUpdate_v3:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:

@ -20,6 +20,7 @@
#define ID_SUFFIX(NAME) NAME##_v0
NGRAPH_OP(DetectionOutput, op::v0)
NGRAPH_OP(RNNCell, op::v0)
#undef ID_SUFFIX
#define ID_SUFFIX(NAME) NAME##_v1
@ -31,6 +32,7 @@ NGRAPH_OP(LogicalNot, op::v1)
#undef ID_SUFFIX
#define ID_SUFFIX(NAME) NAME##_v3
NGRAPH_OP(GRUCell, op::v3)
NGRAPH_OP(EmbeddingBagOffsetsSum, op::v3)
NGRAPH_OP(EmbeddingBagPackedSum, op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
@ -43,4 +45,5 @@ NGRAPH_OP(ScatterUpdate, op::v3)
#define ID_SUFFIX(NAME) NAME##_v4
NGRAPH_OP(CTCLoss, op::v4)
NGRAPH_OP(LSTMCell, op::v4)
#undef ID_SUFFIX

@ -105,3 +105,20 @@ INTERPRETER.onnx_model_gatherND_float
# Round op doesn't support some specific cases of rounding
onnx_model_round_half_nearest_even
# Unsupported op 'LSTMSequence': not FusedOp anymore, no reference implementation yet
onnx_model_lstm_fwd_with_clip
onnx_model_lstm_fwd_mixed_seq
onnx_model_lstm_fwd_hardsigmoid_activation
onnx_model_lstm_fwd_large_batch_no_clip
onnx_model_lstm_bdir_short_input_seq
onnx_model_lstm_mixed_seq_reverse
# Activation function hardsigmoid is not supported.
gru_cell_activation_function
lstm_cell_activaction_functions
onnx_model_gru_fwd_activations
# Peepholes and input_forget are not supported
lstm_cell_bias_peepholes
lstm_cell_bias_peepholes_clip_input_forget

@ -81,7 +81,6 @@ NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
@ -95,8 +94,7 @@ NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op::v0)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(Max, ngraph::op)
@ -124,7 +122,6 @@ NGRAPH_OP(Reshape, ngraph::op)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(Round, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)

@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"
using namespace std;
@ -35,7 +36,7 @@ TEST(type_prop, gru_cell)
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
}
@ -56,7 +57,7 @@ TEST(type_prop, gru_cell_invalid_input)
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -70,7 +71,7 @@ TEST(type_prop, gru_cell_invalid_input)
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -86,7 +87,7 @@ TEST(type_prop, gru_cell_invalid_input)
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -101,7 +102,7 @@ TEST(type_prop, gru_cell_invalid_input)
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -126,7 +127,7 @@ TEST(type_prop, gru_cell_dynamic_batch_size)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -146,7 +147,7 @@ TEST(type_prop, gru_cell_dynamic_hidden_size)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, 3);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, 3);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -163,7 +164,7 @@ TEST(type_prop, gru_cell_dynamic_inputs)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, 2);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, 2);
EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
@ -183,33 +184,37 @@ TEST(type_prop, gru_cell_invalid_input_rank0)
// Invalid rank0 for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid rank0 for X tensor.
W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid rank0 for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid rank0 for R tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid rank0 for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size),
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
}
@ -228,32 +233,36 @@ TEST(type_prop, gru_cell_invalid_input_dynamic_rank)
// Invalid dynamic rank for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid dynamic rank for R tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size),
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
}
@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"
using namespace std;
@ -28,15 +29,15 @@ TEST(type_prop, lstm_cell)
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
const auto W =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
const auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_cell->get_clip(), 0.f);
EXPECT_TRUE(lstm_cell->get_activations_alpha().empty());
@ -44,8 +45,6 @@ TEST(type_prop, lstm_cell)
EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_cell->get_activations()[1], "tanh");
EXPECT_EQ(lstm_cell->get_activations()[2], "tanh");
EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_FALSE(lstm_cell->get_input_forget());
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
@ -59,17 +58,17 @@ TEST(type_prop, lstm_cell_invalid_input)
const size_t hidden_size = 3;
const size_t gates_count = 4;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R =
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -79,11 +78,11 @@ TEST(type_prop, lstm_cell_invalid_input)
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
W = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -94,11 +93,11 @@ TEST(type_prop, lstm_cell_invalid_input)
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -109,11 +108,11 @@ TEST(type_prop, lstm_cell_invalid_input)
}
// Invalid C_t tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
C_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -124,12 +123,12 @@ TEST(type_prop, lstm_cell_invalid_input)
}
// Invalid B tensor shape.
C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
auto P = make_shared<opset4::Parameter>(element::f32, Shape{3 * hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -137,20 +136,6 @@ TEST(type_prop, lstm_cell_invalid_input)
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Parameter hidden_size mistmatched in B input."));
}
// Invalid P tensor shape.
B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
FAIL() << "LSTMCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
{
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Parameter hidden_size mistmatched in P input."));
}
}
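The deleted P branch above is the visible effect of the opset4 redesign: LSTMCell no longer takes peephole weights, so the bias overload now ends at B. A hedged sketch (sizes illustrative, not lines from this diff) of the two constructors these tests exercise:

// Shapes that pass validation for opset4::LSTMCell (4 fused gates).
const size_t batch_size = 2, input_size = 3, hidden_size = 3, gates_count = 4;
const auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto C_t = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = std::make_shared<opset4::Parameter>(
    element::f32, Shape{gates_count * hidden_size, input_size});
const auto R = std::make_shared<opset4::Parameter>(
    element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto B = std::make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size});

// Without an explicit bias and with one; there is no eighth P argument any more.
const auto cell = std::make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto cell_with_bias = std::make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);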
TEST(type_prop, lstm_cell_dynamic_batch_size)
@ -160,17 +145,18 @@ TEST(type_prop, lstm_cell_dynamic_batch_size)
const size_t hidden_size = 3;
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, PartialShape{gates_count * hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, PartialShape{gates_count * hidden_size, hidden_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto C_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
@ -185,17 +171,18 @@ TEST(type_prop, lstm_cell_dynamic_hidden_size)
const auto hidden_size = Dimension::dynamic();
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<op::Parameter>(element::f32,
PartialShape{hidden_size * gates_count, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
PartialShape{hidden_size * gates_count, hidden_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, PartialShape{hidden_size * gates_count, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, PartialShape{hidden_size * gates_count, hidden_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto C_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, 3);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
@ -210,17 +197,18 @@ TEST(type_prop, lstm_cell_dynamic_inputs)
const auto hidden_size = Dimension::dynamic();
const size_t gates_count = 4;
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<op::Parameter>(element::f32,
PartialShape{hidden_size * gates_count, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
PartialShape{hidden_size * gates_count, hidden_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, PartialShape{hidden_size * gates_count, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, PartialShape{hidden_size * gates_count, hidden_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto C_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, 3);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
@ -235,62 +223,54 @@ TEST(type_prop, lstm_cell_invalid_input_rank0)
const size_t hidden_size = 3;
const size_t gates_count = 4;
auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
// Invalid rank0 for W tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for X tensor.
W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for C_t tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for R tensor.
C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
auto P = make_shared<op::Parameter>(element::f32, PartialShape{3 * hidden_size});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid rank0 for P tensor.
B = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
}
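The rank-0 suite above and the dynamic-rank suite below drive two distinct conditions; a small illustration (my gloss on the validator, not code from this diff) of how they differ at the PartialShape level:

// A rank-0 shape is static but scalar; a fully dynamic shape has no rank at all.
const PartialShape scalar{};                                          // rank 0
const PartialShape unranked = PartialShape::dynamic(Rank::dynamic()); // rank unknown

// scalar.rank().is_static()   -> true, with scalar.rank().get_length() == 0
// unranked.rank().is_dynamic() -> true
// The cell validators reject both cases before any per-dimension checks run.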
@ -302,62 +282,54 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank)
const size_t hidden_size = 3;
const size_t gates_count = 4;
auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
// Invalid dynamic rank for W tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for C_t tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for R tensor.
C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto P = make_shared<op::Parameter>(element::f32, PartialShape{3 * hidden_size});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
// Invalid dynamic rank for P tensor.
B = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
}
@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"
// suppress FusedOp deprecation warnings
@ -40,7 +41,7 @@ struct recurrent_sequence_parameters
//
// Create and initialize default input test tensors.
//
shared_ptr<op::LSTMSequence>
shared_ptr<op::v1::LSTMSequence>
lstm_seq_tensor_initialization(const recurrent_sequence_parameters& param)
{
auto batch_size = param.batch_size;
@ -50,20 +51,21 @@ shared_ptr<op::LSTMSequence>
auto hidden_size = param.hidden_size;
auto et = param.et;
const auto X = make_shared<op::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
const auto X =
make_shared<opset4::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
make_shared<opset4::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(et, PartialShape{batch_size});
const auto W =
make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4, input_size});
const auto R =
make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4, hidden_size});
const auto B = make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
const auto P = make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 3});
make_shared<opset4::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(et, PartialShape{batch_size});
const auto W = make_shared<opset4::Parameter>(
et, PartialShape{num_directions, hidden_size * 4, input_size});
const auto R = make_shared<opset4::Parameter>(
et, PartialShape{num_directions, hidden_size * 4, hidden_size});
const auto B =
make_shared<opset4::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
const auto lstm_sequence = make_shared<op::LSTMSequence>();
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>();
lstm_sequence->set_argument(0, X);
lstm_sequence->set_argument(1, initial_hidden_state);
@ -72,7 +74,6 @@ shared_ptr<op::LSTMSequence>
lstm_sequence->set_argument(4, W);
lstm_sequence->set_argument(5, R);
lstm_sequence->set_argument(6, B);
lstm_sequence->set_argument(7, P);
return lstm_sequence;
}
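A hedged usage sketch of the helper above (the field values are illustrative): because the inputs are attached with set_argument() rather than passed to a constructor, shape inference has to be triggered explicitly, which is exactly what the negative tests below rely on.

recurrent_sequence_parameters param;
param.batch_size = 24;    // illustrative values
param.num_directions = 1;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;

auto lstm_sequence = lstm_seq_tensor_initialization(param);
// set_argument() does not validate; this call performs the shape checks and
// throws ngraph::CheckFailure for malformed inputs.
lstm_sequence->validate_and_infer_types();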
@ -86,40 +87,39 @@ TEST(type_prop, lstm_sequence_forward)
const size_t hidden_size = 128;
const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B =
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD;
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction);
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_clip(), 0.f);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
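For orientation (my summary of the opset specification, not lines restored from the truncated hunk), the three outputs a v1::LSTMSequence exposes and the shapes the forward test asserts on:

// output 0, Y:  {batch_size, num_directions, seq_length, hidden_size} - per-step hidden states
// output 1, Ho: {batch_size, num_directions, hidden_size}             - last hidden state
// output 2, Co: {batch_size, num_directions, hidden_size}             - last cell state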
@ -138,47 +138,44 @@ TEST(type_prop, lstm_sequence_bidirectional)
const size_t hidden_size = 256;
const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B =
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
const auto weights_format = op::LSTMWeightsFormat::FICO;
const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
const auto lstm_direction = op::v1::LSTMSequence::direction::BIDIRECTIONAL;
const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};
const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations);
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
activations_alpha,
activations_beta,
activations);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::FICO);
EXPECT_EQ(lstm_sequence->get_direction(), op::v1::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_clip(), 0.f);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
@ -330,15 +327,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dimension)
param.et = element::f32;
auto lstm_sequence = lstm_seq_tensor_initialization(param);
auto invalid_rank0_tensor = make_shared<op::Parameter>(param.et, PartialShape{});
auto invalid_rank0_tensor = make_shared<opset4::Parameter>(param.et, PartialShape{});
// Validate invalid rank0 tensor for all inputs: X, initial_hidden_state, initial_cell_state, W,
// R, B and P
// R, B
for (auto i = 0; i < lstm_sequence->get_input_size(); i++)
{
lstm_sequence = lstm_seq_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_rank0_tensor);
ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
<< "LSTMSequence node was created with invalid data.";
}
@ -357,15 +353,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank)
auto lstm_sequence = lstm_seq_tensor_initialization(param);
auto invalid_dynamic_tensor =
make_shared<op::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));
make_shared<opset4::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));
// Validate invalid dynamic tensor for all inputs: X, initial_hidden_state, initial_cell_state
// W, R, B and P
// W, R, B
for (auto i = 0; i < lstm_sequence->get_input_size(); i++)
{
lstm_sequence = lstm_seq_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_dynamic_tensor);
ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
<< "LSTMSequence node was created with invalid data.";
}
@ -16,6 +16,7 @@
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"
using namespace std;
@ -27,12 +28,12 @@ TEST(type_prop, rnn_cell)
const size_t input_size = 3;
const size_t hidden_size = 3;
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
}
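Taking the three suites together, a quick hedged reference for the weight layouts they assert on (one caveat: the GRU bias grows to 4 * hidden_size when linear_before_reset is enabled):

// op        gates  W                            R                             B
// RNNCell   1      {hidden_size, input_size}    {hidden_size, hidden_size}    {hidden_size}
// GRUCell   3      {3*hidden_size, input_size}  {3*hidden_size, hidden_size}  {3*hidden_size}
// LSTMCell  4      {4*hidden_size, input_size}  {4*hidden_size, hidden_size}  {4*hidden_size}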
@ -43,15 +44,15 @@ TEST(type_prop, rnn_cell_invalid_input)
const size_t input_size = 3;
const size_t hidden_size = 3;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -61,11 +62,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}
// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
W = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -77,11 +78,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}
// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -92,11 +93,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}
// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, Shape{2 * hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -112,13 +113,16 @@ TEST(type_prop, rnn_cell_dynamic_batch_size)
const size_t input_size = 3;
const size_t hidden_size = 3;
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -129,13 +133,16 @@ TEST(type_prop, rnn_cell_dynamic_hidden_size)
const size_t input_size = 3;
const auto hidden_size = Dimension::dynamic();
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, 3);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, 3);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -146,13 +153,16 @@ TEST(type_prop, rnn_cell_dynamic_inputs)
const auto input_size = Dimension::dynamic();
const auto hidden_size = Dimension::dynamic();
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, 2);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, 2);
EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
@ -164,37 +174,41 @@ TEST(type_prop, rnn_cell_invalid_input_rank0)
const size_t input_size = 3;
const size_t hidden_size = 3;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid rank0 for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid rank0 for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
W = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid rank0 for H_t tensor.
X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid rank0 for R tensor.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid rank0 for B tensor.
R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size),
R = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
}
@ -205,37 +219,41 @@ TEST(type_prop, rnn_cell_invalid_input_dynamic_rank)
const size_t input_size = 3;
const size_t hidden_size = 3;
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
// Invalid dynamic rank for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
W = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid dynamic rank for R tensor.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size),
R = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
}
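To close the loop, a minimal sketch (sizes illustrative, not diff content) of the positive counterpart to these negative cases; note that the message streamed after ASSERT_THROW above is printed only when the assertion fails, i.e. when construction unexpectedly succeeds.

#include <memory>

#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"

using namespace ngraph;

std::shared_ptr<opset4::RNNCell> make_valid_rnn_cell()
{
    // A vanilla RNN has a single gate, so no gate multiplier appears anywhere.
    const size_t batch_size = 2, input_size = 3, hidden_size = 3;
    const auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
    const auto H_t = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
    const auto W = std::make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
    const auto R = std::make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
    const auto B = std::make_shared<opset4::Parameter>(element::f32, Shape{hidden_size});
    return std::make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size);
}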