[CPU][LPT] Support quantized RNN (LSTM and GRU algorithms) (#12981)
This commit is contained in:
@@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
// LPT transformation that handles quantized recurrent sequence ops
// (opset5::LSTMSequence and opset5::GRUSequence): it decomposes FakeQuantize
// on the inputs into quantize/dequantize subgraphs and marks the resulting
// dequantization chain so cleanup passes leave it intact.
class LP_TRANSFORMATIONS_API RecurrentCellTransformation : public LayerTransformation {
public:
    OPENVINO_RTTI("RecurrentCellTransformation", "0");
    RecurrentCellTransformation(const Params& params = Params());
    // Rewrites the matched LSTM/GRU sequence; returns true on success.
    bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
    // Accepts only LSTMSequence / GRUSequence roots.
    bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
    bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
    // Attaches SkipCleanupAttribute to the Multiply -> [Subtract] -> Convert chain.
    void propagateSkipCleanupAttribute(std::shared_ptr<Node> dequantization_multiply);
    // Pattern helpers: build matcher subgraphs for FQ / quantize / dequantize.
    static std::shared_ptr<ov::Node> wrap_fake_quantize(const std::shared_ptr<ov::Node> parameter);
    static std::shared_ptr<ov::Node> wrap_quantization(const std::shared_ptr<ov::Node> parameter);
    static std::shared_ptr<ov::Node> wrap_dequantization(const std::shared_ptr<ov::Node> parameter, const bool with_subtract);
};
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
|
||||
#include "low_precision/rt_info/attribute_parameters.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
// Runtime attribute marking a node so LPT cleanup transformations
// (FuseConvert, FuseMultiplyToFakeQuantize, FuseSubtractToFakeQuantize)
// skip it — they check this attribute in their canBeTransformed().
class LP_TRANSFORMATIONS_API SkipCleanupAttribute : public ov::RuntimeAttribute {
public:
    OPENVINO_RTTI("LowPrecision::SkipCleanup", "", ov::RuntimeAttribute, 0);
    // Inserts the attribute into node's rt_info and returns the stored Any.
    static ov::Any create(const std::shared_ptr<ngraph::Node>& node);
};
|
||||
} // namespace ngraph
|
||||
@@ -13,6 +13,7 @@
|
||||
#include "low_precision/common/ie_lpt_exception.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
@@ -113,6 +114,10 @@ bool FuseConvertTransformation::transform(TransformationContext& context, ngraph
|
||||
}
|
||||
|
||||
bool FuseConvertTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const {
|
||||
if (!getAttribute<SkipCleanupAttribute>(op).empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto convert = ov::as_type_ptr<opset1::Convert>(op->get_input_node_shared_ptr(0));
|
||||
// issue #40395
|
||||
if (convert == nullptr) {
|
||||
|
||||
@@ -10,6 +10,7 @@
|
||||
#include "low_precision/fake_quantize.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
@@ -98,6 +99,10 @@ bool FuseMultiplyToFakeQuantizeTransformation::canBeTransformed(const Transforma
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!getAttribute<SkipCleanupAttribute>(operation).empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto parent = operation->get_input_node_shared_ptr(0);
|
||||
auto fq = ov::as_type_ptr<opset1::FakeQuantize>(parent);
|
||||
const auto convert = ov::as_type_ptr<opset1::Convert>(parent);
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "low_precision/fake_quantize.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
@@ -92,6 +93,10 @@ bool FuseSubtractToFakeQuantizeTransformation::canBeTransformed(const Transforma
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!getAttribute<SkipCleanupAttribute>(operation).empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto children = operation->get_output_target_inputs(0);
|
||||
|
||||
for (const auto& target : children) {
|
||||
|
||||
@@ -56,6 +56,7 @@
|
||||
#include "low_precision/normalize_l2.hpp"
|
||||
#include "low_precision/pad.hpp"
|
||||
#include "low_precision/prelu.hpp"
|
||||
#include "low_precision/recurrent_cell.hpp"
|
||||
#include "low_precision/reduce_max.hpp"
|
||||
#include "low_precision/reduce_mean.hpp"
|
||||
#include "low_precision/reduce_min.hpp"
|
||||
@@ -229,6 +230,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
|
||||
common->add_matcher<ngraph::pass::low_precision::NormalizeL2Transformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::PadTransformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::PReluTransformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::RecurrentCellTransformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::ReduceMaxTransformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::ReduceMeanTransformation>(params);
|
||||
common->add_matcher<ngraph::pass::low_precision::ReduceMinTransformation>(params);
|
||||
|
||||
@@ -220,6 +220,8 @@ bool ngraph::pass::low_precision::MarkupPrecisions::isSupported(const std::share
|
||||
{ name<opset1::Transpose>() },
|
||||
{ name<opset1::Unsqueeze>() },
|
||||
{ name<opset1::VariadicSplit>() },
|
||||
{ name<opset5::LSTMSequence>() },
|
||||
{ name<opset6::GRUSequence>() },
|
||||
};
|
||||
|
||||
return supportedOps.find(node->get_type_name()) != supportedOps.end();
|
||||
|
||||
218
src/common/low_precision_transformations/src/recurrent_cell.cpp
Normal file
218
src/common/low_precision_transformations/src/recurrent_cell.cpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/recurrent_cell.hpp"
|
||||
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/node.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
RecurrentCellTransformation::RecurrentCellTransformation(const Params& params) : LayerTransformation(params) {
    // Inputs of the recurrent sequence op: X (data), H (hidden state),
    // C (cell state, LSTM only), S (sequence lengths), and the constant
    // weights W, R and bias B.
    const auto X = ngraph::pattern::any_input();
    const auto H = ngraph::pattern::any_input();
    const auto C = ngraph::pattern::any_input();
    const auto S = ngraph::pattern::any_input();
    const auto W = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto R = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
    const auto B = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();

    // Hidden state may also be a plain (non-quantized) constant.
    const auto H_as_const = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();

    // Each quantizable input is matched either as a FakeQuantize output...
    const auto fq_X = wrap_fake_quantize(X);
    const auto fq_H = wrap_fake_quantize(H);
    const auto fq_W = wrap_fake_quantize(W);
    const auto fq_R = wrap_fake_quantize(R);

    // ...or as an already-decomposed dequantization subgraph
    // Convert -> Subtract -> Multiply ...
    const auto dequantization_X = wrap_dequantization(ngraph::pattern::any_input(), true);
    const auto dequantization_H = wrap_dequantization(ngraph::pattern::any_input(), true);
    const auto dequantization_W = wrap_dequantization(ngraph::pattern::any_input(), true);
    const auto dequantization_R = wrap_dequantization(ngraph::pattern::any_input(), true);

    // ...or the same subgraph without the (optional) Subtract.
    const auto dequantization_without_subtract_X = wrap_dequantization(ngraph::pattern::any_input(), false);
    const auto dequantization_without_subtract_H = wrap_dequantization(ngraph::pattern::any_input(), false);
    const auto dequantization_without_subtract_W = wrap_dequantization(ngraph::pattern::any_input(), false);
    const auto dequantization_without_subtract_R = wrap_dequantization(ngraph::pattern::any_input(), false);

    auto X_in = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            fq_X, dequantization_X, dequantization_without_subtract_X
        });

    auto H_in = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            H_as_const, fq_H, dequantization_H, dequantization_without_subtract_H
        });

    auto W_in = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            fq_W, dequantization_W, dequantization_without_subtract_W
        });

    auto R_in = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            fq_R, dequantization_R, dequantization_without_subtract_R
        });

    // LSTMSequence takes X, H, C, seq_lengths, W, R, B;
    // GRUSequence has no cell state input.
    const auto lstm_seq = ngraph::pattern::wrap_type<ngraph::opset5::LSTMSequence>(
        {X_in, H_in, C, S, W_in, R_in, B});
    const auto gru_seq = ngraph::pattern::wrap_type<ngraph::opset5::GRUSequence>(
        {X_in, H_in, S, W_in, R_in, B});

    ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
        auto op = m.get_match_root();
        if (transformation_callback(op)) {
            return false;
        }

        return transform(*context, m);
    };

    // Match either sequence type with a single matcher.
    auto m = std::make_shared<ngraph::pattern::Matcher>(
        std::make_shared<pattern::op::Or>(
            OutputVector {
                lstm_seq,
                gru_seq
            }),
        "RecurrentCellTransformation");

    this->register_matcher(m, callback);
}
|
||||
|
||||
bool RecurrentCellTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) {
    const auto lstm = m.get_match_root();
    if (!canBeTransformed(context, lstm)) {
        return false;
    }
    // Inspect every direct input of the matched sequence op.
    for (size_t parentIndex = 0ul; parentIndex < lstm->get_input_size(); parentIndex++) {
        auto lstm_parent = lstm->get_input_node_shared_ptr(parentIndex);
        if (is_type<ngraph::opset1::FakeQuantize>(lstm_parent)) {
            auto fq_parent = lstm_parent->get_input_node_shared_ptr(0);
            if (is_type<ngraph::opset5::Constant>(fq_parent)) {
                // FakeQuantize on a constant input: decompose it into a
                // quantize (new FQ) + dequantize subgraph at the chosen precision.
                auto fq_node = as_type_ptr<ngraph::opset1::FakeQuantize>(lstm_parent);
                const QuantizationDetails quantizationDetails = QuantizationDetails::getDetails(fq_node);
                const auto precisionsAttribute = getAttributeFromOutput<PrecisionsAttribute>(lstm_parent);
                const auto precisions = precisionsAttribute.empty()
                    ? defaultPrecisions
                    : precisionsAttribute.as<PrecisionsAttribute>().value();
                const DataPrecision dataPrecision = getDataPrecision(lstm_parent, quantizationDetails, precisions);
                auto QDQ = NetworkHelper::decomposeFakeQuantize(fq_node,
                                                                dataPrecision.precision,
                                                                dataPrecision.min,
                                                                dataPrecision.max,
                                                                dataPrecision.hasZeroPoint,
                                                                updatePrecisions);
                std::shared_ptr<ngraph::Node> new_fq = std::get<0>(QDQ);
                std::shared_ptr<ngraph::Node> deq_multiply = std::get<1>(QDQ);
                if (deq_multiply == nullptr || new_fq == nullptr) {
                    return false;
                }

                // Find the Convert of the new dequantization chain
                // (Multiply -> [Subtract] -> Convert) and keep it from being
                // constant-folded back into the weights.
                std::shared_ptr<ngraph::Node> convert;
                auto multiply_parent = deq_multiply->get_input_node_shared_ptr(0);
                if (is_type<ngraph::opset1::Subtract>(multiply_parent)) {
                    convert = multiply_parent->get_input_node_shared_ptr(0);
                } else {
                    convert = multiply_parent;
                }
                ov::disable_constant_folding(convert);
                // Protect the chain from LPT cleanup passes.
                propagateSkipCleanupAttribute(deq_multiply);

                this->register_new_node(new_fq);
                updateOutput(context, deq_multiply, new_fq);
            } else {
                // FQ on a non-constant input is left as-is.
                continue;
            }
        } else {
            // Already-decomposed dequantization feeding the cell: disable
            // folding of the Multiply's parent and mark the chain to be skipped
            // by cleanup transformations.
            if (is_type<ngraph::opset1::Multiply>(lstm_parent)) {
                auto multiply = lstm_parent->get_input_node_shared_ptr(0);
                ov::disable_constant_folding(multiply);
                propagateSkipCleanupAttribute(lstm_parent);
            }
            continue;
        }
    }
    return true;
}
|
||||
|
||||
// Checks whether the matched root can be handled by this transformation.
// Only opset5 LSTMSequence and GRUSequence ops are supported; every other
// node type is rejected. (The previous implementation also fetched the W/R
// weight inputs into locals that were never read — those dead stores are
// removed; observable behavior is unchanged.)
bool RecurrentCellTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> lstm) const {
    return is_type<opset5::LSTMSequence>(lstm) || is_type<opset5::GRUSequence>(lstm);
}
|
||||
|
||||
// Recurrent cells are treated as precision-preserving for LPT propagation
// purposes, regardless of the concrete node.
bool RecurrentCellTransformation::isPrecisionPreserved(std::shared_ptr<Node>) const noexcept {
    return true;
}
|
||||
|
||||
void RecurrentCellTransformation::propagateSkipCleanupAttribute(std::shared_ptr<Node> multiply) {
|
||||
SkipCleanupAttribute::create(multiply);
|
||||
auto multiply_parent = multiply->get_input_node_shared_ptr(0);
|
||||
SkipCleanupAttribute::create(multiply_parent);
|
||||
if (is_type<ngraph::opset1::Subtract>(multiply_parent)) {
|
||||
auto subtract_parent = multiply_parent->get_input_node_shared_ptr(0);
|
||||
SkipCleanupAttribute::create(subtract_parent);
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_fake_quantize(
|
||||
const std::shared_ptr<ov::Node> parameter) {
|
||||
const auto input_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
const auto input_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
const auto output_low = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
const auto output_high = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
return ngraph::pattern::wrap_type<opset1::FakeQuantize>({
|
||||
parameter,
|
||||
input_low,
|
||||
input_high,
|
||||
output_low,
|
||||
output_high});
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_quantization(
|
||||
const std::shared_ptr<ov::Node> parameter) {
|
||||
const auto quantization_fake_quantize = wrap_fake_quantize(parameter);
|
||||
const auto quantization_convert = ngraph::pattern::wrap_type<ngraph::opset1::Convert>(
|
||||
{quantization_fake_quantize});
|
||||
return quantization_convert;
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Node> RecurrentCellTransformation::wrap_dequantization(
|
||||
const std::shared_ptr<ov::Node> parameter,
|
||||
const bool with_subtract) {
|
||||
const auto dequantization_convert = ngraph::pattern::wrap_type<ngraph::opset1::Convert>({parameter});
|
||||
const auto subtract_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
const auto dequantization_subtract = ngraph::pattern::wrap_type<ngraph::opset1::Subtract>(
|
||||
{dequantization_convert, subtract_constant});
|
||||
const auto multiply_constant = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
const auto multiply_parent = with_subtract ? dequantization_subtract : dequantization_convert;
|
||||
const auto dequantization_multiply = ngraph::pattern::wrap_type<ngraph::opset1::Multiply>(
|
||||
{multiply_parent, multiply_constant});
|
||||
return dequantization_multiply;
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,20 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/rt_info/skip_cleanup_attribute.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace ov;
|
||||
|
||||
// Attaches a SkipCleanupAttribute to the node's runtime info (overwriting any
// existing entry under the same key) and returns the stored Any by value.
ov::Any SkipCleanupAttribute::create(
    const std::shared_ptr<ngraph::Node>& node) {
    auto& rt = node->get_rt_info();
    return (rt[SkipCleanupAttribute::get_type_info_static()] = SkipCleanupAttribute());
}
|
||||
@@ -0,0 +1,221 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "convert_fq_rnn_to_quantized_rnn.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include "ngraph/except.hpp"
|
||||
#include "ngraph/node_output.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include <ngraph_ops/type_relaxed.hpp>
|
||||
|
||||
#include "ie_common.h"
|
||||
#include "itt.hpp"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
#include <cstdlib>
|
||||
|
||||
// Matches a quantized LSTM/GRU sequence (dequantize subgraphs on X, H and the
// weights) and replaces it with a TypeRelaxed RNN node that carries the
// quantization parameters (inputScale / inputShift / weightsScales) in its
// runtime info, where the CPU plugin's RNN node picks them up.
ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
    MATCHER_SCOPE(ConvertFqRnnToQuantizedRnn);

    // Dequantize pattern on the data input X:
    // X -> Convert -> [Subtract(shift)] -> Multiply(scale).
    auto X_m = ngraph::pattern::any_input();
    auto convert_X = ngraph::pattern::wrap_type<ngraph::opset9::Convert>({X_m});
    auto input_shift_X = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto subtract_X = ngraph::pattern::wrap_type<ngraph::opset9::Subtract>({convert_X, input_shift_X});
    auto input_scale_X = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();

    auto deq_X = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({convert_X, input_scale_X}),
            ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({subtract_X, input_scale_X}),
        });

    // Same dequantize pattern on the hidden state H.
    auto H_m = ngraph::pattern::any_input();
    auto convert_H = ngraph::pattern::wrap_type<ngraph::opset9::Convert>({H_m});
    auto input_shift_H = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto subtract_H = ngraph::pattern::wrap_type<ngraph::opset9::Subtract>({convert_H, input_shift_H});
    auto input_scale_H = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();

    auto deq_H = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector{
            ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({convert_H, input_scale_H}),
            ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({subtract_H, input_scale_H}),
        });

    // H may alternatively be a plain f32 constant (no dequantization).
    auto H_as_const = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto H_in = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector {
            deq_H,
            H_as_const
        });

    auto cell_state_m = ngraph::pattern::any_input(); // for LSTM
    auto sequence_length_m = ngraph::pattern::any_input(); // for Sequences

    // Weights W and R: Constant -> Convert -> Multiply(per-channel scales).
    auto W_m = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto convert_W = ngraph::pattern::wrap_type<ngraph::opset9::Convert>({W_m});
    auto weights_scale_W = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto deq_W = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({convert_W, weights_scale_W});

    auto R_m = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto convert_R = ngraph::pattern::wrap_type<ngraph::opset9::Convert>({R_m});
    auto weights_scale_R = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();
    auto deq_R = ngraph::pattern::wrap_type<ngraph::opset9::Multiply>({convert_R, weights_scale_R});

    const auto B_m = ngraph::pattern::wrap_type<ngraph::opset9::Constant>();

    auto lstm_seq_m = ngraph::pattern::wrap_type<ngraph::opset9::LSTMSequence>({deq_X, H_in, cell_state_m, sequence_length_m, deq_W, deq_R, B_m});
    auto gru_seq_m  = ngraph::pattern::wrap_type<ngraph::opset9::GRUSequence> ({deq_X, H_in, sequence_length_m, deq_W, deq_R, B_m});

    auto rnn_pattern = std::make_shared<ngraph::pattern::op::Or>(
        OutputVector {
            lstm_seq_m,
            gru_seq_m
        });

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto rnn = m.get_match_root();

        if (!rnn || transformation_callback(rnn))
            return false;

        const auto& pattern_map = m.get_pattern_value_map();
        const auto& activation = pattern_map.at(X_m);
        const auto hidden_state_it = pattern_map.find(H_m);

        ngraph::Output<ngraph::Node> hidden_state;
        if (hidden_state_it != pattern_map.end()) { // is it H(i8/u8) -> dequantized -> RNN pattern?
            hidden_state = hidden_state_it->second;
        } else {
            hidden_state = pattern_map.at(H_as_const); // if not, then it is just H (f32 const) -> RNN
        }

        const auto& weights = pattern_map.at(W_m);
        const auto& r_weights = pattern_map.at(R_m);
        const auto& bias = pattern_map.at(B_m);

        std::shared_ptr<ngraph::Node> rnn_quantized;

        if (const auto lstm_seq = ngraph::as_type_ptr<ngraph::opset9::LSTMSequence>(rnn)) {
            const auto& cell_state = pattern_map.at(cell_state_m);
            const auto& sequence_length = pattern_map.at(sequence_length_m);

            // @todo prototype removal of unnecessary fq between two consecutive rnn nodes
            // TypeRelaxed wrapper: all inputs are presented to the base op as f32
            // while the quantized (u8/i8) inputs are attached directly.
            auto rnn_quantized_tr = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset9::LSTMSequence>>(
                element::TypeVector{ element::f32, element::f32, element::f32, element::f32, element::f32, element::f32, element::f32 },
                element::TypeVector{ element::f32, element::f32, element::f32 },
                ngraph::op::TemporaryReplaceOutputType(activation, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(hidden_state, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(cell_state, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(sequence_length, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(weights, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(r_weights, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(bias, element::f32).get(),
                lstm_seq->get_hidden_size(),
                lstm_seq->get_direction(),
                lstm_seq->get_activations_alpha(),
                lstm_seq->get_activations_beta(),
                lstm_seq->get_activations(),
                lstm_seq->get_clip());

            // Output hidden state keeps the (possibly quantized) H input type.
            rnn_quantized_tr->set_overridden_output_type(hidden_state.get_element_type(), 1);
            rnn_quantized = rnn_quantized_tr;
        } else if (const auto gru_seq = ngraph::as_type_ptr<ngraph::opset9::GRUSequence>(rnn)) {
            const auto& sequence_length = pattern_map.at(sequence_length_m);

            auto rnn_quantized_tr = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset9::GRUSequence>>(
                std::vector<ngraph::element::Type>{ element::f32, element::f32, element::f32, element::f32, element::f32, element::f32},
                std::vector<ngraph::element::Type>{ element::f32, element::f32 },
                ngraph::op::TemporaryReplaceOutputType(activation, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(hidden_state, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(sequence_length, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(weights, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(r_weights, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(bias, element::f32).get(),
                gru_seq->get_hidden_size(),
                gru_seq->get_direction(),
                gru_seq->get_activations(),
                gru_seq->get_activations_alpha(),
                gru_seq->get_activations_beta(),
                gru_seq->get_clip(),
                gru_seq->get_linear_before_reset());

            rnn_quantized_tr->set_overridden_output_type(hidden_state.get_element_type(), 1);
            rnn_quantized = rnn_quantized_tr;
        } else {
            return false;
        }

        // input scales (Multiply per tensor) and weights_scales (Multiply per multiple dimensions) must be present
        const auto& input_scale_output = pattern_map.at(input_scale_X);
        const auto& weights_scale_output = pattern_map.at(weights_scale_W);
        // extract constant values
        const auto input_scale_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(input_scale_output.get_node_shared_ptr());
        const auto weights_scale_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(weights_scale_output.get_node_shared_ptr());

        if (!input_scale_constant || !weights_scale_constant)
            return false;

        const float* input_scale_ptr = input_scale_constant->get_data_ptr<float>();
        if (*input_scale_ptr == 0.f)
            throw ngraph::ngraph_error("Cannot handle zero input scale");

        // oneDNN expects the reciprocal of the dequantization scale.
        const float input_scale = 1 / *input_scale_ptr;
        const std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();

        auto& runtime_info = rnn_quantized->get_rt_info();

        // Quantization parameters are passed to the CPU RNN node via rt_info.
        runtime_info["inputScale"] = input_scale;
        runtime_info["weightsScales"] = weights_scales;

        // input shift (Subtract) is optional
        const auto input_shift_it = pattern_map.find(input_shift_X);

        if (input_shift_it != pattern_map.end()) {
            // NOTE(review): the cast result is dereferenced without a null
            // check; the pattern guarantees a Constant here, but confirm.
            const auto input_shift_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(input_shift_it->second.get_node_shared_ptr());
            const float* input_shift_ptr = input_shift_constant->get_data_ptr<float>();

            runtime_info["inputShift"] = *input_shift_ptr;
        }

        // Capture H-output consumers before replacing the node.
        auto H_outputs = rnn->output(1).get_target_inputs();
        rnn_quantized->set_friendly_name(rnn->get_friendly_name());
        ngraph::copy_runtime_info(rnn, rnn_quantized);
        ngraph::replace_node(rnn, rnn_quantized);

        /* in case of pattern:
         * H(u8,i8) -> dequantize -> RNN
         * dequantize has to be inserted after H output port since
         * oneDNN supports only equal data types on H in/out ports
         * either: u8u8, i8i8 or f32f32 */
        if (hidden_state_it != pattern_map.end()) {
            const auto& convert = pattern_map.at(convert_H).get_node_shared_ptr();
            const auto subtract_it = pattern_map.find(subtract_H);
            const auto& multiply = rnn->get_input_node_shared_ptr(1);

            // Re-create the dequantize chain on the quantized H output.
            auto new_convert = convert->clone_with_new_inputs({rnn_quantized->output(1)});
            std::shared_ptr<Node> multiply_in = new_convert;
            // dequantize with subtract
            if (subtract_it != pattern_map.end()) {
                const auto subtract = std::dynamic_pointer_cast<ngraph::opset9::Subtract>(subtract_it->second.get_node_shared_ptr());
                auto new_subtract = subtract->clone_with_new_inputs({rnn_quantized->output(1), subtract->input_value(1)});
                multiply_in = new_subtract;
            }

            auto new_multiply = multiply->clone_with_new_inputs({multiply_in, multiply->input_value(1)});

            // Re-route former H consumers to the new dequantized output.
            for (auto output : H_outputs) {
                output.replace_source_output(new_multiply);
            }
        }

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(rnn_pattern, matcher_name);
    this->register_matcher(m, callback);
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
/*
|
||||
* Description:
|
||||
* ConvertFqRnnToQuantizedRnn detects RNN / LSTM / GRU_RNN operations
|
||||
* with FQ operations on the inputs and forms a new TypeRelaxed operation
|
||||
* with quantization parameters as runtime parameters of the operation.
|
||||
* @todo add ascii graph examples
|
||||
*/
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
// Matcher pass (see file-level description above the namespace): rewrites
// FQ/dequantize-fed LSTM/GRU sequences into TypeRelaxed ops carrying the
// quantization parameters in runtime info.
class ConvertFqRnnToQuantizedRnn: public ngraph::pass::MatcherPass {
public:
    OPENVINO_RTTI("ConvertFqRnnToQuantizedRnn", "0");
    ConvertFqRnnToQuantizedRnn();
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
@@ -741,7 +741,8 @@ void Node::prepareMemory(const std::vector<DnnlMemoryDescPtr>& intDescs) {
|
||||
}
|
||||
|
||||
if (internalBlobs.size() != intDescs.size()) {
|
||||
IE_THROW() << "Can't prepare memory for internal blob, internal blobs and internal descs number do not match";
|
||||
IE_THROW() << "Can't prepare memory for internal blob, internal blob and internal descs number do not match "
|
||||
<< internalBlobs.size() << " vs " << intDescs.size();
|
||||
}
|
||||
|
||||
internalBlobMemory.clear();
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "rnn.h"
|
||||
#include <utils/general_utils.h>
|
||||
#include "ie_precision.hpp"
|
||||
#include "nodes/common/cpu_memcpy.h"
|
||||
#include "nodes/common/cpu_convert.h"
|
||||
#include "utils/bfloat16.hpp"
|
||||
@@ -17,6 +18,7 @@
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
|
||||
#include <oneapi/dnnl/dnnl.hpp>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
@@ -124,13 +126,13 @@ inline bool haveAttention(const dnnl::algorithm& alg) {
|
||||
return alg == dnnl::algorithm::vanilla_augru || alg == dnnl::algorithm::lbr_augru;
|
||||
}
|
||||
|
||||
const std::map<Precision, Precision> RNN::weightsByLayerPrec {
|
||||
// layer precision, weights precision
|
||||
{Precision::FP32, Precision::FP32},
|
||||
{Precision::BF16, Precision::BF16},
|
||||
// FP16 and U8 are not supported yet
|
||||
// {Precision::FP16, Precision::FP16},
|
||||
// {Precision::U8, Precision::I8},
|
||||
// what weight data type should be used for particular input data type
// (quantized u8/s8 activations both map to s8 weights, as required by oneDNN)
const std::map<memory::data_type, memory::data_type> RNN::weightsByinputDataType {
    // layer data type        weights data type
    {memory::data_type::f32,  memory::data_type::f32},
    {memory::data_type::bf16, memory::data_type::bf16},
    {memory::data_type::u8,   memory::data_type::s8},
    {memory::data_type::s8,   memory::data_type::s8},
};
|
||||
|
||||
|
||||
@@ -292,20 +294,29 @@ RNN::RNN(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSh
|
||||
|
||||
if (one_of(op->get_type_info(),
|
||||
ov::op::v0::RNNCell::get_type_info_static(),
|
||||
ov::op::v3::GRUCell::get_type_info_static(),
|
||||
ov::op::internal::AUGRUCell::get_type_info_static())) {
|
||||
ov::op::v3::GRUCell::get_type_info_static())) {
|
||||
wIdx = 2; rIdx = 3; bIdx = 4;
|
||||
hoIdx = 0;
|
||||
} else if (op->get_type_info() == ov::op::internal::AUGRUCell::get_type_info_static()) {
|
||||
wIdx = 2; rIdx = 3; bIdx = 4; aIdx = 5;
|
||||
} else if (one_of(op->get_type_info(),
|
||||
ov::op::v0::LSTMCell::get_type_info_static(),
|
||||
ov::op::v4::LSTMCell::get_type_info_static())) {
|
||||
wIdx = 3; rIdx = 4; bIdx = 5;
|
||||
yIdx = hoIdx = 0; coIdx = 1;
|
||||
} else if (one_of(op->get_type_info(),
|
||||
ov::op::v5::RNNSequence::get_type_info_static(),
|
||||
ov::op::v0::LSTMCell::get_type_info_static(),
|
||||
ov::op::v4::LSTMCell::get_type_info_static(),
|
||||
ov::op::v5::GRUSequence::get_type_info_static(),
|
||||
ov::op::internal::AUGRUSequence::get_type_info_static())) {
|
||||
wIdx = 3; rIdx = 4; bIdx = 5;
|
||||
ov::op::v5::GRUSequence::get_type_info_static())) {
|
||||
sIdx = 2; wIdx = 3; rIdx = 4; bIdx = 5;
|
||||
yIdx = 0; hoIdx = 1;
|
||||
} else if (op->get_type_info() == ov::op::internal::AUGRUSequence::get_type_info_static()) {
|
||||
sIdx = 2; wIdx = 3; rIdx = 4; bIdx = 5; aIdx = 6;
|
||||
yIdx = 0; hoIdx = 1;
|
||||
} else if (one_of(op->get_type_info(),
|
||||
ov::op::v0::LSTMSequence::get_type_info_static(),
|
||||
ov::op::v5::LSTMSequence::get_type_info_static())) {
|
||||
wIdx = 4; rIdx = 5; bIdx = 6;
|
||||
sIdx = 3; wIdx = 4; rIdx = 5; bIdx = 6;
|
||||
yIdx = 0; hoIdx = 1; coIdx = 2;
|
||||
}
|
||||
|
||||
auto rnnCellBase = std::dynamic_pointer_cast<ngraph::op::util::RNNCellBase>(op);
|
||||
@@ -322,26 +333,62 @@ RNN::RNN(const std::shared_ptr<ov::Node>& op, const dnnl::engine& eng, WeightsSh
|
||||
SC = rnnCellBase->get_hidden_size();
|
||||
N = {getInputShapeAtPort(0).getMinDims()[0], getInputShapeAtPort(0).getMaxDims()[0]};
|
||||
|
||||
const auto& rtInfo = op->get_rt_info();
|
||||
|
||||
if (rtInfo.count("inputScale"))
|
||||
inputScale = rtInfo.at("inputScale").as<float>();
|
||||
|
||||
if (rtInfo.count("inputShift"))
|
||||
inputShift = rtInfo.at("inputShift").as<float>();
|
||||
|
||||
if (rtInfo.count("weightsScales"))
|
||||
weightsScales = rtInfo.at("weightsScales").as<std::vector<float>>();
|
||||
|
||||
if (is_cell) {
|
||||
initCell();
|
||||
} else {
|
||||
direction = ieDirection2dnnl(op);
|
||||
|
||||
nativeOrder = false;
|
||||
const auto& rtInfo = op->get_rt_info();
|
||||
if (rtInfo.count("seqAxis")) {
|
||||
nativeOrder = rtInfo.at("seqAxis").as<int64_t>() == 0;
|
||||
}
|
||||
|
||||
initSequence();
|
||||
}
|
||||
|
||||
inDataTypes.reserve(getOriginalInputsNumber());
|
||||
outDataTypes.reserve(getOriginalOutputsNumber());
|
||||
}
|
||||
|
||||
bool RNN::created() const {
|
||||
return getType() == (is_cell ? Type::RNNCell : Type::RNNSeq);
|
||||
}
|
||||
|
||||
void RNN::configurePortDataTypes() {
|
||||
inDataTypes[xIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(0));
|
||||
inDataTypes[hIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(1));
|
||||
if (haveCellState(cell_type))
|
||||
inDataTypes[cIdx] = memory::data_type::f32; // @todo bf16 is also allowed, should be tried out
|
||||
if (!is_cell)
|
||||
inDataTypes[sIdx] = memory::data_type::s32;
|
||||
inDataTypes[wIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(wIdx));
|
||||
inDataTypes[rIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(wIdx));
|
||||
|
||||
inDataTypes[bIdx] = memory::data_type::f32;
|
||||
if (haveAttention(cell_type))
|
||||
inDataTypes[aIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(aIdx));
|
||||
|
||||
if (!is_cell)
|
||||
outDataTypes[yIdx] = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalOutputPrecisionAtPort(0));
|
||||
outDataTypes[hoIdx] = inDataTypes[hIdx]; // required by oneDNN. Output hidden state is a input hidden state for the next iteration
|
||||
if (haveCellState(cell_type))
|
||||
outDataTypes[coIdx] = inDataTypes[cIdx]; // required by oneDNN.
|
||||
}
|
||||
|
||||
void RNN::getSupportedDescriptors() {
|
||||
configurePortDataTypes();
|
||||
|
||||
if (is_cell)
|
||||
fillCellDesc();
|
||||
else
|
||||
@@ -386,10 +433,9 @@ void RNN::initCell() {
|
||||
}
|
||||
|
||||
void RNN::fillCellDesc() {
|
||||
const auto dataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(0));
|
||||
const Shape shapeS_4D = MemoryDescUtils::makeDummyShape({{L, D, N.minVal, SC}, {L, D, N.maxVal, SC}}),
|
||||
inShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, DC}, {T.maxVal, N.maxVal, DC}}),
|
||||
outShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, SC}, {T.maxVal, N.maxVal, SC}});
|
||||
const Shape shapeS_4D = MemoryDescUtils::makeDummyShape({{L, D, N.minVal, SC}, {L, D, N.maxVal, SC}});
|
||||
const Shape inShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, DC}, {T.maxVal, N.maxVal, DC}});
|
||||
const Shape outShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, D * SC}, {T.maxVal, N.maxVal, D * SC}});
|
||||
|
||||
// layer input plus states
|
||||
if (haveAttention(cell_type)) {
|
||||
@@ -399,47 +445,52 @@ void RNN::fillCellDesc() {
|
||||
}
|
||||
outDataDescs.reserve(S + 1);
|
||||
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
|
||||
// @todo use indexies instead of emplacing back, since order matters
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, inDataTypes[xIdx], memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, outDataTypes[yIdx], memory::format_tag::tnc));
|
||||
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[hIdx], memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, outDataTypes[hoIdx], memory::format_tag::ldnc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
}
|
||||
|
||||
if (haveAttention(cell_type)) {
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[cIdx], memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[coIdx], memory::format_tag::ldnc));
|
||||
} else if (haveAttention(cell_type)) {
|
||||
const Shape attnShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, 1}, {T.maxVal, N.maxVal, 1}});
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(attnShape, dataType, memory::format_tag::tnc));
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(attnShape, inDataTypes[aIdx], memory::format_tag::tnc));
|
||||
}
|
||||
|
||||
copyWeightsData();
|
||||
|
||||
// Expected shapes.
|
||||
Shape shapeD{{N.minVal, DC}, {N.maxVal, DC}}, shapeS{{N.minVal, SC}, {N.maxVal, SC}},
|
||||
WShape{SC * G, DC}, RShape{SC * G, SC}, BShape{SC * Gb};
|
||||
std::vector<MemoryDescPtr> inCandidate, outCandidate;
|
||||
inCandidate.reserve(6);
|
||||
const Shape shapeD{{N.minVal, DC}, {N.maxVal, DC}};
|
||||
const Shape shapeS{{N.minVal, SC}, {N.maxVal, SC}};
|
||||
const Shape WShape{SC * G, DC};
|
||||
const Shape RShape{SC * G, SC};
|
||||
const Shape BShape{SC * Gb};
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeD, dataType, memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, dataType, memory::format_tag::nc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, dataType, memory::format_tag::nc));
|
||||
std::vector<MemoryDescPtr> inCandidate, outCandidate;
|
||||
|
||||
inCandidate.reserve(getOriginalInputsNumber());
|
||||
outCandidate.reserve(getOriginalOutputsNumber());
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeD, inDataTypes[xIdx], memory::format_tag::nc));
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, inDataTypes[hIdx], memory::format_tag::nc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, outDataTypes[hoIdx], memory::format_tag::nc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, memory::data_type::f32, memory::format_tag::nc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, memory::data_type::f32, memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, inDataTypes[cIdx], memory::format_tag::nc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS, outDataTypes[coIdx], memory::format_tag::nc));
|
||||
}
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, memory::data_type::f32, memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, memory::data_type::f32, memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(BShape, memory::data_type::f32, memory::format_tag::x));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, inDataTypes[wIdx], memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, inDataTypes[rIdx], memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(BShape, inDataTypes[bIdx], memory::format_tag::x));
|
||||
|
||||
// note: the order matters. attention is the last input of augru.
|
||||
if (haveAttention(cell_type)) {
|
||||
Shape shapeAttn{{N.minVal, 1}, {N.maxVal, 1}};
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeAttn, dataType, memory::format_tag::nc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeAttn, inDataTypes[aIdx], memory::format_tag::nc));
|
||||
}
|
||||
|
||||
createDescriptor(inCandidate, outCandidate);
|
||||
@@ -470,118 +521,94 @@ void RNN::initSequence() {
|
||||
} else {
|
||||
inDataDescs.reserve(S + 1);
|
||||
}
|
||||
|
||||
outDataDescs.reserve(S + 1);
|
||||
}
|
||||
|
||||
void RNN::fillSequenceDesc() {
|
||||
const auto dataType = DnnlExtensionUtils::IEPrecisionToDataType(getOriginalInputPrecisionAtPort(0));
|
||||
const Shape shapeS_4D = MemoryDescUtils::makeDummyShape({{L, D, N.minVal, SC}, {L, D, N.maxVal, SC}}),
|
||||
inShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, DC}, {T.maxVal, N.maxVal, DC}}),
|
||||
outShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, SC}, {T.maxVal, N.maxVal, SC}}),
|
||||
shapeNDSC {{N.minVal, D, SC}, {N.maxVal, D, SC}},
|
||||
shapeNTSC {{N.minVal, T.minVal, SC}, {N.maxVal, T.maxVal, SC}},
|
||||
shapeNTDC {{N.minVal, T.minVal, DC}, {N.maxVal, T.maxVal, DC}};
|
||||
const Shape shapeS_4D = MemoryDescUtils::makeDummyShape({{L, D, N.minVal, SC}, {L, D, N.maxVal, SC}});
|
||||
const Shape inShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, DC}, {T.maxVal, N.maxVal, DC}});
|
||||
const Shape outShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, D * SC}, {T.maxVal, N.maxVal, D * SC}});
|
||||
|
||||
// Try to create descriptor and corresponding configuration
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, inDataTypes[xIdx], memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, outDataTypes[yIdx], memory::format_tag::tnc));
|
||||
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[hIdx], memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, outDataTypes[hoIdx], memory::format_tag::ldnc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
}
|
||||
if (haveAttention(cell_type)) {
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[cIdx], memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, outDataTypes[coIdx], memory::format_tag::ldnc));
|
||||
} else if (haveAttention(cell_type)) {
|
||||
const Shape attnShape = MemoryDescUtils::makeDummyShape({{T.minVal, N.minVal, 1}, {T.maxVal, N.maxVal, 1}});
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(attnShape, dataType, memory::format_tag::tnc));
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(attnShape, inDataTypes[aIdx], memory::format_tag::tnc));
|
||||
}
|
||||
|
||||
copyWeightsData();
|
||||
|
||||
std::vector<MemoryDescPtr> inCandidate;
|
||||
inCandidate.reserve(7);
|
||||
const Shape shapeNDSC {{N.minVal, D, SC}, {N.maxVal, D, SC}};
|
||||
Shape shapeNTSC {{N.minVal, T.minVal, SC}, {N.maxVal, T.maxVal, SC}};
|
||||
const Shape shapeNTDC {{N.minVal, T.minVal, DC}, {N.maxVal, T.maxVal, DC}};
|
||||
const Shape TShape {VectorDims{N.minVal}, VectorDims{N.maxVal}};
|
||||
const Shape WShape {D, G * SC, DC};
|
||||
const Shape RShape {D, G * SC, SC};
|
||||
const Shape BShape {D, Gb * SC};
|
||||
|
||||
if (nativeOrder)
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inputShapes[RNNInOutKind::Layer], dataType, memory::format_tag::tnc));
|
||||
else if (N.isStatic() && N.maxVal == 1)
|
||||
// WA to avoid reorder before sequence for some models.
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTDC, dataType, memory::format_tag::tnc));
|
||||
else
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTDC, dataType, memory::format_tag::ntc));
|
||||
std::vector<MemoryDescPtr> inCandidate, outCandidate;
|
||||
|
||||
// Initial hidden state.
|
||||
// WA to avoid reorder before.
|
||||
if (D == 1)
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, dataType, memory::format_tag::tnc));
|
||||
else
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, dataType, memory::format_tag::ntc));
|
||||
inCandidate.reserve(getOriginalInputsNumber());
|
||||
outCandidate.reserve(getOriginalOutputsNumber());
|
||||
|
||||
auto srcLayerMemoryFormat = memory::format_tag::undef;
|
||||
auto dstLayerMemoryFormat = memory::format_tag::undef;
|
||||
|
||||
if (nativeOrder) {
|
||||
srcLayerMemoryFormat = memory::format_tag::tnc;
|
||||
dstLayerMemoryFormat = memory::format_tag::abcd;
|
||||
shapeNTSC = {{N.minVal, D, T.minVal, SC}, {N.maxVal, D, T.maxVal, SC}};
|
||||
} else if (N.isStatic() && N.maxVal == 1) {
|
||||
srcLayerMemoryFormat = memory::format_tag::tnc;
|
||||
dstLayerMemoryFormat = memory::format_tag::tnc;
|
||||
} else {
|
||||
srcLayerMemoryFormat = memory::format_tag::ntc;
|
||||
dstLayerMemoryFormat = memory::format_tag::ntc;
|
||||
}
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTDC, inDataTypes[xIdx], srcLayerMemoryFormat));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, outDataTypes[yIdx], dstLayerMemoryFormat));
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, inDataTypes[hIdx], memory::format_tag::tnc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, outDataTypes[hoIdx], memory::format_tag::tnc));
|
||||
|
||||
// initial cell state
|
||||
if (haveCellState(cell_type)) {
|
||||
if (D == 1)
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, memory::data_type::f32, memory::format_tag::tnc));
|
||||
else
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, memory::data_type::f32, memory::format_tag::ntc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, inDataTypes[cIdx], memory::format_tag::tnc));
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, outDataTypes[coIdx], memory::format_tag::tnc));
|
||||
}
|
||||
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(Shape{VectorDims{N.minVal}, VectorDims{N.maxVal}},
|
||||
memory::data_type::s32, memory::format_tag::x)); // sequence lengths
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(Shape{D, G * SC, DC}, memory::data_type::f32, memory::format_tag::ntc)); // W
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(Shape{D, G * SC, SC}, memory::data_type::f32, memory::format_tag::ntc)); // R
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(Shape{D, Gb * SC}, memory::data_type::f32, memory::format_tag::nc)); // B
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(TShape, inDataTypes[sIdx], memory::format_tag::x)); // sequence lengths
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(WShape, inDataTypes[wIdx], memory::format_tag::ntc)); // W
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(RShape, inDataTypes[rIdx], memory::format_tag::ntc)); // R
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(BShape, inDataTypes[bIdx], memory::format_tag::nc)); // B
|
||||
|
||||
// note: the order matters. attention is the last input of augru.
|
||||
if (haveAttention(cell_type)) {
|
||||
Shape shapeAttn{{N.minVal, T.minVal, 1}, {N.maxVal, T.maxVal, 1}};
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeAttn, dataType, memory::format_tag::ntc));
|
||||
}
|
||||
|
||||
std::vector<MemoryDescPtr> outCandidate;
|
||||
outCandidate.reserve(3);
|
||||
|
||||
if (nativeOrder) {
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(getOutputShapeAtPort(0), dataType, memory::format_tag::abcd));
|
||||
} else if (N.isStatic() && N.maxVal == 1) {
|
||||
// WA to avoid reorder after sequence for some models
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, dataType, memory::format_tag::tnc));
|
||||
} else {
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, dataType, memory::format_tag::ntc));
|
||||
}
|
||||
|
||||
// WA to avoid reorder after
|
||||
if (D == 1)
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, dataType, memory::format_tag::tnc));
|
||||
else
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, dataType, memory::format_tag::ntc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
if (D == 1)
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, memory::data_type::f32, memory::format_tag::tnc));
|
||||
else
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNDSC, memory::data_type::f32, memory::format_tag::ntc));
|
||||
inCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeAttn, inDataTypes[aIdx], memory::format_tag::ntc));
|
||||
}
|
||||
|
||||
createDescriptor(inCandidate, outCandidate);
|
||||
}
|
||||
|
||||
bool RNN::verifyWeightsPrecision(const Precision &layerPrec, const Precision &weightsPrec) {
|
||||
if (!weightsByLayerPrec.count(layerPrec))
|
||||
THROW_ERROR << "has unsupported layer precision " << layerPrec;
|
||||
return weightsPrec == weightsByLayerPrec.at(layerPrec);
|
||||
}
|
||||
|
||||
template <typename Prec>
|
||||
void RNN::fillWeights(const int *gate_map, const size_t wIdx, const size_t rIdx) {
|
||||
const auto& dataPrecision = getOriginalInputPrecisionAtPort(0);
|
||||
const auto& weightPrec = getOriginalInputPrecisionAtPort(wIdx);
|
||||
if (!verifyWeightsPrecision(dataPrecision, weightPrec) && dataPrecision != Precision::BF16 && weightPrec != Precision::FP32) {
|
||||
THROW_ERROR << "doesn't support combination of weights precision: " << weightPrec << " and runtime precision: " << dataPrecision;
|
||||
}
|
||||
const auto& weightPrec = DnnlExtensionUtils::DataTypeToIEPrecision(inDataTypes[wIdx]);
|
||||
const auto& targetWeightPrec = DnnlExtensionUtils::DataTypeToIEPrecision(weightsByinputDataType.at(inDataTypes[xIdx]));
|
||||
|
||||
// create weight blobs (data and state part)
|
||||
const VectorDims dims_w = { L, D, DC, G, SC };
|
||||
TensorDesc w_data_desc(dataPrecision, dims_w, getWeightsLayoutByDims(dims_w, false));
|
||||
TensorDesc w_data_desc(targetWeightPrec, dims_w, getWeightsLayoutByDims(dims_w, false));
|
||||
|
||||
Blob::Ptr w_data_mem = make_shared_blob<Prec>(w_data_desc);
|
||||
w_data_mem->allocate();
|
||||
auto w_ptr = static_cast<Prec*>(w_data_mem->buffer());
|
||||
@@ -589,7 +616,7 @@ void RNN::fillWeights(const int *gate_map, const size_t wIdx, const size_t rIdx)
|
||||
IE_THROW(NotAllocated) << "Internal blob was not allocated for node " << getName() << ".";
|
||||
|
||||
const VectorDims dims_s = { L, D, SC, G, SC };
|
||||
TensorDesc w_state_desc(dataPrecision, dims_s, getWeightsLayoutByDims(dims_s, false));
|
||||
TensorDesc w_state_desc(targetWeightPrec, dims_s, getWeightsLayoutByDims(dims_s, false));
|
||||
Blob::Ptr w_state_mem = make_shared_blob<Prec>(w_state_desc);
|
||||
w_state_mem->allocate();
|
||||
auto r_ptr = static_cast<Prec*>(w_state_mem->buffer());
|
||||
@@ -609,8 +636,9 @@ void RNN::fillWeights(const int *gate_map, const size_t wIdx, const size_t rIdx)
|
||||
|
||||
auto ie_w_ptr = ie_w_vec.data();
|
||||
auto ie_r_ptr = ie_r_vec.data();
|
||||
cpu_convert(wConstBlob->GetPtr(), ie_w_ptr, weightPrec, dataPrecision, ie_w_vec_size);
|
||||
cpu_convert(rConstBlob->GetPtr(), ie_r_ptr, weightPrec, dataPrecision, ie_r_vec_size);
|
||||
|
||||
cpu_convert(wConstBlob->GetPtr(), ie_w_ptr, weightPrec, targetWeightPrec, ie_w_vec_size);
|
||||
cpu_convert(rConstBlob->GetPtr(), ie_r_ptr, weightPrec, targetWeightPrec, ie_r_vec_size);
|
||||
|
||||
const int step = SC * G;
|
||||
|
||||
@@ -668,6 +696,7 @@ void RNN::fillBiases(const int *gate_map) {
|
||||
const dataType *l_ie_b_ptr = &ie_b_vec[g * SC];
|
||||
cpu_memcpy(l_b_ptr, l_ie_b_ptr, SC * sizeof(typename PrecisionTrait<Prec>::value_type));
|
||||
}
|
||||
// @todo replace push_back with copy assignment by index, since order matters
|
||||
internalBlobs.push_back(w_bias_data_mem);
|
||||
}
|
||||
|
||||
@@ -733,12 +762,13 @@ void RNN::copyWeightsData() {
|
||||
if (T.minVal > 1 || N.maxVal < optimalBatchSize)
|
||||
wFormat = dnnl::memory::format_tag::ldigo;
|
||||
fillWeights<float>(gate_map, wIdx, rIdx);
|
||||
} else {// TODO FP16 and INT8 support
|
||||
} else if (dataPrecision == Precision::U8 || dataPrecision == Precision::I8) {
|
||||
fillWeights<int8_t>(gate_map, wIdx, rIdx);
|
||||
} else {
|
||||
THROW_ERROR << "has unsupported data type: " << dataPrecision;
|
||||
}
|
||||
|
||||
if (dataPrecision == Precision::BF16 || dataPrecision == Precision::FP32)
|
||||
fillBiases<Precision::FP32>(gate_map);
|
||||
fillBiases<Precision::FP32>(gate_map);
|
||||
}
|
||||
|
||||
void RNN::fillDescs() {
|
||||
@@ -834,17 +864,19 @@ void RNN::fillDescs() {
|
||||
}
|
||||
|
||||
void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
|
||||
const std::vector<MemoryDescPtr> &outputDesc) {
|
||||
const std::vector<MemoryDescPtr> &outputDesc) {
|
||||
if (descs.empty()) {
|
||||
wDescs.resize(3);
|
||||
const auto& dataPrecision = getOriginalInputPrecisionAtPort(0);
|
||||
auto dataType = DnnlExtensionUtils::IEPrecisionToDataType(dataPrecision);
|
||||
|
||||
/* for descriptor configuration use the same type which is used for internalBlobs
|
||||
since internalBlobs are used for the execution, not the initial weights */
|
||||
const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]);
|
||||
auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC });
|
||||
wDescs[0] = dnnl::memory::desc(weightsDims, dataType, wFormat);
|
||||
wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat);
|
||||
auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC });
|
||||
wDescs[1] = dnnl::memory::desc(statesDims, dataType, wFormat);
|
||||
wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat);
|
||||
auto biasDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, Gb, SC });
|
||||
wDescs[2] = dnnl::memory::desc(biasDims, memory::data_type::f32, memory::format_tag::ldgo);
|
||||
wDescs[2] = dnnl::memory::desc(biasDims, inDataTypes[bIdx], memory::format_tag::ldgo);
|
||||
|
||||
fillDescs();
|
||||
}
|
||||
@@ -871,6 +903,20 @@ void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
|
||||
supportedPrimitiveDescriptors.emplace_back(config, ref_any);
|
||||
}
|
||||
|
||||
Node::AttrPtr RNN::initPrimitiveAttr() {
|
||||
auto attr = std::make_shared<dnnl::primitive_attr>(dnnl::primitive_attr());
|
||||
attr->set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
if (one_of(getOriginalInputPrecisionAtPort(0), Precision::U8, Precision::I8)) {
|
||||
const int weightsScaleMask = 0;
|
||||
|
||||
attr->set_rnn_weights_qparams(weightsScaleMask, weightsScales);
|
||||
attr->set_rnn_data_qparams(inputScale, inputShift);
|
||||
}
|
||||
|
||||
return attr;
|
||||
}
|
||||
|
||||
void RNN::prepareParams() {
|
||||
for (size_t i = 0; i < wIdx; i++) {
|
||||
auto memPtr = getParentEdgesAtPort(i).front()->getMemoryPtr();
|
||||
@@ -878,31 +924,28 @@ void RNN::prepareParams() {
|
||||
THROW_ERROR << "has uninitialized memory at port " << i;
|
||||
}
|
||||
|
||||
const auto& dataPrecision = getOriginalInputPrecisionAtPort(0);
|
||||
const auto dataType = DnnlExtensionUtils::IEPrecisionToDataType(dataPrecision);
|
||||
|
||||
auto dataMemPtr = getParentEdgesAtPort(0).front()->getMemoryPtr();
|
||||
const size_t B = dataMemPtr->GetShape().getStaticDims()[0];
|
||||
const size_t SL = is_cell ? 1lu : dataMemPtr->GetShape().getStaticDims()[1];
|
||||
const Shape shapeS_4D{L, D, B, SC};
|
||||
|
||||
inDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, DC}, dataType, memory::format_tag::tnc);
|
||||
outDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, SC}, dataType, memory::format_tag::tnc);
|
||||
inDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, DC}, inDataTypes[xIdx], memory::format_tag::tnc);
|
||||
outDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, D * SC}, outDataTypes[yIdx], memory::format_tag::tnc);
|
||||
|
||||
inDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
outDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
inDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[hIdx], memory::format_tag::ldnc);
|
||||
outDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, outDataTypes[hoIdx], memory::format_tag::ldnc);
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
outDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, inDataTypes[cIdx], memory::format_tag::ldnc);
|
||||
outDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, outDataTypes[coIdx], memory::format_tag::ldnc);
|
||||
} else if (haveAttention(cell_type)) {
|
||||
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, 1}, inDataTypes[aIdx], memory::format_tag::tnc);
|
||||
}
|
||||
|
||||
if (haveAttention(cell_type)) {
|
||||
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, 1}, dataType, memory::format_tag::tnc);
|
||||
}
|
||||
bool wFormatWasChanged = false;
|
||||
// WA To avoid different weights layer and iter formats in FP32 case.
|
||||
if ((dataPrecision == Precision::FP32) && (SL != 1 || B < optimalBatchSize)) {
|
||||
if (one_of(inDataTypes[xIdx], memory::data_type::f32, memory::data_type::bf16) &&
|
||||
(SL != 1 || B < optimalBatchSize)) {
|
||||
if (wFormat != dnnl::memory::format_tag::ldigo) {
|
||||
wFormat = dnnl::memory::format_tag::ldigo;
|
||||
wFormatWasChanged = true;
|
||||
@@ -911,38 +954,40 @@ void RNN::prepareParams() {
|
||||
wFormat = dnnl::memory::format_tag::any;
|
||||
wFormatWasChanged = true;
|
||||
}
|
||||
|
||||
if (wFormatWasChanged) {
|
||||
auto weightsDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, DC, G, SC });
|
||||
wDescs[0] = dnnl::memory::desc(weightsDims, dataType, wFormat);
|
||||
const auto& targetWeightDataType = weightsByinputDataType.at(inDataTypes[xIdx]);
|
||||
wDescs[0] = dnnl::memory::desc(weightsDims, targetWeightDataType, wFormat);
|
||||
auto statesDims = DnnlExtensionUtils::convertToDnnlDims(VectorDims{ L, D, SC, G, SC });
|
||||
wDescs[1] = dnnl::memory::desc(statesDims, dataType, wFormat);
|
||||
wDescs[1] = dnnl::memory::desc(statesDims, targetWeightDataType, wFormat);
|
||||
}
|
||||
|
||||
RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type, cell_act, direction };
|
||||
|
||||
auto builder = [this](const RNNKey& key) -> std::shared_ptr<dnnl::primitive> {
|
||||
const auto attr = initPrimitiveAttr();
|
||||
|
||||
auto builder = [this, attr](const RNNKey& key) -> std::shared_ptr<dnnl::primitive> {
|
||||
fillDescs();
|
||||
dnnl::primitive_attr attr;
|
||||
attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
if (key.cellType == dnnl::algorithm::vanilla_rnn) {
|
||||
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
|
||||
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::vanilla_gru) {
|
||||
std::shared_ptr<gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::lbr_gru) {
|
||||
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::vanilla_lstm) {
|
||||
std::shared_ptr<lstm_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::vanilla_augru) {
|
||||
std::shared_ptr<augru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<augru_forward>(augru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<augru_forward>(augru_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else if (key.cellType == dnnl::algorithm::lbr_augru) {
|
||||
std::shared_ptr<lbr_augru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lbr_augru_forward>(lbr_augru_forward::primitive_desc(*desc, attr, getEngine()));
|
||||
return std::make_shared<lbr_augru_forward>(lbr_augru_forward::primitive_desc(*desc, *attr, getEngine()));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ public:
|
||||
bool created() const override;
|
||||
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
|
||||
const std::vector<MemoryDescPtr>& outputDesc) override;
|
||||
std::shared_ptr<dnnl::primitive_attr> initPrimitiveAttr() override;
|
||||
|
||||
void execute(dnnl::stream strm) override;
|
||||
|
||||
@@ -41,6 +42,7 @@ protected:
|
||||
void executeDynamicImpl(dnnl::stream strm) override;
|
||||
|
||||
private:
|
||||
void configurePortDataTypes();
|
||||
void initCell();
|
||||
void initSequence();
|
||||
void fillCellDesc();
|
||||
@@ -106,6 +108,9 @@ private:
|
||||
std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
|
||||
std::vector<dnnl::memory::desc> wDescs;
|
||||
|
||||
std::vector<dnnl::memory::data_type> inDataTypes;
|
||||
std::vector<dnnl::memory::data_type> outDataTypes;
|
||||
|
||||
enum RNNInOutKind {
|
||||
Layer = 0,
|
||||
HiddenState = 1,
|
||||
@@ -113,17 +118,31 @@ private:
|
||||
Attention = 2
|
||||
};
|
||||
|
||||
size_t wIdx = 0;
|
||||
size_t rIdx = 0;
|
||||
size_t bIdx = 0;
|
||||
const size_t xIdx = 0; // ov -> input X; dnnl -> src_layer
|
||||
const size_t hIdx = 1; // ov -> initial_hidden_state; dnnl -> src_iter_h
|
||||
const size_t cIdx = 2; // ov -> initial_cell_state; dnnl -> src_iter_c
|
||||
size_t sIdx = 0; // ov -> sequence_length; dnnl -> additional input dimension 't'
|
||||
// oneDNN does not support unique t (seq_len) per batch
|
||||
size_t wIdx = 0; // ov -> W; dnnl -> weights_layer
|
||||
size_t rIdx = 0; // ov -> R; dnnl -> weights_iter
|
||||
size_t bIdx = 0; // ov -> B; dnnl -> bias
|
||||
size_t aIdx = 0; // ov -> A: dnnl -> attention
|
||||
|
||||
static const std::map<InferenceEngine::Precision, InferenceEngine::Precision> weightsByLayerPrec;
|
||||
size_t yIdx = 0; // ov -> Y; dnnl -> dst_layer
|
||||
size_t hoIdx = 0; // ov -> Ho; dnnl -> dst_iter_h
|
||||
size_t coIdx = 0; // ov -> Co; dnnl -> dst_iter_c
|
||||
|
||||
static const std::map<dnnl::memory::data_type, dnnl::memory::data_type> weightsByinputDataType;
|
||||
|
||||
static constexpr size_t optimalBatchSize = 16lu;
|
||||
static constexpr size_t batchDimDummyValue = 64lu;
|
||||
|
||||
bool wasMemoryPrepared = false;
|
||||
MemoryPtr scratchpadMem;
|
||||
|
||||
float inputScale = 0.f;
|
||||
float inputShift = 0.f;
|
||||
std::vector<float> weightsScales;
|
||||
};
|
||||
|
||||
} // namespace node
|
||||
|
||||
@@ -22,8 +22,6 @@
|
||||
#include <ie_system_conf.h>
|
||||
#include <ie_ngraph_utils.hpp>
|
||||
|
||||
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
|
||||
#include <transformations/common_optimizations/add_fake_quantize_fusion.hpp>
|
||||
#include <transformations/common_optimizations/common_optimizations.hpp>
|
||||
@@ -35,6 +33,10 @@
|
||||
#include <transformations/common_optimizations/wrap_interpolate_into_transposes.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking.hpp>
|
||||
#include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp"
|
||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||
|
||||
#include <transformations/opset_conversions/convert_opset3_to_opset2.hpp>
|
||||
#include <transformations/opset_conversions/convert_opset2_to_opset1.hpp>
|
||||
#include <transformations/op_conversions/convert_broadcast_to_tiles.hpp>
|
||||
#include <transformations/op_conversions/convert_depth_to_space.hpp>
|
||||
#include <transformations/op_conversions/convert_shuffle_channels3.hpp>
|
||||
@@ -81,15 +83,23 @@
|
||||
#include <transformations/rt_info/fused_names_attribute.hpp>
|
||||
#include <transformations/op_conversions/fq_decomposition.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <snippets/pass/collapse_subgraph.hpp>
|
||||
#include <snippets/pass/common_optimizations.hpp>
|
||||
#include <snippets/pass/convert_constants.hpp>
|
||||
#include "ngraph_transformations/snippets_mark_skipped.hpp"
|
||||
#include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
|
||||
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
|
||||
#include <transformations/op_conversions/softsign_decomposition.hpp>
|
||||
#include "transformations/op_conversions/eye_decomposition.hpp"
|
||||
#include "transformations/smart_reshape/smart_reshape.hpp"
|
||||
|
||||
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
|
||||
#include "ngraph_transformations/snippets_mark_skipped.hpp"
|
||||
#include "ngraph_transformations/mha_fusion.hpp"
|
||||
#include "ngraph_transformations/convert_to_interaction.hpp"
|
||||
#include "ngraph_transformations/convert_fq_rnn_to_quantized_rnn.hpp"
|
||||
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
|
||||
#include "ngraph_transformations/swap_convert_transpose.hpp"
|
||||
|
||||
#include <snippets/pass/collapse_subgraph.hpp>
|
||||
#include <snippets/pass/common_optimizations.hpp>
|
||||
#include <snippets/pass/convert_constants.hpp>
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
@@ -97,15 +107,12 @@
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
#include <ngraph/op/util/op_types.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/graph_util.hpp>
|
||||
|
||||
#include "ngraph_ops/augru_cell.hpp"
|
||||
#include "ngraph_ops/augru_sequence.hpp"
|
||||
|
||||
#include <transformations/common_optimizations/lin_op_sequence_fusion.hpp>
|
||||
|
||||
#include <transformations/low_precision/disable_convert_constant_folding_on_const_path.hpp>
|
||||
#include <low_precision/common/quantization_granularity_restriction.hpp>
|
||||
#include <low_precision/common/precisions_restriction.hpp>
|
||||
@@ -126,11 +133,6 @@
|
||||
#include "nodes/fake_quantize.h"
|
||||
#include "nodes/normalize.h"
|
||||
#include "nodes/mha.h"
|
||||
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
|
||||
#include "ngraph_transformations/convert_to_interaction.hpp"
|
||||
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
|
||||
#include "transformations/smart_reshape/smart_reshape.hpp"
|
||||
#include "ngraph_transformations/swap_convert_transpose.hpp"
|
||||
#include "utils/denormals.hpp"
|
||||
|
||||
#if !defined(__arm__) && !defined(_M_ARM) && !defined(__aarch64__) && !defined(_M_ARM64)
|
||||
@@ -566,6 +568,12 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
|
||||
{{0}, {ngraph::element::u8, ngraph::element::i8}},
|
||||
{{1}, {ngraph::element::i8}}
|
||||
}),
|
||||
PrecisionsRestriction::create<ngraph::opset5::LSTMSequence>({
|
||||
{{0, 1}, {ngraph::element::u8, ngraph::element::i8}},
|
||||
}),
|
||||
PrecisionsRestriction::create<ngraph::opset6::GRUSequence>({
|
||||
{{0, 1}, {ngraph::element::u8, ngraph::element::i8}},
|
||||
}),
|
||||
});
|
||||
|
||||
auto quantizationRestrictions = std::vector<QuantizationGranularityRestriction>({
|
||||
@@ -639,6 +647,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
|
||||
|
||||
return false;
|
||||
});
|
||||
|
||||
// Execute before snippets. Otherwise FQ will be converted to Subgraph
|
||||
postLPTPassManager.register_pass<ConvertFqRnnToQuantizedRnn>();
|
||||
postLPTPassManager.run_passes(nGraphFunc);
|
||||
|
||||
if (_enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
|
||||
|
||||
@@ -164,9 +164,6 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
// is shared across plugins
|
||||
// passed local test and cpu has specific test cases with nms9 to cover
|
||||
R"(smoke_NmsLayerTest.*)",
|
||||
// Issue: 95915
|
||||
R"(smoke_dynamic/AUGRUCellCPUTest.CompareWithRefs/IS=\(\{\?\.1\}_\{\?\.1\}_\{\?\.1\}_\)_TS=\{\(1\.1\)_\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT
|
||||
R"(smoke_dynamic/GRUCellCPUTest.CompareWithRefs/IS=\(\{\?\.1\}_\{\?\.1\}_\)_TS=\{\(1\.1\)_\(1\.1\)\}_\{\(3\.1\)_\(3\.1\)\}_\{\(5\.1\)_\(5\.1\)\}_decompose=0_activations=\(sigmoid\.tanh\)_clip=0_linear=0_netPrec=f32__inFmts=nc\.nc_outFmts=nc_primitive=ref_any_PluginConf_ENFORCE_BF16=YES)", // NOLINT
|
||||
// 94982. FP32->I32 conversion issue in the reference implementation. There can be some garbage in the rest of float values like 0.333333745.
|
||||
// The kernel does not have such garbage. The diff 0.000000745 is taken into account in calculations and affects further type conversion.
|
||||
// Reorder->GridSample->Reorder also does not work here. Potential fix is to use nearest conversion instead of truncation.
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
#include <ngraph/opsets/opset6.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
@@ -98,6 +99,7 @@
|
||||
#include <low_precision/strided_slice.hpp>
|
||||
#include <low_precision/network_helper.hpp>
|
||||
#include "transformations/op_conversions/eye_decomposition.hpp"
|
||||
#include <low_precision/recurrent_cell.hpp>
|
||||
|
||||
#include "intel_gpu/plugin/itt.hpp"
|
||||
|
||||
@@ -425,7 +427,9 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
||||
PrecisionsRestriction::create<ngraph::opset1::GroupConvolution>({
|
||||
{{0}, {ngraph::element::u8, ngraph::element::i8}},
|
||||
{{1}, {ngraph::element::i8}}
|
||||
})
|
||||
}),
|
||||
PrecisionsRestriction::create<ngraph::opset5::LSTMSequence>({}),
|
||||
PrecisionsRestriction::create<ngraph::opset6::GRUSequence>({})
|
||||
});
|
||||
|
||||
auto perTensorQuantization = std::vector<QuantizationGranularityRestriction>({
|
||||
@@ -436,6 +440,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
|
||||
ngraph::pass::Manager lptManager;
|
||||
|
||||
auto lptPassConfig = lptManager.get_pass_config();
|
||||
// quantized LSTMSequence / GPUSequence are not supported yet. Avoid extra transformation
|
||||
lptPassConfig->disable<ngraph::pass::low_precision::RecurrentCellTransformation>();
|
||||
lptPassConfig->set_callback<ngraph::pass::low_precision::MarkupPrecisions>([](const_node_ptr& node) -> bool {
|
||||
if (const auto mulitply = std::dynamic_pointer_cast<const ngraph::opset1::Multiply>(node)) {
|
||||
return !MultiplyToGroupConvolutionTransformation::canBeTransformedToGroupConvolution(mulitply);
|
||||
|
||||
@@ -0,0 +1,352 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <low_precision/common/precisions_restriction.hpp>
|
||||
#include <low_precision/recurrent_cell.hpp>
|
||||
#include <low_precision/fold_convert.hpp>
|
||||
#include <low_precision/fuse_convert.hpp>
|
||||
#include <low_precision/fuse_multiply_to_fake_quantize.hpp>
|
||||
#include <low_precision/fuse_subtract_to_fake_quantize.hpp>
|
||||
#include <low_precision/rt_info/intervals_alignment_attribute.hpp>
|
||||
#include <low_precision/rt_info/precision_preserved_attribute.hpp>
|
||||
#include <low_precision/rt_info/quantization_alignment_attribute.hpp>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/recurrent_cell_function.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
#include <ngraph/opsets/opset5.hpp>
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
namespace {
|
||||
|
||||
class RecurrentCellTransformationValues {
|
||||
public:
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_X;
|
||||
ngraph::builder::subgraph::DequantizationOperations::Convert convert_X;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization_X;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_H;
|
||||
ngraph::builder::subgraph::DequantizationOperations::Convert convert_H;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization_H;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_W;
|
||||
ngraph::builder::subgraph::DequantizationOperations::Convert convert_W;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization_W;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_R;
|
||||
ngraph::builder::subgraph::DequantizationOperations::Convert convert_R;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization_R;
|
||||
ngraph::element::Type precisionAfterOperation;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationValues& values) {
|
||||
return out << "_" << values.fakeQuantize_X << "_" << values.convert_X << "_" << values.dequantization_X <<
|
||||
"_" << values.fakeQuantize_H << "_" << values.convert_H << "_" << values.dequantization_H <<
|
||||
"_" << values.fakeQuantize_W << "_" << values.convert_W << "_" << values.dequantization_W <<
|
||||
"_" << values.fakeQuantize_R << "_" << values.convert_R << "_" << values.dequantization_R;
|
||||
}
|
||||
|
||||
class RecurrentCellTransformationTestValues {
|
||||
public:
|
||||
RecurrentCellTransformationTestValues() = default;
|
||||
RecurrentCellTransformationTestValues(const TestTransformationParams& params,
|
||||
const RecurrentCellFunction::RNNType type,
|
||||
const RecurrentCellTransformationValues& actual,
|
||||
const RecurrentCellTransformationValues& result,
|
||||
const bool addNotPrecisionPreservedOperation = false,
|
||||
const bool checkIntervalsAlignmentAttributes = true)
|
||||
: params(params),
|
||||
type(type),
|
||||
actual(actual),
|
||||
result(result) {}
|
||||
|
||||
TestTransformationParams params;
|
||||
RecurrentCellFunction::RNNType type;
|
||||
RecurrentCellTransformationValues actual;
|
||||
RecurrentCellTransformationValues result;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const RecurrentCellTransformationTestValues& values) {
|
||||
return out << "_" << values.actual << "_" << values.result;
|
||||
}
|
||||
|
||||
typedef std::tuple<ngraph::element::Type, std::vector<ngraph::PartialShape>, std::vector<ngraph::Shape>, RecurrentCellTransformationTestValues>
|
||||
RecurrentCellTransformationParams;
|
||||
|
||||
class RecurrentCellTransformation : public LayerTransformation, public testing::WithParamInterface<RecurrentCellTransformationParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const ngraph::element::Type precision = std::get<0>(GetParam());
|
||||
const std::vector<ngraph::PartialShape> activations_shapes = std::get<1>(GetParam());
|
||||
const std::vector<ngraph::Shape> weights_shapes = std::get<2>(GetParam());
|
||||
RecurrentCellTransformationTestValues testValues = std::get<3>(GetParam());
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::RecurrentCellFunction::get(precision,
|
||||
activations_shapes,
|
||||
weights_shapes,
|
||||
testValues.type,
|
||||
{
|
||||
testValues.actual.fakeQuantize_X,
|
||||
testValues.actual.fakeQuantize_H,
|
||||
testValues.actual.fakeQuantize_W,
|
||||
testValues.actual.fakeQuantize_R
|
||||
},
|
||||
{
|
||||
testValues.actual.convert_X,
|
||||
testValues.actual.convert_H,
|
||||
testValues.actual.convert_W,
|
||||
testValues.actual.convert_R
|
||||
},
|
||||
{
|
||||
testValues.actual.dequantization_X,
|
||||
testValues.actual.dequantization_H,
|
||||
testValues.actual.dequantization_W,
|
||||
testValues.actual.dequantization_R
|
||||
});
|
||||
|
||||
const auto params = TestTransformationParams::toParams(testValues.params);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.commonGraphRewrite->add_matcher<ngraph::pass::low_precision::RecurrentCellTransformation>(params);
|
||||
transformer.transform(actualFunction);
|
||||
|
||||
SimpleLowPrecisionTransformer clenup_transformer;
|
||||
clenup_transformer.commonGraphRewrite->add_matcher<ngraph::pass::low_precision::FoldConvertTransformation>(params);
|
||||
clenup_transformer.commonGraphRewrite->add_matcher<ngraph::pass::low_precision::FuseConvertTransformation>(params);
|
||||
clenup_transformer.commonGraphRewrite->add_matcher<ngraph::pass::low_precision::FuseSubtractToFakeQuantizeTransformation>(params);
|
||||
clenup_transformer.commonGraphRewrite->add_matcher<ngraph::pass::low_precision::FuseMultiplyToFakeQuantizeTransformation>(params);
|
||||
clenup_transformer.transform(actualFunction);
|
||||
|
||||
// dequantization output precision depends on input precision
|
||||
// to avoid huge amount of tests cases let's define dequantization output precision as input precision
|
||||
if (!testValues.result.dequantizationAfter.multiply.empty()) {
|
||||
testValues.result.dequantizationAfter.multiply.outPrecision = precision;
|
||||
}
|
||||
|
||||
referenceFunction =
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::get(precision,
|
||||
activations_shapes,
|
||||
weights_shapes,
|
||||
testValues.type,
|
||||
{
|
||||
testValues.result.fakeQuantize_X,
|
||||
testValues.result.fakeQuantize_H,
|
||||
testValues.result.fakeQuantize_W,
|
||||
testValues.result.fakeQuantize_R
|
||||
},
|
||||
{
|
||||
testValues.result.convert_X,
|
||||
testValues.result.convert_H,
|
||||
testValues.result.convert_W,
|
||||
testValues.result.convert_R
|
||||
},
|
||||
{
|
||||
testValues.result.dequantization_X,
|
||||
testValues.result.dequantization_H,
|
||||
testValues.result.dequantization_W,
|
||||
testValues.result.dequantization_R
|
||||
});
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<RecurrentCellTransformationParams> obj) {
|
||||
const ngraph::element::Type precision = std::get<0>(obj.param);
|
||||
const std::vector<ngraph::PartialShape> activations_shapes = std::get<1>(obj.param);
|
||||
const std::vector<ngraph::Shape> weights_shapes = std::get<2>(obj.param);
|
||||
const RecurrentCellTransformationTestValues testValues = std::get<3>(obj.param);
|
||||
|
||||
std::ostringstream result;
|
||||
result << LayerTransformation::getTestCaseNameByParams(precision, activations_shapes[0], testValues.params)
|
||||
<< "_" << testValues.actual << "_" << testValues.result << "_";
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(RecurrentCellTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(actualFunction, referenceFunction);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique";
|
||||
}
|
||||
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f32,
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
namespace testValues2 {
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 1, 16}, {1, 1, 128}, {1, 1, 128}}};
|
||||
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}};
|
||||
|
||||
const std::vector<RecurrentCellTransformationTestValues> testValues = {
|
||||
// LSTM Sequence
|
||||
{LayerTransformation::createParamsU8I8(),
|
||||
RecurrentCellFunction::RNNType::LSTMSequence,
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
},
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{},
|
||||
{},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f}
|
||||
},
|
||||
// R
|
||||
{},
|
||||
{},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f}
|
||||
},
|
||||
}
|
||||
},
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_LPT,
|
||||
RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::ValuesIn(activations_shapes),
|
||||
::testing::ValuesIn(weights_shapes),
|
||||
::testing::ValuesIn(testValues)),
|
||||
RecurrentCellTransformation::getTestCaseName);
|
||||
} // namespace testValues2
|
||||
|
||||
namespace testValues3 {
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 2, 3}, {1, 1, 3}, {}}};
|
||||
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 9, 3}, {1, 9, 3}, {1, 9}}};
|
||||
|
||||
const std::vector<RecurrentCellTransformationTestValues> testValues = {
|
||||
// GRU
|
||||
{LayerTransformation::createParamsU8I8(),
|
||||
RecurrentCellFunction::RNNType::GRUSequence,
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
},
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{},
|
||||
{},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f}
|
||||
},
|
||||
// R
|
||||
{},
|
||||
{},
|
||||
{
|
||||
{element::f32},
|
||||
{},
|
||||
{0.01f}
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_LPT,
|
||||
RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::ValuesIn(activations_shapes),
|
||||
::testing::ValuesIn(weights_shapes),
|
||||
::testing::ValuesIn(testValues)),
|
||||
RecurrentCellTransformation::getTestCaseName);
|
||||
} // namespace testValues3
|
||||
|
||||
} // namespace
|
||||
@@ -0,0 +1,177 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "low_precision_transformations/recurrent_cell_transformation.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
const std::vector<ngraph::element::Type> netPrecisions = {
|
||||
ngraph::element::f32,
|
||||
//ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true)
|
||||
};
|
||||
|
||||
namespace testValues1 {
|
||||
|
||||
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
|
||||
// LSTMCell
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
|
||||
"RNNCell",
|
||||
"U8"
|
||||
},
|
||||
// asymmetrical FQ on weights
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
|
||||
"RNNCell",
|
||||
"FP32"
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{512, 16}, {512, 128}, {512}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::ValuesIn(activations_shapes),
|
||||
::testing::ValuesIn(weights_shapes),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(params)),
|
||||
RecurrentCellTransformation::getTestCaseName);
|
||||
} // namespace testValues1
|
||||
|
||||
namespace testValues2 {
|
||||
|
||||
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
|
||||
// GRU
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
|
||||
"RNNCell",
|
||||
"U8"
|
||||
},
|
||||
// asymmetrical FQ on weights
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
|
||||
"RNNCell",
|
||||
"FP32"
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{2, 3}, {2, 3}, {}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{9, 3}, {9, 3}, {9}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::ValuesIn(activations_shapes),
|
||||
::testing::ValuesIn(weights_shapes),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(params)),
|
||||
RecurrentCellTransformation::getTestCaseName);
|
||||
} // namespace testValues2
|
||||
@@ -0,0 +1,123 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/fusing_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using ConvertFqRnnToQuantizedRnnTestParams = std::tuple<std::string, SizeVector>;
|
||||
/* using ConvertFqRnnToQuantizedRnnTestParams = std::string; */
|
||||
|
||||
class ConvertFqRnnToQuantizedRnn : public testing::WithParamInterface<ConvertFqRnnToQuantizedRnnTestParams>,
|
||||
public CpuTestWithFusing,
|
||||
virtual public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<ConvertFqRnnToQuantizedRnnTestParams>& obj) {
|
||||
SizeVector inputShapes;
|
||||
std::string rnnType;
|
||||
std::tie(rnnType, inputShapes) = obj.param;
|
||||
|
||||
auto batchSize = inputShapes[0];
|
||||
auto inputSize = inputShapes[1];
|
||||
auto hiddenSize = inputShapes[2];
|
||||
|
||||
std::ostringstream result;
|
||||
result << "Type = " << rnnType << "_";
|
||||
result << "batch = " << batchSize << "_";
|
||||
result << "input = " << inputSize << "_";
|
||||
result << "hidden = " << hiddenSize << "_";
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
SizeVector inputShapes;
|
||||
std::string rnnType;
|
||||
|
||||
std::tie(rnnType, inputShapes) = this->GetParam();
|
||||
|
||||
auto batchSize = inputShapes[0];
|
||||
auto inputSize = inputShapes[1];
|
||||
auto hiddenSize = inputShapes[2];
|
||||
|
||||
const float inputDataMin = 6.43123;
|
||||
const float inputDataMax = -6.48187;
|
||||
const float outputDataMin = inputDataMin;
|
||||
const float outputDataMax = outputDataMin;
|
||||
|
||||
const SizeVector inputShape = {batchSize, inputSize};
|
||||
const SizeVector hiddenStateShape = {batchSize, hiddenSize};
|
||||
const SizeVector cellStateShape = {batchSize, hiddenSize};
|
||||
|
||||
init_input_shapes({
|
||||
{{}, {inputShape}},
|
||||
{{}, {hiddenStateShape}},
|
||||
{{}, {cellStateShape}}
|
||||
});
|
||||
|
||||
const auto ngPrec = element::f32;
|
||||
auto inputParams = builder::makeParams(ngPrec, {inputShape, hiddenStateShape, cellStateShape});
|
||||
const auto outputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(inputParams));
|
||||
|
||||
std::vector<float> empty;
|
||||
auto W = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, inputSize}, empty, true);
|
||||
auto R = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, hiddenSize}, empty, true);
|
||||
auto B = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize}, empty, true);
|
||||
|
||||
const auto fqLevels = 256;
|
||||
|
||||
auto inputFQ = ngraph::builder::makeFakeQuantize(outputNodes[0], ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { outputDataMin }, { outputDataMax });
|
||||
|
||||
auto hiddenStateFQ = ngraph::builder::makeFakeQuantize(outputNodes[1], ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto weightsFQ = ngraph::builder::makeFakeQuantize(W, ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto recurrentWeightsFQ = ngraph::builder::makeFakeQuantize(R, ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto rnnCellOp = std::make_shared<ov::op::v4::LSTMCell>(inputFQ, hiddenStateFQ, inputParams[2], weightsFQ, recurrentWeightsFQ, B, hiddenSize);
|
||||
|
||||
function = makeNgraphFunction(ngPrec, inputParams, rnnCellOp, "ConvertFqRnnToQuantizedRnn");
|
||||
}
|
||||
};
|
||||
|
||||
// Entry point: the fixture's SetUp() builds the FQ-wrapped RNN subgraph and
// run() executes it, comparing plugin results against the reference.
TEST_P(ConvertFqRnnToQuantizedRnn, CompareWithRefs) {
    run();
}
|
||||
|
||||
namespace {

// {batch size, input size, hidden size} triples consumed by the fixture.
const std::vector<SizeVector> inputShapes {
    {37, 128, 512},
    /* {256, 128, 256}, */
};

// NOTE(review): the visible part of SetUp() constructs an LSTMCell
// unconditionally; confirm the other rnnTypes are actually exercised.
std::vector<std::string> rnnTypes {"LSTMCell", "RNNCell", "GRUCell"};

INSTANTIATE_TEST_SUITE_P(smoke_Check, ConvertFqRnnToQuantizedRnn,
                         /* ::testing::ValuesIn(rnnTypes), */
                         ::testing::Combine(::testing::ValuesIn(rnnTypes),
                                            ::testing::ValuesIn(inputShapes)),
                         ConvertFqRnnToQuantizedRnn::getTestCaseName);

} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
@@ -0,0 +1,179 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "low_precision_transformations/recurrent_cell_transformation.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
// Network precisions the test suite is instantiated with.
const std::vector<ngraph::element::Type> netPrecisions = {
    ngraph::element::f32,
    ngraph::element::f16
};
|
||||
|
||||
// LPT parameter sets used for every instantiation below.
// NOTE(review): identifier is misspelled ("trasformation"); it is referenced
// by both instantiations in this file, so renaming must touch all uses.
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
    LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true)
};
|
||||
|
||||
namespace testValues1 {

// LSTMSequence cases. Each entry describes, per input (X, H, W, R):
// FakeQuantize, an optional Convert, and an optional dequantization.
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
    // LSTMSequence
    {
        // X
        {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
        {ngraph::element::u8},
        {
            {ngraph::element::f32},
            {},
            {0.01f},
        },
        // H
        {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
        {ngraph::element::u8},
        {
            {ngraph::element::f32},
            {},
            {0.01f},
        },
        // W (255 levels: symmetric i8 quantization)
        {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
        {},
        {{}, {}, {}},
        // R
        {255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
        {},
        {{}, {}, {}},
        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
        "RNNCell",
        "U8"
    },
    // asymmetrical FQ on weights (256 levels) — expected to stay in FP32
    {
        // X
        {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
        {ngraph::element::u8},
        {
            {ngraph::element::f32},
            {},
            {0.01f},
        },
        // H
        {256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
        {ngraph::element::u8},
        {
            {ngraph::element::f32},
            {},
            {0.01f},
        },
        // W
        {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
        {},
        {{}, {}, {}},
        // R
        {256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
        {},
        {{}, {}, {}},
        ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
        "RNNCell",
        "FP32"
    }
};

// Shapes of the X, H, C activations and the W, R, B weights respectively.
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 1, 16}, {1, 1, 128}, {1, 1, 128}}};
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}};

// Quantized Recurrent models are not supported by GPU yet. Keep tests for future
INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation,
    ::testing::Combine(
        ::testing::ValuesIn(netPrecisions),
        ::testing::ValuesIn(activations_shapes),
        ::testing::ValuesIn(weights_shapes),
        ::testing::Values(CommonTestUtils::DEVICE_GPU),
        ::testing::ValuesIn(trasformationParamValues),
        ::testing::ValuesIn(params)),
    RecurrentCellTransformation::getTestCaseName);
} // namespace testValues1
|
||||
|
||||
namespace testValues2 {
|
||||
|
||||
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
|
||||
// GRUSequence
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
|
||||
"RNNCell",
|
||||
"U8"
|
||||
},
|
||||
// asymmetrical FQ on weights
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// H
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
{ngraph::element::u8},
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{},
|
||||
{0.01f},
|
||||
},
|
||||
// W
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
// R
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
|
||||
"RNNCell",
|
||||
"FP32"
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 2, 3}, {1, 2, 3}, {}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 9, 3}, {1, 9, 3}, {1, 9}}};
|
||||
|
||||
// Quantized Recurrent models are not supported by GPU yet. Keep tests for future
|
||||
INSTANTIATE_TEST_SUITE_P(DISABLED_smoke_LPT, RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::ValuesIn(activations_shapes),
|
||||
::testing::ValuesIn(weights_shapes),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(params)),
|
||||
RecurrentCellTransformation::getTestCaseName);
|
||||
} // namespace testValues2
|
||||
@@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_weights.hpp"
|
||||
|
||||
#include "low_precision/recurrent_cell.hpp"
|
||||
|
||||
#include "lpt_ngraph_functions/recurrent_cell_function.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
// Per-test description of the FakeQuantize / Convert / dequantization
// subgraphs attached to each recurrent-cell input, plus the expected
// runtime kernel precision of the resulting layer.
class RecurrentCellTransformationParam {
public:
    // X: input activations
    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_X;
    ngraph::builder::subgraph::DequantizationOperations::Convert convert_X;
    ngraph::builder::subgraph::DequantizationOperations dequantization_X;
    // H: initial hidden state
    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_H;
    ngraph::builder::subgraph::DequantizationOperations::Convert convert_H;
    ngraph::builder::subgraph::DequantizationOperations dequantization_H;
    // W: input weights
    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_W;
    ngraph::builder::subgraph::DequantizationOperations::Convert convert_W;
    ngraph::builder::subgraph::DequantizationOperations dequantization_W;
    // R: recurrent weights
    ngraph::builder::subgraph::FakeQuantizeOnDataWithConstant fakeQuantize_R;
    ngraph::builder::subgraph::DequantizationOperations::Convert convert_R;
    ngraph::builder::subgraph::DequantizationOperations dequantization_R;
    // NOTE(review): the field intentionally shares its type's name; callers
    // access it as param.RNNType.
    ngraph::builder::subgraph::RecurrentCellFunction::RNNType RNNType;
    // Runtime layer type whose kernel precision is checked (e.g. "RNNCell").
    std::string layerName;
    // Expected runtime precision of that layer (e.g. "U8", "FP32").
    std::string expectedKernelType;
};
|
||||
|
||||
// Full test parameter tuple: net precision, activation shapes (X/H/C),
// weight shapes (W/R/B), target device, LPT params, and the per-input
// quantization description above.
typedef std::tuple<
    ngraph::element::Type,
    std::vector<ngraph::PartialShape>,
    std::vector<ngraph::Shape>,
    std::string,
    ngraph::pass::low_precision::LayerTransformation::Params,
    RecurrentCellTransformationParam
