[CPU] Add and correct tests for int8 LSTM (#17447)
This commit is contained in:
@@ -455,6 +455,9 @@ void RNN::configurePortDataTypes() {
|
||||
|
||||
if (one_of(memory::data_type::bf16, inDataTypes[xIdx], inDataTypes[hIdx]))
|
||||
inDataTypes[xIdx] = outDataTypes[yIdx] = outDataTypes[hoIdx] = inDataTypes[hIdx] = memory::data_type::bf16; // required by oneDNN.
|
||||
|
||||
if (outDataTypes[yIdx] == memory::data_type::bf16 && one_of(inDataTypes[xIdx], memory::data_type::s8, memory::data_type::u8))
|
||||
outDataTypes[yIdx] = memory::data_type::f32; // oneDNN does not support bf16 output precision for quantized rnn primitive yet
|
||||
}
|
||||
|
||||
void RNN::getSupportedDescriptors() {
|
||||
@@ -870,7 +873,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[1], // Weights state
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
attr);
|
||||
case dnnl::algorithm::vanilla_gru:
|
||||
return dnnl::gru_forward::primitive_desc(
|
||||
engine,
|
||||
@@ -882,7 +886,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[1], // Weights state
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
attr);
|
||||
case dnnl::algorithm::lbr_gru:
|
||||
return dnnl::lbr_gru_forward::primitive_desc(
|
||||
engine,
|
||||
@@ -894,7 +899,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[1], // Weights state
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
attr);
|
||||
case dnnl::algorithm::vanilla_lstm:
|
||||
return dnnl::lstm_forward::primitive_desc(
|
||||
engine,
|
||||
@@ -908,7 +914,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc()); // Out State C
|
||||
outDataDescs[RNN::InOutKind::CellState]->getDnnlDesc(), // Out State C
|
||||
attr);
|
||||
case dnnl::algorithm::vanilla_augru:
|
||||
return dnnl::augru_forward::primitive_desc(
|
||||
engine,
|
||||
@@ -921,7 +928,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[1], // Weights state
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
attr);
|
||||
case dnnl::algorithm::lbr_augru:
|
||||
return dnnl::lbr_augru_forward::primitive_desc(
|
||||
engine,
|
||||
@@ -934,7 +942,8 @@ dnnl::primitive_desc createPrimitiveDescriptor(const dnnl::engine engine,
|
||||
wDescs[1], // Weights state
|
||||
wDescs[2], // Bias
|
||||
outDataDescs[RNN::InOutKind::Layer]->getDnnlDesc(), // Out Data
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc()); // Out State
|
||||
outDataDescs[RNN::InOutKind::HiddenState]->getDnnlDesc(), // Out State
|
||||
attr);
|
||||
default:
|
||||
IE_THROW() << "RNN. Unknown cell type";
|
||||
}
|
||||
@@ -979,19 +988,19 @@ void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
|
||||
|
||||
// Fill supported config
|
||||
NodeConfig config;
|
||||
for (size_t i = 0; i < inputDesc.size(); i++) {
|
||||
for (const auto &desc : inputDesc) {
|
||||
PortConfig dataConfig;
|
||||
dataConfig.inPlace(-1);
|
||||
dataConfig.constant(false);
|
||||
dataConfig.setMemDesc(inputDesc[i]);
|
||||
dataConfig.setMemDesc(desc);
|
||||
config.inConfs.push_back(dataConfig);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < outputDesc.size(); i++) {
|
||||
for (const auto &desc : outputDesc) {
|
||||
PortConfig dataConfig;
|
||||
dataConfig.inPlace(-1);
|
||||
dataConfig.constant(false);
|
||||
dataConfig.setMemDesc(outputDesc[i]);
|
||||
dataConfig.setMemDesc(desc);
|
||||
config.outConfs.push_back(dataConfig);
|
||||
}
|
||||
|
||||
@@ -1003,7 +1012,12 @@ Node::AttrPtr RNN::initPrimitiveAttr() {
|
||||
attr->set_scratchpad_mode(dnnl::scratchpad_mode::user);
|
||||
|
||||
if (one_of(inDataTypes[xIdx], memory::data_type::u8, memory::data_type::s8)) {
|
||||
const int weightsScaleMask = 0;
|
||||
const int weightsScaleMask = 0
|
||||
+ (1 << 3) // bit, indicating the unique scales for `g` dim in `ldigo`
|
||||
+ (1 << 4); // bit, indicating the unique scales for `o` dim in `ldigo`
|
||||
|
||||
DEBUG_LOG(getName(), ": inputScale: ", inputScale, ", inputShift: ", inputShift,
|
||||
", weightsScaleMask: ", weightsScaleMask, ", weightsScales[0]: ", weightsScales[0]);
|
||||
|
||||
attr->set_rnn_weights_qparams(weightsScaleMask, weightsScales);
|
||||
attr->set_rnn_data_qparams(inputScale, inputShift);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
//
|
||||
#include "convert_fq_rnn_to_quantized_rnn.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <ngraph/opsets/opset9.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
@@ -14,6 +15,7 @@
|
||||
|
||||
#include "ie_common.h"
|
||||
#include "itt.hpp"
|
||||
#include "openvino/core/type/element_type.hpp"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
@@ -164,11 +166,15 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
|
||||
if (*input_scale_ptr == 0.f)
|
||||
OPENVINO_THROW("Cannot handle zero input scale");
|
||||
|
||||
const float input_scale = 1 / *input_scale_ptr;
|
||||
const std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
|
||||
const float input_scale = 1 / *input_scale_ptr;
|
||||
std::vector<float> weights_scales = weights_scale_constant->get_vector<float>();
|
||||
|
||||
// transform dequantization scales into quantization ones
|
||||
std::transform(weights_scales.begin(), weights_scales.end(), weights_scales.begin(), [](float& scale) { return 1 / scale; });
|
||||
|
||||
auto& runtime_info = rnn_quantized->get_rt_info();
|
||||
|
||||
// use runtime information to store input and weight scales
|
||||
runtime_info["inputScale"] = input_scale;
|
||||
runtime_info["weightsScales"] = weights_scales;
|
||||
|
||||
@@ -178,7 +184,6 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
|
||||
if (input_shift_it != pattern_map.end()) {
|
||||
const auto input_shift_constant = std::dynamic_pointer_cast<ngraph::opset9::Constant>(input_shift_it->second.get_node_shared_ptr());
|
||||
const float* input_shift_ptr = input_shift_constant->get_data_ptr<float>();
|
||||
|
||||
runtime_info["inputShift"] = *input_shift_ptr;
|
||||
}
|
||||
|
||||
@@ -207,6 +212,7 @@ ov::intel_cpu::ConvertFqRnnToQuantizedRnn::ConvertFqRnnToQuantizedRnn() {
|
||||
}
|
||||
|
||||
auto new_multiply = multiply->clone_with_new_inputs({multiply_in, multiply->input_value(1)});
|
||||
new_multiply->set_friendly_name(rnn_quantized->get_friendly_name() + ".1");
|
||||
|
||||
for (auto output : H_outputs) {
|
||||
output.replace_source_output(new_multiply);
|
||||
|
||||
@@ -12,6 +12,63 @@
|
||||
* with FQ operations on the inputs and forms a new TypeRelaxed operation
|
||||
* with quantization parameters as runtime parameters of the operation.
|
||||
* @todo add ascii graph examples
|
||||
*
|
||||
* Before:
|
||||
*
|
||||
* +-------+ +-------+ +-------+ +-------+ +-------+ +-------+
|
||||
* | X | | H | | C | | W | | R | | B |
|
||||
* | | | | | | | | | | | |
|
||||
* | u8/i8 | | u8/i8 | | f32 | | i8 | | i8 | | f32 |
|
||||
* +---+---+ +---+---+ +---+---+ +---+---+ +---+---+ +---+---+
|
||||
* | | | | | |
|
||||
* +---v---+ +---v---+ | +---v---+ +---v---+ |
|
||||
* | | | | | | | | | |
|
||||
* | deq | | deq | | | deq | | deq | |
|
||||
* | | | | | | | | | |
|
||||
* +---+---+ +---+---+ | +---+---+ +---+---+ |
|
||||
* | | | | | |
|
||||
* | | | | | |
|
||||
* +---v-----------v-----------v----------v----------v----------v---+
|
||||
* | |
|
||||
* | LSTMSequence / GRUSequence (f32) |
|
||||
* | |
|
||||
* +---------------+-----------+----------+-------------------------+
|
||||
* | | |
|
||||
* |Y f32 |Ho f32 |Co f32
|
||||
* | | |
|
||||
* | | |
|
||||
* | | |
|
||||
* v v v
|
||||
*
|
||||
* v
|
||||
*
|
||||
*
|
||||
* After:
|
||||
*
|
||||
* +-------+ +-------+ +-------+ +-------+ +-------+ +-------+
|
||||
* | X | | H | | C | | W | | R | | B |
|
||||
* | | | | | | | | | | | |
|
||||
* | u8/i8 | | u8/i8 | | f32 | | i8 | | i8 | | f32 |
|
||||
* +---+---+ +---+---+ +---+---+ +---+---+ +---+---+ +---+---+
|
||||
* | | | | | |
|
||||
* | | | | | |
|
||||
* +---v-----------v-----------v----------v----------v----------v---+
|
||||
* | TypeRelaxed rt_info[inputScales] |
|
||||
* | |
|
||||
* | LSTMSequence / GRUSequence (u8/i8) rt_info[weightsScales] |
|
||||
* +---------------+-----------+----------+-------------------------+
|
||||
* | | |
|
||||
* |Y f32 |Ho u8/i8 |Co f32
|
||||
* | | |
|
||||
* | +---v---+ |
|
||||
* | | | |
|
||||
* | | deq | |
|
||||
* | | | |
|
||||
* | +---+---+ |
|
||||
* | | |
|
||||
* | | |
|
||||
* | | |
|
||||
* v v v
|
||||
*/
|
||||
|
||||
namespace ov {
|
||||
|
||||
@@ -87,6 +87,7 @@
|
||||
#include "low_precision/convolution_backprop_data.hpp"
|
||||
#include "low_precision/group_convolution.hpp"
|
||||
#include "low_precision/multiply_to_group_convolution.hpp"
|
||||
#include "low_precision/recurrent_cell.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/bias_attribute.hpp"
|
||||
#include "transformations/low_precision/mark_dequantization_subgraph.hpp"
|
||||
@@ -504,10 +505,10 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
|
||||
{{1}, {ov::element::i8}}
|
||||
}),
|
||||
PrecisionsRestriction::create<ov::opset5::LSTMSequence>({
|
||||
{{0, 1}, {ov::element::u8, ov::element::i8}},
|
||||
{{0, 1}, {ov::element::u8}}
|
||||
}),
|
||||
PrecisionsRestriction::create<ov::opset6::GRUSequence>({
|
||||
{{0, 1}, {ov::element::u8, ov::element::i8}},
|
||||
{{0, 1}, {ov::element::u8}}
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -548,6 +549,7 @@ void Transformations::Lpt(const bool hasINT16orINT32Levels, const std::vector<ov
|
||||
return ov::marked_as_bias(node);
|
||||
});
|
||||
|
||||
CPU_DISABLE_PASS_ARM(lptManager, ngraph::pass::low_precision::RecurrentCellTransformation);
|
||||
CPU_DISABLE_PASS_COMMON(lptManager, ngraph::pass::low_precision::MultiplyToGroupConvolutionTransformation);
|
||||
|
||||
lptManager.run_passes(model);
|
||||
@@ -609,7 +611,7 @@ void Transformations::PostLpt() {
|
||||
}
|
||||
|
||||
// Execute before snippets. Otherwise FQ will be converted to Subgraph
|
||||
CPU_REGISTER_PASS_COMMON(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
|
||||
CPU_REGISTER_PASS_X64(postLPTPassManager, ConvertFqRnnToQuantizedRnn);
|
||||
postLPTPassManager.run_passes(model);
|
||||
}
|
||||
|
||||
|
||||
@@ -21,7 +21,7 @@ const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> tras
|
||||
namespace testValues1 {
|
||||
|
||||
const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> params = {
|
||||
// LSTMCell
|
||||
// LSTMSequence
|
||||
{
|
||||
// X
|
||||
{256ul, {}, {0.f}, {2.55f}, {0.f}, {255.f}},
|
||||
@@ -47,8 +47,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
|
||||
"RNNCell",
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
|
||||
"RNNSeq",
|
||||
"U8"
|
||||
},
|
||||
// asymmetrical FQ on weights
|
||||
@@ -77,14 +77,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMCell,
|
||||
"RNNCell",
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::LSTMSequence,
|
||||
"RNNSeq",
|
||||
"FP32"
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 16}, {1, 128}, {1, 128}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{512, 16}, {512, 128}, {512}}};
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 2, 16}, {1, 1, 128}, {1, 1, 128}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 512, 16}, {1, 512, 128}, {1, 512}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
@@ -126,8 +126,8 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
|
||||
{255ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
|
||||
"RNNCell",
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
|
||||
"RNNSeq",
|
||||
"U8"
|
||||
},
|
||||
// asymmetrical FQ on weights
|
||||
@@ -156,14 +156,14 @@ const std::vector<LayerTestsDefinitions::RecurrentCellTransformationParam> param
|
||||
{256ul, {}, {-1.27f}, {1.27f}, {-1.27f}, {1.27f}},
|
||||
{},
|
||||
{{}, {}, {}},
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRU,
|
||||
"RNNCell",
|
||||
ngraph::builder::subgraph::RecurrentCellFunction::RNNType::GRUSequence,
|
||||
"RNNSeq",
|
||||
"FP32"
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{2, 3}, {2, 3}, {}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{9, 3}, {9, 3}, {9}}};
|
||||
const std::vector<std::vector<ngraph::PartialShape>> activations_shapes = {{{1, 1, 3}, {1, 1, 3}, {}}};
|
||||
const std::vector<std::vector<ngraph::Shape>> weights_shapes = {{{1, 9, 3}, {1, 9, 3}, {1, 9}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, RecurrentCellTransformation,
|
||||
::testing::Combine(
|
||||
@@ -0,0 +1,219 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph/output_vector.hpp"
|
||||
#include "ngraph/type/element_type.hpp"
|
||||
#include "openvino/core/node.hpp"
|
||||
#include "openvino/core/type/element_type.hpp"
|
||||
#include "openvino/op/gru_sequence.hpp"
|
||||
#include "openvino/op/lstm_sequence.hpp"
|
||||
#include "openvino/runtime/tensor.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/fusing_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include <common_test_utils/ov_tensor_utils.hpp>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace ov::test;
|
||||
using namespace ov;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using ConvertFqRnnToQuantizedRnnTestParams = std::tuple<std::string, std::vector<InputShape>, bool>;
|
||||
|
||||
/**
 * Subgraph test that builds an LSTMSequence / GRUSequence with FakeQuantize
 * operations on the data input X, optionally on the initial hidden state H,
 * and on the W and R weights, then runs it on the CPU device.
 *
 * Test parameters (see ConvertFqRnnToQuantizedRnnTestParams):
 *   - rnnType:              "LSTMSequence", "GRUSequence" or "LBRGRUSequence"
 *   - inputShapes:          shapes for X, H and C (the C entry is dropped for
 *                           every type other than "LSTMSequence")
 *   - quantizedHiddenState: true  -> H is a model input followed by FakeQuantize
 *                           false -> H is a randomly filled f32 constant
 */
class ConvertFqRnnToQuantizedRnn : public testing::WithParamInterface<ConvertFqRnnToQuantizedRnnTestParams>,
                                   public CpuTestWithFusing,
                                   virtual public ov::test::SubgraphBaseTest {
public:
    /**
     * Builds the test-case name from the test parameters:
     * "Type=<rnnType>_IS=(<dynamic shapes>)_TS=<static shapes>_quantizedHiddenState=<0|1>".
     */
    static std::string getTestCaseName(const testing::TestParamInfo<ConvertFqRnnToQuantizedRnnTestParams>& obj) {
        std::vector<InputShape> inputShapes;
        std::string rnnType;
        bool quantizedHiddenState = false;

        std::tie(rnnType, inputShapes, quantizedHiddenState) = obj.param;

        std::ostringstream result;

        result << "Type=" << rnnType << "_";

        result << "IS=(";
        for (const auto& shape : inputShapes) {
            result << CommonTestUtils::partialShape2str({shape.first}) << "_";
        }
        result << ")_TS=";
        // One "{...}" group per target-shape set, listing that set's shape for every input
        for (size_t i = 0lu; i < inputShapes.front().second.size(); i++) {
            result << "{";
            for (size_t j = 0lu; j < inputShapes.size(); j++) {
                result << CommonTestUtils::vec2str(inputShapes[j].second[i]) << (j < inputShapes.size() - 1 ? "_" : "");
            }
            result << "}_";
        }

        result << "quantizedHiddenState=" << quantizedHiddenState;

        return result.str();
    }

protected:
    /**
     * Fills the input tensors for X (input 0) and H (input 1) and, when the
     * model has a cell state (hasCell), for C (input cellIdx).
     * X/H and C deliberately use different create_and_fill_tensor arguments.
     */
    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
        inputs.clear();

        const auto& funcInputs = function->inputs();
        const auto& shapeX = targetInputStaticShapes[0];
        const auto& shapeH = targetInputStaticShapes[1];

        ov::Tensor tensorX = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), shapeX, 1, 0, 16);
        ov::Tensor tensorH = utils::create_and_fill_tensor(funcInputs[1].get_element_type(), shapeH, 1, 0, 16);

        inputs.insert({funcInputs[0].get_node_shared_ptr(), tensorX});
        inputs.insert({funcInputs[1].get_node_shared_ptr(), tensorH});

        if (hasCell) {
            const auto& shapeC = targetInputStaticShapes[cellIdx];
            ov::Tensor tensorC = utils::create_and_fill_tensor(funcInputs[cellIdx].get_element_type(), shapeC, 2, -1, 128, 2);
            inputs.insert({funcInputs[cellIdx].get_node_shared_ptr(), tensorC});
        }
    }

    /**
     * Builds the model under test:
     *   X -> FQ(256 levels) ----------------------+
     *   H -> FQ(256 levels) or f32 Constant ------+--> LSTMSequence / GRUSequence
     *   W, R (Constants) -> FQ(255 levels) -------+
     *   B (Constant) -----------------------------+
     * The sequence_lengths input is a constant filled with the maximal sequence length.
     */
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;
        selectedType = "ref_any_I8";

        std::vector<InputShape> inputShapes;
        std::string rnnType;
        bool quantizedHiddenState = false;

        std::tie(rnnType, inputShapes, quantizedHiddenState) = this->GetParam();

        if (rnnType != "LSTMSequence")  // remove cell input for non-cell rnn types
            inputShapes.erase(inputShapes.begin() + cellIdx);

        init_input_shapes(inputShapes);

        // X is indexed as [batch, seq_len, input_size], H as [batch, num_directions, hidden_size]
        const auto inputSize = targetStaticShapes.front()[0][2];
        const auto hiddenSize = targetStaticShapes.front()[1][2];
        const size_t numDirections = 1;
        const size_t numOfGates = rnnType == "LSTMSequence" ? 4 : 3;
        // The linear-before-reset GRU variant carries one extra bias block
        const size_t numOfBiasGates = rnnType == "LBRGRUSequence" ? numOfGates + 1 : numOfGates;

        const auto ngPrec = element::f32;
        ngraph::ParameterVector inputParams = ngraph::builder::makeDynamicParams(ngPrec, inputDynamicShapes);

        const auto outputNodes = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(inputParams));

        // FakeQuantize for activations: 256 levels, identical input/output range [-128/127, 1]
        auto makeDataFQ = [](const ngraph::Output<Node>& input) {
            const auto fqLevels = 256;
            return ngraph::builder::makeFakeQuantize(input, ngraph::element::f32, fqLevels, {},
                                                     {-128.f/127}, {1.f},
                                                     {-128.f/127}, {1.f});
        };

        auto X_FQ = makeDataFQ(outputNodes[0]);

        std::shared_ptr<Node> H;
        if (quantizedHiddenState) {
            H = makeDataFQ(outputNodes[1]);
        } else {
            H = ngraph::builder::makeConstant(ngraph::element::f32, inputDynamicShapes[1].get_shape(), {}, true, 1.f, -1.f);
        }

        auto W = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfGates * hiddenSize, inputSize}, {}, true, 1.f, -1.f);
        auto R = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfGates * hiddenSize, hiddenSize}, {}, true, 1.f, -1.f);
        auto B = ngraph::builder::makeConstant(ngraph::element::f32, {numDirections, numOfBiasGates * hiddenSize}, {}, true, 0.1f, -0.1f);

        // FakeQuantize for weights: 255 levels, symmetric range [-127/63, 127/63]
        auto makeWeightsFQ = [](const std::shared_ptr<Node>& weight) {
            const auto fqLevelsW = 255;
            return ngraph::builder::makeFakeQuantize(weight, ngraph::element::f32,
                                                     fqLevelsW, std::vector<size_t>{},
                                                     {-127.f/63}, {127.f/63},
                                                     {-127.f/63}, {127.f/63});
        };

        auto W_FQ = makeWeightsFQ(W);
        auto R_FQ = makeWeightsFQ(R);

        std::shared_ptr<ov::Node> rnnCellOp;

        // fill sequence_length constant with max sequence length values
        const auto batchSize = targetStaticShapes.front()[0][0];
        const auto maxSeqLen = targetStaticShapes.front()[0][1];
        std::vector<int> lengths(batchSize, static_cast<int>(maxSeqLen));
        auto seq_lengths = ngraph::opset1::Constant::create(element::i64, Shape{batchSize}, lengths);

        if (rnnType == "LSTMSequence") {
            hasCell = true;
            auto C = outputNodes[cellIdx];
            rnnCellOp = std::make_shared<ov::op::v5::LSTMSequence>(
                X_FQ, H, C, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD);
        } else if (rnnType == "GRUSequence") {
            rnnCellOp = std::make_shared<ov::op::v5::GRUSequence>(
                X_FQ, H, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD);
        } else if (rnnType == "LBRGRUSequence") {
            const std::vector<std::string> activations{"sigmoid", "tanh"};
            const std::vector<float> activations_alpha, activations_beta;
            // The trailing arguments are clip = 0.f and linear_before_reset = true
            rnnCellOp = std::make_shared<ov::op::v5::GRUSequence>(
                X_FQ, H, seq_lengths, W_FQ, R_FQ, B,
                hiddenSize, op::RecurrentSequenceDirection::FORWARD,
                activations, activations_alpha, activations_beta, 0.f, true);
        } else {
            IE_THROW() << "Unexpected RNN type: " << rnnType;
        }

        if (maxSeqLen > 1)
            abs_threshold = 0.05;  // RNN int8 computation is expected to affect the accuracy, especially when sequence_length > 1

        function = makeNgraphFunction(ngPrec, inputParams, rnnCellOp, "ConvertFqRnnToQuantizedRnn");
    }

