LSTMCell/Sequence v1, reference implementations and decomposition transformations for LSTM/GRU/RNN cells (#2000)
* validate_and_infer_types() implementation
* input parameter validation for LSTM, GRU and RNN
* style-check applied
* Add LSTMSequence dynamic shape validation and test props for RNNCell, GRUCell, LSTMCell and LSTMSequence.
* recurrent_sequence.hpp moved to ngraph/core/include/ngraph/op/util/
* style check applied
* removed unused variable from LSTMSequence::validate_and_infer_types
* Add missing newline mark at the end of file.
* Add suppression macro for FusedOp deprecation.
* Add element type initialization
* transpose, rnn cell reference implementations
* Apply PR review remarks
* reference implementations for cell ops, single layer tests, align LSTM cell/sequence with the spec
* LSTM/GRU/RNN cell decomposition transformations
* ngraph codestyle
* clean up
* ngraph code style
* change inheritance of Cells, fix build
* fix build
* fix build again
* remove Peepholes from LSTMSeq, fix copy_runtime_info in transformations
* Rewrite tests to use gtest exception assertions.
* resolve test issues
* ngraph codestyle
* add missed files
* fix typeprop tests
* fix LSTM sequence checks
* fix arm build
* fix arm build again
* delete unnecessary file
* add convert weights format function, enable LSTM test, resolve review comments
* add ngraph builders
* ngraph codestyle
* fix unit tests
* revert transpose reference implementation
* revert LSTMCell v0, add LSTMCell v1, update transformation lstm_cell_to_cell_ie
* v1 version of LSTMCell op
* LSTMSequence v1 operation, exclude LSTMSequence from opset4
* fix Python API tests
* resolve review comments, tests for decomposition transformations, switch LSTMCell to opset4 in MO

Co-authored-by: Szymon Durawa <szymon.durawa@intel.com>
This commit is contained in: parent 28eed7708e · commit 2f5a28d44f
@@ -410,6 +410,29 @@ InferenceEngine::details::CNNLayerCreator::CNNLayerCreator(const std::shared_ptr
         return res;
     });

+    addSpecificCreator({"LSTMCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                          const std::map<std::string, std::string> params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "LSTMCell",
+                             details::convertPrecision(node->get_output_element_type(0))};
+        auto res = std::make_shared<LSTMCell>(attrs);
+        res->params = params;
+        Builder::NodeConverter<ngraph::op::Constant> converter;
+        const auto weightsNode = node->input_value(3).get_node_shared_ptr();
+        if (converter.canCreate(weightsNode)) {
+            const auto& weights = converter.createLayer(weightsNode);
+            res->blobs["weights"] = weights->blobs["custom"];
+            res->_weights = weights->blobs["custom"];
+        }
+
+        const auto biasNode = node->input_value(4).get_node_shared_ptr();
+        if (converter.canCreate(biasNode)) {
+            const auto& bias = converter.createLayer(biasNode);
+            res->blobs["biases"] = bias->blobs["custom"];
+            res->_biases = bias->blobs["custom"];
+        }
+        return res;
+    });
+
+    addSpecificCreator({"RNNCellIE"}, [](const std::shared_ptr<::ngraph::Node>& node,
+                                         const std::map<std::string, std::string>& params) -> CNNLayerPtr {
+        LayerParams attrs = {node->get_friendly_name(), "RNNCell",
@@ -672,7 +695,6 @@ void convertFunctionToICNNNetwork(const std::shared_ptr<const ::ngraph::Function
         std::make_shared<Builder::NodeConverter<::ngraph::op::TopKIE>>(),
         std::make_shared<Builder::NodeConverter<::ngraph::op::Unsqueeze>>(),
         std::make_shared<Builder::NodeConverter<::ngraph::op::TensorIterator>>(),
-        std::make_shared<Builder::NodeConverter<::ngraph::op::LSTMCellIE>>(),
         std::make_shared<Builder::NodeConverter<::ngraph::op::HardSigmoid_IE>>(),
         std::make_shared<Builder::NodeConverter<::ngraph::op::v1::LogicalNot>>(),
         std::make_shared<Builder::NodeConverter<::ngraph::op::ShuffleChannels>>(),
@@ -1866,54 +1866,6 @@ CNNLayer::Ptr NodeConverter<ngraph::op::FullyConnected>::createLayer(const std::
     return res;
 }

-template <>
-CNNLayer::Ptr NodeConverter<ngraph::op::LSTMCellIE>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
-    LayerParams params = {layer->get_friendly_name(), "LSTMCell",
-                          details::convertPrecision(layer->get_output_element_type(0))};
-    auto castedLayer = ngraph::as_type_ptr<ngraph::op::LSTMCellIE>(layer);
-    if (castedLayer == nullptr) THROW_IE_EXCEPTION << "Cannot get " << params.type << " layer " << params.name;
-
-    auto res = std::make_shared<InferenceEngine::LSTMCell>(params);
-    res->params["hidden_size"] = asString(castedLayer->get_hidden_size());
-    std::string value;
-    for (const auto& val : castedLayer->get_activations()) {
-        if (!value.empty()) value += ",";
-        value += val;
-    }
-    res->params["activations"] = value;
-
-    value.clear();
-    for (const auto& val : castedLayer->get_activations_alpha()) {
-        if (!value.empty()) value += ",";
-        value += val;
-    }
-    res->params["activations_alpha"] = value;
-
-    value.clear();
-    for (const auto& val : castedLayer->get_activations_beta()) {
-        if (!value.empty()) value += ",";
-        value += val;
-    }
-    res->params["activations_beta"] = value;
-    res->params["clip"] = asString(castedLayer->get_clip());
-
-    NodeConverter<ngraph::op::Constant> converter;
-    const auto weightsNode = layer->input_value(3).get_node_shared_ptr();
-    if (converter.canCreate(weightsNode)) {
-        const auto& weights = converter.createLayer(weightsNode);
-        res->blobs["weights"] = weights->blobs["custom"];
-        res->_weights = weights->blobs["custom"];
-    }
-
-    const auto biasNode = layer->input_value(4).get_node_shared_ptr();
-    if (converter.canCreate(biasNode)) {
-        const auto& bias = converter.createLayer(biasNode);
-        res->blobs["biases"] = bias->blobs["custom"];
-        res->_biases = bias->blobs["custom"];
-    }
-    return res;
-}
-
 template <>
 CNNLayer::Ptr NodeConverter<ngraph::op::MatMul>::createLayer(const std::shared_ptr<ngraph::Node>& layer) const {
     LayerParams params = {layer->get_friendly_name(), "Gemm",
@@ -439,7 +439,7 @@ std::shared_ptr<ngraph::Node> V10Parser::createNode(const std::vector<ngraph::Ou
         std::make_shared<LayerCreator<ngraph::op::v1::Select>>("Select"),
         std::make_shared<LayerCreator<ngraph::op::LRN>>("LRN"),
         std::make_shared<LayerCreator<ngraph::op::MVN>>("MVN"),
-        std::make_shared<LayerCreator<ngraph::op::LSTMCell>>("LSTMCell"),
+        std::make_shared<LayerCreator<ngraph::op::v0::LSTMCell>>("LSTMCell"),
         std::make_shared<LayerCreator<ngraph::op::v1::MaxPool>>("MaxPool"),
         std::make_shared<LayerCreator<ngraph::op::v1::Maximum>>("Maximum"),
         std::make_shared<LayerCreator<ngraph::op::v1::Minimum>>("Minimum"),
@@ -910,7 +910,7 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::Convert>::crea

 // LSTMCell layer
 template <>
-std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::LSTMCell>::createLayer(
+std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::v0::LSTMCell>::createLayer(
     const ngraph::OutputVector& inputs, const pugi::xml_node& node, std::istream& binStream,
     const GenericLayerParams& layerParsePrms) {
     checkParameters(inputs, layerParsePrms, 6);

@@ -922,7 +922,7 @@ std::shared_ptr<ngraph::Node> V10Parser::LayerCreator<ngraph::op::LSTMCell>::cre
     std::vector<float> activations_alpha = getParameters<float>(dn, "activations_alpha", {});
     std::vector<float> activations_beta = getParameters<float>(dn, "activations_beta", {});
     float clip = GetFloatAttr(dn, "clip", 0.f);
-    return std::make_shared<ngraph::op::LSTMCell>(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
+    return std::make_shared<ngraph::op::v0::LSTMCell>(inputs[0], inputs[1], inputs[2], inputs[3], inputs[4], inputs[5],
                                                   GetUInt64Attr(dn, "hidden_size"), ngraph::op::LSTMWeightsFormat::IFCO,
                                                   activations, activations_alpha, activations_beta, clip);
 }
@@ -41,13 +41,14 @@ public:
     const std::vector<float>& get_activations_alpha() { return m_activations_alpha; }
     const std::vector<float>& get_activations_beta() { return m_activations_beta; }
     float get_clip() { return m_clip; }
+    bool visit_attributes(AttributeVisitor& visitor) override;

 protected:
     int64_t m_hidden_size{};

-    const std::vector<std::string> m_activations;
-    const std::vector<float> m_activations_alpha;
-    const std::vector<float> m_activations_beta;
+    std::vector<std::string> m_activations;
+    std::vector<float> m_activations_alpha;
+    std::vector<float> m_activations_beta;
     float m_clip;
 };
@@ -0,0 +1,41 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API GRUCellDecomposition;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief GRUCellDecomposition transformation decomposes a GRUCell layer with inputs X, H, W, R, B
 * into Add, Split, MatMul, Multiply and Subtract ops according to the formula:
 *
 *   (.)  - denotes element-wise multiplication
 *    *   - denotes dot product
 *   f, g - activation functions
 *
 *   zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
 *   rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
 *   ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)    # when linear_before_reset := false (default)
 *   ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  # when linear_before_reset := true
 *   Ht = (1 - zt) (.) ht + zt (.) Ht-1
 */

class ngraph::pass::GRUCellDecomposition: public ngraph::pass::MatcherPass {
public:
    GRUCellDecomposition();
};
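The decomposition passes are wired up like any other MatcherPass. Below is a minimal, hedged sketch (not part of this diff) of applying the pass to a function containing a single opset4::GRUCell; the shapes and weight values are illustrative assumptions, and the pass::Manager usage mirrors the tests added later in this PR:

// Hedged sketch: decompose one opset4::GRUCell. Shapes (batch=2, input_size=3,
// hidden_size=3) and constant values are illustrative assumptions.
#include <memory>
#include <vector>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/gru_cell_decomposition.hpp>

std::shared_ptr<ngraph::Function> decompose_gru_example() {
    using namespace ngraph;
    const size_t batch = 2, input_size = 3, hidden_size = 3;
    auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, input_size});
    auto H = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, hidden_size});
    // GRU has 3 gates (z, r, h), hence the 3 * hidden_size leading dimension.
    auto W = opset4::Constant::create(element::f32, Shape{3 * hidden_size, input_size},
                                      std::vector<float>(3 * hidden_size * input_size, 0.1f));
    auto R = opset4::Constant::create(element::f32, Shape{3 * hidden_size, hidden_size},
                                      std::vector<float>(3 * hidden_size * hidden_size, 0.1f));
    auto B = opset4::Constant::create(element::f32, Shape{3 * hidden_size},
                                      std::vector<float>(3 * hidden_size, 0.f));
    auto cell = std::make_shared<opset4::GRUCell>(X, H, W, R, B, hidden_size);
    auto f = std::make_shared<Function>(NodeVector{cell}, ParameterVector{X, H});

    // After run_passes the GRUCell is replaced by MatMul/Split/Add/Multiply/Subtract ops.
    pass::Manager manager;
    manager.register_pass<pass::GRUCellDecomposition>();
    manager.run_passes(f);
    return f;
}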
@@ -0,0 +1,42 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API LSTMCellDecomposition;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief LSTMCellDecomposition transformation decomposes an LSTMCell layer with inputs X, H, C, W, R, B
 * into Add, Split, MatMul and Multiply ops according to the formula:
 *
 *   (.)     - denotes element-wise multiplication
 *    *      - denotes dot product
 *   f, g, h - activation functions
 *
 *   it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
 *   ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
 *   ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
 *   ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
 *   Ct = ft (.) Ct-1 + it (.) ct
 *   Ht = ot (.) h(Ct)
 */

class ngraph::pass::LSTMCellDecomposition: public ngraph::pass::MatcherPass {
public:
    LSTMCellDecomposition();
};
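The LSTM variant is applied the same way; the one structural difference worth noting is that the decomposed function must keep both cell outputs. A hedged sketch under the same illustrative assumptions as the GRU example above (4 gates instead of 3):

// Hedged sketch: decompose one opset4::LSTMCell; shapes are assumptions.
#include <memory>
#include <vector>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/lstm_cell_decomposition.hpp>

std::shared_ptr<ngraph::Function> decompose_lstm_example() {
    using namespace ngraph;
    const size_t batch = 1, input_size = 4, hidden_size = 3, gates = 4;
    auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, input_size});
    auto H = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, hidden_size});
    auto C = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, hidden_size});
    auto W = opset4::Constant::create(element::f32, Shape{gates * hidden_size, input_size},
                                      std::vector<float>(gates * hidden_size * input_size, 0.1f));
    auto R = opset4::Constant::create(element::f32, Shape{gates * hidden_size, hidden_size},
                                      std::vector<float>(gates * hidden_size * hidden_size, 0.1f));
    auto B = opset4::Constant::create(element::f32, Shape{gates * hidden_size},
                                      std::vector<float>(gates * hidden_size, 0.f));
    auto cell = std::make_shared<opset4::LSTMCell>(X, H, C, W, R, B, hidden_size);

    // Both outputs (Ht, Ct) are preserved as results; after the pass they come from the
    // Multiply/Add nodes named "<cell_name>.0" and "<cell_name>.1" (see the implementation below).
    ResultVector results{std::make_shared<opset4::Result>(cell->output(0)),
                         std::make_shared<opset4::Result>(cell->output(1))};
    auto f = std::make_shared<Function>(results, ParameterVector{X, H, C});

    pass::Manager manager;
    manager.register_pass<pass::LSTMCellDecomposition>();
    manager.run_passes(f);
    return f;
}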
@@ -0,0 +1,36 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <vector>
#include <memory>

#include <transformations_visibility.hpp>

#include <ngraph/pass/graph_rewrite.hpp>

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API RNNCellDecomposition;

}  // namespace pass
}  // namespace ngraph

/**
 * @ingroup ie_transformation_common_api
 * @brief RNNCellDecomposition transformation decomposes an RNNCell layer with inputs X, H, W, R, B
 * into Add and MatMul ops according to the formula:
 *
 *   * - denotes dot product
 *   f - activation function
 *
 *   Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
 */

class ngraph::pass::RNNCellDecomposition: public ngraph::pass::MatcherPass {
public:
    RNNCellDecomposition();
};
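Since the RNN cell is the simplest of the three, the subgraph the pass emits can be written out directly. The sketch below mirrors the reference decomposition implemented later in this diff, with f fixed to tanh for illustration; the function and parameter names are hypothetical:

// Hedged sketch: the op subgraph RNNCellDecomposition emits for one step.
#include <memory>
#include <ngraph/opsets/opset4.hpp>

std::shared_ptr<ngraph::Node> rnn_step(const ngraph::Output<ngraph::Node>& X,
                                       const ngraph::Output<ngraph::Node>& H_t,
                                       const ngraph::Output<ngraph::Node>& W,
                                       const ngraph::Output<ngraph::Node>& R,
                                       const ngraph::Output<ngraph::Node>& B) {
    using namespace ngraph;
    // Xt*(W^T) and Ht-1*(R^T): transpose_b=true folds the W^T/R^T transposes into the MatMuls.
    auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
    auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
    // B already holds Wbi + Rbi, so two Adds complete the pre-activation.
    auto sum = std::make_shared<opset4::Add>(Xt_W, std::make_shared<opset4::Add>(Ht_R, B));
    // f = tanh here for illustration; the real pass picks f from get_activations()[0].
    return std::make_shared<opset4::Tanh>(sum);
}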
@@ -101,6 +101,9 @@ TRANSFORMATIONS_API bool has_f16_constants(const std::shared_ptr<const ngraph::F

 TRANSFORMATIONS_API bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &other_shape);

+TRANSFORMATIONS_API std::shared_ptr<ngraph::Node> activation(const std::string& activation_name,
+                                                             const ngraph::Output<ngraph::Node>& apply_to);
+
 }  // namespace util
 }  // namespace op
 }  // namespace ngraph
@@ -37,6 +37,15 @@ void op::LSTMCellIE::validate_and_infer_types() {
     set_output_type(1, arg_type, output_shape);
 }

+bool ngraph::op::LSTMCellIE::visit_attributes(AttributeVisitor& visitor) {
+    visitor.on_attribute("hidden_size", m_hidden_size);
+    visitor.on_attribute("activations", m_activations);
+    visitor.on_attribute("activations_alpha", m_activations_alpha);
+    visitor.on_attribute("activations_beta", m_activations_beta);
+    visitor.on_attribute("clip", m_clip);
+    return true;
+}
+
 shared_ptr<Node> op::LSTMCellIE::clone_with_new_inputs(const OutputVector& new_args) const {
     check_new_args_count(this, new_args);
     return make_shared<op::LSTMCellIE>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4),
@@ -9,22 +9,25 @@

 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset3.hpp>
+#include <ngraph/opsets/opset4.hpp>
 #include <ngraph/rt_info.hpp>
 #include <ngraph/pattern/op/wrap_type.hpp>
+#include <ngraph/op/util/rnn_cell_base.hpp>

 #include <ngraph_ops/lstm_cell_ie.hpp>
 #include <ngraph_ops/gru_cell_ie.hpp>
 #include <ngraph_ops/rnn_cell_ie.hpp>

 ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
-    auto lstm_cell_ngraph = ngraph::pattern::wrap_type<ngraph::opset1::LSTMCell>();
-
+    auto is_supported_lstm_cell = [](const std::shared_ptr<Node>& n) {
+        return pattern::has_class<ngraph::opset1::LSTMCell>()(n) || pattern::has_class<ngraph::opset4::LSTMCell>()(n);
+    };
+    auto any_lstm = std::make_shared<pattern::op::Label>(element::f32, Shape{}, is_supported_lstm_cell);
     ngraph::matcher_pass_callback callback = [](pattern::Matcher& m) {
-        auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset1::LSTMCell>(m.get_match_root());
+        auto lstm_cell = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(m.get_match_root());
         if (!lstm_cell) {
             return false;
         }

         auto W = std::dynamic_pointer_cast<ngraph::opset1::Constant>(lstm_cell->input_value(3).get_node_shared_ptr());
         if (!W) {
             return false;
         }

@@ -53,7 +56,7 @@ ngraph::pass::ConvertLSTMCellMatcher::ConvertLSTMCellMatcher() {
         return true;
     };

-    auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell_ngraph, "ConvertLSTMCellToLSTMCellIE");
+    auto m = std::make_shared<ngraph::pattern::Matcher>(any_lstm, "ConvertLSTMCellToLSTMCellIE");
     this->register_matcher(m, callback);
 }
@@ -0,0 +1,104 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/gru_cell_decomposition.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>

ngraph::pass::GRUCellDecomposition::GRUCellDecomposition() {
    auto gru_cell = ngraph::pattern::wrap_type<opset4::GRUCell>();
    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        auto gru_cell = std::dynamic_pointer_cast<ngraph::opset4::GRUCell>(m.get_match_root());
        if (!gru_cell) {
            return false;
        }

        const Output<Node>& X = gru_cell->input_value(0);
        const Output<Node>& H_t = gru_cell->input_value(1);
        const Output<Node>& W = gru_cell->input_value(2);
        const Output<Node>& R = gru_cell->input_value(3);
        const Output<Node>& B = gru_cell->input_value(4);

        // Xt*(W^T)
        auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
        // Ht-1*(R^T)
        auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);

        // split to gates:
        auto axis_0 = ngraph::opset4::Constant::create(element::i64, Shape{}, {0});
        auto axis_1 = ngraph::opset4::Constant::create(element::i64, Shape{}, {1});
        auto Xt_W_zrh = std::make_shared<opset4::Split>(Xt_W, axis_1, 3);
        auto R_zrh = std::make_shared<opset4::Split>(R, axis_0, 3);
        auto Ht_R_zrh = std::make_shared<opset4::Split>(Ht_R, axis_1, 3);
        auto biases_zrh = std::make_shared<opset4::Split>(B, axis_0, gru_cell->get_linear_before_reset() ? 4 : 3);

        // Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz
        auto add_z_1 = std::make_shared<opset4::Add>(Ht_R_zrh->output(0), biases_zrh->output(0));
        auto add_z_2 = std::make_shared<opset4::Add>(Xt_W_zrh->output(0), add_z_1);

        // Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
        auto add_r_1 = std::make_shared<opset4::Add>(Ht_R_zrh->output(1), biases_zrh->output(1));
        auto add_r_2 = std::make_shared<opset4::Add>(Xt_W_zrh->output(1), add_r_1);

        auto clip = gru_cell->get_clip();
        std::shared_ptr<Node> clamp_z = add_z_2;
        std::shared_ptr<Node> clamp_r = add_r_2;
        if (clip > 0.f) {
            clamp_z = std::make_shared<opset4::Clamp>(add_z_2, -clip, clip);
            clamp_r = std::make_shared<opset4::Clamp>(add_r_2, -clip, clip);
            ngraph::copy_runtime_info(gru_cell, {clamp_z, clamp_r});
        }

        // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
        auto z_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_z);
        // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
        auto r_t = ngraph::op::util::activation(gru_cell->get_activations()[0], clamp_r);

        std::shared_ptr<Node> _h;
        if (gru_cell->get_linear_before_reset()) {
            // _h = Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh
            auto Ht_Rh_Rbh = std::make_shared<opset4::Add>(Ht_R_zrh->output(2), biases_zrh->output(3));
            auto mul_h_1 = std::make_shared<opset4::Multiply>(r_t, Ht_Rh_Rbh);
            auto add_h_1 = std::make_shared<opset4::Add>(mul_h_1, biases_zrh->output(2));
            _h = std::make_shared<opset4::Add>(Xt_W_zrh->output(2), add_h_1);
            ngraph::copy_runtime_info(gru_cell, {Ht_Rh_Rbh, mul_h_1, add_h_1, _h});
        } else {
            // _h = Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh
            auto rt_Ht = std::make_shared<opset4::Multiply>(r_t, H_t);
            auto mul_h_1 = std::make_shared<opset4::MatMul>(rt_Ht, R_zrh->output(2), false, true);
            auto add_h_1 = std::make_shared<opset4::Add>(mul_h_1, biases_zrh->output(2));
            _h = std::make_shared<opset4::Add>(Xt_W_zrh->output(2), add_h_1);
            ngraph::copy_runtime_info(gru_cell, {rt_Ht, mul_h_1, add_h_1, _h});
        }
        // ht = g(_h)
        std::shared_ptr<Node> clamp_h = _h;
        if (clip > 0.f) {
            clamp_h = std::make_shared<opset4::Clamp>(_h, -clip, clip);
            ngraph::copy_runtime_info(gru_cell, clamp_h);
        }
        auto h_t = ngraph::op::util::activation(gru_cell->get_activations()[1], clamp_h);

        // Ht = (1 - zt) (.) ht + zt (.) Ht-1
        auto one = opset4::Constant::create(z_t->get_element_type(), Shape{1}, {1.f});
        auto sub = std::make_shared<opset4::Subtract>(one, z_t);
        auto mul_1 = std::make_shared<opset4::Multiply>(sub, h_t);
        auto mul_2 = std::make_shared<opset4::Multiply>(z_t, H_t);
        auto out_H = std::make_shared<opset4::Add>(mul_1, mul_2);

        out_H->set_friendly_name(gru_cell->get_friendly_name());
        ngraph::copy_runtime_info(gru_cell, {Xt_W, Ht_R, axis_0, Xt_W_zrh, R_zrh, Ht_R_zrh, biases_zrh,
                                             add_z_1, add_z_2, add_r_1, add_r_2, h_t, one, sub, mul_1, mul_2, out_H});
        ngraph::replace_node(gru_cell, out_H);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(gru_cell, "GRUCellDecomposition");
    register_matcher(m, callback);
}
@@ -0,0 +1,85 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/lstm_cell_decomposition.hpp"

#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>

ngraph::pass::LSTMCellDecomposition::LSTMCellDecomposition() {
    auto lstm_cell = ngraph::pattern::wrap_type<opset4::LSTMCell>();
    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        auto lstm_cell = std::dynamic_pointer_cast<ngraph::opset4::LSTMCell>(m.get_match_root());
        if (!lstm_cell) {
            return false;
        }
        const Output<Node>& X = lstm_cell->input_value(0);
        const Output<Node>& H_t = lstm_cell->input_value(1);
        const Output<Node>& C_t = lstm_cell->input_value(2);
        const Output<Node>& W = lstm_cell->input_value(3);
        const Output<Node>& R = lstm_cell->input_value(4);
        const Output<Node>& bias = lstm_cell->input_value(5);

        // Xt*(W^T)
        auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
        // Ht-1*(R^T)
        auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
        // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
        auto add = std::make_shared<opset4::Add>(Ht_R, bias);
        auto XHB = std::make_shared<opset4::Add>(Xt_W, add);

        auto axis_node = ngraph::opset4::Constant::create(element::u64, Shape{}, {1});
        auto split = std::make_shared<opset4::Split>(XHB, axis_node, 4);
        Output<Node> f = split->output(0);
        Output<Node> i = split->output(1);
        Output<Node> c = split->output(2);
        Output<Node> o = split->output(3);

        auto clip = lstm_cell->get_clip();
        if (clip > 0.f) {
            auto clamp_f = std::make_shared<opset4::Clamp>(f, -clip, clip);
            auto clamp_i = std::make_shared<opset4::Clamp>(i, -clip, clip);
            auto clamp_c = std::make_shared<opset4::Clamp>(c, -clip, clip);
            auto clamp_o = std::make_shared<opset4::Clamp>(o, -clip, clip);
            f = clamp_f;
            i = clamp_i;
            c = clamp_c;
            o = clamp_o;
            ngraph::copy_runtime_info(lstm_cell, {clamp_f, clamp_i, clamp_c, clamp_o});
        }

        // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
        // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
        // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
        // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
        auto f_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], f);
        auto i_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], i);
        auto c_t = ngraph::op::util::activation(lstm_cell->get_activations()[1], c);
        auto o_t = ngraph::op::util::activation(lstm_cell->get_activations()[0], o);

        // Ct = ft (.) Ct-1 + it (.) ct
        auto mul1 = std::make_shared<opset4::Multiply>(f_t, C_t);
        auto mul2 = std::make_shared<opset4::Multiply>(i_t, c_t);
        auto out_C = std::make_shared<opset4::Add>(mul1, mul2);

        // H = ot (.) h(Ct)
        auto hC = ngraph::op::util::activation(lstm_cell->get_activations()[2], out_C);
        auto out_H = std::make_shared<opset4::Multiply>(o_t, hC);

        out_H->set_friendly_name(lstm_cell->get_friendly_name() + ".0");
        out_C->set_friendly_name(lstm_cell->get_friendly_name() + ".1");
        ngraph::copy_runtime_info(lstm_cell, {Xt_W, Ht_R, add, split, mul1, mul2, out_H, hC, out_C, axis_node, XHB,
                                              f_t, i_t, c_t, o_t});
        ngraph::replace_node(lstm_cell, {out_H->output(0), out_C->output(0)});
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(lstm_cell, "LSTMCellDecomposition");
    register_matcher(m, callback);
}
@@ -0,0 +1,52 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/rnn_cell_decomposition.hpp"

#include <memory>
#include <transformations/utils/utils.hpp>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/op/util/activation_functions.hpp>

ngraph::pass::RNNCellDecomposition::RNNCellDecomposition() {
    auto rnn_cell = ngraph::pattern::wrap_type<opset4::RNNCell>();
    ngraph::matcher_pass_callback callback = [](ngraph::pattern::Matcher& m) {
        auto rnn_cell = std::dynamic_pointer_cast<ngraph::opset4::RNNCell>(m.get_match_root());
        if (!rnn_cell) {
            return false;
        }
        const Output<Node>& X = rnn_cell->input_value(0);
        const Output<Node>& H_t = rnn_cell->input_value(1);
        const Output<Node>& W = rnn_cell->input_value(2);
        const Output<Node>& R = rnn_cell->input_value(3);
        const Output<Node>& bias = rnn_cell->input_value(4);

        // Xt*(W^T)
        auto Xt_W = std::make_shared<opset4::MatMul>(X, W, false, true);
        // Ht-1*(R^T)
        auto Ht_R = std::make_shared<opset4::MatMul>(H_t, R, false, true);
        // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
        auto add = std::make_shared<opset4::Add>(Ht_R, bias);
        auto i_t = std::make_shared<opset4::Add>(Xt_W, add);

        // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
        auto clip = rnn_cell->get_clip();
        std::shared_ptr<Node> clamp = i_t;
        if (clip > 0.f) {
            clamp = std::make_shared<opset4::Clamp>(i_t, -clip, clip);
            ngraph::copy_runtime_info(rnn_cell, clamp);
        }
        auto out = ngraph::op::util::activation(rnn_cell->get_activations()[0], clamp);
        out->set_friendly_name(rnn_cell->get_friendly_name());
        ngraph::copy_runtime_info(rnn_cell, {Xt_W, Ht_R, add, i_t, out});
        ngraph::replace_node(rnn_cell, out);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_cell, "RNNCellDecomposition");
    register_matcher(m, callback);
}
@@ -108,6 +108,18 @@ bool check_for_broadcast(const ngraph::Shape &ref_shape, const ngraph::Shape &ot
     return false;
 }

+std::shared_ptr<ngraph::Node> activation(const std::string& activation_name, const ngraph::Output<ngraph::Node>& apply_to) {
+    if (activation_name == "relu") {
+        return std::make_shared<ngraph::opset4::Relu>(apply_to);
+    } else if (activation_name == "sigmoid") {
+        return std::make_shared<ngraph::opset4::Sigmoid>(apply_to);
+    } else if (activation_name == "tanh") {
+        return std::make_shared<ngraph::opset4::Tanh>(apply_to);
+    } else {
+        throw ngraph_error("Unsupported activation function");
+    }
+}
+
 }  // namespace util
 }  // namespace op
 }  // namespace ngraph
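A small usage sketch of this helper (the parameter node is hypothetical; the decomposition passes above call it with cell->get_activations()[i]):

// Hedged sketch: mapping an activation name from a cell attribute to a node.
#include <memory>
#include <ngraph/opsets/opset4.hpp>
#include <transformations/utils/utils.hpp>

void activation_example() {
    using namespace ngraph;
    auto x = std::make_shared<opset4::Parameter>(element::f32, Shape{1, 8});  // hypothetical input
    auto s = op::util::activation("sigmoid", x);  // -> opset4::Sigmoid
    auto t = op::util::activation("tanh", x);     // -> opset4::Tanh
    // Any name outside {relu, sigmoid, tanh} throws ngraph_error("Unsupported activation function"),
    // so callers should validate a cell's activations before decomposing it.
}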
(One file diff suppressed because it is too large.)
@@ -14,6 +14,7 @@
 #include <ngraph/ops.hpp>
 #include <ngraph/opsets/opset1.hpp>
 #include <ngraph/opsets/opset3.hpp>
+#include <ngraph/opsets/opset4.hpp>
 #include <ngraph/function.hpp>
 #include <transformations/convert_opset1_to_legacy/convert_cells_to_cells_ie.hpp>
 #include <transformations/init_node_info.hpp>

@@ -129,7 +130,7 @@ TEST(TransformationTests, RNNCellConversionTest) {
     ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertRNNCellToRNNCellIE should keep output names.\n";
 }

-TEST(TransformationTests, LSTMCellConversionTest) {
+TEST(TransformationTests, LSTMCellConversionTest_opset3) {
     const size_t batch_size = 2;
     const size_t input_size = 3;
     const size_t hidden_size = 3;

@@ -186,4 +187,76 @@ TEST(TransformationTests, LSTMCellConversionTest) {
     auto result_node_of_converted_f = f->get_output_op(0);
     auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
     ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell") << "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n";
 }
+
+TEST(TransformationTests, LSTMCellConversionTest_opset4) {
+    const size_t batch_size = 2;
+    const size_t input_size = 3;
+    const size_t hidden_size = 3;
+    const size_t gates_count = 4;
+
+    std::shared_ptr<ngraph::Function> f(nullptr), f_ref(nullptr);
+    std::shared_ptr<ngraph::opset4::LSTMCell> cell;
+    {
+        const auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                   ngraph::Shape{batch_size, input_size});
+        const auto W =
+                std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                           ngraph::Shape{gates_count * hidden_size, input_size});
+        const auto R =
+                std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                           ngraph::Shape{gates_count * hidden_size, hidden_size});
+        const auto H_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                     ngraph::Shape{batch_size, hidden_size});
+        const auto C_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                     ngraph::Shape{batch_size, hidden_size});
+        const auto B = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                                  ngraph::Shape{gates_count * hidden_size});
+
+        cell = std::make_shared<ngraph::opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
+        cell->set_friendly_name("test_cell");
+
+        f = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell}, ngraph::ParameterVector{X, H_t, C_t});
+        ngraph::pass::Manager manager;
+        manager.register_pass<ngraph::pass::InitNodeInfo>();
+        manager.register_pass<ngraph::pass::ConvertLSTMCellMatcher>();
+        manager.run_passes(f);
+        ASSERT_NO_THROW(check_rt_info(f));
+    }
+
+    {
+        const auto X = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                   ngraph::Shape{batch_size, input_size});
+        const auto W =
+                std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                           ngraph::Shape{gates_count * hidden_size, input_size});
+        const auto R =
+                std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                           ngraph::Shape{gates_count * hidden_size, hidden_size});
+        const auto H_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                     ngraph::Shape{batch_size, hidden_size});
+        const auto C_t = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
+                                                                     ngraph::Shape{batch_size, hidden_size});
+        const auto B = std::make_shared<ngraph::opset1::Constant>(ngraph::element::f32,
+                                                                  ngraph::Shape{gates_count * hidden_size});
+
+        auto concat = std::make_shared<ngraph::opset1::Concat>(ngraph::NodeVector({W, R}), 1);
+        auto cell_ie = std::make_shared<ngraph::op::LSTMCellIE>(X, H_t, C_t, concat, B,
+                                                                cell->get_hidden_size(),
+                                                                cell->get_activations(),
+                                                                cell->get_activations_alpha(),
+                                                                cell->get_activations_beta(),
+                                                                cell->get_clip());
+        cell_ie->set_friendly_name("test_cell");
+
+        f_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{cell_ie}, ngraph::ParameterVector{X, H_t, C_t});
+    }
+
+    auto res = compare_functions(f, f_ref);
+    ASSERT_TRUE(res.first) << res.second;
+
+    auto result_node_of_converted_f = f->get_output_op(0);
+    auto cell_node = result_node_of_converted_f->input(0).get_source_output().get_node_shared_ptr();
+    ASSERT_TRUE(cell_node->get_friendly_name() == "test_cell")
+            << "Transformation ConvertLSTMCellToLSTMCellIE should keep output names.\n";
+}
@@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "single_layer_tests/gru_cell.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;

namespace {
    std::vector<bool> should_decompose{false, true};
    std::vector<size_t> batch{5};
    std::vector<size_t> hidden_size{1, 10};
    std::vector<size_t> input_size{1, 30};
    std::vector<std::vector<std::string>> activations = {{"relu", "tanh"}, {"tanh", "sigmoid"}, {"sigmoid", "tanh"},
                                                         {"tanh", "relu"}};
    std::vector<float> clip = {0.0f, 0.7f};
    std::vector<bool> linear_before_reset = {true, false};
    std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
                                                             InferenceEngine::Precision::FP16};

    INSTANTIATE_TEST_CASE_P(GRUCellCommon, GRUCellTest,
            ::testing::Combine(
                    ::testing::ValuesIn(should_decompose),
                    ::testing::ValuesIn(batch),
                    ::testing::ValuesIn(hidden_size),
                    ::testing::ValuesIn(input_size),
                    ::testing::ValuesIn(activations),
                    ::testing::ValuesIn(clip),
                    ::testing::ValuesIn(linear_before_reset),
                    ::testing::ValuesIn(netPrecisions),
                    ::testing::Values(CommonTestUtils::DEVICE_CPU)),
            GRUCellTest::getTestCaseName);

}  // namespace
@@ -0,0 +1,36 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "single_layer_tests/lstm_cell.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;

namespace {
    std::vector<bool> should_decompose{false, true};
    std::vector<size_t> batch{5};
    std::vector<size_t> hidden_size{1, 10};
    std::vector<size_t> input_size{1, 30};
    std::vector<std::vector<std::string>> activations = {{"relu", "sigmoid", "tanh"}, {"sigmoid", "tanh", "tanh"},
                                                         {"tanh", "relu", "sigmoid"}, {"sigmoid", "sigmoid", "sigmoid"},
                                                         {"tanh", "tanh", "tanh"}, {"relu", "relu", "relu"}};
    std::vector<float> clip{0.f, 0.7f};
    std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
                                                             InferenceEngine::Precision::FP16};

    INSTANTIATE_TEST_CASE_P(LSTMCellCommon, LSTMCellTest,
            ::testing::Combine(
                    ::testing::ValuesIn(should_decompose),
                    ::testing::ValuesIn(batch),
                    ::testing::ValuesIn(hidden_size),
                    ::testing::ValuesIn(input_size),
                    ::testing::ValuesIn(activations),
                    ::testing::ValuesIn(clip),
                    ::testing::ValuesIn(netPrecisions),
                    ::testing::Values(CommonTestUtils::DEVICE_CPU)),
            LSTMCellTest::getTestCaseName);

}  // namespace
@@ -0,0 +1,34 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>

#include "single_layer_tests/rnn_cell.hpp"
#include "common_test_utils/test_constants.hpp"

using namespace LayerTestsDefinitions;

namespace {
    std::vector<bool> should_decompose{false, true};
    std::vector<size_t> batch{1, 5};
    std::vector<size_t> hidden_size{1, 10};
    std::vector<size_t> input_size{1, 30};
    std::vector<std::vector<std::string>> activations = {{"relu"}, {"sigmoid"}, {"tanh"}};
    std::vector<float> clip = {0.f, 0.7f};
    std::vector<InferenceEngine::Precision> netPrecisions = {InferenceEngine::Precision::FP32,
                                                             InferenceEngine::Precision::FP16};

    INSTANTIATE_TEST_CASE_P(RNNCellCommon, RNNCellTest,
            ::testing::Combine(
                    ::testing::ValuesIn(should_decompose),
                    ::testing::ValuesIn(batch),
                    ::testing::ValuesIn(hidden_size),
                    ::testing::ValuesIn(input_size),
                    ::testing::ValuesIn(activations),
                    ::testing::ValuesIn(clip),
                    ::testing::ValuesIn(netPrecisions),
                    ::testing::Values(CommonTestUtils::DEVICE_CPU)),
            RNNCellTest::getTestCaseName);

}  // namespace
@@ -0,0 +1,38 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <string>
#include <vector>
#include <memory>

#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"

namespace LayerTestsDefinitions {

using GRUCellParams = typename std::tuple<
        bool,                        // using decompose to sub-ops transformation
        size_t,                      // batch
        size_t,                      // hidden size
        size_t,                      // input size
        std::vector<std::string>,    // activations
        float,                       // clip
        bool,                        // linear_before_reset
        InferenceEngine::Precision,  // Network precision
        std::string>;                // Device name

class GRUCellTest : public testing::WithParamInterface<GRUCellParams>,
                    virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<GRUCellParams>& obj);

protected:
    void SetUp() override;
};

}  // namespace LayerTestsDefinitions
@@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <string>
#include <vector>
#include <memory>

#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"

namespace LayerTestsDefinitions {

using LSTMCellParams = typename std::tuple<
        bool,                        // using decompose to sub-ops transformation
        size_t,                      // batch
        size_t,                      // hidden size
        size_t,                      // input size
        std::vector<std::string>,    // activations
        float,                       // clip
        InferenceEngine::Precision,  // Network precision
        std::string>;                // Device name

class LSTMCellTest : public testing::WithParamInterface<LSTMCellParams>,
                     virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<LSTMCellParams>& obj);

protected:
    void SetUp() override;
};

}  // namespace LayerTestsDefinitions
@@ -0,0 +1,37 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <string>
#include <vector>
#include <memory>

#include "functional_test_utils/layer_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/utils/ngraph_helpers.hpp"

namespace LayerTestsDefinitions {

using RNNCellParams = typename std::tuple<
        bool,                        // using decompose to sub-ops transformation
        size_t,                      // batch
        size_t,                      // hidden size
        size_t,                      // input size
        std::vector<std::string>,    // activations
        float,                       // clip
        InferenceEngine::Precision,  // Network precision
        std::string>;                // Device name

class RNNCellTest : public testing::WithParamInterface<RNNCellParams>,
                    virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<RNNCellParams>& obj);

protected:
    void SetUp() override;
};

}  // namespace LayerTestsDefinitions
@@ -0,0 +1,90 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>

#include "ie_core.hpp"

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"

#include <transformations/gru_cell_decomposition.hpp>
#include "single_layer_tests/gru_cell.hpp"

namespace LayerTestsDefinitions {

std::string GRUCellTest::getTestCaseName(const testing::TestParamInfo<GRUCellParams>& obj) {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    std::vector<float> activations_alpha;
    std::vector<float> activations_beta;
    float clip;
    bool linear_before_reset;
    InferenceEngine::Precision netPrecision;
    std::string targetDevice;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
             linear_before_reset, netPrecision, targetDevice) = obj.param;
    std::vector<std::vector<size_t>> inputShapes = {
            {{batch, input_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
             {3 * hidden_size, hidden_size}, {(linear_before_reset ? 4 : 3) * hidden_size}},
    };
    std::ostringstream result;
    result << "decomposition" << should_decompose << "_";
    result << "batch=" << batch << "_";
    result << "hidden_size=" << hidden_size << "_";
    result << "input_size=" << input_size << "_";
    result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
    result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
    result << "clip=" << clip << "_";
    result << "linear_before_reset=" << linear_before_reset << "_";
    result << "netPRC=" << netPrecision.name() << "_";
    result << "targetDevice=" << targetDevice << "_";
    return result.str();
}

void GRUCellTest::SetUp() {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    std::vector<float> activations_alpha;
    std::vector<float> activations_beta;
    float clip;
    bool linear_before_reset;
    InferenceEngine::Precision netPrecision;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, linear_before_reset,
             netPrecision, targetDevice) = this->GetParam();

    std::vector<std::vector<size_t>> inputShapes = {
            {{batch, input_size}, {batch, hidden_size}, {3 * hidden_size, input_size},
             {3 * hidden_size, hidden_size}, {(linear_before_reset ? 4 : 3) * hidden_size}},
    };

    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
    std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
    auto gru_cell = ngraph::builder::makeGRUCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
                                                 WRB, hidden_size, activations, {}, {}, clip, linear_before_reset);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(gru_cell->output(0))};
    function = std::make_shared<ngraph::Function>(results, params, "gru_cell");
    if (should_decompose) {
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::GRUCellDecomposition>();
        m.run_passes(function);
    }
}

TEST_P(GRUCellTest, CompareWithRefs) {
    Run();
};
}  // namespace LayerTestsDefinitions
@@ -0,0 +1,89 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>

#include "ie_core.hpp"

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"

#include <transformations/lstm_cell_decomposition.hpp>
#include "single_layer_tests/lstm_cell.hpp"

namespace LayerTestsDefinitions {

std::string LSTMCellTest::getTestCaseName(const testing::TestParamInfo<LSTMCellParams>& obj) {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    std::vector<float> activations_alpha;
    std::vector<float> activations_beta;
    float clip;
    InferenceEngine::Precision netPrecision;
    std::string targetDevice;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision,
             targetDevice) = obj.param;
    std::vector<std::vector<size_t>> inputShapes = {
            {{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
             {4 * hidden_size, hidden_size}, {4 * hidden_size}},
    };
    std::ostringstream result;
    result << "decomposition" << should_decompose << "_";
    result << "batch=" << batch << "_";
    result << "hidden_size=" << hidden_size << "_";
    result << "input_size=" << input_size << "_";
    result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
    result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
    result << "clip=" << clip << "_";
    result << "netPRC=" << netPrecision.name() << "_";
    result << "targetDevice=" << targetDevice << "_";
    return result.str();
}

void LSTMCellTest::SetUp() {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    std::vector<float> activations_alpha;
    std::vector<float> activations_beta;
    float clip;
    InferenceEngine::Precision netPrecision;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip, netPrecision,
             targetDevice) = this->GetParam();
    std::vector<std::vector<size_t>> inputShapes = {
            {{batch, input_size}, {batch, hidden_size}, {batch, hidden_size}, {4 * hidden_size, input_size},
             {4 * hidden_size, hidden_size}, {4 * hidden_size}},
    };
    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1], inputShapes[2]});
    std::vector<ngraph::Shape> WRB = {inputShapes[3], inputShapes[4], inputShapes[5]};
    auto lstm_cell = ngraph::builder::makeLSTMCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
                                                   WRB, hidden_size, activations, {}, {}, clip);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(lstm_cell->output(0)),
                                 std::make_shared<ngraph::opset1::Result>(lstm_cell->output(1))};
    function = std::make_shared<ngraph::Function>(results, params, "lstm_cell");
    if (should_decompose) {
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::LSTMCellDecomposition>();
        m.run_passes(function);
    }
}

TEST_P(LSTMCellTest, CompareWithRefs) {
    Run();
};
}  // namespace LayerTestsDefinitions
@@ -0,0 +1,82 @@
// Copyright (C) 2019 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <functional>

#include "ie_core.hpp"

#include "common_test_utils/common_utils.hpp"
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "functional_test_utils/plugin_cache.hpp"
#include "functional_test_utils/skip_tests_config.hpp"

#include <transformations/rnn_cell_decomposition.hpp>
#include "single_layer_tests/rnn_cell.hpp"

namespace LayerTestsDefinitions {

std::string RNNCellTest::getTestCaseName(const testing::TestParamInfo<RNNCellParams>& obj) {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    float clip;
    InferenceEngine::Precision netPrecision;
    std::string targetDevice;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
             netPrecision, targetDevice) = obj.param;
    std::vector<std::vector<size_t>> inputShapes = {{batch, input_size}, {batch, hidden_size},
                                                    {hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}};
    std::ostringstream result;
    result << "decomposition" << should_decompose << "_";
    result << "batch=" << batch << "_";
    result << "hidden_size=" << hidden_size << "_";
    result << "input_size=" << input_size << "_";
    result << "IS=" << CommonTestUtils::vec2str(inputShapes) << "_";
    result << "activations=" << CommonTestUtils::vec2str(activations) << "_";
    result << "clip=" << clip << "_";
    result << "netPRC=" << netPrecision.name() << "_";
    result << "targetDevice=" << targetDevice << "_";
    return result.str();
}

void RNNCellTest::SetUp() {
    bool should_decompose;
    size_t batch;
    size_t hidden_size;
    size_t input_size;
    std::vector<std::string> activations;
    std::vector<float> activations_alpha;
    std::vector<float> activations_beta;
    float clip;
    InferenceEngine::Precision netPrecision;
    std::tie(should_decompose, batch, hidden_size, input_size, activations, clip,
             netPrecision, targetDevice) = this->GetParam();
    std::vector<std::vector<size_t>> inputShapes = {{batch, input_size}, {batch, hidden_size},
                                                    {hidden_size, input_size}, {hidden_size, hidden_size}, {hidden_size}};
    auto ngPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision);
    auto params = ngraph::builder::makeParams(ngPrc, {inputShapes[0], inputShapes[1]});
    std::vector<ngraph::Shape> WRB = {inputShapes[2], inputShapes[3], inputShapes[4]};
    auto rnn_cell = ngraph::builder::makeRNNCell(ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(params)),
                                                 WRB, hidden_size, activations, {}, {}, clip);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(rnn_cell)};
    function = std::make_shared<ngraph::Function>(results, params, "rnn_cell");
    if (should_decompose) {
        ngraph::pass::Manager m;
        m.register_pass<ngraph::pass::RNNCellDecomposition>();
        m.run_passes(function);
    }
}

TEST_P(RNNCellTest, CompareWithRefs) {
    Run();
};
}  // namespace LayerTestsDefinitions
@@ -72,7 +72,7 @@ void Basic_LSTM_S::SetUp() {
     //lstm [1, 10], [1, 118], [1, 118] -> [1, 118], [1, 118]
     outFormShapes1 = { batch_size, reshape1_shape[2] };
     auto constantX = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{2}, outFormShapes1);
-    auto lstm1 = std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
+    auto lstm1 = std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
                                                             H_t, C_t,
                                                             weightsNode, reccurrenceWeightsNode, hidden_size);

@@ -137,7 +137,7 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::CreateGraphWithUnrolledTI() {

     ngraph::Output<ngraph::Node> H[iterations + 1];
     ngraph::Output<ngraph::Node> C[iterations + 1];
-    std::shared_ptr<ngraph::opset1::LSTMCell> lstm[iterations];
+    std::shared_ptr<ngraph::opset4::LSTMCell> lstm[iterations];
     H[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
     C[0] = ngraph::builder::makeConstant<float>(ngPrc, { batch_size, hidden_size }, {}, true);
     auto reshape1_shape = reshape1->output(0).get_shape();

@@ -149,7 +149,7 @@ std::shared_ptr<ngraph::Function> Basic_LSTM_S::CreateGraphWithUnrolledTI() {

     for (size_t i = 0; i < iterations; ++i) {
         auto X = split1->output(i);
-        lstm[i] = std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
+        lstm[i] = std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
                                                              H[i], C[i],
                                                              weightsNode, reccurrenceWeightsNode, hidden_size);
@@ -389,5 +389,31 @@ std::shared_ptr<ngraph::Node> makePad(const ngraph::Output<Node>& data,
 std::shared_ptr<ngraph::Node> makeBatchNormInference(const ngraph::Output<Node>& data,
                                                      double epsilon);

+std::shared_ptr<ngraph::Node> makeLSTMCell(const OutputVector& in,
+                                           const std::vector<ngraph::Shape>& WRB,
+                                           std::size_t hidden_size,
+                                           const std::vector<std::string>& activations =
+                                                   std::vector<std::string>{"sigmoid", "tanh", "tanh"},
+                                           const std::vector<float>& activations_alpha = {},
+                                           const std::vector<float>& activations_beta = {},
+                                           float clip = 0.f);
+
+std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
+                                          const std::vector<ngraph::Shape>& WRB,
+                                          std::size_t hidden_size,
+                                          const std::vector<std::string>& activations =
+                                                  std::vector<std::string>{"sigmoid", "tanh"},
+                                          const std::vector<float>& activations_alpha = {},
+                                          const std::vector<float>& activations_beta = {},
+                                          float clip = 0.f,
+                                          bool linear_before_reset = false);
+
+std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
+                                          const std::vector<ngraph::Shape>& WRB,
+                                          std::size_t hidden_size,
+                                          const std::vector<std::string>& activations = std::vector<std::string>{"tanh"},
+                                          const std::vector<float>& activations_alpha = {},
+                                          const std::vector<float>& activations_beta = {},
+                                          float clip = 0.f);
+
 }  // namespace builder
 }  // namespace ngraph
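These builder declarations are exercised by the single-layer tests above; a compact, hedged sketch of calling one of them (input_size and hidden_size are assumed values, and the caller supplies the {X, H_t, C_t} outputs):

// Hedged sketch: building a random-weight LSTMCell via the test builder.
#include "ngraph_functions/builders.hpp"

std::shared_ptr<ngraph::Node> build_cell_example(const ngraph::OutputVector& in /* {X, H_t, C_t} */) {
    const std::size_t hidden_size = 3;
    // W, R, B shapes follow the 4-gate LSTM layout used throughout this PR.
    std::vector<ngraph::Shape> WRB = {{4 * hidden_size, /* input_size = */ 10},
                                      {4 * hidden_size, hidden_size},
                                      {4 * hidden_size}};
    // Defaults apply: activations {"sigmoid", "tanh", "tanh"}, clip = 0.
    return ngraph::builder::makeLSTMCell(in, WRB, hidden_size);
}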
@@ -130,7 +130,7 @@ static std::shared_ptr<ngraph::Function> makeTIwithLSTMcell(InferenceEngine::Pre
     inShape = {N, I};
     auto constantX = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{2}, inShape);
     auto LSTM_cell =
-        std::make_shared<ngraph::opset1::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
+        std::make_shared<ngraph::opset4::LSTMCell>(std::make_shared<ngraph::opset1::Reshape>(X, constantX, false),
                                                    std::make_shared<ngraph::opset1::Reshape>(H_t, constantH, false),
                                                    std::make_shared<ngraph::opset1::Reshape>(C_t, constantH, false),
                                                    W_body,
inference-engine/tests/ngraph_functions/src/gru_cell.cpp (new file, 30 lines)
@@ -0,0 +1,30 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>

#include "ngraph_functions/builders.hpp"

namespace ngraph {
namespace builder {

std::shared_ptr<ngraph::Node> makeGRUCell(const OutputVector& in,
                                          const std::vector<ngraph::Shape>& WRB,
                                          std::size_t hidden_size,
                                          const std::vector<std::string>& activations,
                                          const std::vector<float>& activations_alpha,
                                          const std::vector<float>& activations_beta,
                                          float clip,
                                          bool linear_before_reset) {
    std::vector<float> empty;
    auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
    auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
    auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
    return std::make_shared<ngraph::opset4::GRUCell>(in[0], in[1], W, R, B, hidden_size, activations,
                                                     activations_alpha, activations_beta, clip, linear_before_reset);
}

}  // namespace builder
}  // namespace ngraph
inference-engine/tests/ngraph_functions/src/lstm_cell.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>

#include "ngraph_functions/builders.hpp"

namespace ngraph {
namespace builder {

std::shared_ptr<ngraph::Node> makeLSTMCell(const std::vector<ngraph::Output<Node>>& in,
                                           const std::vector<ngraph::Shape>& WRB,
                                           std::size_t hidden_size,
                                           const std::vector<std::string>& activations,
                                           const std::vector<float>& activations_alpha,
                                           const std::vector<float>& activations_beta,
                                           float clip) {
    std::vector<float> empty;
    auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
    auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
    auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
    return std::make_shared<ngraph::opset4::LSTMCell>(in[0], in[1], in[2], W, R, B, hidden_size, activations,
                                                      activations_alpha, activations_beta, clip);
}

} // namespace builder
} // namespace ngraph
inference-engine/tests/ngraph_functions/src/rnn_cell.cpp (new file, 29 lines)
@@ -0,0 +1,29 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <vector>
#include <memory>

#include "ngraph_functions/builders.hpp"

namespace ngraph {
namespace builder {

std::shared_ptr<ngraph::Node> makeRNNCell(const OutputVector& in,
                                          const std::vector<ngraph::Shape>& WRB,
                                          std::size_t hidden_size,
                                          const std::vector<std::string>& activations,
                                          const std::vector<float>& activations_alpha,
                                          const std::vector<float>& activations_beta,
                                          float clip) {
    std::vector<float> empty;
    auto W = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[0], empty, true);
    auto R = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[1], empty, true);
    auto B = ngraph::builder::makeConstant(in[0].get_element_type(), WRB[2], empty, true);
    return std::make_shared<ngraph::opset4::RNNCell>(in[0], in[1], W, R, B, hidden_size, activations,
                                                     activations_alpha, activations_beta, clip);
}

} // namespace builder
} // namespace ngraph
@@ -42,7 +42,7 @@ class LSTMCell(Op):
        mandatory_props = {
            'type': __class__.op,
            'op': __class__.op,
            'version': 'opset1',
            'version': 'opset4',
            'infer': __class__.infer,
            'in_ports_count': 5,
            'out_ports_count': 2,

@@ -26,8 +26,6 @@
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

namespace ngraph
{
    namespace op
@@ -42,7 +40,7 @@ namespace ngraph
            ///
            /// Note this class represents only a single *cell* and not the whole GRU *layer*.
            ///
            class NGRAPH_API GRUCell : public util::FusedOp, public util::RNNCellBase
            class NGRAPH_API GRUCell : public util::RNNCellBase
            {
            public:
                static constexpr NodeTypeInfo type_info{"GRUCell", 3};
@@ -151,8 +149,6 @@ namespace ngraph

                virtual void validate_and_infer_types() override;
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual OutputVector decompose_op() const override;
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

@@ -180,8 +176,5 @@ namespace ngraph
                bool m_linear_before_reset;
            };
        }
        using v3::GRUCell;
    }
}

NGRAPH_SUPPRESS_DEPRECATED_END
@@ -69,7 +69,7 @@ namespace ngraph
            ///
            /// \sa LSTMSequence, RNNCell, GRUCell
            ///
            class NGRAPH_API LSTMCell : public util::FusedOp, public util::RNNCellBase
            class NGRAPH_API LSTMCell : public util::RNNCellBase
            {
            public:
                static constexpr NodeTypeInfo type_info{"LSTMCell", 0};
@@ -216,24 +216,11 @@ namespace ngraph

                virtual void validate_and_infer_types() override;
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual OutputVector decompose_op() const override;
                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                bool get_input_forget() const { return m_input_forget; }
                LSTMWeightsFormat get_weights_format() const { return m_weights_format; }
                ///
                /// \brief Change data format of provided node into IFCO.
                ///
                /// \note The IFCO format was chosen because it is the default DNNL format.
                ///
                /// \param[in] node The input node to be permuted.
                ///
                /// \return Node representing reshaped tensor according to IFCO weights format.
                ///
                std::shared_ptr<Node> convert_node_format(const Output<Node>& node) const;

            private:
                ///
                /// \brief Creates the default bias input initialized with zeros.
@@ -273,9 +260,149 @@ namespace ngraph
                static constexpr std::size_t s_gates_count{4};
                static constexpr std::size_t s_peepholes_count{3};
            };
        }
        using v0::LSTMCell;
    } // namespace op
        } // v0

        namespace v4
        {
            ///
            /// \brief Class for single lstm cell node.
            ///
            /// \note The following implementation supports:
            ///       \li \c peepholes Gers & Schmidhuber (2000)
            ///           https://ieeexplore.ieee.org/document/861302
            ///       \li Coupling input and forget gates.
            ///
            /// \note It calculates the following equations:
            ///
            ///       it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
            ///       ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
            ///       ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
            ///       Ct = ft (.) Ct-1 + it (.) ct
            ///       ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
            ///       Ht = ot (.) h(Ct)
            ///
            ///       * - is a dot product,
            ///       (.) - is a Hadamard product (element-wise),
            ///       f, g, h - are activation functions.
            ///
            /// \note This class represents only a single *cell* (for the current time step)
            ///       and not the whole LSTM Sequence layer.
            ///
            /// \sa LSTMSequence, RNNCell, GRUCell
            ///
            class NGRAPH_API LSTMCell : public util::RNNCellBase
            {
            public:
                static constexpr NodeTypeInfo type_info{"LSTMCell", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                LSTMCell();
                ///
                /// \brief Constructs LSTMCell node.
                ///
                /// \param[in] X                    The input tensor with shape:
                ///                                 [batch_size, input_size].
                /// \param[in] initial_hidden_state The hidden state tensor at current time step
                ///                                 with shape: [batch_size, hidden_size].
                /// \param[in] initial_cell_state   The cell state tensor at current time step
                ///                                 with shape: [batch_size, hidden_size].
                /// \param[in] W                    The gate weights tensor with shape:
                ///                                 [4*hidden_size, input_size].
                /// \param[in] R                    The recurrence weights tensor with shape:
                ///                                 [4*hidden_size, hidden_size].
                /// \param[in] hidden_size          The number of hidden units for recurrent cell.
                /// \param[in] activations          The vector of activation functions used inside
                ///                                 recurrent cell.
                /// \param[in] activations_alpha    The vector of alpha parameters for activation
                ///                                 functions in order respective to activation
                ///                                 list.
                /// \param[in] activations_beta     The vector of beta parameters for activation
                ///                                 functions in order respective to activation
                ///                                 list.
                /// \param[in] clip                 The value defining clipping range [-clip,
                ///                                 clip] on input of activation functions.
                LSTMCell(const Output<Node>& X,
                         const Output<Node>& initial_hidden_state,
                         const Output<Node>& initial_cell_state,
                         const Output<Node>& W,
                         const Output<Node>& R,
                         std::size_t hidden_size,
                         const std::vector<std::string>& activations =
                             std::vector<std::string>{"sigmoid", "tanh", "tanh"},
                         const std::vector<float>& activations_alpha = {},
                         const std::vector<float>& activations_beta = {},
                         float clip = 0.f);

                ///
                /// \brief Constructs LSTMCell node.
                ///
                /// \param[in] X                    The input tensor with shape:
                ///                                 [batch_size, input_size].
                /// \param[in] initial_hidden_state The hidden state tensor at current time step
                ///                                 with shape: [batch_size, hidden_size].
                /// \param[in] initial_cell_state   The cell state tensor at current time step
                ///                                 with shape: [batch_size, hidden_size].
                /// \param[in] W                    The weight tensor with shape:
                ///                                 [4*hidden_size, input_size].
                /// \param[in] R                    The recurrence weight tensor with shape:
                ///                                 [4*hidden_size, hidden_size].
                /// \param[in] B                    The bias tensor for gates with shape:
                ///                                 [4*hidden_size].
                /// \param[in] hidden_size          The number of hidden units for recurrent cell.
                /// \param[in] activations          The vector of activation functions used inside
                ///                                 recurrent cell.
                /// \param[in] activations_alpha    The vector of alpha parameters for activation
                ///                                 functions in order respective to activation
                ///                                 list.
                /// \param[in] activations_beta     The vector of beta parameters for activation
                ///                                 functions in order respective to activation
                ///                                 list.
                /// \param[in] clip                 The value defining clipping range [-clip,
                ///                                 clip] on input of activation functions.
                ///
                LSTMCell(const Output<Node>& X,
                         const Output<Node>& initial_hidden_state,
                         const Output<Node>& initial_cell_state,
                         const Output<Node>& W,
                         const Output<Node>& R,
                         const Output<Node>& B,
                         std::size_t hidden_size,
                         const std::vector<std::string>& activations =
                             std::vector<std::string>{"sigmoid", "tanh", "tanh"},
                         const std::vector<float>& activations_alpha = {},
                         const std::vector<float>& activations_beta = {},
                         float clip = 0.f);

                void validate_and_infer_types() override;

                bool visit_attributes(AttributeVisitor& visitor) override;
                std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

            private:
                ///
                /// \brief Creates the default bias input initialized with zeros.
                ///
                /// \return The object of Output class.
                ///
                Output<Node> get_default_bias_input() const;

                ///
                /// \brief The Activation function f.
                ///
                util::ActivationFunction m_activation_f;
                ///
                /// \brief The Activation function g.
                ///
                util::ActivationFunction m_activation_g;
                ///
                /// \brief The Activation function h.
                ///
                util::ActivationFunction m_activation_h;

                static constexpr std::size_t s_gates_count{4};
            };
        } // v4
    } // namespace op

    NGRAPH_API
    std::ostream& operator<<(std::ostream& s, const op::LSTMWeightsFormat& type);
@@ -294,5 +421,3 @@ namespace ngraph
        const DiscreteTypeInfo& get_type_info() const override { return type_info; }
    };
} // namespace ngraph

NGRAPH_SUPPRESS_DEPRECATED_END

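Compared to v0, the v4 cell drops weights_format, the peephole input P and input_forget, so construction reduces to the seven inputs documented above. A minimal sketch (hypothetical helper name, toy sizes, constant-filled weights standing in for trained ones):

    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset4.hpp>

    std::shared_ptr<ngraph::Function> makeSingleLstmCell() {
        using namespace ngraph;
        const std::size_t batch = 1, input_size = 16, hidden_size = 128;
        auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, input_size});
        auto H = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, hidden_size});
        auto C = std::make_shared<opset4::Parameter>(element::f32, Shape{batch, hidden_size});
        auto W = opset4::Constant::create(element::f32, Shape{4 * hidden_size, input_size},
                                          std::vector<float>(4 * hidden_size * input_size, 0.01f));
        auto R = opset4::Constant::create(element::f32, Shape{4 * hidden_size, hidden_size},
                                          std::vector<float>(4 * hidden_size * hidden_size, 0.01f));
        auto B = opset4::Constant::create(element::f32, Shape{4 * hidden_size},
                                          std::vector<float>(4 * hidden_size, 0.f));
        // Two outputs, Ht and Ct, both with shape [batch_size, hidden_size].
        auto cell = std::make_shared<opset4::LSTMCell>(X, H, C, W, R, B, hidden_size);
        return std::make_shared<Function>(cell->outputs(), ParameterVector{X, H, C});
    }
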
@@ -27,8 +27,7 @@
#include "ngraph/op/lstm_cell.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/fused_op.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START
#include "ngraph/op/util/rnn_cell_base.hpp"

namespace ngraph
{
@@ -186,9 +185,66 @@ namespace ngraph
                LSTMWeightsFormat m_weights_format;
            };
        }
        using v0::LSTMSequence;

        namespace v1
        {
            ///
            /// \brief Class for lstm sequence node.
            ///
            /// \note It follows notation and equations defined as in ONNX standard:
            ///       https://github.com/onnx/onnx/blob/master/docs/Operators.md#LSTM
            ///
            /// \sa LSTMCell, RNNCell, GRUCell
            ///
            class NGRAPH_API LSTMSequence : public util::RNNCellBase
            {
            public:
                static constexpr NodeTypeInfo type_info{"LSTMSequence", 1};
                const NodeTypeInfo& get_type_info() const override { return type_info; }
                LSTMSequence() = default;

                using direction = RecurrentSequenceDirection;

                size_t get_default_output_index() const override { return no_default_index(); }
                explicit LSTMSequence(const Output<Node>& X,
                                      const Output<Node>& initial_hidden_state,
                                      const Output<Node>& initial_cell_state,
                                      const Output<Node>& sequence_lengths,
                                      const Output<Node>& W,
                                      const Output<Node>& R,
                                      const Output<Node>& B,
                                      const std::int64_t hidden_size,
                                      const direction lstm_direction,
                                      const std::vector<float> activations_alpha = {},
                                      const std::vector<float> activations_beta = {},
                                      const std::vector<std::string> activations = {"sigmoid",
                                                                                    "tanh",
                                                                                    "tanh"},
                                      const float clip = 0.f)
                    : RNNCellBase(
                          {X, initial_hidden_state, initial_cell_state, sequence_lengths, W, R, B},
                          hidden_size,
                          clip,
                          activations,
                          activations_alpha,
                          activations_beta)
                    , m_direction(lstm_direction)
                {
                    constructor_validate_and_infer_types();
                }

                void validate_and_infer_types() override;
                bool visit_attributes(AttributeVisitor& visitor) override;

                virtual std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

                direction get_direction() const { return m_direction; }
            private:
                direction m_direction;
            };
        }
    } // namespace op

} // namespace ngraph

NGRAPH_SUPPRESS_DEPRECATED_END

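A sketch of instantiating the v1 sequence op directly. The shapes below assume the usual LSTMSequence layout (X: [batch, seq_len, input_size]; H/C: [batch, num_directions, hidden_size]; W/R: [num_directions, 4*hidden_size, input_size|hidden_size]; B: [num_directions, 4*hidden_size]); the helper name and sizes are illustrative:

    #include <ngraph/ngraph.hpp>
    #include <ngraph/op/lstm_sequence.hpp>

    std::shared_ptr<ngraph::Node> makeForwardLstmSequence() {
        using namespace ngraph;
        const std::size_t batch = 2, seq_len = 5, input_size = 8, hidden_size = 16, num_dirs = 1;
        auto X = std::make_shared<op::Parameter>(element::f32, Shape{batch, seq_len, input_size});
        auto H = std::make_shared<op::Parameter>(element::f32, Shape{batch, num_dirs, hidden_size});
        auto C = std::make_shared<op::Parameter>(element::f32, Shape{batch, num_dirs, hidden_size});
        auto len = op::Constant::create(element::i32, Shape{batch}, {seq_len, seq_len});
        auto W = op::Constant::create(element::f32, Shape{num_dirs, 4 * hidden_size, input_size},
                                      std::vector<float>(num_dirs * 4 * hidden_size * input_size, 0.01f));
        auto R = op::Constant::create(element::f32, Shape{num_dirs, 4 * hidden_size, hidden_size},
                                      std::vector<float>(num_dirs * 4 * hidden_size * hidden_size, 0.01f));
        auto B = op::Constant::create(element::f32, Shape{num_dirs, 4 * hidden_size},
                                      std::vector<float>(num_dirs * 4 * hidden_size, 0.f));
        return std::make_shared<op::v1::LSTMSequence>(
            X, H, C, len, W, R, B, hidden_size,
            op::v1::LSTMSequence::direction::FORWARD);
    }
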
@@ -26,8 +26,6 @@
#include "ngraph/op/util/fused_op.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

namespace ngraph
{
    namespace op
@@ -52,7 +50,7 @@ namespace ngraph
            ///
            /// \sa LSTMSequence, LSTMCell, GRUCell
            ///
            class NGRAPH_API RNNCell : public util::FusedOp, public util::RNNCellBase
            class NGRAPH_API RNNCell : public util::RNNCellBase
            {
            public:
                static constexpr NodeTypeInfo type_info{"RNNCell", 0};
@@ -129,11 +127,9 @@ namespace ngraph
                         const std::vector<float>& activations_beta = {},
                         float clip = 0.f);

                virtual void validate_and_infer_types() override;
                void validate_and_infer_types() override;
                bool visit_attributes(AttributeVisitor& visitor) override;
                virtual void pre_validate_and_infer_types() override;
                virtual OutputVector decompose_op() const override;
                virtual std::shared_ptr<Node>
                std::shared_ptr<Node>
                    clone_with_new_inputs(const OutputVector& new_args) const override;

            private:
@@ -152,8 +148,5 @@ namespace ngraph
                static constexpr std::size_t s_gates_count{1};
            };
        }
        using v0::RNNCell;
    } // namespace op
} // namespace ngraph

NGRAPH_SUPPRESS_DEPRECATED_END

@@ -30,11 +30,39 @@ namespace ngraph
{
    namespace util
    {
        enum class LSTMWeightsFormat
        {
            FICO, // IE
            ICOF, // PyTorch
            IFCO, // DNNL, TF, MxNet
            IFOC, // Caffe
            IOFC, // ONNX
        };

        ///
        /// \brief Change data format of provided node.
        ///
        /// \param[in] node        The input node to be permuted.
        ///
        /// \param[in] from_format Original node weights format.
        ///
        /// \param[in] to_format   Weights format to convert to.
        ///
        /// \return Node representing reshaped tensor according to `to_format` weights
        ///         format.
        ///
        std::shared_ptr<Node> NGRAPH_API
            convert_lstm_node_format(const Output<Node>& node,
                                     LSTMWeightsFormat from_format,
                                     LSTMWeightsFormat to_format = LSTMWeightsFormat::FICO);

        /// \brief Base class for all recurrent network cells.
        ///
        /// \note It holds all common attributes.
        ///
        class NGRAPH_API RNNCellBase
        class NGRAPH_API RNNCellBase : public Op
        {
        public:
            ///
@@ -50,7 +78,8 @@ namespace ngraph
            /// \param[in] activations_beta The vector of beta parameters for activation
            ///                             functions in order respective to activation list.
            ///
            RNNCellBase(std::size_t hidden_size,
            RNNCellBase(const OutputVector& args,
                        std::size_t hidden_size,
                        float clip,
                        const std::vector<std::string>& activations,
                        const std::vector<float>& activations_alpha,
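The free function replaces the member LSTMCell::convert_node_format, so gate reordering can be applied to any weights-producing node. A hypothetical use, converting ONNX-ordered LSTM weights into the IE order:

    #include <ngraph/op/util/rnn_cell_base.hpp>

    // Sketch: permute the gate blocks of ONNX-ordered weights (IOFC) into FICO (IE)
    // before handing them to an IE-facing op. Function name is illustrative.
    std::shared_ptr<ngraph::Node> toIeGateOrder(const ngraph::Output<ngraph::Node>& onnx_weights) {
        using ngraph::op::util::LSTMWeightsFormat;
        return ngraph::op::util::convert_lstm_node_format(
            onnx_weights, LSTMWeightsFormat::IOFC, LSTMWeightsFormat::FICO);
    }
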
@@ -70,8 +70,7 @@ NGRAPH_OP(LogicalNot, ngraph::op::v1)
NGRAPH_OP(LogicalOr, ngraph::op::v1)
NGRAPH_OP(LogicalXor, ngraph::op::v1)
NGRAPH_OP(LRN, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v0)
NGRAPH_OP(LSTMSequence, ngraph::op::v0)
NGRAPH_OP(LSTMCell, ngraph::op::v4)
NGRAPH_OP(MatMul, ngraph::op::v0)
NGRAPH_OP(MaxPool, ngraph::op::v1)
NGRAPH_OP(Maximum, ngraph::op::v1)

@@ -0,0 +1,316 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/multiply.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/split.hpp>
#include <ngraph/runtime/reference/subtract.hpp>
#include <ngraph/runtime/reference/tanh.hpp>

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            template <typename T>
            void gru_cell(const T* X,
                          const Shape& X_shape,
                          const T* H,
                          const Shape& H_shape,
                          const T* W,
                          const Shape& W_shape,
                          const T* R,
                          const Shape& R_shape,
                          const T* B,
                          const Shape& B_shape,
                          T* dst_data,
                          const std::string& activation_f,
                          const std::string& activation_g,
                          float clip,
                          bool linear_before_reset)
            {
                // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
                // The names used below are analogous to the ones used in ONNX documentation.
                //
                // ------ ACRONYMS ------
                // z_t - update gate at current time step
                // r_t - reset gate at current time step
                // h_t - hidden gate at current time step
                // t - time step (t-1 means previous time step)
                // X - The input data tensor. Shape: [batch_size, input_size].
                // W[zrh] - The weight tensor for update, reset and hidden gates.
                //          Shape: [gates_count * hidden_size, input_size].
                // R[zrh] - The recurrence weight tensor for update, reset and hidden gates.
                //          Shape: [gates_count * hidden_size, hidden_size].
                // H_t - The hidden state tensor at current time step. Shape: [batch_size,
                //       hidden_size].
                // B - The sum of biases (weight and recurrence) for update, reset and hidden
                //     gates.
                //     If linear_before_reset := true then biases for hidden gates are placed
                //     separately (weight and recurrence).
                //     Shape: [gates_count * hidden_size] when linear_before_reset := false
                //     Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := true
                // Wb[zrh] - W bias vectors for update, reset and hidden gates.
                // Rb[zrh] - R bias vectors for update, reset and hidden gates.

                // (.) - Denotes element-wise multiplication.
                // * - Denotes dot product.

                // ---- Equations ----
                // f, g - are activation functions
                // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
                // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
                // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)    # when linear_before_reset := false
                //                                                         # (default)
                // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  # when linear_before_reset := true
                // Ht = (1 - zt) (.) ht + zt (.) Ht-1
                // -------------------

                Shape gate_shape{X_shape[0], H_shape[1]};
                Shape all_gates_shape{X_shape[0], 3 * H_shape[1]};
                Shape bias_shape{H_shape[1], H_shape[1]};
                auto gate_shape_size = X_shape[0] * H_shape[1];
                auto all_gates_shape_size = gate_shape_size * 3;
                auto bias_shape_size = H_shape[1] * H_shape[1];

                // Xt*(W^T)
                std::vector<T> Xt_W(all_gates_shape_size);
                reference::matmul(
                    X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true);

                // Ht-1*(R^T)
                std::vector<T> Ht_R(all_gates_shape_size);
                reference::matmul(
                    H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true);

                std::vector<std::vector<T>> X_W_zrh(3, std::vector<T>(gate_shape_size));
                std::vector<char*> pointers_XW = {reinterpret_cast<char*>(X_W_zrh[0].data()),
                                                  reinterpret_cast<char*>(X_W_zrh[1].data()),
                                                  reinterpret_cast<char*>(X_W_zrh[2].data())};
                std::vector<std::vector<T>> R_zrh(3, std::vector<T>(bias_shape_size));
                std::vector<char*> pointers_R = {reinterpret_cast<char*>(R_zrh[0].data()),
                                                 reinterpret_cast<char*>(R_zrh[1].data()),
                                                 reinterpret_cast<char*>(R_zrh[2].data())};
                std::vector<std::vector<T>> Ht_R_zrh(3, std::vector<T>(gate_shape_size));
                std::vector<char*> pointers_H_R = {reinterpret_cast<char*>(Ht_R_zrh[0].data()),
                                                   reinterpret_cast<char*>(Ht_R_zrh[1].data()),
                                                   reinterpret_cast<char*>(Ht_R_zrh[2].data())};

                size_t num_b_splits = linear_before_reset ? 4 : 3;
                std::vector<std::vector<T>> biases_zrh(num_b_splits,
                                                       std::vector<T>(B_shape[0] / num_b_splits));
                std::vector<char*> pointers_biases = {
                    reinterpret_cast<char*>(biases_zrh[0].data()),
                    reinterpret_cast<char*>(biases_zrh[1].data()),
                    reinterpret_cast<char*>(biases_zrh[2].data())};
                if (linear_before_reset)
                {
                    pointers_biases.push_back(reinterpret_cast<char*>(biases_zrh[3].data()));
                }

                // split on gates
                reference::split(reinterpret_cast<char*>(Xt_W.data()),
                                 all_gates_shape,
                                 sizeof(T),
                                 1,
                                 3,
                                 pointers_XW.data());
                reference::split(
                    reinterpret_cast<const char*>(R), R_shape, sizeof(T), 0, 3, pointers_R.data());
                reference::split(reinterpret_cast<char*>(Ht_R.data()),
                                 all_gates_shape,
                                 sizeof(T),
                                 1,
                                 3,
                                 pointers_H_R.data());
                reference::split(reinterpret_cast<const char*>(B),
                                 B_shape,
                                 sizeof(T),
                                 0,
                                 num_b_splits,
                                 pointers_biases.data());

                auto clip_activation = [&clip](std::vector<T>& gate,
                                               const std::string& activation) {
                    if (clip > 0.f)
                    {
                        reference::clamp(gate.data(),
                                         gate.data(),
                                         static_cast<T>(-clip),
                                         static_cast<T>(clip),
                                         gate.size());
                    }
                    if (activation == "relu")
                    {
                        reference::relu(gate.data(), gate.data(), gate.size());
                    }
                    else if (activation == "sigmoid")
                    {
                        reference::sigmoid(gate.data(), gate.data(), gate.size());
                    }
                    else if (activation == "tanh")
                    {
                        reference::tanh(gate.data(), gate.data(), gate.size());
                    }
                    else
                    {
                        throw ngraph_error("Activation function " + activation +
                                           " is not supported.");
                    }
                };

                // calculate z_t
                // steps:
                // Ht-1*(Rz^T) + Wbz + Rbz
                // Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz
                // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
                std::vector<T> z_t(gate_shape_size);
                reference::add(Ht_R_zrh[0].data(),
                               biases_zrh[0].data(),
                               z_t.data(),
                               gate_shape,
                               {B_shape[0] / num_b_splits},
                               op::AutoBroadcastSpec::NUMPY);
                reference::add(X_W_zrh[0].data(),
                               z_t.data(),
                               z_t.data(),
                               gate_shape,
                               gate_shape,
                               op::AutoBroadcastSpec::NUMPY);
                clip_activation(z_t, activation_f);

                // calculate r_t
                // steps:
                // Ht-1*(Rr^T) + Wbr + Rbr
                // Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr
                // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
                std::vector<T> r_t(gate_shape_size);
                reference::add(Ht_R_zrh[1].data(),
                               biases_zrh[1].data(),
                               r_t.data(),
                               gate_shape,
                               {B_shape[0] / num_b_splits},
                               op::AutoBroadcastSpec::NUMPY);
                reference::add(X_W_zrh[1].data(),
                               r_t.data(),
                               r_t.data(),
                               gate_shape,
                               gate_shape,
                               op::AutoBroadcastSpec::NUMPY);
                clip_activation(r_t, activation_f);

                // calculate h_t
                std::vector<T> h_t(gate_shape_size);
                if (linear_before_reset)
                {
                    // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
                    reference::add(Ht_R_zrh[2].data(),
                                   biases_zrh[3].data(),
                                   h_t.data(),
                                   gate_shape,
                                   {B_shape[0] / num_b_splits},
                                   op::AutoBroadcastSpec::NUMPY);
                    reference::multiply(r_t.data(),
                                        h_t.data(),
                                        h_t.data(),
                                        gate_shape,
                                        gate_shape,
                                        op::AutoBroadcastSpec::NUMPY);
                    reference::add(h_t.data(),
                                   biases_zrh[2].data(),
                                   h_t.data(),
                                   gate_shape,
                                   {B_shape[0] / num_b_splits},
                                   op::AutoBroadcastSpec::NUMPY);
                    reference::add(X_W_zrh[2].data(),
                                   h_t.data(),
                                   h_t.data(),
                                   gate_shape,
                                   gate_shape,
                                   op::AutoBroadcastSpec::NUMPY);
                }
                else
                {
                    // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
                    reference::multiply(r_t.data(),
                                        H,
                                        h_t.data(),
                                        gate_shape,
                                        H_shape,
                                        op::AutoBroadcastSpec::NUMPY);
                    std::vector<T> matmul(gate_shape_size);
                    reference::matmul(h_t.data(),
                                      R_zrh[2].data(),
                                      matmul.data(),
                                      gate_shape,
                                      bias_shape,
                                      gate_shape,
                                      false,
                                      true);
                    reference::add(matmul.data(),
                                   biases_zrh[2].data(),
                                   h_t.data(),
                                   gate_shape,
                                   {B_shape[0] / num_b_splits},
                                   op::AutoBroadcastSpec::NUMPY);
                    reference::add(X_W_zrh[2].data(),
                                   h_t.data(),
                                   h_t.data(),
                                   gate_shape,
                                   gate_shape,
                                   op::AutoBroadcastSpec::NUMPY);
                }
                clip_activation(h_t, activation_g);
                // Ht = (1 - zt) (.) ht + zt (.) Ht-1
                std::vector<T> mul1(gate_shape_size);
                std::vector<T> mul2(gate_shape_size);
                T one[] = {1};
                reference::subtract(
                    one, z_t.data(), mul1.data(), {1}, gate_shape, op::AutoBroadcastSpec::NUMPY);
                reference::multiply(mul1.data(),
                                    h_t.data(),
                                    mul1.data(),
                                    gate_shape,
                                    gate_shape,
                                    op::AutoBroadcastSpec::NUMPY);
                reference::multiply(z_t.data(),
                                    H,
                                    mul2.data(),
                                    gate_shape,
                                    gate_shape,
                                    op::AutoBroadcastSpec::NUMPY);
                reference::add(mul1.data(),
                               mul2.data(),
                               dst_data,
                               gate_shape,
                               gate_shape,
                               op::AutoBroadcastSpec::NUMPY);
            }
        }
    }
}
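The kernel operates on raw buffers, so it can be exercised without building a graph. A self-contained sketch with toy sizes (values arbitrary):

    #include <vector>
    #include <ngraph/runtime/reference/gru_cell.hpp>

    int main() {
        using ngraph::Shape;
        const size_t batch = 1, input_size = 2, hidden = 2;
        std::vector<float> X{0.5f, -1.0f};                    // [batch, input_size]
        std::vector<float> H{0.0f, 0.0f};                     // [batch, hidden]
        std::vector<float> W(3 * hidden * input_size, 0.1f);  // [3*hidden, input_size]
        std::vector<float> R(3 * hidden * hidden, 0.1f);      // [3*hidden, hidden]
        std::vector<float> B(3 * hidden, 0.0f);               // [3*hidden], sum of Wb and Rb
        std::vector<float> Ht(batch * hidden);
        ngraph::runtime::reference::gru_cell<float>(
            X.data(), Shape{batch, input_size},
            H.data(), Shape{batch, hidden},
            W.data(), Shape{3 * hidden, input_size},
            R.data(), Shape{3 * hidden, hidden},
            B.data(), Shape{3 * hidden},
            Ht.data(), "sigmoid", "tanh", 0.f, false);
        // Ht now holds the next hidden state, shape [batch, hidden].
    }
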
@@ -0,0 +1,217 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <cstring>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/multiply.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/split.hpp>
#include <ngraph/runtime/reference/subtract.hpp>
#include <ngraph/runtime/reference/tanh.hpp>

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            template <typename T>
            void lstm_cell(const T* X,
                           const Shape& X_shape,
                           const T* H,
                           const Shape& H_shape,
                           const T* C,
                           const Shape& C_shape,
                           const T* W,
                           const Shape& W_shape,
                           const T* R,
                           const Shape& R_shape,
                           const T* B,
                           const Shape& B_shape,
                           T* out_Ht,
                           T* out_Ct,
                           const std::string& activation_f,
                           const std::string& activation_g,
                           const std::string& activation_h,
                           float clip)
            {
                // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
                // The names used below are analogous to the ones used in ONNX documentation.
                //
                // ------ ACRONYMS ------
                // i - input gate
                // o - output gate
                // f - forget gate
                // c - cell gate
                // t - time step (t-1 means previous time step)
                // Wb - W bias vectors for input, output, forget, and cell gates.
                // Rb - R bias vectors for input, output, forget, and cell gates.
                // P - The peephole weights for input, output and forget gates.
                // ------ VARIABLE NAMES ------
                // X - The input data tensor. Shape: [batch_size, input_size].
                // W - The weight matrix for input, forget, cell and output gates
                //     Shape: [4*hidden_size, input_size]
                // R - The recurrence weight matrix for input, forget, cell and output gates.
                //     Shape: [4*hidden_size, hidden_size].
                // H_t - The hidden state tensor at current time step. Shape: [batch_size,
                //       hidden_size].
                // C_t - The cell state tensor at current time step. Shape: [batch_size,
                //       hidden_size].
                // bias - The sum of biases (weight and recurrence) for input, forget, cell and
                //        output gates.
                //        Shape: [4 * hidden_size]
                // p_[iof] - The peephole weight vector for respectively: input, output, and forget
                //           gates.
                //           Each peephole has shape [hidden_size].
                //
                // (.) - Denotes element-wise multiplication.
                // * - Denotes dot product.
                //
                // ---- Equations ----
                // f, g, h - are activation functions.
                // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
                // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
                // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
                // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
                // Ct = ft (.) Ct-1 + it (.) ct
                // Ht = ot (.) h(Ct)
                // --------------------
                Shape gate_shape{X_shape[0], H_shape[1]};
                Shape all_gates_shape{X_shape[0], 4 * H_shape[1]};
                auto gate_shape_size = X_shape[0] * H_shape[1];
                auto all_gates_shape_size = gate_shape_size * 4;
                // Xt*(W^T)
                std::vector<T> Xt_W(all_gates_shape_size);
                reference::matmul(
                    X, W, Xt_W.data(), X_shape, W_shape, all_gates_shape, false, true);

                // Ht-1*(R^T)
                std::vector<T> Ht_R(all_gates_shape_size);
                reference::matmul(
                    H, R, Ht_R.data(), H_shape, R_shape, all_gates_shape, false, true);

                // Ht-1*(R^T) + Wb + Rb
                std::vector<T> Ht_R_B(all_gates_shape_size);
                reference::add(Ht_R.data(),
                               B,
                               Ht_R_B.data(),
                               all_gates_shape,
                               B_shape,
                               op::AutoBroadcastSpec::NUMPY);

                // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
                std::vector<T> XHB(all_gates_shape_size);
                reference::add(Xt_W.data(),
                               Ht_R_B.data(),
                               XHB.data(),
                               all_gates_shape,
                               all_gates_shape,
                               op::AutoBroadcastSpec::NUMPY);

                std::vector<std::vector<T>> X_W_fico(4, std::vector<T>(all_gates_shape_size / 4));
                std::vector<char*> pointers = {reinterpret_cast<char*>(X_W_fico[0].data()),
                                               reinterpret_cast<char*>(X_W_fico[1].data()),
                                               reinterpret_cast<char*>(X_W_fico[2].data()),
                                               reinterpret_cast<char*>(X_W_fico[3].data())};
                // split on gates
                reference::split(reinterpret_cast<char*>(XHB.data()),
                                 all_gates_shape,
                                 sizeof(T),
                                 1,
                                 4,
                                 pointers.data());

                auto clip_activation = [&clip](
                    std::vector<T>& gate, const std::string& activation, bool enable_clip = true) {
                    if (clip > 0.f && enable_clip)
                    {
                        reference::clamp(gate.data(),
                                         gate.data(),
                                         static_cast<T>(-clip),
                                         static_cast<T>(clip),
                                         gate.size());
                    }
                    if (activation == "relu")
                    {
                        reference::relu(gate.data(), gate.data(), gate.size());
                    }
                    else if (activation == "sigmoid")
                    {
                        reference::sigmoid(gate.data(), gate.data(), gate.size());
                    }
                    else if (activation == "tanh")
                    {
                        reference::tanh(gate.data(), gate.data(), gate.size());
                    }
                    else
                    {
                        throw ngraph_error("Activation function " + activation +
                                           " is not supported.");
                    }
                };

                // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Wbf + Rbf)
                clip_activation(X_W_fico[0], activation_f);
                // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
                clip_activation(X_W_fico[1], activation_f);
                // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
                clip_activation(X_W_fico[2], activation_g);
                // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Wbo + Rbo)
                clip_activation(X_W_fico[3], activation_f);

                std::vector<T> mul1(gate_shape_size);
                std::vector<T> mul2(gate_shape_size);
                std::vector<T> Ct(gate_shape_size);
                // ft (.) Ct-1
                reference::multiply(X_W_fico[0].data(),
                                    C,
                                    mul1.data(),
                                    gate_shape,
                                    C_shape,
                                    op::AutoBroadcastSpec::NUMPY);
                // it (.) ct
                reference::multiply(X_W_fico[1].data(),
                                    X_W_fico[2].data(),
                                    mul2.data(),
                                    gate_shape,
                                    gate_shape,
                                    op::AutoBroadcastSpec::NUMPY);
                // Ct = ft (.) Ct-1 + it (.) ct
                reference::add(mul1.data(),
                               mul2.data(),
                               Ct.data(),
                               gate_shape,
                               gate_shape,
                               op::AutoBroadcastSpec::NUMPY);
                std::memcpy(out_Ct, Ct.data(), Ct.size() * sizeof(T));
                clip_activation(Ct, activation_h, false);

                // Ht = ot (.) h(Ct)
                reference::multiply(X_W_fico[3].data(),
                                    Ct.data(),
                                    out_Ht,
                                    gate_shape,
                                    gate_shape,
                                    op::AutoBroadcastSpec::NUMPY);
            }
        }
    }
}
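The same pattern applies here; a self-contained sketch with toy sizes (values arbitrary):

    #include <vector>
    #include <ngraph/runtime/reference/lstm_cell.hpp>

    int main() {
        using ngraph::Shape;
        const size_t batch = 1, input_size = 2, hidden = 2;
        std::vector<float> X{1.0f, 2.0f}, H{0.f, 0.f}, C{0.f, 0.f};
        std::vector<float> W(4 * hidden * input_size, 0.1f);  // [4*hidden, input_size]
        std::vector<float> R(4 * hidden * hidden, 0.1f);      // [4*hidden, hidden]
        std::vector<float> B(4 * hidden, 0.f);                // [4*hidden]
        std::vector<float> Ht(batch * hidden), Ct(batch * hidden);
        ngraph::runtime::reference::lstm_cell<float>(
            X.data(), Shape{batch, input_size},
            H.data(), Shape{batch, hidden},
            C.data(), Shape{batch, hidden},
            W.data(), Shape{4 * hidden, input_size},
            R.data(), Shape{4 * hidden, hidden},
            B.data(), Shape{4 * hidden},
            Ht.data(), Ct.data(),
            "sigmoid", "tanh", "tanh", 0.f);
        // Ht and Ct now hold the next hidden and cell state.
    }
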
@@ -0,0 +1,132 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <ngraph/runtime/reference/add.hpp>
#include <ngraph/runtime/reference/clamp.hpp>
#include <ngraph/runtime/reference/matmul.hpp>
#include <ngraph/runtime/reference/relu.hpp>
#include <ngraph/runtime/reference/sigmoid.hpp>
#include <ngraph/runtime/reference/tanh.hpp>

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            template <typename T>
            void rnn_cell(const T* X,
                          const Shape& X_shape,
                          const T* H,
                          const Shape& H_shape,
                          const T* W,
                          const Shape& W_shape,
                          const T* R,
                          const Shape& R_shape,
                          const T* B,
                          const Shape& B_shape,
                          T* dst_data,
                          const std::string& activation_f,
                          float clip)
            {
                // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
                // The names used below are analogous to the ones used in ONNX documentation.
                //
                // ------ ACRONYMS ------
                // i_t - input gate at current time step
                // t - time step (t-1 means previous time step)
                // X - The input data tensor. Shape: [batch_size, input_size].
                // W - The weight tensor for input gate. Shape: [hidden_size, input_size].
                // R - The recurrence weight tensor for input gate. Shape: [hidden_size,
                //     hidden_size].
                // H_t - The hidden state tensor at current time step. Shape: [batch_size,
                //       hidden_size].
                // B - The bias tensor for the input gate. Shape: [hidden_size].
                // Wb - W bias vectors for input gate.
                // Rb - R bias vectors for input gate.
                // ------ VARIABLE NAMES ------
                // Xt_W - Input sequence multiplied by weights tensor at current time step.
                // Ht_R - Hidden state multiplied by weights tensor at current time step.

                // (.) - Denotes element-wise multiplication.
                // * - Denotes dot product.

                // ---- Equations ----
                // f - is the activation function.
                // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
                // --------------------

                // Xt*(W^T)
                std::vector<T> Xt_W(X_shape[0] * W_shape[0]);
                reference::matmul(
                    X, W, Xt_W.data(), X_shape, W_shape, {X_shape[0], W_shape[0]}, false, true);

                // Ht-1*(R^T)
                std::vector<T> Ht_R(H_shape[0] * R_shape[0]);
                reference::matmul(
                    H, R, Ht_R.data(), H_shape, R_shape, {H_shape[0], R_shape[0]}, false, true);

                // Ht-1*(R^T) + Wb + Rb
                std::vector<T> Ht_R_B(H_shape[0] * R_shape[0]);
                reference::add(Ht_R.data(),
                               B,
                               Ht_R_B.data(),
                               {H_shape[0], R_shape[0]},
                               B_shape,
                               op::AutoBroadcastSpec::NUMPY);

                // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
                std::vector<T> i_t(H_shape[0] * R_shape[0]);
                reference::add(Xt_W.data(),
                               Ht_R_B.data(),
                               i_t.data(),
                               {X_shape[0], W_shape[0]},
                               {H_shape[0], R_shape[0]},
                               op::AutoBroadcastSpec::NUMPY);

                // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
                if (clip != 0.f)
                {
                    reference::clamp(i_t.data(),
                                     i_t.data(),
                                     static_cast<T>(-clip),
                                     static_cast<T>(clip),
                                     i_t.size());
                }
                if (activation_f == "relu")
                {
                    reference::relu(i_t.data(), dst_data, i_t.size());
                }
                else if (activation_f == "sigmoid")
                {
                    reference::sigmoid(i_t.data(), dst_data, i_t.size());
                }
                else if (activation_f == "tanh")
                {
                    reference::tanh(i_t.data(), dst_data, i_t.size());
                }
                else
                {
                    throw ngraph_error("Activation function " + activation_f +
                                       " is not supported.");
                }
            }
        }
    }
}
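And the single-gate RNN variant, for completeness (toy sizes, arbitrary values):

    #include <vector>
    #include <ngraph/runtime/reference/rnn_cell.hpp>

    int main() {
        using ngraph::Shape;
        const size_t batch = 1, input_size = 2, hidden = 2;
        std::vector<float> X{0.5f, 0.5f}, H{0.f, 0.f};
        std::vector<float> W(hidden * input_size, 0.1f);  // [hidden, input_size]
        std::vector<float> R(hidden * hidden, 0.1f);      // [hidden, hidden]
        std::vector<float> B(hidden, 0.f);                // [hidden]
        std::vector<float> Ht(batch * hidden);
        ngraph::runtime::reference::rnn_cell<float>(
            X.data(), Shape{batch, input_size},
            H.data(), Shape{batch, hidden},
            W.data(), Shape{hidden, input_size},
            R.data(), Shape{hidden, hidden},
            B.data(), Shape{hidden},
            Ht.data(), "tanh", 0.f);
    }
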
@@ -0,0 +1,37 @@
//*****************************************************************************
// Copyright 2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>

#include "ngraph/runtime/reference/slice.hpp"

namespace ngraph
{
    namespace runtime
    {
        namespace reference
        {
            void split(const char* data,
                       const Shape& data_shape,
                       size_t elem_size,
                       int64_t axis,
                       size_t num_splits,
                       char** out_data);
        }
    }
}
ngraph/core/reference/src/runtime/reference/split.cpp (new file, 54 lines)
@@ -0,0 +1,54 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <cmath>
#include <stdio.h>

#include "ngraph/check.hpp"
#include "ngraph/runtime/reference/split.hpp"

using namespace ngraph;

void runtime::reference::split(const char* data,
                               const Shape& data_shape,
                               size_t elem_size,
                               int64_t axis,
                               size_t num_splits,
                               char** out_data)
{
    const size_t part_length = data_shape.at(axis) / num_splits;

    Shape output_shape = data_shape;
    output_shape.at(axis) = part_length;

    std::vector<size_t> lower_bounds(data_shape.size(), 0);
    std::vector<size_t> upper_bounds = data_shape;
    upper_bounds.at(axis) = part_length;

    for (size_t i = 0; i < num_splits; ++i)
    {
        runtime::reference::slice(data,
                                  out_data[i],
                                  data_shape,
                                  lower_bounds,
                                  upper_bounds,
                                  Strides(lower_bounds.size(), 1),
                                  output_shape,
                                  elem_size);
        lower_bounds.at(axis) += part_length;
        upper_bounds.at(axis) += part_length;
    }
}
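reference::split is byte-oriented: it slices equal parts along one axis into caller-provided buffers, which is how the cell kernels above separate their gate blocks. A small sketch splitting a [2, 4] float tensor along axis 1:

    #include <vector>
    #include <ngraph/runtime/reference/split.hpp>

    int main() {
        // Split a [2, 4] float buffer into two [2, 2] parts along axis 1.
        std::vector<float> data{0, 1, 2, 3, 4, 5, 6, 7};
        std::vector<float> part0(4), part1(4);
        char* outs[] = {reinterpret_cast<char*>(part0.data()),
                        reinterpret_cast<char*>(part1.data())};
        ngraph::runtime::reference::split(reinterpret_cast<const char*>(data.data()),
                                          ngraph::Shape{2, 4}, sizeof(float), 1, 2, outs);
        // part0 = {0, 1, 4, 5}, part1 = {2, 3, 6, 7}
    }
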
@@ -15,12 +15,9 @@
//*****************************************************************************

#include <cmath>
#include <functional>

#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/gru_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@@ -28,8 +25,6 @@
using namespace std;
using namespace ngraph;

NGRAPH_SUPPRESS_DEPRECATED_START

constexpr NodeTypeInfo op::v3::GRUCell::type_info;

op::v3::GRUCell::GRUCell()
@@ -68,8 +63,12 @@ op::v3::GRUCell::GRUCell(const Output<Node>& X,
                         const vector<float>& activations_beta,
                         float clip,
                         bool linear_before_reset)
    : FusedOp({X, initial_hidden_state, W, R})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
    : RNNCellBase({X, initial_hidden_state, W, R},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_linear_before_reset{linear_before_reset}
@@ -89,8 +88,12 @@ op::v3::GRUCell::GRUCell(const Output<Node>& X,
                         const vector<float>& activations_beta,
                         float clip,
                         bool linear_before_reset)
    : FusedOp({X, initial_hidden_state, W, R, B})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
    : RNNCellBase({X, initial_hidden_state, W, R, B},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_linear_before_reset{linear_before_reset}
@@ -104,83 +107,12 @@ bool op::v3::GRUCell::visit_attributes(AttributeVisitor& visitor)
    return op::util::RNNCellBase::visit_attributes(visitor);
}

void op::v3::GRUCell::pre_validate_and_infer_types()
{
    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());

    if (is_dynamic())
    {
        return;
    }

    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
    const auto& w_pshape = get_input_partial_shape(2);
    const auto& r_pshape = get_input_partial_shape(3);
    const auto& b_pshape = get_input_partial_shape(4);

    const Shape& x_shape{x_pshape.to_shape()};

    const size_t batch_size = x_shape.at(0);
    const size_t input_size = x_shape.at(1);

    const Shape& w_shape{w_pshape.to_shape()};
    const Shape& r_shape{r_pshape.to_shape()};
    const Shape& ht_shape{ht_pshape.to_shape()};

    NODE_VALIDATION_CHECK(this,
                          (w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
                          "Input tensor W must have shape (",
                          s_gates_count * get_hidden_size(),
                          ", ",
                          input_size,
                          "). Actual shape is:",
                          w_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
                          "Input tensor R must have shape (",
                          s_gates_count * get_hidden_size(),
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          w_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (ht_shape == Shape{batch_size, get_hidden_size()}),
                          "Input tensor initial_hidden_state must have shape (",
                          batch_size,
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          w_shape,
                          ".");

    const Shape& b_shape{b_pshape.to_shape()};
    NODE_VALIDATION_CHECK(
        this,
        (b_shape == Shape{(s_gates_count + m_linear_before_reset) * get_hidden_size()}),
        "Input tensor B must have shape (",
        (s_gates_count + m_linear_before_reset) * get_hidden_size(),
        "). Actual shape is:",
        b_shape,
        ".");
}

void op::v3::GRUCell::validate_and_infer_types()
{
    std::vector<ngraph::PartialShape> input_param{};

    auto merged_batch_size = Dimension::dynamic();
    auto merged_hidden_size = Dimension::dynamic();
    auto result_et = element::dynamic;

    // Copy all inputs for further validation
    for (size_t i = 0; i < get_input_size(); i++)
    {
        input_param.push_back(get_input_partial_shape(i));
    }

    // Get input partial shape for all inputs
    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
@@ -188,7 +120,7 @@ void op::v3::GRUCell::validate_and_infer_types()
    const auto& r_pshape = get_input_partial_shape(3);
    const auto& b_pshape = get_input_partial_shape(4);

    validate_input_rank_dimension(input_param);
    validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});

    // Validate input types and save result for output type
    NODE_VALIDATION_CHECK(
@@ -265,90 +197,6 @@ void op::v3::GRUCell::validate_and_infer_types()
    set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
}

OutputVector op::v3::GRUCell::decompose_op() const
{
    // ------ VARIABLE'S NAMES AND ACRONYM DEFINITIONS ------
    // The names used below are analogous to the ones used in ONNX documentation.
    //
    // ------ ACRONYMS ------
    // z_t - update gate at current time step
    // r_t - reset gate at current time step
    // h_t - hidden gate at current time step
    // t - time step (t-1 means previous time step)
    // X - The input data tensor. Shape: [batch_size, input_size].
    // W[zrh] - The weight tensor for update, reset and hidden gates.
    //          Shape: [gates_count * hidden_size, input_size].
    // R[zrh] - The recurrence weight tensor for update, reset and hidden gates.
    //          Shape: [gates_count * hidden_size, hidden_size].
    // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
    // B - The sum of biases (weight and recurrence) for update, reset and hidden gates.
    //     If linear_before_reset := true then biases for hidden gates are placed separately
    //     (weight and recurrence).
    //     Shape: [gates_count * hidden_size] when linear_before_reset := false
    //     Shape: [(gates_count + 1) * hidden_size] when linear_before_reset := true
    // Wb[zrh] - W bias vectors for update, reset and hidden gates.
    // Rb[zrh] - R bias vectors for update, reset and hidden gates.

    // (.) - Denotes element-wise multiplication.
    // * - Denotes dot product.

    // ---- Equations ----
    // f, g - are activation functions
    // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
    // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
    // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)    # when linear_before_reset := false
    //                                                         # (default)
    // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)  # when linear_before_reset := true
    // Ht = (1 - zt) (.) ht + zt (.) Ht-1
    // -------------------

    Output<Node> X = input_value(0);
    Output<Node> H_t = input_value(1);
    Output<Node> W = input_value(2);
    Output<Node> R = input_value(3);
    Output<Node> B = input_value(4);

    // Xt*(W^T)
    auto Xt_W = make_shared<op::Dot>(X, builder::opset1::transpose(W));
    auto R_transpose = builder::opset1::transpose(R);
    // Ht-1*(R^T)
    auto Ht_R = make_shared<op::Dot>(H_t, R_transpose);

    // split to gates:
    OutputVector Xt_W_zrh = builder::split(Xt_W, 3, 1);
    OutputVector R_zrh = builder::split(R_transpose, 3, 1);
    OutputVector Ht_R_zrh = builder::split(Ht_R, 3, 1);
    OutputVector biases_zrh = m_linear_before_reset ? builder::split(B, 4) : builder::split(B, 3);

    // zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz)
    auto z_t = m_activation_f(clip(add(Xt_W_zrh[0], add(Ht_R_zrh[0], biases_zrh[0]))));
    // rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr)
    auto r_t = m_activation_f(clip(add(Xt_W_zrh[1], add(Ht_R_zrh[1], biases_zrh[1]))));

    Output<Node> h_t;
    if (m_linear_before_reset)
    {
        // ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh)
        auto Ht_Rh_Rbh = add(Ht_R_zrh[2], biases_zrh[3]);
        h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(mul(r_t, Ht_Rh_Rbh), biases_zrh[2]))));
    }
    else
    {
        // ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh)
        auto rt_Ht = mul(r_t, H_t);
        auto rt_Ht_Rh = make_shared<op::Dot>(rt_Ht, R_zrh[2]);
        // Tensor shape: [batch_size, hidden_size]
        h_t = m_activation_g(clip(add(Xt_W_zrh[2], add(rt_Ht_Rh, biases_zrh[2]))));
    }

    auto one = op::Constant::create(z_t->get_element_type(),
                                    z_t->get_shape(),
                                    vector<float>(shape_size(z_t->get_shape()), 1.f));
    // Ht = (1 - zt) (.) ht + zt (.) Ht-1
    H_t = add(mul(sub(one, z_t), h_t), mul(z_t, H_t));
    return {H_t.get_node_shared_ptr()};
}

void op::v3::GRUCell::add_default_bias_input()
{
    Output<Node> B = op::Constant::create(

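The practical effect of folding the shape checks into validate_and_infer_types() is that the cells now tolerate dynamic dimensions instead of requiring static shapes up front. A small sketch (hypothetical sizes) of a GRUCell with a dynamic batch dimension:

    #include <iostream>
    #include <ngraph/ngraph.hpp>
    #include <ngraph/opsets/opset4.hpp>

    int main() {
        using namespace ngraph;
        const std::size_t input_size = 3, hidden = 4;
        // Dynamic batch dimension: shape inference still yields {?, hidden}.
        auto X = std::make_shared<opset4::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), input_size});
        auto H = std::make_shared<opset4::Parameter>(
            element::f32, PartialShape{Dimension::dynamic(), hidden});
        auto W = opset4::Constant::create(element::f32, Shape{3 * hidden, input_size},
                                          std::vector<float>(3 * hidden * input_size, 0.01f));
        auto R = opset4::Constant::create(element::f32, Shape{3 * hidden, hidden},
                                          std::vector<float>(3 * hidden * hidden, 0.01f));
        auto B = opset4::Constant::create(element::f32, Shape{3 * hidden},
                                          std::vector<float>(3 * hidden, 0.f));
        auto cell = std::make_shared<opset4::GRUCell>(X, H, W, R, B, hidden);
        std::cout << cell->get_output_partial_shape(0) << std::endl; // {?,4}
    }
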
@ -18,12 +18,8 @@
#include <functional>

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/concat.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/lstm_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
@ -31,11 +27,10 @@
using namespace std;
using namespace ngraph;

NGRAPH_SUPPRESS_DEPRECATED_START
constexpr NodeTypeInfo op::v4::LSTMCell::type_info;
constexpr NodeTypeInfo op::v0::LSTMCell::type_info;

constexpr NodeTypeInfo op::LSTMCell::type_info;

op::LSTMCell::LSTMCell()
op::v0::LSTMCell::LSTMCell()
    : m_input_forget(false)
    , m_weights_format(LSTMWeightsFormat::IFCO)
{
@ -45,20 +40,24 @@ op::LSTMCell::LSTMCell()
    m_activation_h = get_activation_function(2);
}

op::LSTMCell::LSTMCell(const Output<Node>& X,
                       const Output<Node>& initial_hidden_state,
                       const Output<Node>& initial_cell_state,
                       const Output<Node>& W,
                       const Output<Node>& R,
                       size_t hidden_size,
                       op::LSTMWeightsFormat weights_format,
                       const vector<string>& activations,
                       const vector<float>& activations_alpha,
                       const vector<float>& activations_beta,
                       float clip,
                       bool input_forget)
    : FusedOp({X, initial_hidden_state, initial_cell_state, W, R})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
                           const Output<Node>& initial_hidden_state,
                           const Output<Node>& initial_cell_state,
                           const Output<Node>& W,
                           const Output<Node>& R,
                           size_t hidden_size,
                           op::LSTMWeightsFormat weights_format,
                           const vector<string>& activations,
                           const vector<float>& activations_alpha,
                           const vector<float>& activations_beta,
                           float clip,
                           bool input_forget)
    : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_activation_h{get_activation_function(2)}
@ -70,21 +69,25 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
    constructor_validate_and_infer_types();
}

op::LSTMCell::LSTMCell(const Output<Node>& X,
                       const Output<Node>& initial_hidden_state,
                       const Output<Node>& initial_cell_state,
                       const Output<Node>& W,
                       const Output<Node>& R,
                       const Output<Node>& B,
                       size_t hidden_size,
                       op::LSTMWeightsFormat weights_format,
                       const vector<string>& activations,
                       const vector<float>& activations_alpha,
                       const vector<float>& activations_beta,
                       float clip,
                       bool input_forget)
    : FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
                           const Output<Node>& initial_hidden_state,
                           const Output<Node>& initial_cell_state,
                           const Output<Node>& W,
                           const Output<Node>& R,
                           const Output<Node>& B,
                           size_t hidden_size,
                           op::LSTMWeightsFormat weights_format,
                           const vector<string>& activations,
                           const vector<float>& activations_alpha,
                           const vector<float>& activations_beta,
                           float clip,
                           bool input_forget)
    : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_activation_h{get_activation_function(2)}
@ -95,22 +98,26 @@ op::LSTMCell::LSTMCell(const Output<Node>& X,
    constructor_validate_and_infer_types();
}

op::LSTMCell::LSTMCell(const Output<Node>& X,
                       const Output<Node>& initial_hidden_state,
                       const Output<Node>& initial_cell_state,
                       const Output<Node>& W,
                       const Output<Node>& R,
                       const Output<Node>& B,
                       const Output<Node>& P,
                       size_t hidden_size,
                       op::LSTMWeightsFormat weights_format,
                       const vector<string>& activations,
                       const vector<float>& activations_alpha,
                       const vector<float>& activations_beta,
                       float clip,
                       bool input_forget)
    : FusedOp({X, initial_hidden_state, initial_cell_state, W, R, B, P})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::LSTMCell::LSTMCell(const Output<Node>& X,
                           const Output<Node>& initial_hidden_state,
                           const Output<Node>& initial_cell_state,
                           const Output<Node>& W,
                           const Output<Node>& R,
                           const Output<Node>& B,
                           const Output<Node>& P,
                           size_t hidden_size,
                           op::LSTMWeightsFormat weights_format,
                           const vector<string>& activations,
                           const vector<float>& activations_alpha,
                           const vector<float>& activations_beta,
                           float clip,
                           bool input_forget)
    : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B, P},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_activation_h{get_activation_function(2)}
@ -133,101 +140,7 @@ bool ngraph::op::v0::LSTMCell::visit_attributes(AttributeVisitor& visitor)
    return true;
}

void op::LSTMCell::pre_validate_and_infer_types()
{
    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());
    set_output_type(1, get_input_element_type(0), PartialShape::dynamic());
    if (is_dynamic())
    {
        return;
    }

    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
    const auto& ct_pshape = get_input_partial_shape(2);
    const auto& w_pshape = get_input_partial_shape(3);
    const auto& r_pshape = get_input_partial_shape(4);

    NODE_VALIDATION_CHECK(this,
                          (x_pshape.is_static() && w_pshape.is_static() && r_pshape.is_static() &&
                           ht_pshape.is_static() && ct_pshape.is_static()),
                          "LSTMCell supports only static input tensors.");

    const Shape& x_shape{x_pshape.to_shape()};

    const size_t batch_size = x_shape.at(0);
    const size_t input_size = x_shape.at(1);

    const Shape& w_shape{w_pshape.to_shape()};
    const Shape& r_shape{r_pshape.to_shape()};
    const Shape& ht_shape{ht_pshape.to_shape()};
    const Shape& ct_shape{ct_pshape.to_shape()};

    NODE_VALIDATION_CHECK(this,
                          (w_shape == Shape{s_gates_count * get_hidden_size(), input_size}),
                          "Input tensor W must have shape (",
                          s_gates_count * get_hidden_size(),
                          ", ",
                          input_size,
                          "). Actual shape is:",
                          w_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (r_shape == Shape{s_gates_count * get_hidden_size(), get_hidden_size()}),
                          "Input tensor R must have shape (",
                          s_gates_count * get_hidden_size(),
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          r_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (ht_shape == Shape{batch_size, get_hidden_size()}),
                          "Input tensor initial_hidden_state must have shape (",
                          batch_size,
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          ht_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (ct_shape == Shape{batch_size, get_hidden_size()}),
                          "Input tensor initial_cell_state must have shape (",
                          batch_size,
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          ct_shape,
                          ".");

    const auto& b_pshape = get_input_partial_shape(5);
    const auto& p_pshape = get_input_partial_shape(6);

    NODE_VALIDATION_CHECK(this,
                          (b_pshape.is_static() && p_pshape.is_static()),
                          "LSTMCell supports only static input tensors.");

    const Shape& b_shape{b_pshape.to_shape()};
    const Shape& p_shape{p_pshape.to_shape()};

    NODE_VALIDATION_CHECK(this,
                          (b_shape == Shape{s_gates_count * get_hidden_size()}),
                          "Input tensor B must have shape (",
                          s_gates_count * get_hidden_size(),
                          "). Actual shape is:",
                          b_shape,
                          ".");

    NODE_VALIDATION_CHECK(this,
                          (p_shape == Shape{s_peepholes_count * get_hidden_size()}),
                          "Input tensor P must have shape (",
                          s_peepholes_count * get_hidden_size(),
                          "). Actual shape is:",
                          p_shape,
                          ".");
}

void op::LSTMCell::validate_and_infer_types()
void op::v0::LSTMCell::validate_and_infer_types()
{
    std::vector<ngraph::PartialShape> input_param{};

@ -367,186 +280,69 @@ void op::LSTMCell::validate_and_infer_types()
    set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
}

OutputVector op::LSTMCell::decompose_op() const
{
    // ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
    // The names used below are analogous to the ones used in the ONNX documentation.
    //
    // ------ ACRONYMS ------
    // i - input gate
    // o - output gate
    // f - forget gate
    // c - cell gate
    // t - time step (t-1 means previous time step)
    // Wb - W bias vectors for input, output, forget, and cell gates.
    // Rb - R bias vectors for input, output, forget, and cell gates.
    // P - The peephole weights for input, output and forget gates.
    // ------ VARIABLE NAMES ------
    // X - The input data tensor. Shape: [batch_size, input_size].
    // W - The weight matrix for input, forget, cell and output gates
    //     Shape: [4*hidden_size, input_size]
    // R - The recurrence weight matrix for input, forget, cell and output gates.
    //     Shape: [4*hidden_size, hidden_size].
    // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
    // C_t - The cell state tensor at current time step. Shape: [batch_size, hidden_size].
    // bias - The sum of biases (weight and recurrence) for input, forget, cell and output gates.
    //        Shape: [4 * hidden_size]
    // p_[iof] - The peephole weight vector for respectively: input, output, and forget gates.
    //           Each peephole has shape [hidden_size].
    //
    // (.) - Denotes element-wise multiplication.
    // *   - Denotes dot product.
    //
    // ---- Equations ----
    // f, g, h - are activation functions.
    // it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    // ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
    // ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc)
    // Ct = ft (.) Ct-1 + it (.) ct
    // ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    // Ht = ot (.) h(Ct)
    // --------------------

    Output<Node> X = input_value(0);
    Output<Node> H_t = input_value(1);
    Output<Node> C_t = input_value(2);
    Output<Node> W = input_value(3);
    Output<Node> R = input_value(4);
    Output<Node> bias = input_value(5);
    OutputVector p_iof = builder::split(input_value(6), s_peepholes_count);

    // Converting to IFCO format since it's the DNNL default.
    if (m_weights_format != op::LSTMWeightsFormat::IFCO)
    {
        W = convert_node_format(W);
        R = convert_node_format(R);
        bias = convert_node_format(bias);
    }

    const auto& p_i = p_iof.at(0);
    const auto& p_o = p_iof.at(1);
    const auto& p_f = p_iof.at(2);

    // Xt*(W^T) -- for [iofc] gates.
    auto Xt_W = make_shared<op::Dot>(X, builder::opset1::transpose(W));
    // Ht-1*(R^T) -- for [iofc] gates.
    auto Ht_R = make_shared<op::Dot>(H_t, builder::opset1::transpose(R));
    // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb -- for [iofc] gates.
    auto gates = add(Xt_W, add(Ht_R, bias));

    OutputVector split_gates = builder::split(gates, 4, -1);
    auto i_t = split_gates.at(0);
    auto f_t = split_gates.at(1);
    auto c_t = split_gates.at(2);
    auto o_t = split_gates.at(3);

    // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi)
    i_t = m_activation_f(clip(add(i_t, mul(p_i, C_t))));
    if (m_input_forget)
    {
        // Couple input with forget gate: 1 - i_t
        f_t = sub(op::Constant::create(i_t.get_element_type(),
                                       i_t.get_shape(),
                                       vector<float>(shape_size(i_t.get_shape()), 1.f)),
                  i_t);
    }
    else
    {
        // f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf)
        f_t = m_activation_f(clip(add(f_t, mul(p_f, C_t))));
    }
    // ft (.) Ct-1 + it (.) ct
    auto C = add(mul(f_t, C_t), mul(i_t, m_activation_g(clip(c_t))));
    // f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo)
    o_t = m_activation_f(clip(add(o_t, mul(p_o, C))));
    // ot (.) h(Ct)
    auto H = mul(o_t, m_activation_h(clip(C)));

    return {H, C};
}

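The peephole LSTM arithmetic above, after the weights have been converted to IFCO gate order, corresponds to the following NumPy sketch (illustrative only; lstm_cell_step is a hypothetical helper, clipping is omitted, and sigmoid/tanh stand in for the configured f/g/h activations):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell_step(X, H_t, C_t, W, R, b, P, input_forget=False):
    # W: [4*hidden, input], R: [4*hidden, hidden], b = Wb + Rb: [4*hidden],
    # P = [p_i, p_o, p_f]: [3*hidden]; gate blocks stacked in IFCO order.
    gates = X @ W.T + H_t @ R.T + b
    i, f, c, o = np.split(gates, 4, axis=-1)
    p_i, p_o, p_f = np.split(P, 3)
    i_t = sigmoid(i + p_i * C_t)
    f_t = 1.0 - i_t if input_forget else sigmoid(f + p_f * C_t)  # coupled-gate option
    C = f_t * C_t + i_t * np.tanh(c)   # Ct = ft (.) Ct-1 + it (.) ct
    o_t = sigmoid(o + p_o * C)
    H = o_t * np.tanh(C)               # Ht = ot (.) h(Ct)
    return H, C
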
Output<Node> op::LSTMCell::get_default_bias_input() const
Output<Node> op::v0::LSTMCell::get_default_bias_input() const
{
    return Output<Node>{op::Constant::create(
        get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector<float>{0.f})};
}

Output<Node> op::LSTMCell::get_default_peepholes_input() const
Output<Node> op::v0::LSTMCell::get_default_peepholes_input() const
{
    return Output<Node>{op::Constant::create(get_input_element_type(0),
                                             Shape{s_peepholes_count * get_hidden_size()},
                                             vector<float>{0.f})};
}

shared_ptr<Node> op::LSTMCell::convert_node_format(const Output<Node>& node) const
{
    static const std::map<op::LSTMWeightsFormat, std::vector<size_t>> gate_order_conversion_map{
        {op::LSTMWeightsFormat::FICO, {1, 0, 2, 3}},
        {op::LSTMWeightsFormat::ICOF, {0, 3, 1, 2}},
        {op::LSTMWeightsFormat::IFOC, {0, 1, 3, 2}},
        {op::LSTMWeightsFormat::IOFC, {0, 2, 3, 1}},
    };

    OutputVector splitted_node = builder::split(node, s_gates_count);
    OutputVector nodes_in_new_format;
    nodes_in_new_format.reserve(s_gates_count);
    for (const auto& axis : gate_order_conversion_map.at(m_weights_format))
    {
        nodes_in_new_format.push_back(splitted_node.at(axis));
    }
    return make_shared<op::Concat>(nodes_in_new_format, 0);
}

shared_ptr<Node> op::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 5)
    {
        return make_shared<LSTMCell>(new_args.at(0),
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
                                     new_args.at(4),
                                     get_hidden_size(),
                                     get_weights_format(),
                                     get_activations(),
                                     get_activations_alpha(),
                                     get_activations_beta(),
                                     get_clip(),
                                     m_input_forget);
        return make_shared<op::v0::LSTMCell>(new_args.at(0),
                                             new_args.at(1),
                                             new_args.at(2),
                                             new_args.at(3),
                                             new_args.at(4),
                                             get_hidden_size(),
                                             get_weights_format(),
                                             get_activations(),
                                             get_activations_alpha(),
                                             get_activations_beta(),
                                             get_clip(),
                                             m_input_forget);
    }
    else if (new_args.size() == 6)
    {
        return make_shared<LSTMCell>(new_args.at(0),
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
                                     new_args.at(4),
                                     new_args.at(5),
                                     get_hidden_size(),
                                     get_weights_format(),
                                     get_activations(),
                                     get_activations_alpha(),
                                     get_activations_beta(),
                                     get_clip(),
                                     m_input_forget);
        return make_shared<op::v0::LSTMCell>(new_args.at(0),
                                             new_args.at(1),
                                             new_args.at(2),
                                             new_args.at(3),
                                             new_args.at(4),
                                             new_args.at(5),
                                             get_hidden_size(),
                                             get_weights_format(),
                                             get_activations(),
                                             get_activations_alpha(),
                                             get_activations_beta(),
                                             get_clip(),
                                             m_input_forget);
    }
    else if (new_args.size() == 7)
    {
        return make_shared<LSTMCell>(new_args.at(0),
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
                                     new_args.at(4),
                                     new_args.at(5),
                                     new_args.at(6),
                                     get_hidden_size(),
                                     get_weights_format(),
                                     get_activations(),
                                     get_activations_alpha(),
                                     get_activations_beta(),
                                     get_clip(),
                                     m_input_forget);
        return make_shared<op::v0::LSTMCell>(new_args.at(0),
                                             new_args.at(1),
                                             new_args.at(2),
                                             new_args.at(3),
                                             new_args.at(4),
                                             new_args.at(5),
                                             new_args.at(6),
                                             get_hidden_size(),
                                             get_weights_format(),
                                             get_activations(),
                                             get_activations_alpha(),
                                             get_activations_beta(),
                                             get_clip(),
                                             m_input_forget);
    }
    else
    {
@ -576,3 +372,212 @@ namespace ngraph
    return s << as_string(type);
}
} // namespace ngraph

op::v4::LSTMCell::LSTMCell()
{
    m_activations = {"sigmoid", "tanh", "tanh"};
    m_activation_f = get_activation_function(0);
    m_activation_g = get_activation_function(1);
    m_activation_h = get_activation_function(2);
}

op::v4::LSTMCell::LSTMCell(const Output<Node>& X,
                           const Output<Node>& initial_hidden_state,
                           const Output<Node>& initial_cell_state,
                           const Output<Node>& W,
                           const Output<Node>& R,
                           size_t hidden_size,
                           const vector<string>& activations,
                           const vector<float>& activations_alpha,
                           const vector<float>& activations_beta,
                           float clip)
    : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_activation_h{get_activation_function(2)}
{
    set_argument(5, get_default_bias_input());
    constructor_validate_and_infer_types();
}

op::v4::LSTMCell::LSTMCell(const Output<Node>& X,
                           const Output<Node>& initial_hidden_state,
                           const Output<Node>& initial_cell_state,
                           const Output<Node>& W,
                           const Output<Node>& R,
                           const Output<Node>& B,
                           size_t hidden_size,
                           const vector<string>& activations,
                           const vector<float>& activations_alpha,
                           const vector<float>& activations_beta,
                           float clip)
    : RNNCellBase({X, initial_hidden_state, initial_cell_state, W, R, B},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
    , m_activation_g{get_activation_function(1)}
    , m_activation_h{get_activation_function(2)}
{
    constructor_validate_and_infer_types();
}

bool ngraph::op::v4::LSTMCell::visit_attributes(AttributeVisitor& visitor)
{
    return op::util::RNNCellBase::visit_attributes(visitor);
}

void op::v4::LSTMCell::validate_and_infer_types()
{
    auto merged_batch_size = Dimension::dynamic();
    auto merged_hidden_size = Dimension::dynamic();
    auto result_et = element::dynamic;

    // Get input partial shape for all inputs
    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
    const auto& ct_pshape = get_input_partial_shape(2);
    const auto& w_pshape = get_input_partial_shape(3);
    const auto& r_pshape = get_input_partial_shape(4);
    const auto& b_pshape = get_input_partial_shape(5);

    // Validate rank and dimension for initial_cell_state input
    NODE_VALIDATION_CHECK(this,
                          (ct_pshape.rank().is_static()),
                          "LSTMCell input tensor initial_cell_state shall have static rank.");

    NODE_VALIDATION_CHECK(this,
                          (ct_pshape.rank().get_length() == 2),
                          "LSTMCell input tensor initial_cell_state shall have dimension 2D.");

    validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});

    // Validate input element types and save result for output type
    NODE_VALIDATION_CHECK(
        this,
        element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(3)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(5)),
        "Element types for X, initial_hidden_state, initial_cell_state, W, R and B do not match.");

    // Merge batch_size dimension across all inputs to evaluate output[0] dimension
    NODE_VALIDATION_CHECK(
        this,
        Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
            Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
            Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]),
        "Parameter batch_size not matched for X, initial_hidden_state or initial_cell_state "
        "inputs.");

    // Merge hidden_size dimension across all inputs to evaluate output[1] dimension
    NODE_VALIDATION_CHECK(
        this,
        Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[1]) &&
            Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[1]) &&
            Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[1]),
        "Parameter hidden_size not matched for R, initial_hidden_state and initial_cell_state "
        "inputs.");

    // Validate hidden_size value for W, R and B inputs
    if (merged_hidden_size.is_static())
    {
        if (w_pshape[0].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                w_pshape[0].compatible(merged_hidden_size * s_gates_count),
                "Parameter hidden_size mismatched in W input. Current value is: ",
                w_pshape[0].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * s_gates_count,
                ".");
        }

        if (r_pshape[0].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                r_pshape[0].compatible(merged_hidden_size * s_gates_count),
                "Parameter hidden_size mismatched in R input. Current value is: ",
                r_pshape[0].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * s_gates_count,
                ".");
        }

        if (b_pshape[0].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                b_pshape[0].compatible(merged_hidden_size * s_gates_count),
                "Parameter hidden_size mismatched in B input. Current value is: ",
                b_pshape[0].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * s_gates_count,
                ".");
        }
    }

    // Mark inputs which are relevant to output parameters
    set_input_is_relevant_to_shape(0);
    set_input_is_relevant_to_shape(1);
    set_input_is_relevant_to_shape(2);
    set_input_is_relevant_to_shape(4);

    // Set output size, type and shape
    set_output_size(2);
    set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
    set_output_type(1, result_et, {merged_batch_size, merged_hidden_size});
}

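Restating the constraints enforced above: per the opset4 LSTMCell specification, the static input shapes are pinned down as follows (illustrative summary only, not part of the patch):

def v4_lstm_cell_expected_shapes(batch_size, input_size, hidden_size):
    # Four gates, and no peephole input in the v4 op.
    return {
        "X": (batch_size, input_size),
        "initial_hidden_state": (batch_size, hidden_size),
        "initial_cell_state": (batch_size, hidden_size),
        "W": (4 * hidden_size, input_size),
        "R": (4 * hidden_size, hidden_size),
        "B": (4 * hidden_size,),
    }
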
Output<Node> op::v4::LSTMCell::get_default_bias_input() const
{
    return Output<Node>{op::Constant::create(
        get_input_element_type(0), Shape{s_gates_count * get_hidden_size()}, vector<float>{0.f})};
}

shared_ptr<Node> op::v4::LSTMCell::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 5)
    {
        return make_shared<LSTMCell>(new_args.at(0),
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
                                     new_args.at(4),
                                     get_hidden_size(),
                                     get_activations(),
                                     get_activations_alpha(),
                                     get_activations_beta(),
                                     get_clip());
    }
    else if (new_args.size() == 6)
    {
        return make_shared<LSTMCell>(new_args.at(0),
                                     new_args.at(1),
                                     new_args.at(2),
                                     new_args.at(3),
                                     new_args.at(4),
                                     new_args.at(5),
                                     get_hidden_size(),
                                     get_activations(),
                                     get_activations_alpha(),
                                     get_activations_beta(),
                                     get_clip());
    }
    else
    {
        throw ngraph_error("Incorrect number of new arguments");
    }
}
@ -22,13 +22,16 @@
#include "ngraph/builder/split.hpp"

#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset4.hpp"

#include "ngraph/op/util/recurrent_sequence.hpp"

using namespace ngraph;
using namespace std;

constexpr NodeTypeInfo op::v1::LSTMSequence::type_info;
constexpr NodeTypeInfo op::v0::LSTMSequence::type_info;

bool ngraph::op::v0::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("hidden_size", m_hidden_size);
@ -415,3 +418,165 @@ void op::v0::LSTMSequence::validate_and_infer_types()
    set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
    set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
}

bool ngraph::op::v1::LSTMSequence::visit_attributes(AttributeVisitor& visitor)
{
    visitor.on_attribute("direction", m_direction);
    return op::util::RNNCellBase::visit_attributes(visitor);
}

shared_ptr<Node> op::v1::LSTMSequence::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 7)
    {
        return make_shared<op::v1::LSTMSequence>(new_args.at(0), // X
                                                 new_args.at(1), // initial_hidden_state
                                                 new_args.at(2), // initial_cell_state
                                                 new_args.at(3), // sequence_lengths
                                                 new_args.at(4), // W
                                                 new_args.at(5), // R
                                                 new_args.at(6), // B
                                                 m_hidden_size,
                                                 m_direction,
                                                 m_activations_alpha,
                                                 m_activations_beta,
                                                 m_activations,
                                                 m_clip);
    }
    else
    {
        throw ngraph_error("Incorrect number of new arguments");
    }
}

void op::v1::LSTMSequence::validate_and_infer_types()
{
    std::vector<ngraph::PartialShape> input_param{};

    auto lstm_seq_gates_count = 4;
    auto merged_batch_size = Dimension::dynamic();
    auto merged_hidden_size = Dimension::dynamic();
    auto merged_num_directions = Dimension::dynamic();
    auto result_et = element::dynamic;

    // Copy all inputs without initial_cell_state information for further validation
    for (size_t i = 0; i < get_input_size(); i++)
    {
        // exclude initial_cell_state from the loop
        if (i != 2)
        {
            input_param.push_back(get_input_partial_shape(i));
        }
    }

    // Get input partial shape for all inputs
    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
    const auto& ct_pshape = get_input_partial_shape(2);
    const auto& sl_pshape = get_input_partial_shape(3);
    const auto& w_pshape = get_input_partial_shape(4);
    const auto& r_pshape = get_input_partial_shape(5);
    const auto& b_pshape = get_input_partial_shape(6);

    ngraph::op::util::validate_seq_input_rank_dimension(input_param);

    // Validate rank and dimension for initial_cell_state input
    NODE_VALIDATION_CHECK(this,
                          (ct_pshape.rank().is_static()),
                          "LSTMSequence input tensor initial_cell_state shall have static rank.");

    NODE_VALIDATION_CHECK(this,
                          (ct_pshape.rank().get_length() == 3),
                          "LSTMSequence input tensor initial_cell_state shall have dimension 3D.");

    // Validate input types and save result for output type
    NODE_VALIDATION_CHECK(
        this,
        element::Type::merge(result_et, result_et, get_input_element_type(0)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(1)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(2)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(4)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(5)) &&
            element::Type::merge(result_et, result_et, get_input_element_type(6)),
        "Element types for X, initial_hidden_state, initial_cell_state, W, R and B inputs do not "
        "match.");

    // Merge batch_size dimension across all inputs to evaluate output[0] dimension
    NODE_VALIDATION_CHECK(
        this,
        Dimension::merge(merged_batch_size, merged_batch_size, ht_pshape[0]) &&
            Dimension::merge(merged_batch_size, merged_batch_size, ct_pshape[0]) &&
            Dimension::merge(merged_batch_size, merged_batch_size, x_pshape[0]) &&
            Dimension::merge(merged_batch_size, merged_batch_size, sl_pshape[0]),
        "Parameter batch_size not matched in LSTMSequence.");

    // Merge hidden_size dimension across all inputs to evaluate output dimension
    NODE_VALIDATION_CHECK(
        this,
        Dimension::merge(merged_hidden_size, merged_hidden_size, ht_pshape[2]) &&
            Dimension::merge(merged_hidden_size, merged_hidden_size, ct_pshape[2]) &&
            Dimension::merge(merged_hidden_size, merged_hidden_size, r_pshape[2]),
        "Parameter hidden_size not matched in LSTMSequence.");

    // Merge num_directions dimension across all inputs to evaluate output dimension
    NODE_VALIDATION_CHECK(
        this,
        Dimension::merge(merged_num_directions, merged_num_directions, ht_pshape[1]) &&
            Dimension::merge(merged_num_directions, merged_num_directions, ct_pshape[1]) &&
            Dimension::merge(merged_num_directions, merged_num_directions, w_pshape[0]) &&
            Dimension::merge(merged_num_directions, merged_num_directions, r_pshape[0]) &&
            Dimension::merge(merged_num_directions, merged_num_directions, b_pshape[0]),
        "Parameter num_directions not matched in LSTMSequence.");

    // Validate hidden_size value for W, R, B inputs
    if (merged_hidden_size.is_static())
    {
        if (w_pshape[1].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                w_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
                "Parameter hidden_size mismatched in W input. Current value is: ",
                w_pshape[1].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * lstm_seq_gates_count,
                ".");
        }

        if (r_pshape[1].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                r_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
                "Parameter hidden_size mismatched in R input. Current value is: ",
                r_pshape[1].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * lstm_seq_gates_count,
                ".");
        }

        if (b_pshape[1].is_static())
        {
            NODE_VALIDATION_CHECK(
                this,
                b_pshape[1].compatible(merged_hidden_size * lstm_seq_gates_count),
                "Parameter hidden_size mismatched in B input. Current value is: ",
                b_pshape[1].get_length(),
                ", expected: ",
                merged_hidden_size.get_length() * lstm_seq_gates_count,
                ".");
        }
    }

    // Mark inputs which are relevant to output parameters
    for (size_t i = 0; i <= 6; ++i)
        set_input_is_relevant_to_shape(i);

    // Set output size, type and shape
    set_output_size(3);
    set_output_type(
        0, result_et, {merged_batch_size, merged_num_directions, x_pshape[1], merged_hidden_size});
    set_output_type(1, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
    set_output_type(2, result_et, {merged_batch_size, merged_num_directions, merged_hidden_size});
}

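The merges above pin down the v1 LSTMSequence input and output layout; a compact restatement (illustrative summary matching the checks and set_output_type calls above):

def v1_lstm_sequence_shapes(batch, seq_len, input_size, hidden, num_directions):
    inputs = {
        "X": (batch, seq_len, input_size),
        "initial_hidden_state": (batch, num_directions, hidden),
        "initial_cell_state": (batch, num_directions, hidden),
        "sequence_lengths": (batch,),
        "W": (num_directions, 4 * hidden, input_size),
        "R": (num_directions, 4 * hidden, hidden),
        "B": (num_directions, 4 * hidden),
    }
    outputs = {
        "Y": (batch, num_directions, seq_len, hidden),
        "Ho": (batch, num_directions, hidden),
        "Co": (batch, num_directions, hidden),
    }
    return inputs, outputs
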
@ -14,156 +14,79 @@
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/rnn_cell.hpp"
#include <cmath>
#include <functional>

#include "itt.hpp"
#include "ngraph/builder/reshape.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/add.hpp"
#include "ngraph/op/constant.hpp"
#include "ngraph/op/dot.hpp"
#include "ngraph/op/rnn_cell.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"

using namespace std;
using namespace ngraph;

NGRAPH_SUPPRESS_DEPRECATED_START
constexpr NodeTypeInfo op::v0::RNNCell::type_info;

constexpr NodeTypeInfo op::RNNCell::type_info;

op::RNNCell::RNNCell()
op::v0::RNNCell::RNNCell()
{
    m_activations = {"tanh"};
    m_activation_f = get_activation_function(0);
}

op::RNNCell::RNNCell(const Output<Node>& X,
                     const Output<Node>& initial_hidden_state,
                     const Output<Node>& W,
                     const Output<Node>& R,
                     size_t hidden_size,
                     const vector<string>& activations,
                     const vector<float>& activations_alpha,
                     const vector<float>& activations_beta,
                     float clip)
    : FusedOp({X, initial_hidden_state, W, R})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::RNNCell::RNNCell(const Output<Node>& X,
                         const Output<Node>& initial_hidden_state,
                         const Output<Node>& W,
                         const Output<Node>& R,
                         size_t hidden_size,
                         const vector<string>& activations,
                         const vector<float>& activations_alpha,
                         const vector<float>& activations_beta,
                         float clip)
    : RNNCellBase({X, initial_hidden_state, W, R},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
{
    set_argument(4, get_default_bias_input());
    constructor_validate_and_infer_types();
}

op::RNNCell::RNNCell(const Output<Node>& X,
                     const Output<Node>& initial_hidden_state,
                     const Output<Node>& W,
                     const Output<Node>& R,
                     const Output<Node>& B,
                     size_t hidden_size,
                     const vector<string>& activations,
                     const vector<float>& activations_alpha,
                     const vector<float>& activations_beta,
                     float clip)
    : FusedOp({X, initial_hidden_state, W, R, B})
    , RNNCellBase(hidden_size, clip, activations, activations_alpha, activations_beta)
op::v0::RNNCell::RNNCell(const Output<Node>& X,
                         const Output<Node>& initial_hidden_state,
                         const Output<Node>& W,
                         const Output<Node>& R,
                         const Output<Node>& B,
                         size_t hidden_size,
                         const vector<string>& activations,
                         const vector<float>& activations_alpha,
                         const vector<float>& activations_beta,
                         float clip)
    : RNNCellBase({X, initial_hidden_state, W, R, B},
                  hidden_size,
                  clip,
                  activations,
                  activations_alpha,
                  activations_beta)
    , m_activation_f{get_activation_function(0)}
{
    constructor_validate_and_infer_types();
}

bool op::RNNCell::visit_attributes(AttributeVisitor& visitor)
bool op::v0::RNNCell::visit_attributes(AttributeVisitor& visitor)
{
    return op::util::RNNCellBase::visit_attributes(visitor);
}

void op::RNNCell::pre_validate_and_infer_types()
void op::v0::RNNCell::validate_and_infer_types()
{
    set_output_type(0, get_input_element_type(0), PartialShape::dynamic());

    if (is_dynamic())
    {
        return;
    }

    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
    const auto& w_pshape = get_input_partial_shape(2);
    const auto& r_pshape = get_input_partial_shape(3);

    NODE_VALIDATION_CHECK(this,
                          (x_pshape.is_static() && w_pshape.is_static() && r_pshape.is_static() &&
                           ht_pshape.is_static()),
                          "RNNCell supports only static input tensors.");

    const Shape& x_shape{x_pshape.to_shape()};

    const size_t batch_size = x_shape.at(0);
    const size_t input_size = x_shape.at(1);

    const Shape& w_shape{w_pshape.to_shape()};
    const Shape& r_shape{r_pshape.to_shape()};
    const Shape& ht_shape{ht_pshape.to_shape()};

    NODE_VALIDATION_CHECK(this,
                          (w_shape == Shape{get_hidden_size(), input_size}),
                          "Input tensor W must have shape (",
                          get_hidden_size(),
                          ", ",
                          input_size,
                          "). Actual shape is:",
                          w_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (r_shape == Shape{get_hidden_size(), get_hidden_size()}),
                          "Input tensor R must have shape (",
                          get_hidden_size(),
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          r_shape,
                          ".");
    NODE_VALIDATION_CHECK(this,
                          (ht_shape == Shape{batch_size, get_hidden_size()}),
                          "Input tensor initial_hidden_state must have shape (",
                          batch_size,
                          ", ",
                          get_hidden_size(),
                          "). Actual shape is:",
                          ht_shape,
                          ".");

    const auto& b_pshape = get_input_partial_shape(4);

    NODE_VALIDATION_CHECK(
        this, b_pshape.is_static(), "RNNCell supports only static input tensors.");

    const Shape& b_shape{b_pshape.to_shape()};

    NODE_VALIDATION_CHECK(this,
                          (b_shape == Shape{get_hidden_size()}),
                          "Input tensor B must have shape (",
                          get_hidden_size(),
                          "). Actual shape is:",
                          b_shape,
                          ".");
}

void op::RNNCell::validate_and_infer_types()
{
    std::vector<ngraph::PartialShape> input_param{};

    auto merged_batch_size = Dimension::dynamic();
    auto merged_hidden_size = Dimension::dynamic();
    auto result_et = element::dynamic;

    // Copy all inputs for further validation
    for (size_t i = 0; i < get_input_size(); i++)
    {
        input_param.push_back(get_input_partial_shape(i));
    }

    // Get input partial shape for all inputs
    const auto& x_pshape = get_input_partial_shape(0);
    const auto& ht_pshape = get_input_partial_shape(1);
@ -171,7 +94,7 @@ void op::RNNCell::validate_and_infer_types()
    const auto& r_pshape = get_input_partial_shape(3);
    const auto& b_pshape = get_input_partial_shape(4);

    validate_input_rank_dimension(input_param);
    validate_input_rank_dimension({x_pshape, ht_pshape, w_pshape, r_pshape, b_pshape});

    // Validate input types and save result for output type
    NODE_VALIDATION_CHECK(
@ -238,72 +161,23 @@ void op::RNNCell::validate_and_infer_types()
    }

    // Mark inputs which are relevant to output parameters
    set_input_is_relevant_to_shape(0);
    set_input_is_relevant_to_shape(1);
    set_input_is_relevant_to_shape(2);
    set_input_is_relevant_to_shape(3);
    set_input_is_relevant_to_shape(4);
    for (size_t i = 0; i <= 4; ++i)
        set_input_is_relevant_to_shape(i);

    // Set output size, type and shape
    set_output_size(1);
    set_output_type(0, result_et, {merged_batch_size, merged_hidden_size});
}

OutputVector op::RNNCell::decompose_op() const
{
    // ------ VARIABLE NAMES AND ACRONYM DEFINITIONS ------
    // The names used below are analogous to the ones used in the ONNX documentation.
    //
    // ------ ACRONYMS ------
    // i_t - input gate at current time step
    // t - time step (t-1 means previous time step)
    // X - The input data tensor. Shape: [batch_size, input_size].
    // W - The weight tensor for input gate. Shape: [hidden_size, input_size].
    // R - The recurrence weight tensor for input gate. Shape: [hidden_size, hidden_size].
    // H_t - The hidden state tensor at current time step. Shape: [batch_size, hidden_size].
    // B - The bias tensor for the input gate. Shape: [hidden_size].
    // Wb - W bias vectors for input gate.
    // Rb - R bias vectors for input gate.
    // ------ VARIABLE NAMES ------
    // Xt_W - Input sequence multiplied by weights tensor at current time step.
    // Ht_R - Hidden state multiplied by weights tensor at current time step.

    // (.) - Denotes element-wise multiplication.
    // * - Denotes dot product.

    // ---- Equations ----
    // f - is the activation function.
    // Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    // --------------------

    Output<Node> X = input_value(0);
    Output<Node> H_t = input_value(1);
    Output<Node> W = input_value(2);
    Output<Node> R = input_value(3);
    Output<Node> bias = input_value(4);

    // Xt*(W^T)
    auto Xt_W = std::make_shared<op::Dot>(X, builder::opset1::transpose(W));
    // Ht-1*(R^T)
    auto Ht_R = std::make_shared<op::Dot>(H_t, builder::opset1::transpose(R));
    // Xt*(W^T) + Ht-1*(R^T) + Wb + Rb
    auto i_t = add(Xt_W, add(Ht_R, bias));

    // f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi)
    i_t = m_activation_f(clip(i_t));

    return {i_t};
}

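The whole RNNCell step thus reduces to one affine map followed by one activation; a one-line NumPy restatement (illustrative only, clipping omitted, tanh standing in for the configured f):

import numpy as np

def rnn_cell_step(X, H_t, W, R, b):
    # W: [hidden, input], R: [hidden, hidden], b = Wb + Rb: [hidden]
    return np.tanh(X @ W.T + H_t @ R.T + b)  # Ht
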
Output<Node> op::RNNCell::get_default_bias_input() const
Output<Node> op::v0::RNNCell::get_default_bias_input() const
{
    return Output<Node>{
        op::Constant::create(get_input_element_type(0),
                             Shape{s_gates_count * get_hidden_size()},
                             vector<float>(s_gates_count * get_hidden_size(), 0.f))};
        op::v0::Constant::create(get_input_element_type(0),
                                 Shape{s_gates_count * get_hidden_size()},
                                 vector<float>(s_gates_count * get_hidden_size(), 0.f))};
}

shared_ptr<Node> op::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const
shared_ptr<Node> op::v0::RNNCell::clone_with_new_inputs(const OutputVector& new_args) const
{
    check_new_args_count(this, new_args);
    if (new_args.size() == 4)
@ -13,8 +13,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************
#include "ngraph/runtime/reference/split.hpp"
#include <numeric>

#include "ngraph/attribute_visitor.hpp"
#include "ngraph/builder/split.hpp"
#include "ngraph/op/constant.hpp"
@ -23,8 +23,6 @@
#include "ngraph/validation_util.hpp"

#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/slice.hpp"

NGRAPH_SUPPRESS_DEPRECATED_START

using namespace std;
@ -196,20 +194,25 @@ shared_ptr<Node> op::v1::Split::clone_with_new_inputs(const OutputVector& new_ar
namespace
{
    inline bool evaluate(const HostTensorPtr& in,
                         const HostTensorPtr& out,
                         const Coordinate& lower_bounds,
                         const Coordinate& upper_bounds)
    inline bool evaluate(const HostTensorPtr& data_tensor,
                         const HostTensorVector& outputs,
                         const int64_t axis,
                         const int64_t num_splits)
    {
        runtime::reference::slice(in->get_data_ptr<const char>(),
                                  out->get_data_ptr<char>(),
                                  in->get_shape(),
                                  lower_bounds,
                                  upper_bounds,
                                  Strides(lower_bounds.size(), 1),
                                  out->get_shape(),
                                  in->get_element_type().size());

        Shape output_shape = data_tensor->get_shape();
        std::vector<char*> outputs_data(num_splits);
        output_shape.at(axis) /= num_splits;
        for (size_t i = 0; i < outputs.size(); ++i)
        {
            outputs[i]->set_shape(output_shape);
            outputs_data[i] = outputs[i]->get_data_ptr<char>();
        }
        ngraph::runtime::reference::split(data_tensor->get_data_ptr<char>(),
                                          data_tensor->get_shape(),
                                          data_tensor->get_element_type().size(),
                                          axis,
                                          num_splits,
                                          outputs_data.data());
        return true;
    }

@ -236,26 +239,7 @@ namespace
            break;
        }
        axis = ngraph::normalize_axis(split_node, axis, data_tensor->get_partial_shape().rank());

        const auto data_shape = data_tensor->get_shape();
        const size_t axis_dim_length = data_shape.at(axis);
        const size_t part_length = axis_dim_length / num_splits;

        Shape output_shape = data_shape;
        output_shape.at(axis) = part_length;

        std::vector<size_t> lower_bounds(data_shape.size(), 0);
        std::vector<size_t> upper_bounds = data_shape;
        upper_bounds.at(axis) = part_length;

        for (const auto& output : outputs)
        {
            output->set_shape(output_shape);
            evaluate(data_tensor, output, lower_bounds, upper_bounds);
            lower_bounds.at(axis) += part_length;
            upper_bounds.at(axis) += part_length;
        }

        evaluate(data_tensor, outputs, axis, num_splits);
        return true;
    }
}
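The rewritten evaluate delegates slicing to the runtime::reference::split kernel, which writes num_splits equal parts along the normalized axis; for equal parts its observable behavior matches numpy.split (illustrative check):

import numpy as np

data = np.arange(24).reshape(2, 3, 4)
parts = np.split(data, 2, axis=2)  # num_splits = 2 along axis 2
assert [p.shape for p in parts] == [(2, 3, 2), (2, 3, 2)]
assert (parts[1] == data[:, :, 2:]).all()
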
@ -24,11 +24,38 @@
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/subtract.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/util.hpp"

using namespace std;
using namespace ngraph;

std::shared_ptr<Node> ngraph::op::util::convert_lstm_node_format(const Output<Node>& node,
                                                                 LSTMWeightsFormat from_format,
                                                                 LSTMWeightsFormat to_format)
{
    static const std::map<op::util::LSTMWeightsFormat, std::vector<size_t>> gate_order_map{
        {op::util::LSTMWeightsFormat::FICO, {0, 1, 2, 3}},
        {op::util::LSTMWeightsFormat::ICOF, {1, 2, 3, 0}},
        {op::util::LSTMWeightsFormat::IFOC, {1, 0, 3, 2}},
        {op::util::LSTMWeightsFormat::IOFC, {1, 3, 0, 2}},
        {op::util::LSTMWeightsFormat::IFCO, {1, 0, 2, 3}},
    };
    const auto& from = gate_order_map.at(from_format);
    const auto& to = gate_order_map.at(to_format);
    size_t num_gates = 4;

    auto axis_const = std::make_shared<opset4::Constant>(element::i64, Shape{}, 0);
    OutputVector splitted_node =
        std::make_shared<opset4::Split>(node, axis_const, num_gates)->outputs();
    OutputVector nodes_in_new_format(num_gates);
    for (size_t i = 0; i < num_gates; ++i)
    {
        nodes_in_new_format[to[from[i]]] = splitted_node[i];
    }
    return std::make_shared<opset4::Concat>(nodes_in_new_format, 0);
}

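Each gate_order_map entry reads as: entry[slot] = FICO-canonical index of the gate stored at that slot. Note that the indexing to[from[i]] lands gates in the right target slots only when the target map is its own inverse, which appears to hold for the IFCO and FICO targets this code is actually used with. A NumPy sketch of the same reordering (hypothetical helper mirroring the loop above):

import numpy as np

GATE_ORDER = {  # map[slot] = FICO-canonical index of the gate at that slot
    "FICO": [0, 1, 2, 3],
    "ICOF": [1, 2, 3, 0],
    "IFOC": [1, 0, 3, 2],
    "IOFC": [1, 3, 0, 2],
    "IFCO": [1, 0, 2, 3],
}

def convert_lstm_weights(w, from_format, to_format="IFCO"):
    # w stacks four gate blocks along axis 0, e.g. shape [4*hidden, input].
    frm, to = GATE_ORDER[from_format], GATE_ORDER[to_format]
    parts = np.split(w, 4, axis=0)
    out = [None] * 4
    for i in range(4):
        out[to[frm[i]]] = parts[i]  # same indexing as the C++ loop
    return np.concatenate(out, axis=0)
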
// Returns a copy of the input vector with all strings converted to lower case.
static vector<string> to_lower_case(const vector<string>& vs)
{
@ -43,12 +70,14 @@ op::util::RNNCellBase::RNNCellBase()
{
}

op::util::RNNCellBase::RNNCellBase(size_t hidden_size,
op::util::RNNCellBase::RNNCellBase(const OutputVector& args,
                                   size_t hidden_size,
                                   float clip,
                                   const vector<string>& activations,
                                   const vector<float>& activations_alpha,
                                   const vector<float>& activations_beta)
    : m_hidden_size(hidden_size)
    : Op(args)
    , m_hidden_size(hidden_size)
    , m_clip(clip)
    , m_activations(to_lower_case(activations))
    , m_activations_alpha(activations_alpha)

@ -29,6 +29,7 @@
#include "ngraph/op/constant.hpp"
#include "ngraph/op/lstm_sequence.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/shape.hpp"
#include "ngraph/type/element_type.hpp"
#include "onnx_import/core/null_node.hpp"
@ -212,7 +213,10 @@ namespace ngraph
                LSTMNgInputMap input_map{node};
                LSTMAttributes attributes{node};

                auto lstmSequence = std::make_shared<default_opset::LSTMSequence>(
                // LSTMSequence is not fully supported in OpenVINO and is excluded from
                // opset4 (currently the latest opset version), so use one of the
                // previous opsets instead of the default one
                auto lstmSequence = std::make_shared<opset3::LSTMSequence>(
                    input_map.at(LSTMInput::LSTM_INPUT_X),
                    input_map.at(LSTMInput::LSTM_INPUT_INIT_H),
                    input_map.at(LSTMInput::LSTM_INPUT_INIT_C),
@ -82,7 +82,7 @@ from ngraph.opset1.ops import logical_not
from ngraph.opset1.ops import logical_or
from ngraph.opset1.ops import logical_xor
from ngraph.opset1.ops import lrn
from ngraph.opset1.ops import lstm_cell
from ngraph.opset4.ops import lstm_cell
from ngraph.opset1.ops import lstm_sequence
from ngraph.opset1.ops import matmul
from ngraph.opset1.ops import max_pool
@ -367,3 +367,54 @@ def reduce_l2(
    return _get_node_factory_opset4().create(
        "ReduceL2", as_nodes(node, reduction_axes), {"keep_dims": keep_dims}
    )


@nameable_op
def lstm_cell(
    X: NodeInput,
    initial_hidden_state: NodeInput,
    initial_cell_state: NodeInput,
    W: NodeInput,
    R: NodeInput,
    B: NodeInput,
    hidden_size: int,
    activations: List[str] = None,
    activations_alpha: List[float] = None,
    activations_beta: List[float] = None,
    clip: float = 0.0,
    name: Optional[str] = None,
) -> Node:
    """Return a node which performs LSTMCell operation.

    :param X: The input tensor with shape: [batch_size, input_size].
    :param initial_hidden_state: The hidden state tensor with shape: [batch_size, hidden_size].
    :param initial_cell_state: The cell state tensor with shape: [batch_size, hidden_size].
    :param W: The weight tensor with shape: [4*hidden_size, input_size].
    :param R: The recurrence weight tensor with shape: [4*hidden_size, hidden_size].
    :param B: The bias tensor for gates with shape: [4*hidden_size].
    :param hidden_size: Specifies hidden state size.
    :param activations: The list of three activation functions for gates.
    :param activations_alpha: The list of alpha parameters for activation functions.
    :param activations_beta: The list of beta parameters for activation functions.
    :param clip: Specifies bound values [-C, C] for tensor clipping performed before activations.
    :param name: An optional name of the output node.

    :return: The new node representing LSTMCell. Node outputs count: 2.
    """
    if activations is None:
        activations = ["sigmoid", "tanh", "tanh"]
    if activations_alpha is None:
        activations_alpha = []
    if activations_beta is None:
        activations_beta = []

    node_inputs = as_nodes(X, initial_hidden_state, initial_cell_state, W, R, B)

    attributes = {
        "hidden_size": hidden_size,
        "activations": activations,
        "activations_alpha": activations_alpha,
        "activations_beta": activations_beta,
        "clip": clip,
    }
    return _get_node_factory_opset4().create("LSTMCell", node_inputs, attributes)
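
A minimal usage sketch for this new opset4 binding, mirroring the shapes exercised by the tests below (illustrative only):

import numpy as np
import ngraph as ng

batch_size, input_size, hidden_size = 1, 16, 128
X = ng.parameter([batch_size, input_size], name="X", dtype=np.float32)
H_t = ng.parameter([batch_size, hidden_size], name="H_t", dtype=np.float32)
C_t = ng.parameter([batch_size, hidden_size], name="C_t", dtype=np.float32)
W = ng.parameter([4 * hidden_size, input_size], name="W", dtype=np.float32)
R = ng.parameter([4 * hidden_size, hidden_size], name="R", dtype=np.float32)
B = ng.parameter([4 * hidden_size], name="B", dtype=np.float32)

node = ng.lstm_cell(X, H_t, C_t, W, R, B, hidden_size)
assert node.get_output_size() == 2  # Ht and Ct, each [batch_size, hidden_size]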
@ -18,6 +18,7 @@ import pytest
|
||||
from _pyngraph import PartialShape
|
||||
|
||||
import ngraph as ng
|
||||
import ngraph.opset1 as ng_opset1
|
||||
from ngraph.impl import Type
|
||||
|
||||
np_types = [np.float32, np.int32]
|
||||
@ -230,6 +231,62 @@ def test_lstm_cell_operator(dtype):
|
||||
assert list(node_param.get_output_shape(1)) == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
|
||||
def test_lstm_cell_operator_opset1(dtype):
|
||||
batch_size = 1
|
||||
input_size = 16
|
||||
hidden_size = 128
|
||||
|
||||
X_shape = [batch_size, input_size]
|
||||
H_t_shape = [batch_size, hidden_size]
|
||||
C_t_shape = [batch_size, hidden_size]
|
||||
W_shape = [4 * hidden_size, input_size]
|
||||
R_shape = [4 * hidden_size, hidden_size]
|
||||
B_shape = [4 * hidden_size]
|
||||
|
||||
parameter_X = ng.parameter(X_shape, name="X", dtype=dtype)
|
||||
parameter_H_t = ng.parameter(H_t_shape, name="H_t", dtype=dtype)
|
||||
parameter_C_t = ng.parameter(C_t_shape, name="C_t", dtype=dtype)
|
||||
parameter_W = ng.parameter(W_shape, name="W", dtype=dtype)
|
||||
parameter_R = ng.parameter(R_shape, name="R", dtype=dtype)
|
||||
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
|
||||
|
||||
expected_shape = [1, 128]
|
||||
|
||||
node_default = ng_opset1.lstm_cell(
|
||||
parameter_X, parameter_H_t, parameter_C_t, parameter_W, parameter_R, parameter_B, hidden_size,
|
||||
)
|
||||
|
||||
assert node_default.get_type_name() == "LSTMCell"
|
||||
assert node_default.get_output_size() == 2
|
||||
assert list(node_default.get_output_shape(0)) == expected_shape
|
||||
assert list(node_default.get_output_shape(1)) == expected_shape
|
||||
|
||||
activations = ["tanh", "Sigmoid", "RELU"]
|
||||
activation_alpha = [1.0, 2.0, 3.0]
|
||||
activation_beta = [3.0, 2.0, 1.0]
|
||||
clip = 0.5
|
||||
|
||||
node_param = ng_opset1.lstm_cell(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
parameter_W,
|
||||
parameter_R,
|
||||
parameter_B,
|
||||
hidden_size,
|
||||
activations,
|
||||
activation_alpha,
|
||||
activation_beta,
|
||||
clip,
|
||||
)
|
||||
|
||||
assert node_param.get_type_name() == "LSTMCell"
|
||||
assert node_param.get_output_size() == 2
|
||||
assert list(node_param.get_output_shape(0)) == expected_shape
|
||||
assert list(node_param.get_output_shape(1)) == expected_shape
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
|
||||
def test_lstm_sequence_operator_bidirectional(dtype):
|
||||
batch_size = 1
|
||||
@ -255,7 +312,7 @@ def test_lstm_sequence_operator_bidirectional(dtype):
|
||||
parameter_B = ng.parameter(B_shape, name="B", dtype=dtype)
|
||||
|
||||
direction = "BIDIRECTIONAL"
|
||||
node = ng.lstm_sequence(
|
||||
node = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
@ -275,7 +332,7 @@ def test_lstm_sequence_operator_bidirectional(dtype):
|
||||
activation_beta = [3.0, 2.0, 1.0]
|
||||
clip = 1.22
|
||||
|
||||
node_param = ng.lstm_sequence(
|
||||
node_param = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
@ -321,7 +378,7 @@ def test_lstm_sequence_operator_reverse(dtype):
|
||||
|
||||
direction = "REVERSE"
|
||||
|
||||
node_default = ng.lstm_sequence(
|
||||
node_default = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
@ -341,7 +398,7 @@ def test_lstm_sequence_operator_reverse(dtype):
|
||||
activation_beta = [3.0, 2.0, 1.0]
|
||||
clip = 1.22
|
||||
|
||||
node_param = ng.lstm_sequence(
|
||||
node_param = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
@ -387,7 +444,7 @@ def test_lstm_sequence_operator_forward(dtype):
|
||||
|
||||
direction = "forward"
|
||||
|
||||
node_default = ng.lstm_sequence(
|
||||
node_default = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
@ -407,7 +464,7 @@ def test_lstm_sequence_operator_forward(dtype):
|
||||
activation_beta = [1.0]
|
||||
clip = 0.5
|
||||
|
||||
node = ng.lstm_sequence(
|
||||
node = ng_opset1.lstm_sequence(
|
||||
parameter_X,
|
||||
parameter_H_t,
|
||||
parameter_C_t,
|
||||
|
@ -20,6 +20,7 @@
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/opsets/opset3.hpp"
#include "ngraph/opsets/opset4.hpp"

#include "util/visitor.hpp"

@ -1063,7 +1064,7 @@ TEST(attributes, lrn_op)

TEST(attributes, lstm_cell_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMCell>();
FactoryRegistry<Node>::get().register_factory<opset4::LSTMCell>();
auto X = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto H = make_shared<op::Parameter>(element::f32, Shape{2, 3});
auto W = make_shared<op::Parameter>(element::f32, Shape{12, 3});
@ -1072,40 +1073,33 @@ TEST(attributes, lstm_cell_op)
const auto initial_cell_state = make_shared<op::Parameter>(element::f32, Shape{2, 3});

const auto hidden_size = 3;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
auto activations_alpha = std::vector<float>{1.0, 1.5};
auto activations_beta = std::vector<float>{2.0, 1.0};
const float clip = 0.5f;
bool input_forget = true;

const auto lstm_cell = make_shared<opset1::LSTMCell>(X,
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
initial_hidden_state,
initial_cell_state,
W,
R,
hidden_size,
weights_format,
activations,
activations_alpha,
activations_beta,
clip,
input_forget);
clip);
NodeBuilder builder(lstm_cell);
auto g_lstm_cell = as_type_ptr<opset1::LSTMCell>(builder.create());
auto g_lstm_cell = as_type_ptr<opset4::LSTMCell>(builder.create());

EXPECT_EQ(g_lstm_cell->get_hidden_size(), lstm_cell->get_hidden_size());
EXPECT_EQ(g_lstm_cell->get_activations(), lstm_cell->get_activations());
EXPECT_EQ(g_lstm_cell->get_activations_alpha(), lstm_cell->get_activations_alpha());
EXPECT_EQ(g_lstm_cell->get_activations_beta(), lstm_cell->get_activations_beta());
EXPECT_EQ(g_lstm_cell->get_clip(), lstm_cell->get_clip());
EXPECT_EQ(g_lstm_cell->get_input_forget(), lstm_cell->get_input_forget());
EXPECT_EQ(g_lstm_cell->get_weights_format(), lstm_cell->get_weights_format());
}

TEST(attributes, lstm_sequence_op)
{
FactoryRegistry<Node>::get().register_factory<opset1::LSTMSequence>();
FactoryRegistry<Node>::get().register_factory<op::v1::LSTMSequence>();

const size_t batch_size = 4;
const size_t num_directions = 2;
@ -1127,14 +1121,12 @@ TEST(attributes, lstm_sequence_op)
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});

const auto lstm_direction = op::RecurrentSequenceDirection::BIDIRECTIONAL;
const auto weights_format = op::LSTMWeightsFormat::ICOF;
const std::vector<float> activations_alpha = {1, 2, 3};
const std::vector<float> activations_beta = {4, 5, 6};
const std::vector<std::string> activations = {"tanh", "sigmoid", "tanh"};
const float clip_threshold = 0.5f;
const bool input_forget = true;

const auto lstm_sequence = make_shared<opset1::LSTMSequence>(X,
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
@ -1143,23 +1135,19 @@ TEST(attributes, lstm_sequence_op)
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations,
clip_threshold,
input_forget);
clip_threshold);
NodeBuilder builder(lstm_sequence);
auto g_lstm_sequence = as_type_ptr<opset1::LSTMSequence>(builder.create());
auto g_lstm_sequence = as_type_ptr<op::v1::LSTMSequence>(builder.create());

EXPECT_EQ(g_lstm_sequence->get_hidden_size(), lstm_sequence->get_hidden_size());
EXPECT_EQ(g_lstm_sequence->get_activations(), lstm_sequence->get_activations());
EXPECT_EQ(g_lstm_sequence->get_activations_alpha(), lstm_sequence->get_activations_alpha());
EXPECT_EQ(g_lstm_sequence->get_activations_beta(), lstm_sequence->get_activations_beta());
EXPECT_EQ(g_lstm_sequence->get_clip_threshold(), lstm_sequence->get_clip_threshold());
EXPECT_EQ(g_lstm_sequence->get_clip(), lstm_sequence->get_clip());
EXPECT_EQ(g_lstm_sequence->get_direction(), lstm_sequence->get_direction());
EXPECT_EQ(g_lstm_sequence->get_input_forget(), lstm_sequence->get_input_forget());
EXPECT_EQ(g_lstm_sequence->get_weights_format(), lstm_sequence->get_weights_format());
}

TEST(attributes, shuffle_channels_op)
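For orientation, a sketch (not the authors' code) of the attribute-surface change the test above now covers: opset1::LSTMCell carried weights_format and input_forget as attributes, while opset4::LSTMCell keeps only the attributes shared by all the recurrent cells. Both call shapes mirror the argument lists visible in the diff; the variable names are the test's own.

// Sketch mirroring the diff above; variables as declared in lstm_cell_op.
// opset1 (v0) LSTMCell: weights format and input_forget were attributes.
const auto cell_v0 = make_shared<opset1::LSTMCell>(
    X, initial_hidden_state, initial_cell_state, W, R, hidden_size,
    op::LSTMWeightsFormat::ICOF,
    activations, activations_alpha, activations_beta, clip, input_forget);
// opset4 (v4) LSTMCell: the same data inputs, fewer attributes.
const auto cell_v4 = make_shared<opset4::LSTMCell>(
    X, initial_hidden_state, initial_cell_state, W, R, hidden_size,
    activations, activations_alpha, activations_beta, clip);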
@ -33,7 +33,9 @@
#include "gtest/gtest.h"
#include "ngraph/check.hpp"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "ngraph/op/util/attr_types.hpp"
#include "ngraph/op/util/rnn_cell_base.hpp"
#include "op/group_conv.hpp"
#include "util/all_close.hpp"
#include "util/all_close_f.hpp"
@ -1629,11 +1631,17 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});

const auto lstm_cell = make_shared<op::LSTMCell>(
X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC);
const auto lstm_cell = make_shared<opset4::LSTMCell>(
X,
H_t,
C_t,
op::util::convert_lstm_node_format(W, op::util::LSTMWeightsFormat::IOFC),
op::util::convert_lstm_node_format(R, op::util::LSTMWeightsFormat::IOFC),
op::util::convert_lstm_node_format(B, op::util::LSTMWeightsFormat::IOFC),
hidden_size);

auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);

// X
@ -1665,18 +1673,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_zero_bias_peepholes)
// P
vector<float> in_P(3 * hidden_size, 0.f);

ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.81457126f, 0.61109227f, 0.769522f, 0.52239674f, 0.4324641f, 0.63183f});
ht_test_case.run();

auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.4444952f, 0.9635685f, 1.2875274f, 0.8053419f, 0.7184521f, 0.95803297f});
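The op::util::convert_lstm_node_format helper lets this test keep its IOFC-ordered weight data while feeding opset4::LSTMCell, which expects one canonical gate order. A rough sketch of the reordering idea, assuming FICO (forget, input, cell, output) as the target order and axis 0 as the stacked-gate axis; the helper's real signature and internals may differ:

// Illustrative only: reorder stacked LSTM gate blocks from IOFC
// (input, output, forget, cell) to FICO by splitting along axis 0
// and concatenating the four blocks in the new order.
std::shared_ptr<ngraph::Node> iofc_to_fico(const ngraph::Output<ngraph::Node>& W)
{
    using namespace ngraph;
    const auto axis = opset4::Constant::create(element::i64, Shape{}, {0});
    const auto gates = std::make_shared<opset4::Split>(W, axis, 4);
    return std::make_shared<opset4::Concat>(
        OutputVector{gates->output(2),   // F
                     gates->output(0),   // I
                     gates->output(3),   // C
                     gates->output(1)},  // O
        0);
}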
@ -1700,11 +1706,10 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});

const auto lstm_cell = make_shared<op::LSTMCell>(
X, H_t, C_t, W, R, B, P, hidden_size, op::LSTMWeightsFormat::IOFC);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);

auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);

// X
@ -1755,18 +1760,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes)
0.13840231f,
0.24175227f};

ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.9218244f, 0.78787273f, 0.8754273f, 0.7361462f, 0.70927656f, 0.83522964f});
ht_test_case.run();

auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{1.7094649f, 1.1259761f, 1.444019f, 1.086587f, 0.9762144f, 1.3066899f});
@ -1792,22 +1795,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});

const auto lstm_cell = make_shared<op::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
P,
hidden_size,
op::LSTMWeightsFormat::IOFC,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
clip_threshold,
input_forget);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh", "tanh"},
vector<float>{},
vector<float>{},
clip_threshold);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);

// X
@ -1858,18 +1858,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_bias_peepholes_clip_input_forget)
0.13840231f,
0.24175227f};

ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.71485436f, 0.71844107f, 0.72704613f, 0.6235602f, 0.68306124f, 0.6978715f});
ht_test_case.run();

auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
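Note that clip_threshold survives into the opset4 call while input_forget (which coupled the input and forget gates) is dropped. A minimal sketch of the assumed clip semantics, with clip == 0 taken to mean "clipping disabled" as in the cell specs:

#include <algorithm>

// Sketch only: bound a gate pre-activation before the activation function.
float apply_clip(float x, float clip)
{
    return clip == 0.f ? x : std::max(-clip, std::min(x, clip));
}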
@ -1898,22 +1896,19 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions)
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
const auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});

const auto lstm_cell = make_shared<op::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
P,
hidden_size,
op::LSTMWeightsFormat::IOFC,
activations,
activation_alpha,
activation_beta,
clip_threshold,
input_forget);
const auto lstm_cell = make_shared<opset4::LSTMCell>(X,
H_t,
C_t,
W,
R,
B,
hidden_size,
activations,
activation_alpha,
activation_beta,
clip_threshold);
auto ht_function = make_shared<Function>(OutputVector{lstm_cell->output(0)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ht_test_case = test::TestCase<TestEngine>(ht_function);

// X
@ -1964,18 +1959,16 @@ NGRAPH_TEST(${BACKEND_NAME}, lstm_cell_activaction_functions)
0.13840231f,
0.24175227f};

ht_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ht_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ht_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.96834344f, 0.9695254f, 0.97068775f, 0.9077866f, 0.94161016f, 0.96599925f});
ht_test_case.run();

auto ct_function = make_shared<Function>(OutputVector{lstm_cell->output(1)},
ParameterVector{X, H_t, C_t, W, R, B, P});
ParameterVector{X, H_t, C_t, W, R, B});
auto ct_test_case = test::TestCase<TestEngine>(ct_function);
ct_test_case.add_multiple_inputs(
vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B, in_P});
ct_test_case.add_multiple_inputs(vector<vector<float>>{in_X, in_Ht, in_Ct, in_W, in_R, in_B});
ct_test_case.add_expected_output<float>(
Shape{batch_size, hidden_size},
{0.94656503f, 0.9527454f, 0.9706756f, 0.84206575f, 0.91898793f, 0.9127192f});
@ -2168,7 +2161,7 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_no_bias)
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R});

auto test_case = test::TestCase<TestEngine>(function);
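For orientation, the single-gate recurrence these RNNCell tests exercise, in the standard formulation (W is the input weight matrix, R the recurrent one; f defaults to tanh unless the test passes another activation):

\[ H_t = f(X_t W^{T} + H_{t-1} R^{T} + B) \]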
@ -2219,16 +2212,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_bias_clip)
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"tanh"},
vector<float>{},
vector<float>{},
clip);
const auto rnn_cell = make_shared<opset4::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"tanh"},
vector<float>{},
vector<float>{},
clip);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R, B});

auto test_case = test::TestCase<TestEngine>(function);
@ -2281,16 +2274,16 @@ NGRAPH_TEST(${BACKEND_NAME}, rnn_cell_activation_function)
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid"},
vector<float>{},
vector<float>{},
clip);
const auto rnn_cell = make_shared<opset4::RNNCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid"},
vector<float>{},
vector<float>{},
clip);
auto function = make_shared<Function>(rnn_cell, ParameterVector{X, H_t, W, R, B});

auto test_case = test::TestCase<TestEngine>(function);
@ -2347,17 +2340,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_bias_clip)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});

auto test_case = test::TestCase<TestEngine>(function);
@ -2420,17 +2413,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_linear_before_reset)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{(gates_count + 1) * hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"sigmoid", "tanh"},
vector<float>{},
vector<float>{},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});

auto test_case = test::TestCase<TestEngine>(function);
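The linear_before_reset flag toggled above changes where the reset gate is applied when computing the candidate state, which is also why B grows to (gates_count + 1) * hidden_size: the candidate's recurrent bias is kept separate. Assuming the ONNX-style definition, with g the candidate activation (tanh here), r_t the reset gate, and b_Wh, b_Rh the two halves of the candidate bias:

\[
\tilde{h}_t =
\begin{cases}
g\left(X_t W_h^{T} + r_t \odot (H_{t-1} R_h^{T}) + b_{Wh} + b_{Rh}\right), & \text{linear\_before\_reset} = \text{false}\\
g\left(X_t W_h^{T} + r_t \odot (H_{t-1} R_h^{T} + b_{Rh}) + b_{Wh}\right), & \text{linear\_before\_reset} = \text{true}
\end{cases}
\]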
@ -2492,17 +2485,17 @@ NGRAPH_TEST(${BACKEND_NAME}, gru_cell_activation_function)
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{(gates_count + 1) * hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"hardsigmoid", "hardsigmoid"},
vector<float>{1.8345f, 1.8345f},
vector<float>{3.05f, 3.05f},
clip,
linear_before_reset);
const auto gru_cell = make_shared<opset4::GRUCell>(X,
H_t,
W,
R,
B,
hidden_size,
vector<string>{"hardsigmoid", "hardsigmoid"},
vector<float>{1.8345f, 1.8345f},
vector<float>{3.05f, 3.05f},
clip,
linear_before_reset);
auto function = make_shared<Function>(gru_cell, ParameterVector{X, H_t, W, R, B});

auto test_case = test::TestCase<TestEngine>(function);
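The hardsigmoid activation used above is parameterized by the alpha/beta vectors passed to the cell. A sketch, assuming the usual piecewise definition (with alpha = 1.8345f and beta = 3.05f as in this test):

#include <algorithm>

// Sketch only: f(x) = max(0, min(1, alpha * x + beta)).
float hardsigmoid(float x, float alpha, float beta)
{
    return std::max(0.f, std::min(1.f, alpha * x + beta));
}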
@ -346,7 +346,7 @@ namespace

void op_is_GRUCell()
{
op::GRUCell node;
op::v3::GRUCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -472,7 +472,7 @@ namespace

void op_is_LSTMCell()
{
op::LSTMCell node;
op::v4::LSTMCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -481,7 +481,7 @@ namespace

void op_is_LSTMSequence()
{
op::LSTMSequence node;
op::v0::LSTMSequence node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -733,7 +733,7 @@ namespace

void op_is_RNNCell()
{
op::RNNCell node;
op::v0::RNNCell node;
EXPECT_FALSE(op::is_unary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_arithmetic(&node));
EXPECT_FALSE(op::is_binary_elementwise_comparison(&node));
@ -1085,14 +1085,14 @@ IE_CPU.builder_opset1_collapse_dyn_shape
# IE_CPU.interpolate_down_scales_const_linear

# GRUCell operation has a form that is not supported
IE_CPU.onnx_model_gru_defaults_fwd
IE_CPU.onnx_model_gru_fwd_activations
IE_CPU.onnx_model_gru_fwd_mixed_seq_len
IE_CPU.onnx_model_gru_rev_clip
IE_CPU.onnx_model_gru_reverse
IE_CPU.onnx_model_gru_fwd_bias_initial_h
IE_CPU.onnx_model_gru_bidirectional
IE_CPU.onnx_model_gru_fwd_linear_before_reset
onnx_model_gru_defaults_fwd
onnx_model_gru_fwd_activations
onnx_model_gru_fwd_mixed_seq_len
onnx_model_gru_rev_clip
onnx_model_gru_reverse
onnx_model_gru_fwd_bias_initial_h
onnx_model_gru_bidirectional
onnx_model_gru_fwd_linear_before_reset

# Not implemented Interpolate-4:
IE_CPU.onnx_model_resize10_import_only
@ -59,8 +59,10 @@
#include "ngraph/runtime/reference/floor.hpp"
#include "ngraph/runtime/reference/gather.hpp"
#include "ngraph/runtime/reference/gather_nd.hpp"
#include "ngraph/runtime/reference/gru_cell.hpp"
#include "ngraph/runtime/reference/log.hpp"
#include "ngraph/runtime/reference/lrn.hpp"
#include "ngraph/runtime/reference/lstm_cell.hpp"
#include "ngraph/runtime/reference/matmul.hpp"
#include "ngraph/runtime/reference/max.hpp"
#include "ngraph/runtime/reference/max_pool.hpp"
@ -77,6 +79,7 @@
#include "ngraph/runtime/reference/result.hpp"
#include "ngraph/runtime/reference/reverse.hpp"
#include "ngraph/runtime/reference/reverse_sequence.hpp"
#include "ngraph/runtime/reference/rnn_cell.hpp"
#include "ngraph/runtime/reference/round.hpp"
#include "ngraph/runtime/reference/scatter_nd_update.hpp"
#include "ngraph/runtime/reference/select.hpp"
@ -692,6 +695,67 @@ protected:
}
break;
}
case OP_TYPEID::GRUCell_v3:
{
const op::v3::GRUCell* gru_cell = static_cast<const op::v3::GRUCell*>(&node);
runtime::reference::gru_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
gru_cell->get_activations()[0],
gru_cell->get_activations()[1],
gru_cell->get_clip(),
gru_cell->get_linear_before_reset());
break;
}
case OP_TYPEID::LSTMCell_v4:
{
const op::v4::LSTMCell* lstm_cell = static_cast<const op::v4::LSTMCell*>(&node);
runtime::reference::lstm_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
args[5]->get_data_ptr<T>(),
args[5]->get_shape(),
out[0]->get_data_ptr<T>(),
out[1]->get_data_ptr<T>(),
lstm_cell->get_activations()[0],
lstm_cell->get_activations()[1],
lstm_cell->get_activations()[2],
lstm_cell->get_clip());
break;
}
case OP_TYPEID::RNNCell_v0:
{
const op::v0::RNNCell* rnn_cell = static_cast<const op::v0::RNNCell*>(&node);
runtime::reference::rnn_cell(args[0]->get_data_ptr<T>(),
args[0]->get_shape(),
args[1]->get_data_ptr<T>(),
args[1]->get_shape(),
args[2]->get_data_ptr<T>(),
args[2]->get_shape(),
args[3]->get_data_ptr<T>(),
args[3]->get_shape(),
args[4]->get_data_ptr<T>(),
args[4]->get_shape(),
out[0]->get_data_ptr<T>(),
rnn_cell->get_activations()[0],
rnn_cell->get_clip());
break;
}
case OP_TYPEID::Log:
{
size_t element_count = shape_size(node.get_output_shape(0));
@ -1203,15 +1267,12 @@ protected:
case OP_TYPEID::GRN:
case OP_TYPEID::GroupConvolution:
case OP_TYPEID::GroupConvolutionBackpropData:
case OP_TYPEID::GRUCell:
case OP_TYPEID::HardSigmoid:
case OP_TYPEID::Interpolate:
case OP_TYPEID::LSTMCell:
case OP_TYPEID::LSTMSequence:
case OP_TYPEID::MVN:
case OP_TYPEID::NormalizeL2:
case OP_TYPEID::PRelu:
case OP_TYPEID::RNNCell:
case OP_TYPEID::ScatterUpdate_v3:
case OP_TYPEID::Selu:
case OP_TYPEID::ShuffleChannels:
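The three reference kernels dispatched above each compute one time step. For the LSTM case, the standard recurrence is shown below; the gate order inside W, R, and B follows the opset spec rather than this sketch, and sigma/tanh are the default activations the tests rely on:

\[
\begin{aligned}
f_t &= \sigma(X_t W_f^{T} + H_{t-1} R_f^{T} + b_f)\\
i_t &= \sigma(X_t W_i^{T} + H_{t-1} R_i^{T} + b_i)\\
\tilde{c}_t &= \tanh(X_t W_c^{T} + H_{t-1} R_c^{T} + b_c)\\
o_t &= \sigma(X_t W_o^{T} + H_{t-1} R_o^{T} + b_o)\\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{c}_t\\
H_t &= o_t \odot \tanh(C_t)
\end{aligned}
\]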
@ -20,6 +20,7 @@

#define ID_SUFFIX(NAME) NAME##_v0
NGRAPH_OP(DetectionOutput, op::v0)
NGRAPH_OP(RNNCell, op::v0)
#undef ID_SUFFIX

#define ID_SUFFIX(NAME) NAME##_v1
@ -31,6 +32,7 @@ NGRAPH_OP(LogicalNot, op::v1)
#undef ID_SUFFIX

#define ID_SUFFIX(NAME) NAME##_v3
NGRAPH_OP(GRUCell, op::v3)
NGRAPH_OP(EmbeddingBagOffsetsSum, op::v3)
NGRAPH_OP(EmbeddingBagPackedSum, op::v3)
NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
@ -43,4 +45,5 @@ NGRAPH_OP(ScatterUpdate, op::v3)

#define ID_SUFFIX(NAME) NAME##_v4
NGRAPH_OP(CTCLoss, op::v4)
NGRAPH_OP(LSTMCell, op::v4)
#undef ID_SUFFIX
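A sketch of how this table is assumed to feed the interpreter: the same NGRAPH_OP list is expanded with ID_SUFFIX appended, producing the versioned enumerators (GRUCell_v3, LSTMCell_v4, RNNCell_v0) matched in the dispatch earlier. Illustrative macro expansion, not the backend's literal code:

// Illustrative expansion only.
#define ID_SUFFIX(NAME) NAME##_v4
#define NGRAPH_OP(NAME, NAMESPACE) ID_SUFFIX(NAME),
enum class OP_TYPEID
{
    NGRAPH_OP(CTCLoss, op::v4)  // expands to CTCLoss_v4,
    NGRAPH_OP(LSTMCell, op::v4) // expands to LSTMCell_v4,
    UnknownOp
};
#undef NGRAPH_OP
#undef ID_SUFFIX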
@ -105,3 +105,20 @@ INTERPRETER.onnx_model_gatherND_float

# Round op doesn't support some specific cases of rounding
onnx_model_round_half_nearest_even

# Unsupported op 'LSTMSequence': not FusedOp anymore, no reference implementation yet
onnx_model_lstm_fwd_with_clip
onnx_model_lstm_fwd_mixed_seq
onnx_model_lstm_fwd_hardsigmoid_activation
onnx_model_lstm_fwd_large_batch_no_clip
onnx_model_lstm_bdir_short_input_seq
onnx_model_lstm_mixed_seq_reverse

# Activation function hardsigmoid is not supported.
gru_cell_activation_function
lstm_cell_activaction_functions
onnx_model_gru_fwd_activations

# Peepholes, input_forget are not supported
lstm_cell_bias_peepholes
lstm_cell_bias_peepholes_clip_input_forget
@ -81,7 +81,6 @@ NGRAPH_OP(Exp, ngraph::op)
NGRAPH_OP(FakeQuantize, ngraph::op)
NGRAPH_OP(Floor, ngraph::op)
NGRAPH_OP(GRN, ngraph::op)
NGRAPH_OP(GRUCell, ngraph::op)
NGRAPH_OP(Gather, ngraph::op)
NGRAPH_OP(GatherND, ngraph::op)
NGRAPH_OP(Gelu, ngraph::op)
@ -95,8 +94,7 @@ NGRAPH_OP(Less, ngraph::op)
NGRAPH_OP(LessEq, ngraph::op)
NGRAPH_OP(Log, ngraph::op)
NGRAPH_OP(LRN, ngraph::op)
NGRAPH_OP(LSTMCell, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op)
NGRAPH_OP(LSTMSequence, ngraph::op::v0)
NGRAPH_OP(MatMul, ngraph::op)
NGRAPH_OP(NormalizeL2, ngraph::op)
NGRAPH_OP(Max, ngraph::op)
@ -124,7 +122,6 @@ NGRAPH_OP(Reshape, ngraph::op)
NGRAPH_OP(Result, ngraph::op)
NGRAPH_OP(Reverse, ngraph::op)
NGRAPH_OP(ReverseSequence, ngraph::op)
NGRAPH_OP(RNNCell, ngraph::op)
NGRAPH_OP(Round, ngraph::op)
NGRAPH_OP(Select, ngraph::op)
NGRAPH_OP(Selu, ngraph::op)
@ -16,6 +16,7 @@

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"

using namespace std;
@ -35,7 +36,7 @@ TEST(type_prop, gru_cell)
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
}
@ -56,7 +57,7 @@ TEST(type_prop, gru_cell_invalid_input)
auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -70,7 +71,7 @@ TEST(type_prop, gru_cell_invalid_input)
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -86,7 +87,7 @@ TEST(type_prop, gru_cell_invalid_input)
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -101,7 +102,7 @@ TEST(type_prop, gru_cell_invalid_input)
auto B = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
try
{
const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "GRUCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -126,7 +127,7 @@ TEST(type_prop, gru_cell_dynamic_batch_size)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, hidden_size);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -146,7 +147,7 @@ TEST(type_prop, gru_cell_dynamic_hidden_size)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, 3);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, 3);
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -163,7 +164,7 @@ TEST(type_prop, gru_cell_dynamic_inputs)
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});

const auto gru_cell = make_shared<op::GRUCell>(X, H_t, W, R, 2);
const auto gru_cell = make_shared<opset4::GRUCell>(X, H_t, W, R, 2);

EXPECT_EQ(gru_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(gru_cell->get_output_element_type(0), element::f32);
@ -183,33 +184,37 @@ TEST(type_prop, gru_cell_invalid_input_rank0)

// Invalid rank0 for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid rank0 for X tensor.
W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid rank0 for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid rank0 for R tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid rank0 for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size),
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
}
@ -228,32 +233,36 @@ TEST(type_prop, gru_cell_invalid_input_dynamic_rank)

// Invalid dynamic rank for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid dynamic rank for R tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";

// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::GRUCell>(X, H_t, W, R, B, hidden_size),
ASSERT_THROW(make_shared<opset4::GRUCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "GRUCell node was created with invalid data.";
}
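Both styles appear in these type_prop files: older checks wrap construction in try/catch with FAIL(), newer ones use gtest's exception assertions directly (one of this PR's stated clean-ups). A sketch of the equivalent one-liner, reusing the shapes from the tests above; the message-substring check can be kept by catching and re-verifying explicitly:

// Sketch only: collapses the try/catch + FAIL() pattern into one assertion.
EXPECT_THROW(
    {
        auto bad_W = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
        make_shared<opset4::GRUCell>(X, H_t, bad_W, R, hidden_size);
    },
    ngraph::NodeValidationFailure);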
@ -16,6 +16,7 @@
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "ngraph/opsets/opset4.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -28,15 +29,15 @@ TEST(type_prop, lstm_cell)
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 4;
|
||||
|
||||
const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
|
||||
const auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
|
||||
const auto W =
|
||||
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
|
||||
const auto R =
|
||||
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
const auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
const auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
const auto C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
EXPECT_EQ(lstm_cell->get_hidden_size(), hidden_size);
|
||||
EXPECT_EQ(lstm_cell->get_clip(), 0.f);
|
||||
EXPECT_TRUE(lstm_cell->get_activations_alpha().empty());
|
||||
@ -44,8 +45,6 @@ TEST(type_prop, lstm_cell)
|
||||
EXPECT_EQ(lstm_cell->get_activations()[0], "sigmoid");
|
||||
EXPECT_EQ(lstm_cell->get_activations()[1], "tanh");
|
||||
EXPECT_EQ(lstm_cell->get_activations()[2], "tanh");
|
||||
EXPECT_EQ(lstm_cell->get_weights_format(), op::LSTMWeightsFormat::IFCO);
|
||||
EXPECT_FALSE(lstm_cell->get_input_forget());
|
||||
EXPECT_EQ(lstm_cell->get_output_element_type(0), element::f32);
|
||||
EXPECT_EQ(lstm_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
|
||||
EXPECT_EQ(lstm_cell->get_output_element_type(1), element::f32);
|
||||
@ -59,17 +58,17 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 4;
|
||||
|
||||
auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
|
||||
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
|
||||
auto R =
|
||||
make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
auto C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
auto C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
|
||||
// Invalid W tensor shape.
|
||||
auto W = make_shared<op::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
|
||||
auto W = make_shared<opset4::Parameter>(element::f32, Shape{1 * hidden_size, input_size});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
@ -79,11 +78,11 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
}
|
||||
|
||||
// Invalid R tensor shape.
|
||||
W = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
|
||||
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
|
||||
W = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, input_size});
|
||||
R = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, 1});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
@ -94,11 +93,11 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
}
|
||||
|
||||
// Invalid H_t tensor shape.
|
||||
R = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
|
||||
R = make_shared<opset4::Parameter>(element::f32, Shape{gates_count * hidden_size, hidden_size});
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
@ -109,11 +108,11 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
}
|
||||
|
||||
// Invalid C_t tensor shape.
|
||||
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
C_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
@ -124,12 +123,12 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
}
|
||||
|
||||
// Invalid B tensor shape.
|
||||
C_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
|
||||
auto P = make_shared<op::Parameter>(element::f32, Shape{3 * hidden_size});
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
|
||||
auto B = make_shared<opset4::Parameter>(element::f32, Shape{2 * gates_count * hidden_size});
|
||||
auto P = make_shared<opset4::Parameter>(element::f32, Shape{3 * hidden_size});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
@ -137,20 +136,6 @@ TEST(type_prop, lstm_cell_invalid_input)
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Parameter hidden_size mistmatched in B input."));
|
||||
}
|
||||
|
||||
// Invalid P tensor shape.
|
||||
B = make_shared<op::Parameter>(element::f32, Shape{gates_count * hidden_size});
|
||||
P = make_shared<op::Parameter>(element::f32, Shape{hidden_size});
|
||||
try
|
||||
{
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size);
|
||||
FAIL() << "LSTMCell node was created with invalid data.";
|
||||
}
|
||||
catch (const NodeValidationFailure& error)
|
||||
{
|
||||
EXPECT_HAS_SUBSTRING(error.what(),
|
||||
std::string("Parameter hidden_size mistmatched in P input."));
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, lstm_cell_dynamic_batch_size)
|
||||
@ -160,17 +145,18 @@ TEST(type_prop, lstm_cell_dynamic_batch_size)
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 4;
|
||||
|
||||
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
const auto R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
const auto X =
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{gates_count * hidden_size, input_size});
|
||||
const auto R = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{gates_count * hidden_size, hidden_size});
|
||||
const auto H_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
const auto C_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
|
||||
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
|
||||
@ -185,17 +171,18 @@ TEST(type_prop, lstm_cell_dynamic_hidden_size)
|
||||
const auto hidden_size = Dimension::dynamic();
|
||||
const size_t gates_count = 4;
|
||||
|
||||
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{hidden_size * gates_count, input_size});
|
||||
const auto R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{hidden_size * gates_count, hidden_size});
|
||||
const auto X =
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{hidden_size * gates_count, input_size});
|
||||
const auto R = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{hidden_size * gates_count, hidden_size});
|
||||
const auto H_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
const auto C_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
|
||||
@ -210,17 +197,18 @@ TEST(type_prop, lstm_cell_dynamic_inputs)
|
||||
const auto hidden_size = Dimension::dynamic();
|
||||
const size_t gates_count = 4;
|
||||
|
||||
const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{hidden_size * gates_count, input_size});
|
||||
const auto R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{hidden_size * gates_count, hidden_size});
|
||||
const auto X =
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
const auto W = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{hidden_size * gates_count, input_size});
|
||||
const auto R = make_shared<opset4::Parameter>(
|
||||
element::f32, PartialShape{hidden_size * gates_count, hidden_size});
|
||||
const auto H_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
const auto C_t =
|
||||
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
const auto lstm_cell = make_shared<op::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||
const auto lstm_cell = make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, 3);
|
||||
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
|
||||
EXPECT_EQ(lstm_cell->get_output_partial_shape(1), (PartialShape{batch_size, hidden_size}));
|
||||
@ -235,62 +223,54 @@ TEST(type_prop, lstm_cell_invalid_input_rank0)
|
||||
const size_t hidden_size = 3;
|
||||
const size_t gates_count = 4;
|
||||
|
||||
auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
auto W = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
auto R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
auto C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
auto W = make_shared<opset4::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
auto R = make_shared<opset4::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
|
||||
// Invalid rank0 for W tensor.
|
||||
W = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
W = make_shared<opset4::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for X tensor.
|
||||
W = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
W = make_shared<opset4::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, input_size});
|
||||
X = make_shared<opset4::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for H_t tensor.
|
||||
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for C_t tensor.
|
||||
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
C_t = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for R tensor.
|
||||
C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
|
||||
R = make_shared<opset4::Parameter>(element::f32, PartialShape{});
|
||||
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for B tensor.
|
||||
R = make_shared<op::Parameter>(element::f32,
|
||||
PartialShape{gates_count * hidden_size, hidden_size});
|
||||
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
|
||||
auto P = make_shared<op::Parameter>(element::f32, PartialShape{3 * hidden_size});
|
||||
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
|
||||
ngraph::NodeValidationFailure)
|
||||
<< "LSTMCell node was created with invalid data.";
|
||||
|
||||
// Invalid rank0 for P tensor.
B = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
}

@ -302,62 +282,54 @@ TEST(type_prop, lstm_cell_invalid_input_dynamic_rank)
const size_t hidden_size = 3;
const size_t gates_count = 4;

auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
auto C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});

// Invalid dynamic rank for W tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
W = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
X = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for C_t tensor.
H_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for R tensor.
C_t = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
C_t = make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
auto P = make_shared<op::Parameter>(element::f32, PartialShape{3 * hidden_size});
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";

// Invalid dynamic rank for P tensor.
B = make_shared<op::Parameter>(element::f32, PartialShape{gates_count * hidden_size});
P = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::LSTMCell>(X, H_t, C_t, W, R, B, P, hidden_size),
R = make_shared<opset4::Parameter>(element::f32,
PartialShape{gates_count * hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "LSTMCell node was created with invalid data.";
}
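The hunks above capture the LSTMCell v0 → opset4 migration in the type_prop tests: the peephole input P disappears from the constructor, so both the rank-0 and dynamic-rank P cases are dropped. For contrast with the failure cases, here is a minimal sketch (not part of this commit) of a validly shaped opset4::LSTMCell:

```cpp
// Illustrative sketch only (not from this diff): a validly shaped
// opset4::LSTMCell. There is no peephole input P anymore; the cell
// takes X, H_t, C_t, W, R and optionally B.
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"

using namespace ngraph;

std::shared_ptr<opset4::LSTMCell> make_valid_lstm_cell()
{
    const size_t batch_size = 2, input_size = 3, hidden_size = 3, gates_count = 4;
    auto X = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
    auto H_t = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto C_t = std::make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
    auto W = std::make_shared<opset4::Parameter>(
        element::f32, Shape{gates_count * hidden_size, input_size});
    auto R = std::make_shared<opset4::Parameter>(
        element::f32, Shape{gates_count * hidden_size, hidden_size});
    // Both outputs (H_o and C_o) have shape {batch_size, hidden_size}.
    return std::make_shared<opset4::LSTMCell>(X, H_t, C_t, W, R, hidden_size);
}
```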
@ -16,6 +16,7 @@

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"

// suppress FusedOp deprecation warnings
@ -40,7 +41,7 @@ struct recurrent_sequence_parameters
//
// Create and initialize default input test tensors.
//
shared_ptr<op::LSTMSequence>
shared_ptr<op::v1::LSTMSequence>
lstm_seq_tensor_initialization(const recurrent_sequence_parameters& param)
{
auto batch_size = param.batch_size;
@ -50,20 +51,21 @@ shared_ptr<op::LSTMSequence>
auto hidden_size = param.hidden_size;
auto et = param.et;

const auto X = make_shared<op::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
const auto X =
make_shared<opset4::Parameter>(et, PartialShape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
make_shared<opset4::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(et, PartialShape{batch_size});
const auto W =
make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4, input_size});
const auto R =
make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4, hidden_size});
const auto B = make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 4});
const auto P = make_shared<op::Parameter>(et, PartialShape{num_directions, hidden_size * 3});
make_shared<opset4::Parameter>(et, PartialShape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(et, PartialShape{batch_size});
const auto W = make_shared<opset4::Parameter>(
et, PartialShape{num_directions, hidden_size * 4, input_size});
const auto R = make_shared<opset4::Parameter>(
et, PartialShape{num_directions, hidden_size * 4, hidden_size});
const auto B =
make_shared<opset4::Parameter>(et, PartialShape{num_directions, hidden_size * 4});

const auto lstm_sequence = make_shared<op::LSTMSequence>();
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>();

lstm_sequence->set_argument(0, X);
lstm_sequence->set_argument(1, initial_hidden_state);
@ -72,7 +74,6 @@ shared_ptr<op::LSTMSequence>
lstm_sequence->set_argument(4, W);
lstm_sequence->set_argument(5, R);
lstm_sequence->set_argument(6, B);
lstm_sequence->set_argument(7, P);

return lstm_sequence;
}
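For orientation, a sketch (not from this commit) of how the type_prop tests drive this helper. The field names are taken from the surrounding hunks; the struct definition itself is assumed to live earlier in this file:

```cpp
// Hypothetical usage of lstm_seq_tensor_initialization(); values are arbitrary.
recurrent_sequence_parameters param;
param.batch_size = 24;
param.num_directions = 1;
param.seq_length = 12;
param.input_size = 8;
param.hidden_size = 256;
param.et = element::f32;

auto lstm_sequence = lstm_seq_tensor_initialization(param);
// set_argument() does not trigger shape inference, so validation is explicit:
lstm_sequence->validate_and_infer_types();
```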
@ -86,40 +87,39 @@ TEST(type_prop, lstm_sequence_forward)
const size_t hidden_size = 128;

const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B =
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});

const auto lstm_direction = op::RecurrentSequenceDirection::FORWARD;

const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction);
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction);

EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::RecurrentSequenceDirection::FORWARD);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::IFCO);
EXPECT_TRUE(lstm_sequence->get_activations_alpha().empty());
EXPECT_TRUE(lstm_sequence->get_activations_beta().empty());
EXPECT_EQ(lstm_sequence->get_activations()[0], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[1], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[2], "tanh");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_clip(), 0.f);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
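The forward test pins down output 0 of op::v1::LSTMSequence as {batch_size, num_directions, seq_length, hidden_size}. Per the usual LSTMSequence output layout (an assumption here, not shown in this diff), the remaining two outputs are the final hidden and cell states, which a test could check the same way:

```cpp
// Sketch only: shape expectations for the other two LSTMSequence outputs,
// assuming the conventional Y_h / Y_c layout from the operation spec.
EXPECT_EQ(lstm_sequence->get_output_shape(1),
          (Shape{batch_size, num_directions, hidden_size}));
EXPECT_EQ(lstm_sequence->get_output_shape(2),
          (Shape{batch_size, num_directions, hidden_size}));
```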
@ -138,47 +138,44 @@ TEST(type_prop, lstm_sequence_bidirectional)
const size_t hidden_size = 256;

const auto X =
make_shared<op::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state =
make_shared<op::Parameter>(element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<op::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32,
Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B = make_shared<op::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});
make_shared<opset4::Parameter>(element::f32, Shape{batch_size, seq_length, input_size});
const auto initial_hidden_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto initial_cell_state = make_shared<opset4::Parameter>(
element::f32, Shape{batch_size, num_directions, hidden_size});
const auto sequence_lengths = make_shared<opset4::Parameter>(element::i32, Shape{batch_size});
const auto W = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(
element::f32, Shape{num_directions, 4 * hidden_size, hidden_size});
const auto B =
make_shared<opset4::Parameter>(element::f32, Shape{num_directions, 4 * hidden_size});

const auto weights_format = op::LSTMWeightsFormat::FICO;
const auto lstm_direction = op::LSTMSequence::direction::BIDIRECTIONAL;
const auto lstm_direction = op::v1::LSTMSequence::direction::BIDIRECTIONAL;
const std::vector<float> activations_alpha = {2.7, 7.0, 32.367};
const std::vector<float> activations_beta = {0.0, 5.49, 6.0};
const std::vector<std::string> activations = {"tanh", "sigmoid", "sigmoid"};

const auto lstm_sequence = make_shared<op::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
weights_format,
activations_alpha,
activations_beta,
activations);
const auto lstm_sequence = make_shared<op::v1::LSTMSequence>(X,
initial_hidden_state,
initial_cell_state,
sequence_lengths,
W,
R,
B,
hidden_size,
lstm_direction,
activations_alpha,
activations_beta,
activations);
EXPECT_EQ(lstm_sequence->get_hidden_size(), hidden_size);
EXPECT_EQ(lstm_sequence->get_direction(), op::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_weights_format(), op::LSTMWeightsFormat::FICO);
EXPECT_EQ(lstm_sequence->get_direction(), op::v1::LSTMSequence::direction::BIDIRECTIONAL);
EXPECT_EQ(lstm_sequence->get_activations_alpha(), activations_alpha);
EXPECT_EQ(lstm_sequence->get_activations_beta(), activations_beta);
EXPECT_EQ(lstm_sequence->get_activations()[0], "tanh");
EXPECT_EQ(lstm_sequence->get_activations()[1], "sigmoid");
EXPECT_EQ(lstm_sequence->get_activations()[2], "sigmoid");
EXPECT_EQ(lstm_sequence->get_clip_threshold(), 0.f);
EXPECT_FALSE(lstm_sequence->get_input_forget());
EXPECT_EQ(lstm_sequence->get_clip(), 0.f);
EXPECT_EQ(lstm_sequence->get_output_element_type(0), element::f32);
EXPECT_EQ(lstm_sequence->get_output_shape(0),
(Shape{batch_size, num_directions, seq_length, hidden_size}));
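Two things change in the bidirectional test: the v1 constructor no longer takes a weights_format argument (so the FICO expectation is gone), and the direction argument alone distinguishes forward from bidirectional runs. A small sketch of the resulting relationship, not part of the diff:

```cpp
// Illustrative only: the num_directions dimension of the weight and state
// tensors follows the direction attribute.
const size_t num_directions =
    lstm_direction == op::v1::LSTMSequence::direction::BIDIRECTIONAL ? 2 : 1;
// With hidden_size = 256 and BIDIRECTIONAL, W has shape {2, 1024, input_size}.
```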
@ -330,15 +327,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dimension)
param.et = element::f32;

auto lstm_sequence = lstm_seq_tensor_initialization(param);
auto invalid_rank0_tensor = make_shared<op::Parameter>(param.et, PartialShape{});
auto invalid_rank0_tensor = make_shared<opset4::Parameter>(param.et, PartialShape{});

// Validate invalid rank0 tensor for all inputs: X, initial_hidden_state, initial_cell_state, W,
// R, B and P
// R, B
for (auto i = 0; i < lstm_sequence->get_input_size(); i++)
{
lstm_sequence = lstm_seq_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_rank0_tensor);

ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
<< "LSTMSequence node was created with invalid data.";
}
@ -357,15 +353,14 @@ TEST(type_prop, lstm_sequence_invalid_input_dynamic_rank)

auto lstm_sequence = lstm_seq_tensor_initialization(param);
auto invalid_dynamic_tensor =
make_shared<op::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));
make_shared<opset4::Parameter>(param.et, PartialShape::dynamic(Rank::dynamic()));

// Validate invalid dynamic tensor for all inputs: X, initial_hidden_state, initial_cell_state,
// W, R, B and P
// W, R, B
for (auto i = 0; i < lstm_sequence->get_input_size(); i++)
{
lstm_sequence = lstm_seq_tensor_initialization(param);
lstm_sequence->set_argument(i, invalid_dynamic_tensor);

ASSERT_THROW(lstm_sequence->validate_and_infer_types(), ngraph::CheckFailure)
<< "LSTMSequence node was created with invalid data.";
}
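Both loops above sweep every input of the node rather than enumerating cases by hand, so dropping P shrinks the coverage automatically. For reference, the input order being iterated, reconstructed from the set_argument() calls in this file (indices 2 and 3 are inferred from the constructor argument order, not shown directly in the hunks):

```cpp
// v1::LSTMSequence inputs covered by the loops, after removal of P:
// 0: X                     3: sequence_lengths   6: B
// 1: initial_hidden_state  4: W
// 2: initial_cell_state    5: R
```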
@ -16,6 +16,7 @@

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "ngraph/opsets/opset4.hpp"
#include "util/type_prop.hpp"

using namespace std;
@ -27,12 +28,12 @@ TEST(type_prop, rnn_cell)
const size_t input_size = 3;
const size_t hidden_size = 3;

const auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
const auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
const auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
const auto W = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
const auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_shape(0), (Shape{batch_size, hidden_size}));
}
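A small companion sketch (not in the commit): the same cell with an explicit bias. Judging by the invalid-input test below, which rejects a B of shape {2 * hidden_size}, a valid RNNCell bias has shape {hidden_size}:

```cpp
// Sketch: opset4::RNNCell with an explicit bias input.
const auto B = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size});
const auto rnn_cell_b = make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size);
EXPECT_EQ(rnn_cell_b->get_output_shape(0), (Shape{batch_size, hidden_size}));
```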
@ -43,15 +44,15 @@ TEST(type_prop, rnn_cell_invalid_input)
const size_t input_size = 3;
const size_t hidden_size = 3;

auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});

// Invalid W tensor shape.
auto W = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
auto W = make_shared<opset4::Parameter>(element::f32, Shape{2 * hidden_size, input_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -61,11 +62,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}

// Invalid R tensor shape.
W = make_shared<op::Parameter>(element::f32, Shape{hidden_size, input_size});
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, 1});
W = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, input_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, 1});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -77,11 +78,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}

// Invalid H_t tensor shape.
R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
H_t = make_shared<op::Parameter>(element::f32, Shape{4, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{4, hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -92,11 +93,11 @@ TEST(type_prop, rnn_cell_invalid_input)
}

// Invalid B tensor shape.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, Shape{2 * hidden_size});
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, Shape{2 * hidden_size});
try
{
const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size);
FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure& error)
@ -112,13 +113,16 @@ TEST(type_prop, rnn_cell_dynamic_batch_size)
const size_t input_size = 3;
const size_t hidden_size = 3;

const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, hidden_size);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -129,13 +133,16 @@ TEST(type_prop, rnn_cell_dynamic_hidden_size)
const size_t input_size = 3;
const auto hidden_size = Dimension::dynamic();

const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, 3);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, 3);
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
}
@ -146,13 +153,16 @@ TEST(type_prop, rnn_cell_dynamic_inputs)
const auto input_size = Dimension::dynamic();
const auto hidden_size = Dimension::dynamic();

const auto X = make_shared<op::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto X =
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, input_size});
const auto R =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
const auto W =
make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
const auto H_t =
make_shared<op::Parameter>(element::f32, PartialShape{batch_size, hidden_size});
make_shared<opset4::Parameter>(element::f32, PartialShape{batch_size, hidden_size});

const auto rnn_cell = make_shared<op::RNNCell>(X, H_t, W, R, 2);
const auto rnn_cell = make_shared<opset4::RNNCell>(X, H_t, W, R, 2);

EXPECT_EQ(rnn_cell->get_output_partial_shape(0), (PartialShape{batch_size, hidden_size}));
EXPECT_EQ(rnn_cell->get_output_element_type(0), element::f32);
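The three dynamic-shape tests above show the pattern: dimensions left dynamic in the inputs survive into the output partial shape instead of failing validation. A condensed sketch, not from the diff:

```cpp
// Illustrative only: a fully dynamic batch dimension propagates through.
const auto dyn_batch = Dimension::dynamic();
const auto Xd = make_shared<opset4::Parameter>(element::f32, PartialShape{dyn_batch, 3});
const auto Hd = make_shared<opset4::Parameter>(element::f32, PartialShape{dyn_batch, 3});
const auto Wd = make_shared<opset4::Parameter>(element::f32, PartialShape{3, 3});
const auto Rd = make_shared<opset4::Parameter>(element::f32, PartialShape{3, 3});
const auto cell = make_shared<opset4::RNNCell>(Xd, Hd, Wd, Rd, 3);
EXPECT_TRUE(cell->get_output_partial_shape(0).is_dynamic());
```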
@ -164,37 +174,41 @@ TEST(type_prop, rnn_cell_invalid_input_rank0)
const size_t input_size = 3;
const size_t hidden_size = 3;

auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});

// Invalid rank0 for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape{})
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid rank0 for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
W = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid rank0 for H_t tensor.
X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid rank0 for R tensor.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid rank0 for B tensor.
R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size),
R = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape{});
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
}
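Note the stylistic split in this file: rnn_cell_invalid_input still uses the try/FAIL()/catch idiom, while the rank-0 and dynamic-rank tests use ASSERT_THROW. The two are equivalent failure checks; a side-by-side sketch (not from the diff, adapted from the test bodies above):

```cpp
// Equivalent gtest failure checks for an invalid RNNCell construction.
try
{
    const auto cell = make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size);
    FAIL() << "RNNCell node was created with invalid data.";
}
catch (const NodeValidationFailure&)
{
    // reaching the handler is the expected outcome
}

ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
             ngraph::NodeValidationFailure)
    << "RNNCell node was created with invalid data.";
```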
@ -205,37 +219,41 @@ TEST(type_prop, rnn_cell_invalid_input_dynamic_rank)
const size_t input_size = 3;
const size_t hidden_size = 3;

auto X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<op::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
auto X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
auto R = make_shared<opset4::Parameter>(element::f32, Shape{hidden_size, hidden_size});
auto H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});

// Invalid dynamic rank for W tensor.
auto W = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
auto W = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid dynamic rank for X tensor.
W = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
W = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, input_size});
X = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid dynamic rank for H_t tensor.
X = make_shared<op::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
X = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, input_size});
H_t = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid dynamic rank for R tensor.
H_t = make_shared<op::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, hidden_size), ngraph::NodeValidationFailure)
H_t = make_shared<opset4::Parameter>(element::f32, Shape{batch_size, hidden_size});
R = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";

// Invalid dynamic rank for B tensor.
R = make_shared<op::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<op::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<op::RNNCell>(X, H_t, W, R, B, hidden_size),
R = make_shared<opset4::Parameter>(element::f32, PartialShape{hidden_size, hidden_size});
auto B = make_shared<opset4::Parameter>(element::f32, PartialShape::dynamic(Rank::dynamic()));
ASSERT_THROW(make_shared<opset4::RNNCell>(X, H_t, W, R, B, hidden_size),
ngraph::NodeValidationFailure)
<< "RNNCell node was created with invalid data.";
}