[CPU] Add and correct tests for int8 LSTM (#17447)

This commit is contained in:
Egor Duplenskii
2023-06-14 07:58:31 +02:00
committed by GitHub
parent ca0d40969a
commit d66e322529
9 changed files with 333 additions and 154 deletions

View File

@@ -455,6 +455,9 @@ void RNN::configurePortDataTypes() {
if (one_of(memory::data_type::bf16, inDataTypes[xIdx], inDataTypes[hIdx]))
inDataTypes[xIdx] = outDataTypes[yIdx] = outDataTypes[hoIdx] = inDataTypes[hIdx] = memory::data_type::bf16; // required by oneDNN.
if (outDataTypes[yIdx] == memory::data_type::bf16 && one_of(inDataTypes[xIdx], memory::data_type::s8, memory::data_type::u8))
outDataTypes[yIdx] = memory::data_type::f32; // oneDNN does not support bf16 output precision for quantized rnn primitive yet
}
void RNN::getSupportedDescriptors() {
@@ -870,7 +873,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[1], // Weights state
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
attr);
case dnnl::algorithm::vanilla_gru:
return dnnl::gru_forward::primitive_desc(
engine,
@@ -882,7 +886,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[1], // Weights state
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
attr);
case dnnl::algorithm::lbr_gru:
return dnnl::lbr_gru_forward::primitive_desc(
engine,
@@ -894,7 +899,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[1], // Weights state
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
attr);
case dnnl::algorithm::vanilla_lstm:
return dnnl::lstm_forward::primitive_desc(
engine,
@@ -908,7 +914,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc()); // Out State C
outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc(), // Out State C
attr);
case dnnl::algorithm::vanilla_augru:
return dnnl::augru_forward::primitive_desc(
engine,
@@ -921,7 +928,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[1], // Weights state
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
attr);
case dnnl::algorithm::lbr_augru:
return dnnl::lbr_augru_forward::primitive_desc(
engine,
@@ -934,7 +942,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
wDescs[1], // Weights state
wDescs[2], // Bias
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
attr);
default:
IE_THROW() << "RNN. Unknown cell type";
}
@@ -979,19 +988,19 @@ void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
// Fill supported config
NodeConfig config;
for (size_t i = 0; i < inputDesc.size(); i++) {
for (const auto &desc : inputDesc) {
PortConfig dataConfig;
dataConfig.inPlace(-1);
dataConfig.constant(false);
dataConfig.setMemDesc(inputDesc[i]);
dataConfig.setMemDesc(desc);
config.inConfs.push_back(dataConfig);
}
for (size_t i = 0; i < outputDesc.size(); i++) {
for (const auto &desc : outputDesc) {
PortConfig dataConfig;
dataConfig.inPlace(-1);
dataConfig.constant(false);
dataConfig.setMemDesc(outputDesc[i]);
dataConfig.setMemDesc(desc);
config.outConfs.push_back(dataConfig);
}
@@ -1003,7 +1012,12 @@ Node::AttrPtr RNN::initPrimitiveAttr() {
attr->set_scratchpad_mode(dnnl::scratchpad_mode::user);
if (one_of(inDataTypes[xIdx], memory::data_type::u8, memory::data_type::s8)) {
const int weightsScaleMask = 0;
const int weightsScaleMask = 0
+ (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
+ (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
DEBUG_LOG(getName(), ": inputScale: ", inputScale, ", inputShift: ", inputShift,
", weightsScaleMask: ", weightsScaleMask, ", weightsScales[0]: ", weightsScales[0]);
attr->set_rnn_weights_qparams(weightsScaleMask, weightsScales);
attr->set_rnn_data_qparams(inputScale, inputShift);

View File

@@ -3,6 +3,7 @@
//
#include "convert_fq_rnn_to_quantized_rnn.hpp"
#include <algorithm>
#include <ngraph/opsets/opset9.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
@@ -14,6 +15,7 @@
#include "ie_common.h"
#include "itt.hpp"
#include "openvino/core/type/element_type.hpp"
#include <stdexcept>
#include <vector>
@@ -164,11 +166,15 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
if (*input_scale_ptr == 0.f)
OPENVINO_THROW("Cannot handle zero input scale");
const float input_scale = 1 / *input_scale_ptr;
const std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
const float input_scale = 1 / *input_scale_ptr;
std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
// transform dequantization scales into quantization ones
std::transform(weights_scales.begin(), weights_scales.end(), weights_scales.begin(), [](float& scale) { return 1 / scale; });
auto& runtime_info = rnn_quantized->get_rt_info();
// use runtime information to store input and weight scales
runtime_info["inputScale"] = input_scale;
runtime_info["weightsScales"] = weights_scales;
@@ -178,7 +184,6 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
if (input_shift_it != pattern_map.end()) {
const auto input_shift_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(input_shift_it->second.get_node_shared_ptr());
const float* input_shift_ptr = input_shift_constant->get_data_ptr<float>();
runtime_info["inputShift"] = *input_shift_ptr;
}
@@ -207,6 +212,7 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
}
auto new_multiply = multiply->clone_with_new_inputs({multiply_in, multiply->input_value(1)});
new_multiply->set_friendly_name(rnn_quantized->get_friendly_name() + ".1");
for (auto output : H_outputs) {
output.replace_source_output(new_multiply);

View File

@@ -12,6 +12,63 @@
* with FQ operations on the inputs and forms a new TypeRelaxed operation
* with quantization parameters as runtime parameters of the operation.
* @todo add ascii graph examples
*
* Before:
*
* +-------+ +-------+ +-------+ +-------+ +-------+ +-------+
* | X | | H | | C | | W | | R | | B |
* | | | | | | | | | | | |
* | u8/i8 | | u8/i8 | | f32 | | i8 | | i8 | | f32 |
* +---+---+ +---+---+ +---+---+ +---+---+ +---+---+ +---+---+
* | | | | | |
* +---v---+ +---v---+ | +---v---+ +---v---+ |
* | | | | | | | | | |
* | deq | | deq | | | deq | | deq | |
* | | | | | | | | | |
* +---+---+ +---+---+ | +---+---+ +---+---+ |
* | | | | | |
* | | | | | |
* +---v-----------v-----------v----------v----------v----------v---+
* | |
* | LSTMSequence / GRUSequence (f32) |
* | |
* +---------------+-----------+----------+-------------------------+
* | | |
* |Y f32 |Ho f32 |Co f32
* | | |
* | | |
* | | |
* v v v
*
* v
*
*
* After:
*
* +-------+ +-------+ +-------+ +-------+ +-------+ +-------+
* | X | | H | | C | | W | | R | | B |
* | | | | | | | | | | | |
* | u8/i8 | | u8/i8 | | f32 | | i8 | | i8 | | f32 |
* +---+---+ +---+---+ +---+---+ +---+---+ +---+---+ +---+---+
* | | | | | |
* | | | | | |
* +---v-----------v-----------v----------v----------v----------v---+
* | TypeRelaxed rt_info[inputScales] |
* | |
* | LSTMSequence / GRUSequence (u8/i8) rt_info[weightsScales] |
* +---------------+-----------+----------+-------------------------+
* | | |
* |Y f32 |Ho u8/i8 |Co f32
* | | |
* | +---v---+ |
* | | | |
* | | deq | |
* | | | |
* | +---+---+ |
* | | |
* | | |
* | | |
* v v v
*/
namespace ov {

View File

@@ -87,6 +87,7 @@
#include "low_precision/convolution_backprop_data.hpp"
#include "low_precision/group_convolution.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/recurrent_cell.hpp"
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/bias_attribute.hpp"
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
@@ -504,10 +505,10 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
{{1}, {ov::element::i8}}
}),
PrecisionsRestriction::create<ov::opset5::LSTMSequence>({
{{0, 1}, {ov::element::u8, ov::element::i8}},
{{0, 1}, {ov::element::u8}}
}),
PrecisionsRestriction::create<ov::opset6::GRUSequence>({
{{0, 1}, {ov::element::u8, ov::element::i8}},
{{0, 1}, {ov::element::u8}}
}),
});
@@ -548,6 +549,7 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
return ov::marked_as_bias(node);
});
CPU_DISABLE_PASS_ARM(lptManager, ngraph::pass::low_precision::RecurrentCellTransformation);
CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation);
lptManager.run_passes(model);
@@ -609,7 +611,7 @@ void Transformations::PostLpt() {
}
// Execute before snippets. Otherwise FQ will be converted to Subgraph
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
postLPTPassManager.run_passes(model);
}