private:
    static const size_t cellIdx = 2;  // position of the cell state (C) among the model inputs
    bool hasCell = false;             // set by SetUp() when the model is an LSTMSequence
};
|
||||
|
||||
// Executes the model built in SetUp() through the common subgraph test flow,
// then runs the CPU-plugin check for the "RNNSeq" node type on the compiled model.
// NOTE(review): presumably this validates the executed primitive type against
// selectedType ("ref_any_I8") - confirm against CheckPluginRelatedResults.
TEST_P(ConvertFqRnnToQuantizedRnn, CompareWithRefs) {
    run();
    CheckPluginRelatedResults(compiledModel, "RNNSeq");
}
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<std::vector<InputShape>> staticShapesLSTM = {
|
||||
{ // seq len > 1
|
||||
{ {}, { {2, 5, 10} } }, // X
|
||||
{ {}, { {2, 1, 4}} }, // H
|
||||
{ {}, { {2, 1, 4}} }, // C
|
||||
},
|
||||
{ // seq len = 1
|
||||
{ {}, { {2, 1, 5} } }, // X
|
||||
{ {}, { {2, 1, 1}} }, // H
|
||||
{ {}, { {2, 1, 1}} }, // C
|
||||
},
|
||||
};
|
||||
|
||||
std::vector<bool> quantizedHiddenStateParam{true, false};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_static, ConvertFqRnnToQuantizedRnn,
|
||||
::testing::Combine(::testing::Values("LSTMSequence", "GRUSequence"),
|
||||
// "LBRGRUSequence", // enable after implemented in oneDNN
|
||||
::testing::ValuesIn(staticShapesLSTM),
|
||||
::testing::ValuesIn(quantizedHiddenStateParam)),
|
||||
ConvertFqRnnToQuantizedRnn::getTestCaseName);
|
||||
} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
@@ -1,123 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/fusing_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using ConvertFqRnnToQuantizedRnnTestParams = std::tuple<std::string, SizeVector>;
|
||||
/* using ConvertFqRnnToQuantizedRnnTestParams = std::string; */
|
||||
|
||||
class ConvertFqRnnToQuantizedRnn : public testing::WithParamInterface<ConvertFqRnnToQuantizedRnnTestParams>,
|
||||
public CpuTestWithFusing,
|
||||
virtual public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<ConvertFqRnnToQuantizedRnnTestParams>& obj) {
|
||||
SizeVector inputShapes;
|
||||
std::string rnnType;
|
||||
std::tie(rnnType, inputShapes) = obj.param;
|
||||
|
||||
auto batchSize = inputShapes[0];
|
||||
auto inputSize = inputShapes[1];
|
||||
auto hiddenSize = inputShapes[2];
|
||||
|
||||
std::ostringstream result;
|
||||
result << "Type = " << rnnType << "_";
|
||||
result << "batch = " << batchSize << "_";
|
||||
result << "input = " << inputSize << "_";
|
||||
result << "hidden = " << hiddenSize << "_";
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
SizeVector inputShapes;
|
||||
std::string rnnType;
|
||||
|
||||
std::tie(rnnType, inputShapes) = this->GetParam();
|
||||
|
||||
auto batchSize = inputShapes[0];
|
||||
auto inputSize = inputShapes[1];
|
||||
auto hiddenSize = inputShapes[2];
|
||||
|
||||
const float inputDataMin = 6.43123;
|
||||
const float inputDataMax = -6.48187;
|
||||
const float outputDataMin = inputDataMin;
|
||||
const float outputDataMax = outputDataMin;
|
||||
|
||||
const SizeVector inputShape = {batchSize, inputSize};
|
||||
const SizeVector hiddenStateShape = {batchSize, hiddenSize};
|
||||
const SizeVector cellStateShape = {batchSize, hiddenSize};
|
||||
|
||||
init_input_shapes({
|
||||
{{}, {inputShape}},
|
||||
{{}, {hiddenStateShape}},
|
||||
{{}, {cellStateShape}}
|
||||
});
|
||||
|
||||
const auto ngPrec = element::f32;
|
||||
auto inputParams = builder::makeParams(ngPrec, {inputShape, hiddenStateShape, cellStateShape});
|
||||
const auto outputNodes = helpers::convert2OutputVector(helpers::castOps2Nodes<op::Parameter>(inputParams));
|
||||
|
||||
std::vector<float> empty;
|
||||
auto W = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, inputSize}, empty, true);
|
||||
auto R = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize, hiddenSize}, empty, true);
|
||||
auto B = ngraph::builder::makeConstant(ngraph::element::f32, {4 * hiddenSize}, empty, true);
|
||||
|
||||
const auto fqLevels = 256;
|
||||
|
||||
auto inputFQ = ngraph::builder::makeFakeQuantize(outputNodes[0], ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { outputDataMin }, { outputDataMax });
|
||||
|
||||
auto hiddenStateFQ = ngraph::builder::makeFakeQuantize(outputNodes[1], ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto weightsFQ = ngraph::builder::makeFakeQuantize(W, ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto recurrentWeightsFQ = ngraph::builder::makeFakeQuantize(R, ngraph::element::f32, fqLevels, std::vector<size_t>{},
|
||||
{ inputDataMin }, { inputDataMax }, { inputDataMin }, { inputDataMax });
|
||||
|
||||
auto rnnCellOp = std::make_shared<ov::op::v4::LSTMCell>(inputFQ, hiddenStateFQ, inputParams[2], weightsFQ, recurrentWeightsFQ, B, hiddenSize);
|
||||
|
||||
function = makeNgraphFunction(ngPrec, inputParams, rnnCellOp, "ConvertFqRnnToQuantizedRnn");
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConvertFqRnnToQuantizedRnn, CompareWithRefs) {
|
||||
run();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<SizeVector> inputShapes {
|
||||
{37, 128, 512},
|
||||
/* {256, 128, 256}, */
|
||||
};
|
||||
|
||||
std::vector<std::string> rnnTypes {"LSTMCell", "RNNCell", "GRUCell"};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_Check, ConvertFqRnnToQuantizedRnn,
|
||||
/* ::testing::ValuesIn(rnnTypes), */
|
||||
::testing::Combine(::testing::ValuesIn(rnnTypes),
|
||||
::testing::ValuesIn(inputShapes)),
|
||||
ConvertFqRnnToQuantizedRnn::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
||||
@@ -72,6 +72,9 @@ void RecurrentCellTransformation::SetUp() {
|
||||
void RecurrentCellTransformation::Run() {
|
||||
LayerTestsCommon::Run();
|
||||
|
||||
if (!executableNetwork)
|
||||
return;
|
||||
|
||||
const auto params = std::get<5>(GetParam());
|
||||
const auto actualPrecision = getRuntimePrecisionByType(params.layerName);
|
||||
auto expectedPrecision = params.expectedKernelType;
|
||||
|
||||
@@ -69,7 +69,8 @@ std::shared_ptr<ngraph::Function> RecurrentCellFunction::get(
|
||||
converts[3],
|
||||
dequantizations[3]);
|
||||
auto B = ngraph::opset1::Constant::create(inputPrecision, inputWeightsShapes[2], {1});
|
||||
auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {3});
|
||||
auto max_seq_length = inputActivationsShapes[0][1].get_max_length();
|
||||
auto seq_lengths = ngraph::opset1::Constant::create(element::i32, Shape{1}, {max_seq_length});
|
||||
|
||||
std::shared_ptr<ov::op::util::RNNCellBase> rnn_layer;
|
||||
switch (type) {
|
||||
|
||||
Reference in New Issue
Block a user