>RecurrentCellTransformationParams;
|
||||
|
||||
// Shared LPT test: SetUp() builds a recurrent-cell model via
// RecurrentCellFunction, Run() executes it and compares the runtime kernel
// precision of the target layer with the expected one.
class RecurrentCellTransformation :
    public testing::WithParamInterface<RecurrentCellTransformationParams>,
    public LayerTestsUtils::LayerTransformation {
public:
    static std::string getTestCaseName(testing::TestParamInfo<RecurrentCellTransformationParams> obj);

protected:
    void SetUp() override;

    void Run() override;
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision_transformations/recurrent_cell_transformation.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
#include "functional_test_utils/blob_utils.hpp"
|
||||
#include "lpt_ngraph_functions/recurrent_cell_function.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
// Builds a human-readable test name from the parameter tuple.
// Only the X and W quantization/dequantization descriptions are encoded;
// H and R are intentionally omitted from the name.
std::string RecurrentCellTransformation::getTestCaseName(testing::TestParamInfo<RecurrentCellTransformationParams> obj) {
    ngraph::element::Type precision;
    std::vector<ngraph::PartialShape> activationShapes;
    std::vector<ngraph::Shape> weightShapes;
    std::string device;
    ngraph::pass::low_precision::LayerTransformation::Params lptParams;
    RecurrentCellTransformationParam testParam;
    std::tie(precision, activationShapes, weightShapes, device, lptParams, testParam) = obj.param;

    std::ostringstream name;
    name << getTestCaseNameByParams(precision, activationShapes[0], device, lptParams);
    name << "FQ_X_" << testParam.fakeQuantize_X << "_";
    name << "DQ_X_" << testParam.dequantization_X << "_";
    name << "FQ_W_" << testParam.fakeQuantize_W << "_";
    name << "DQ_W_" << testParam.dequantization_W;
    return name.str();
}
|
||||
|
||||
void RecurrentCellTransformation::SetUp() {
|
||||
ngraph::element::Type precision;
|
||||
std::vector<ngraph::PartialShape> activations_shapes;
|
||||
std::vector<ngraph::Shape> weights_shapes;
|
||||
RecurrentCellTransformationParam param;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
|
||||
std::tie(precision, activations_shapes, weights_shapes, targetDevice, params, param) = this->GetParam();
|
||||
|
||||
function = ngraph::builder::subgraph::RecurrentCellFunction::get(precision,
|
||||
activations_shapes,
|
||||
weights_shapes,
|
||||
param.RNNType,
|
||||
{
|
||||
param.fakeQuantize_X,
|
||||
param.fakeQuantize_H,
|
||||
param.fakeQuantize_W,
|
||||
param.fakeQuantize_R
|
||||
},
|
||||
{
|
||||
param.convert_X,
|
||||
param.convert_H,
|
||||
param.convert_W,
|
||||
param.convert_R
|
||||
},
|
||||
{
|
||||
param.dequantization_X,
|
||||
param.dequantization_H,
|
||||
param.dequantization_W,
|
||||
param.dequantization_R
|
||||
});
|
||||
}
|
||||
|
||||
void RecurrentCellTransformation::Run() {
|
||||
LayerTestsCommon::Run();
|
||||
|
||||
const auto params = std::get<5>(GetParam());
|
||||
const auto actualPrecision = getRuntimePrecisionByType(params.layerName);
|
||||
auto expectedPrecision = params.expectedKernelType;
|
||||
if (expectedPrecision == "FP32" && std::get<0>(GetParam()) == ngraph::element::f16) {
|
||||
expectedPrecision = "FP16";
|
||||
}
|
||||
EXPECT_EQ(actualPrecision, expectedPrecision);
|
||||
}
|
||||
|
||||
// Entry point: SetUp() builds the model, Run() validates runtime precision.
// (Removed the stray ';' after the macro body — it triggers -Wextra-semi.)
TEST_P(RecurrentCellTransformation, CompareWithRefImpl) {
    Run();
}
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
#include "common/fake_quantize_on_data.hpp"
|
||||
#include "common/dequantization_operations.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
// Builder for recurrent-cell (LSTM/GRU sequence) test models with optional
// FakeQuantize / Convert / dequantization subgraphs on each input.
class RecurrentCellFunction {
public:
    enum class RNNType { LSTMSequence, GRUSequence };

    // inputActivationsShapes: shapes of X, H and C (C is used only for LSTM).
    // inputWeightsShapes: shapes of W, R and B.
    // fqOnDatas/converts/dequantizations: one entry per input, in the order
    // X, H, W, R; an empty entry skips the corresponding stage.
    static std::shared_ptr<ngraph::Function> get(
        const ngraph::element::Type inputPrecision,
        const std::vector<ngraph::PartialShape>& inputActivationsShapes,
        const std::vector<ngraph::Shape>& inputWeightsShapes,
        const RNNType type,
        const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
        const std::vector<DequantizationOperations::Convert>& converts,
        const std::vector<DequantizationOperations>& dequantizations);
};
|
||||
|
||||
// Wraps `input` with FakeQuantize, Convert and dequantization stages; a stage
// is skipped when its descriptor is empty. Returns the last node of the
// resulting chain (or `input` itself when all descriptors are empty).
std::shared_ptr<Node> makeQuantizationAndDequantization(const std::shared_ptr<Node> input,
                                                        const ngraph::element::Type inputPrecision,
                                                        const std::string friendly_name,
                                                        const FakeQuantizeOnDataWithConstant& fqOnData,
                                                        const DequantizationOperations::Convert& convert,
                                                        const DequantizationOperations& dequantization);
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
@@ -0,0 +1,149 @@
|
||||
// Copyright (C) 2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "lpt_ngraph_functions/recurrent_cell_function.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include "ngraph_ops/type_relaxed.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||
#include "low_precision/rt_info/intervals_alignment_attribute.hpp"
|
||||
#include "low_precision/rt_info/quantization_alignment_attribute.hpp"
|
||||
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
using namespace ngraph::pass;
|
||||
|
||||
std::shared_ptr<ngraph::Function> RecurrentCellFunction::get(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const std::vector<ngraph::PartialShape>& inputActivationsShapes,
|
||||
const std::vector<ngraph::Shape>& inputWeightsShapes,
|
||||
const RNNType type,
|
||||
const std::vector<FakeQuantizeOnDataWithConstant>& fqOnDatas,
|
||||
const std::vector<DequantizationOperations::Convert>& converts,
|
||||
const std::vector<DequantizationOperations>& dequantizations) {
|
||||
auto X = std::make_shared<opset1::Parameter>(inputPrecision, inputActivationsShapes[0]);
|
||||
X->set_friendly_name("X");
|
||||
std::shared_ptr<Node> parent_X = makeQuantizationAndDequantization(X,
|
||||
inputPrecision,
|
||||
X->get_friendly_name(),
|
||||
fqOnDatas[0],
|
||||
converts[0],
|
||||
dequantizations[0]);
|
||||
auto H = std::make_shared<opset1::Parameter>(inputPrecision, inputActivationsShapes[1]);
|
||||
H->set_friendly_name("H");
|
||||
std::shared_ptr<Node> parent_H = makeQuantizationAndDequantization(H,
|
||||
inputPrecision,
|
||||
H->get_friendly_name(),
|
||||
fqOnDatas[1],
|
||||
converts[1],
|
||||
dequantizations[1]);
|
||||
auto C = std::make_shared<opset1::Parameter>(inputPrecision, inputActivationsShapes[2]);
|
||||
C->set_friendly_name("C");
|
||||
|
||||
auto W = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? ngraph::element::i8 : inputPrecision,
|
||||
inputWeightsShapes[0],
|
||||
{1});
|
||||
std::shared_ptr<Node> parent_W = makeQuantizationAndDequantization(W,
|
||||
inputPrecision,
|
||||
W->get_friendly_name(),
|
||||
fqOnDatas[2],
|
||||
converts[2],
|
||||
dequantizations[2]);
|
||||
auto R = ngraph::opset1::Constant::create(fqOnDatas[2].empty() ? ngraph::element::i8 : inputPrecision,
|
||||
inputWeightsShapes[1],
|
||||
{1});
|
||||
std::shared_ptr<Node> parent_R = makeQuantizationAndDequantization(R,
|
||||
inputPrecision,
|
||||
R->get_friendly_name(),
|
||||
fqOnDatas[3],
|
||||
converts[3],
|
||||
dequantizations[3]);
|
||||
auto B = ngraph::opset1::Constant::create(inputPrecision, inputWeightsShapes[2], {1});
|
||||
auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {3});
|
||||
|
||||
std::shared_ptr<ov::op::util::RNNCellBase> rnn_layer;
|
||||
switch (type) {
|
||||
case RNNType::LSTMSequence:
|
||||
rnn_layer = std::make_shared<opset5::LSTMSequence>(parent_X,
|
||||
parent_H,
|
||||
C,
|
||||
seq_lengths,
|
||||
parent_W,
|
||||
parent_R,
|
||||
B,
|
||||
128,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
rnn_layer->set_friendly_name("lstm_sequense");
|
||||
break;
|
||||
case RNNType::GRUSequence:
|
||||
rnn_layer = std::make_shared<opset5::GRUSequence>(parent_X,
|
||||
parent_H,
|
||||
seq_lengths,
|
||||
parent_W,
|
||||
parent_R,
|
||||
B,
|
||||
3,
|
||||
op::RecurrentSequenceDirection::FORWARD);
|
||||
rnn_layer->set_friendly_name("gru_sequence");
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
auto& rtInfo = rnn_layer->get_rt_info();
|
||||
bool is_lstm = type == RNNType::LSTMSequence;
|
||||
rtInfo["Variant::std::string"] = "rnn_layer";
|
||||
|
||||
auto rnn_layer_res_1 = std::make_shared<opset5::Result>(rnn_layer->output(0));
|
||||
rnn_layer_res_1->set_friendly_name("output_1");
|
||||
std::shared_ptr<ov::op::v0::Result> rnn_layer_res_2 = {};
|
||||
if (is_lstm) {
|
||||
rnn_layer_res_2 = std::make_shared<opset5::Result>(rnn_layer->output(1));
|
||||
rnn_layer_res_2->set_friendly_name("output_2");
|
||||
}
|
||||
|
||||
ngraph::ResultVector results{rnn_layer_res_2 ? rnn_layer_res_1, rnn_layer_res_2 : rnn_layer_res_1};
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
results,
|
||||
is_lstm ? ngraph::ParameterVector{X, H, C} : ngraph::ParameterVector{X, H},
|
||||
"LSTMTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> makeQuantizationAndDequantization(const std::shared_ptr<Node> input,
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const std::string friendly_name,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnData,
|
||||
const DequantizationOperations::Convert& convert,
|
||||
const DequantizationOperations& dequantization) {
|
||||
std::shared_ptr<Node> parent;
|
||||
if (fqOnData.empty()) {
|
||||
parent = input;
|
||||
} else {
|
||||
std::shared_ptr<Node> fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input, inputPrecision, fqOnData);
|
||||
fakeQuantize1->set_friendly_name("fakeQuantize_" + friendly_name);
|
||||
parent = fakeQuantize1;
|
||||
}
|
||||
if (!convert.empty()) {
|
||||
parent = std::make_shared<opset1::Convert>(parent, convert.outPrecision);
|
||||
}
|
||||
if (!dequantization.empty()) {
|
||||
parent = makeDequantization(parent, dequantization);
|
||||
}
|
||||
return parent;
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
||||
Reference in New Issue
Block a user