View File

@@ -21,7 +21,7 @@ const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> tras
namespace testValues1 {
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
// LSTMCell
// LSTMSequence
{
// X
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
@@ -47,8 +47,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
{},
{{}, {}, {}},
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
"RNNCell",
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
"RNNSeq",
"U8"
},
// asymmetrical FQ on weights
@@ -77,14 +77,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
{},
{{}, {}, {}},
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
"RNNCell",
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
"RNNSeq",
"FP32"
}
};
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}};
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{512, 16}, {512, 128}, {512}}};
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 2, 16}, {1, 1, 128}, {1, 1, 128}}};
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
::testing::Combine(
@@ -126,8 +126,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
{},
{{}, {}, {}},
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
"RNNCell",
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
"RNNSeq",
"U8"
},
// asymmetrical FQ on weights
@@ -156,14 +156,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
{},
{{}, {}, {}},
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
"RNNCell",
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
"RNNSeq",
"FP32"
}
};
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{2, 3}, {2, 3}, {}}};
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{9, 3}, {9, 3}, {9}}};
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 1, 3}, {1, 1, 3}, {}}};
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 9, 3}, {1, 9, 3}, {1, 9}}};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
::testing::Combine(

View File

@@ -0,0 +1,219 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph/output_vector.hpp"
#include "ngraph/type/element_type.hpp"
#include "openvino/core/node.hpp"
#include "openvino/core/type/element_type.hpp"
#include "openvino/op/gru_sequence.hpp"
#include "openvino/op/lstm_sequence.hpp"
#include "openvino/runtime/tensor.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "common_test_utils/common_utils.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>
#include <algorithm>
#include <cassert>
#include <memory>
#include <vector>
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;
using namespace ov;
namespace SubgraphTestsDefinitions {
using ConvertFqRnnToQuantizedRnnTestParams = std::tuple<std::string, std::vector<InputShape>, bool>;
/**
 * @brief Subgraph test fixture that builds a sequence RNN (LSTMSequence / GRUSequence)
 *        with FakeQuantize on the data input X, on both weight inputs (W, R) and,
 *        optionally, on the initial hidden state H.
 *
 * Test parameters (see ConvertFqRnnToQuantizedRnnTestParams):
 *   - std::string             RNN type: "LSTMSequence", "GRUSequence" or "LBRGRUSequence"
 *   - std::vector<InputShape> shapes of X, H and C (C is dropped for non-LSTM types)
 *   - bool                    whether H is a FakeQuantize'd parameter (true) or a constant (false)
 */
class ConvertFqRnnToQuantizedRnn : public testing::WithParamInterface<ConvertFqRnnToQuantizedRnnTestParams>,
                                   public CpuTestWithFusing,
                                   virtual public ov::test::SubgraphBaseTest {
public:
    /**
     * @brief Builds a readable test name from the RNN type, the dynamic (IS) and
     *        static (TS) input shapes and the quantizedHiddenState flag.
     */
    static std::string getTestCaseName(const testing::TestParamInfo<ConvertFqRnnToQuantizedRnnTestParams>& obj) {
        std::vector<InputShape> inputShapes;
        std::string rnnType;
        bool quantizedHiddenState = false;
        std::tie(rnnType, inputShapes, quantizedHiddenState) = obj.param;

        std::ostringstream result;
        result << "Type=" << rnnType << "_";
        result << "IS=(";
        for (const auto& shape : inputShapes) {
            result << CommonTestUtils::partialShape2str({shape.first}) << "_";
        }
        result << ")_TS=";
        // One "{...}_" group per static-shape set; each group lists that set's shape for every input.
        for (size_t i = 0lu; i < inputShapes.front().second.size(); i++) {
            result << "{";
            for (size_t j = 0lu; j < inputShapes.size(); j++) {
                result << CommonTestUtils::vec2str(inputShapes[j].second[i]) << (j < inputShapes.size() - 1 ? "_" : "");
            }
            result << "}_";
        }
        result << "quantizedHiddenState=" << quantizedHiddenState;

        return result.str();
    }

protected:
    /**
     * @brief Fills the X and H inputs and, when the model has a cell state, the C input.
     *
     * X and H share one set of generator arguments; C uses a different set.
     * The exact meaning of the numeric arguments is defined by
     * utils::create_and_fill_tensor, which is not shown in this file.
     */
    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
        inputs.clear();

        const auto& funcInputs = function->inputs();
        const auto& shapeX = targetInputStaticShapes[0];
        const auto& shapeH = targetInputStaticShapes[1];

        ov::Tensor tensorX = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), shapeX, 1, 0, 16);
        ov::Tensor tensorH = utils::create_and_fill_tensor(funcInputs[1].get_element_type(), shapeH, 1, 0, 16);

        inputs.insert({funcInputs[0].get_node_shared_ptr(), tensorX});
        inputs.insert({funcInputs[1].get_node_shared_ptr(), tensorH});

        if (hasCell) {
            const auto& shapeC = targetInputStaticShapes[cellIdx];
            ov::Tensor tensorC = utils::create_and_fill_tensor(funcInputs[cellIdx].get_element_type(), shapeC, 2, -1, 128, 2);
            inputs.insert({funcInputs[cellIdx].get_node_shared_ptr(), tensorC});
        }
    }

    /**
     * @brief Creates the tested function: FQ(X), FQ(H) or constant H, FQ(W), FQ(R), B and
     *        the sequence RNN operation selected by the "rnnType" test parameter.
     * @throws via IE_THROW() when the RNN type is not one of the supported names.
     */
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;
        selectedType = "ref_any_I8";

        std::vector<InputShape> inputShapes;
        std::string rnnType;
        bool quantizedHiddenState = false;
        std::tie(rnnType, inputShapes, quantizedHiddenState) = this->GetParam();

        if (rnnType != "LSTMSequence")  // remove cell input for non-cell rnn types
            inputShapes.erase(inputShapes.begin() + cellIdx);

        init_input_shapes(inputShapes);

        // X is indexed as [batch, seq_len, input_size] and H as [batch, num_directions, hidden_size] below.
        const auto inputSize = targetStaticShapes.front()[0][2];
        const auto hiddenSize = targetStaticShapes.front()[1][2];
        const size_t numDirections = 1;
        const size_t numOfGates = rnnType == "LSTMSequence" ? 4 : 3;
        // The linear-before-reset GRU variant uses one extra bias block.
        const size_t numOfBiasGates = rnnType == "LBRGRUSequence" ? numOfGates + 1 : numOfGates;

        const auto ngPrec = element::f32;
        ngraph::ParameterVector inputParams;
        std::shared_ptr<Node> H;

        inputParams = ngraph::builder::makeDynamicParams(ngPrec, inputDynamicShapes);

        const auto outputNodes = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(inputParams));

        // Per-tensor FQ for activations: 256 levels over [-128/127, 1].
        auto makeDataFQ = [](const ngraph::Output<Node>& input) {
            const auto fqLevels = 256;
            return ngraph::builder::makeFakeQuantize(input, ngraph::element::f32, fqLevels, {},
                                                     {-128.f/127}, {1.f},
                                                     {-128.f/127}, {1.f});
        };

        auto X_FQ = makeDataFQ(outputNodes[0]);

        if (quantizedHiddenState) {
            H = makeDataFQ(outputNodes[1]);
        } else {
            H = ngraph::builder::makeConstant(ngraph::element::f32, inputDynamicShapes[1].get_shape(), {}, true, 1.f, -1.f);
        }

        auto W = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfGates * hiddenSize, inputSize}, {}, true, 1.f, -1.f);
        auto R = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfGates * hiddenSize, hiddenSize}, {}, true, 1.f, -1.f);
        auto B = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfBiasGates * hiddenSize}, {}, true, 0.1f, -0.1f);

        // Per-tensor symmetric FQ for weights: 255 levels over [-127/63, 127/63].
        auto makeWeightsFQ = [](const std::shared_ptr<Node>& weight) {
            const auto fqLevelsW = 255;
            return ngraph::builder::makeFakeQuantize(weight, ngraph::element::f32,
                                                     fqLevelsW, std::vector<size_t>{},
                                                     {-127.f/63}, {127.f/63},
                                                     {-127.f/63}, {127.f/63});
        };

        auto W_FQ = makeWeightsFQ(W);
        auto R_FQ = makeWeightsFQ(R);

        std::shared_ptr<ov::Node> rnnCellOp;

        // fill sequence_length constant with max sequence length values
        const auto batchSize = targetStaticShapes.front()[0][0];
        const auto maxSeqLen = targetStaticShapes.front()[0][1];
        std::vector<int> lengths(batchSize, static_cast<int>(maxSeqLen));
        auto seq_lengths = ngraph::opset1::Constant::create(element::i64, Shape{batchSize}, lengths);

        if (rnnType == "LSTMSequence") {
            hasCell = true;
            auto C = outputNodes[cellIdx];
            rnnCellOp = std::make_shared<ov::op::v5::LSTMSequence>(
                X_FQ, H, C, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD);
        } else if (rnnType == "GRUSequence") {
            rnnCellOp = std::make_shared<ov::op::v5::GRUSequence>(
                X_FQ, H, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD);
        } else if (rnnType == "LBRGRUSequence") {
            const std::vector<std::string> activations{"sigmoid", "tanh"};
            const std::vector<float> activations_alpha, activations_beta;
            rnnCellOp = std::make_shared<ov::op::v5::GRUSequence>(
                X_FQ, H, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD,
                activations, activations_alpha, activations_beta, 0.f, true);
        } else {
            IE_THROW() << "Unexpected RNN type: " << rnnType;
        }

        if (maxSeqLen > 1)
            abs_threshold = 0.05;  // RNN int8 computation is expected to affect the accuracy, especially when sequence_length > 1

        function = makeNgraphFunction(ngPrec, inputParams, rnnCellOp, "ConvertFqRnnToQuantizedRnn");
    }

private:
    // Position of the cell-state (C) entry in the shape list and among the function inputs.
    static const size_t cellIdx = 2;
    // Set by SetUp() when the model is an LSTMSequence and therefore has a C input.
    bool hasCell = false;
};
// Executes the test via run() (inherited; its implementation is not shown in this file),
// then hands the compiled model to CheckPluginRelatedResults for nodes of type "RNNSeq".
TEST_P(ConvertFqRnnToQuantizedRnn, CompareWithRefs) {
run();
CheckPluginRelatedResults(compiledModel, "RNNSeq");
}
namespace {

// Static shape sets for the X, H and C inputs. The C entry is always listed;
// the fixture's SetUp() removes it for RNN types other than "LSTMSequence".
const std::vector<std::vector<InputShape>> staticShapesLSTM = {
    {   // seq len > 1
        { {}, { {2, 5, 10} } },  // X
        { {}, { {2, 1, 4}} },    // H
        { {}, { {2, 1, 4}} },    // C
    },
    {   // seq len = 1
        { {}, { {2, 1, 5} } },   // X
        { {}, { {2, 1, 1}} },    // H
        { {}, { {2, 1, 1}} },    // C
    },
};

// Both variants are covered: H as a FakeQuantize'd parameter (true) and H as a constant (false).
const std::vector<bool> quantizedHiddenStateParam{true, false};

INSTANTIATE_TEST_SUITE_P(smoke_static, ConvertFqRnnToQuantizedRnn,
                         ::testing::Combine(::testing::Values("LSTMSequence", "GRUSequence"),
                                            // "LBRGRUSequence", // enable after implemented in oneDNN
                                            ::testing::ValuesIn(staticShapesLSTM),
                                            ::testing::ValuesIn(quantizedHiddenStateParam)),
                         ConvertFqRnnToQuantizedRnn::getTestCaseName);

}  // namespace
} // namespace SubgraphTestsDefinitions

View File

@@ -1,123 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils/cpu_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "common_test_utils/common_utils.hpp"
#include <algorithm>
#include <cassert>
using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
namespace SubgraphTestsDefinitions {
using ConvertFqRnnToQuantizedRnnTestParams = std::tuple<std::string, SizeVector>;
/* using ConvertFqRnnToQuantizedRnnTestParams = std::string; */
// Previous version of the test fixture (this file is removed by the commit).
// It builds an ov::op::v4::LSTMCell with FakeQuantize on X, H, W and R and runs it on CPU.
// Test parameters: (RNN type name, {batch, input size, hidden size}).
class ConvertFqRnnToQuantizedRnn : public testing::WithParamInterface<ConvertFqRnnToQuantizedRnnTestParams>,
public CpuTestWithFusing,
virtual public ov::test::SubgraphBaseTest {
public:
// Formats the test name from the RNN type and the batch / input / hidden sizes.
static std::string getTestCaseName(const testing::TestParamInfo<ConvertFqRnnToQuantizedRnnTestParams>& obj) {
SizeVector inputShapes;
std::string rnnType;
std::tie(rnnType, inputShapes) = obj.param;
auto batchSize = inputShapes[0];
auto inputSize = inputShapes[1];
auto hiddenSize = inputShapes[2];
std::ostringstream result;
result << "Type = " << rnnType << "_";
result << "batch = " << batchSize << "_";
result << "input = " << inputSize << "_";
result << "hidden = " << hiddenSize << "_";
return result.str();
}
protected:
// Builds the tested function. NOTE(review): rnnType is unpacked but never used below;
// an LSTMCell is created regardless of the requested type.
void SetUp() override {
targetDevice = CommonTestUtils::DEVICE_CPU;
SizeVector inputShapes;
std::string rnnType;
std::tie(rnnType, inputShapes) = this->GetParam();
auto batchSize = inputShapes[0];
auto inputSize = inputShapes[1];
auto hiddenSize = inputShapes[2];
// NOTE(review): the "min" value is positive and the "max" value is negative, so min > max.
// This looks unintended (possibly swapped signs) - confirm.
const float inputDataMin = 6.43123;
const float inputDataMax = -6.48187;
const float outputDataMin = inputDataMin;
// NOTE(review): outputDataMax is assigned outputDataMin, so the output range of inputFQ
// collapses to a single value; presumably inputDataMax was meant - confirm.
const float outputDataMax = outputDataMin;
const SizeVector inputShape = {batchSize, inputSize};
const SizeVector hiddenStateShape = {batchSize, hiddenSize};
const SizeVector cellStateShape = {batchSize, hiddenSize};
init_input_shapes({
{{}, {inputShape}},
{{}, {hiddenStateShape}},
{{}, {cellStateShape}}
});
const auto ngPrec = element::f32;
auto inputParams = builder::makeParams(ngPrec, {inputShape, hiddenStateShape, cellStateShape});
const auto outputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(inputParams));
std::vector<float> empty;
// Weights (W), recurrent weights (R) and bias (B) for 4 gates, created by makeConstant
// with an empty data vector and its fourth argument set to true.
auto W = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, inputSize}, empty, true);
auto R = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, hiddenSize}, empty, true);
auto B = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize}, empty, true);
const auto fqLevels = 256;
// Only inputFQ uses the separate output range; the other FQs reuse the input range for output.
auto inputFQ = ngraph::builder::makeFakeQuantize(outputNodes[0], ngraph::element::f32, fqLevels, std::vector<size_t>{},
{ inputDataMin }, { inputDataMax }, { outputDataMin }, { outputDataMax });
auto hiddenStateFQ = ngraph::builder::makeFakeQuantize(outputNodes[1], ngraph::element::f32, fqLevels, std::vector<size_t>{},
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
auto weightsFQ = ngraph::builder::makeFakeQuantize(W, ngraph::element::f32, fqLevels, std::vector<size_t>{},
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
auto recurrentWeightsFQ = ngraph::builder::makeFakeQuantize(R, ngraph::element::f32, fqLevels, std::vector<size_t>{},
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
// The cell state (third parameter) is passed to the cell without FakeQuantize.
auto rnnCellOp = std::make_shared<ov::op::v4::LSTMCell>(inputFQ, hiddenStateFQ, inputParams[2], weightsFQ, recurrentWeightsFQ, B, hiddenSize);
function = makeNgraphFunction(ngPrec, inputParams, rnnCellOp, "ConvertFqRnnToQuantizedRnn");
}
};
// Only calls run() (inherited; not shown in this file); no additional check follows it in this test body.
TEST_P(ConvertFqRnnToQuantizedRnn, CompareWithRefs) {
run();
}
namespace {
// Each entry is {batch, input size, hidden size}, matching how the fixture indexes it.
const std::vector<SizeVector> inputShapes {
{37, 128, 512},
/* {256, 128, 256}, */
};
// NOTE(review): three cell types are listed, but the fixture's SetUp() always builds an LSTMCell,
// so these values appear to affect only the generated test names - confirm.
std::vector<std::string> rnnTypes {"LSTMCell", "RNNCell", "GRUCell"};
// Instantiates the suite over every (rnnType, inputShape) combination.
INSTANTIATE_TEST_SUITE_P(smoke_Check, ConvertFqRnnToQuantizedRnn,
/* ::testing::ValuesIn(rnnTypes), */
::testing::Combine(::testing::ValuesIn(rnnTypes),
::testing::ValuesIn(inputShapes)),
ConvertFqRnnToQuantizedRnn::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

View File

@@ -72,6 +72,9 @@ void RecurrentCellTransformation::SetUp() {
void RecurrentCellTransformation::Run() {
LayerTestsCommon::Run();
if (!executableNetwork)
return;
const auto params = std::get<5>(GetParam());
const auto actualPrecision = getRuntimePrecisionByType(params.layerName);
auto expectedPrecision = params.expectedKernelType;

View File

@@ -69,7 +69,8 @@ std::shared_ptr<ngraph::Function> RecurrentCellFunction::get(
converts[3],
dequantizations[3]);
auto B = ngraph::opset1::Constant::create(inputPrecision, inputWeightsShapes[2], {1});
auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {3});
auto max_seq_length = inputActivationsShapes[0][1].get_max_length();
auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {max_seq_length});
std::shared_ptr<ov::op::util::RNNCellBase> rnn_layer;
switch (type) {