[CPU] RNN node kernel caching. (#9588)

This commit is contained in:
Nikolay Shchegolev 2022-01-14 09:28:05 +03:00 committed by GitHub
parent d5fb0a0c24
commit 77c2c5fab3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 235 additions and 129 deletions

View File

@ -1,4 +1,4 @@
// Copyright (C) 2018-2021 Intel Corporation
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -10,6 +10,7 @@
#include "mkldnn_input_node.h"
#include <mkldnn_extension_utils.h>
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include <common/primitive_hashing_utils.hpp>
#include <ngraph/node.hpp>
@ -110,6 +111,61 @@ const std::map<Precision, Precision> MKLDNNRNN::weightsByLayerPrec {
// {Precision::U8, Precision::I8},
};
// Cache key that uniquely identifies an RNN primitive configuration.
// Used by the runtime cache in MKLDNNRNN::prepareParams() to reuse previously
// created primitives: equal keys (see operator==) must produce equal hash().
struct RNNKey {
    const std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;   // input data/state memory descriptors (entries may be null)
    const std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;  // output data/state memory descriptors (entries may be null)
    const std::vector<mkldnn::memory::desc> wDescs;            // weights / state-weights / bias descriptors
    mkldnn::algorithm cellType;                                // cell algorithm (vanilla_rnn / gru / lbr_gru / lstm)
    // Combined hash over all descriptors and the cell type.
    size_t hash() const;
    // Deep equality over descriptors and cell type.
    bool operator==(const RNNKey& rhs) const;
};
// Computes the cache-key hash by folding every memory descriptor (null entries
// are skipped) and the cell algorithm into a single seed via dnnl's
// primitive-hashing helpers. Must stay consistent with operator==.
size_t RNNKey::hash() const {
    using namespace dnnl::impl;
    using namespace dnnl::impl::primitive_hashing;

    size_t seed = 0lu;

    // Fold one vector of (possibly null) blocked-descriptor pointers into the seed.
    auto combineDescs = [&seed](const std::vector<DnnlBlockedMemoryDescPtr>& descs) {
        for (const auto& d : descs) {
            if (d)
                seed = hash_combine(seed, get_md_hash(d->getDnnlDesc().data));
        }
    };

    combineDescs(inDataDescs);
    combineDescs(outDataDescs);

    for (const auto& wd : wDescs)
        seed = hash_combine(seed, get_md_hash(wd.data));

    return hash_combine(seed, cellType);
}
// Deep equality between two RNN cache keys.
// Keys are equal iff the cell type matches and all descriptor vectors are
// element-wise equal. Descriptor pointers compare equal when they are the same
// pointer, or when both are non-null and their oneDNN descriptors are equal.
//
// Fix: the outDataDescs loop previously tested `rhs.outDataDescs[i]` for
// truthiness (instead of `== nullptr`) and compared descriptors with `==`
// (instead of `!=`). That made keys with equal non-null output descriptors
// compare unequal (so the primitive cache could never hit on outputs), and
// dereferenced rhs.outDataDescs[i] exactly when it was null. The loop now
// mirrors the inDataDescs comparison.
bool RNNKey::operator==(const RNNKey& rhs) const {
    if (inDataDescs.size() != rhs.inDataDescs.size() || outDataDescs.size() != rhs.outDataDescs.size() || wDescs.size() != rhs.wDescs.size() ||
            cellType != rhs.cellType)
        return false;

    for (size_t i = 0lu; i < inDataDescs.size(); i++) {
        if (inDataDescs[i] != rhs.inDataDescs[i] && (inDataDescs[i] == nullptr || rhs.inDataDescs[i] == nullptr ||
                inDataDescs[i]->getDnnlDesc() != rhs.inDataDescs[i]->getDnnlDesc()))
            return false;
    }
    for (size_t i = 0lu; i < outDataDescs.size(); i++) {
        if (outDataDescs[i] != rhs.outDataDescs[i] && (outDataDescs[i] == nullptr || rhs.outDataDescs[i] == nullptr ||
                outDataDescs[i]->getDnnlDesc() != rhs.outDataDescs[i]->getDnnlDesc()))
            return false;
    }
    for (size_t i = 0lu; i < wDescs.size(); i++) {
        if (wDescs[i] != rhs.wDescs[i])
            return false;
    }
    return true;
}
bool MKLDNNRNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (!one_of(op->get_type_info(),
@ -231,7 +287,6 @@ MKLDNNRNN::MKLDNNRNN(const std::shared_ptr<ov::Node>& op, const mkldnn::engine&
THROW_ERROR << "does not have original layer for RNNCell.";
cell_type = ie2dnnl(op);
cell_act = mkldnn::algorithm::undef;
if (!rnnCellBase->get_activations().empty())
cell_act = ie2dnnl(rnnCellBase->get_activations()[0]); // Works only for RNN with one gate
@ -305,15 +360,15 @@ void MKLDNNRNN::fillCellDesc() {
inDataDescs.reserve(S + 1);
outDataDescs.reserve(S + 1);
inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
if (haveCellState(cell_type)) {
inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
}
copyWeightsData();
@ -375,15 +430,15 @@ void MKLDNNRNN::fillSequenceDesc() {
shapeNTDC {{N.minVal, T.minVal, DC}, {N.maxVal, T.maxVal, DC}};
// Try to create descriptor and corresponding configuration
inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
if (haveCellState(cell_type)) {
inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
}
copyWeightsData();
@ -424,7 +479,7 @@ void MKLDNNRNN::fillSequenceDesc() {
outCandidate.reserve(3);
if (nativeOrder) {
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outDataDescs[RNNInOutKind::Layer]));
outCandidate.emplace_back(outDataDescs[RNNInOutKind::Layer]);
} else if (N.isStatic() && N.maxVal == 1) {
// WA to avoid reorder after sequence for some models
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, dataType, memory::format_tag::tnc));
@ -632,54 +687,54 @@ void MKLDNNRNN::fillDescs() {
prop_kind::forward_scoring,
cell_act,
direction,
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
/* Weights data */ wDescs[0],
/* Weights state */ wDescs[1],
/* Bias */ wDescs[2],
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
descs.push_back(desc);
} break;
case mkldnn::algorithm::vanilla_gru: {
MKLDNNDescriptor desc(std::make_shared<gru_forward::desc>(
prop_kind::forward_scoring,
direction,
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
/* Weights data */ wDescs[0],
/* Weights state */ wDescs[1],
/* Bias */ wDescs[2],
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
descs.push_back(desc);
} break;
case mkldnn::algorithm::lbr_gru: {
MKLDNNDescriptor desc(std::make_shared<lbr_gru_forward::desc>(
prop_kind::forward_scoring,
direction,
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
/* Weights data */ wDescs[0],
/* Weights state */ wDescs[1],
/* Bias */ wDescs[2],
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
descs.push_back(desc);
} break;
case mkldnn::algorithm::vanilla_lstm: {
MKLDNNDescriptor desc(std::make_shared<lstm_forward::desc>(
prop_kind::forward_scoring,
direction,
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
/* In State C */ inDataDescs[RNNInOutKind::CellState].getDnnlDesc(),
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
/* In State C */ inDataDescs[RNNInOutKind::CellState]->getDnnlDesc(),
/* Weights data */ wDescs[0],
/* Weights state */ wDescs[1],
/* Bias */ wDescs[2],
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
/* Out State C */ outDataDescs[RNNInOutKind::CellState].getDnnlDesc()));
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
/* Out State C */ outDataDescs[RNNInOutKind::CellState]->getDnnlDesc()));
descs.push_back(desc);
} break;
default:
@ -740,15 +795,15 @@ void MKLDNNRNN::prepareParams() {
const size_t SL = is_cell ? 1lu : dataMemPtr->GetShape().getStaticDims()[1];
const Shape shapeS_4D{L, D, B, SC};
inDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, DC}, dataType, memory::format_tag::tnc);
outDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, SC}, dataType, memory::format_tag::tnc);
inDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, DC}, dataType, memory::format_tag::tnc);
outDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, SC}, dataType, memory::format_tag::tnc);
inDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
outDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
inDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
outDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
if (haveCellState(cell_type)) {
inDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
outDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
outDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
}
bool wFormatWasChanged = false;
@ -771,21 +826,37 @@ void MKLDNNRNN::prepareParams() {
wDescs[1] = mkldnn::memory::desc(statesDims, dataType, wFormat);
}
fillDescs();
if (cell_type == mkldnn::algorithm::vanilla_rnn) {
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
prim.reset(new vanilla_rnn_forward(vanilla_rnn_forward::primitive_desc(*desc, getEngine())));
} else if (cell_type == mkldnn::algorithm::vanilla_gru) {
std::shared_ptr<gru_forward::desc> desc = descs[0];
prim.reset(new gru_forward(gru_forward::primitive_desc(*desc, getEngine())));
} else if (cell_type == mkldnn::algorithm::lbr_gru) {
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
prim.reset(new lbr_gru_forward(lbr_gru_forward::primitive_desc(*desc, getEngine())));
} else if (cell_type == mkldnn::algorithm::vanilla_lstm) {
std::shared_ptr<lstm_forward::desc> desc = descs[0];
prim.reset(new lstm_forward(lstm_forward::primitive_desc(*desc, getEngine())));
RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type };
auto builder = [this](const RNNKey& key) -> std::shared_ptr<mkldnn::primitive> {
fillDescs();
if (key.cellType == mkldnn::algorithm::vanilla_rnn) {
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, getEngine()));
} else if (key.cellType == mkldnn::algorithm::vanilla_gru) {
std::shared_ptr<gru_forward::desc> desc = descs[0];
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, getEngine()));
} else if (key.cellType == mkldnn::algorithm::lbr_gru) {
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, getEngine()));
} else if (key.cellType == mkldnn::algorithm::vanilla_lstm) {
std::shared_ptr<lstm_forward::desc> desc = descs[0];
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, getEngine()));
} else {
return nullptr;
}
};
auto cache = getRuntimeCache();
auto result = cache->getOrCreate(key, builder);
if (!result.first) {
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
}
prim = result.first;
if (!wasMemoryPrepared || wFormatWasChanged) {
auto itpd = descs[0].createPrimitiveDescriptorIterator(getEngine(), mkldnn::primitive_attr());
prepareMemory(itpd);

View File

@ -5,6 +5,7 @@
#pragma once
#include <mkldnn_node.h>
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include <string>
#include <memory>
@ -97,8 +98,8 @@ private:
const size_t L = 1; /**< What is it??. Constant for mkldnn impl */
const size_t D = 1; /**< Num of direction. 1 or 2 */
std::vector<DnnlBlockedMemoryDesc> inDataDescs;
std::vector<DnnlBlockedMemoryDesc> outDataDescs;
std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;
std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
std::vector<mkldnn::memory::desc> wDescs;
enum RNNInOutKind {

View File

@ -166,7 +166,11 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
{ { { {1, 10}, {25, 35} }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, -1 }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
{ { { {1, 10}, {25, 35} }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, -1 }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUCellCPUTest,

View File

@ -310,11 +310,17 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
{ { {5, -1, 10}, // #7. Dynamic shape 0
{ {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {5, 1, 10}, // Dynamic shape 1
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } } // Target shapes
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } }, // Target shapes
{ { {{0, 11}, -1, {7, 11}}, // #8. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, {8, 12}}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -326,7 +332,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
GRUSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -338,7 +344,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
GRUSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -350,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
GRUSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, GRUSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6], dynamicShapes[7]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[6], dynamicShapes[7]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),

View File

@ -167,7 +167,13 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
{ { {1, 20}, {8, 12} }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } }, // Target shapes
{ { {1, 20}, -1 }, // Dynamic shape 2
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
{ { { {1, 20}, {28, 32} }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 20}, {8, 12} }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } }, // Target shapes
{ { {1, 20}, -1 }, // Dynamic shape 2
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMCellLayerCPUTest,

View File

@ -319,14 +319,22 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
{ {1}, {1}, {1} } } }, // Target shapes
{ { {3, -1, {0, 12}}, // #6. Dynamic shape 0
{ {3, 2, 10}, {3, 4, 10}, {3, 5, 10} } }, // Target shapes
{ {3, -1, {0, 12}}, // Dynamic shape 1
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } }, // Target shapes
{ {3, -1, {0, 12}}, // Dynamic shape 2
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } } // Target shapes
{ {3, -1, {0, 12}}, // Dynamic shape 1
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } }, // Target shapes
{ {3, -1, {0, 12}}, // Dynamic shape 2
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } }, // Target shapes
{ { {{0, 11}, -1, {5, 15}}, // #7. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, -1}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1, 1, -1}, // Dynamic shape 2
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 3
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -337,7 +345,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
LSTMSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -348,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
LSTMSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[7]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -359,7 +367,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
LSTMSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, LSTMSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[6]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),

View File

@ -139,18 +139,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_static, RNNCellCPUTest,
RNNCellCPUTest::getTestCaseName);
const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
{ { { {-1}, 1 }, // Dynamic shape 0
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
{ { {-1}, 1 }, // Dynamic shape 1
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
{ { { {1, 10}, 30 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, 10 }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
{ { { {1, 10}, -1 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
{ { { {-1}, 1 }, // Dynamic shape 0
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
{ { {-1}, 1 }, // Dynamic shape 1
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
{ { { {1, 10}, 30 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, 10 }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
{ { { {1, 10}, -1 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
{ { { {1, 10}, -1 }, // Dynamic shape 0
{ {2, 30}, {5, 30}, {2, 30}, {8, 30}, {5, 30}, {8, 30} } }, // Target shapes
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
{ {2, 10}, {5, 10}, {2, 10}, {8, 10}, {5, 10}, {8, 10} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNCellCPUTest,

View File

@ -229,54 +229,60 @@ INSTANTIATE_TEST_SUITE_P(smoke_static_BatchSizeOne, RNNSequenceCPUTest,
RNNSequenceCPUTest::getTestCaseName);
const std::vector<std::vector<InputShape>> dynamicShapes = {
{ { {-1, {1, 5}, 10}, // #0. Dynamic shape 0
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
{ {{0, 15}, 1, 1}, // Dynamic shape 1
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } }, // Target shapes
{ {{0, 12}}, // Dynamic shape 2
{ {10}, {8}, {5} } } }, // Target shapes
{ { {{0, 11}, -1, 10}, // #1. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, 10}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5} } } }, // Target shapes
{ { {{0, 11}, -1, {5, 15}}, // #2. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, -1, {8, 11}}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5} } } }, // Target shapes
{ { {-1, {0, 7}, 10}, // #3. Dynamic shape 0
{ {1, 2, 10}, {1, 3, 10}, {1, 6, 10} } }, // Target shapes
{ {-1, 1, 1}, // Dynamic shape 1
{ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {1, -1, 10}, // #4. Dynamic shape 0
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
{ {1, 1, 10}, // Dynamic shape 1
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
{ {1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {-1, -1, -1}, // #5. Dynamic shape 0
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
{ {-1, -1, -1}, // Dynamic shape 1
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {-1, {1, 5}, 10}, // #6. Dynamic shape 0
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
{ {{0, 15}, 1, 1}, // Dynamic shape 1
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } } }, // Target shapes
{ { {{0, 11}, -1, 10}, // #7. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, 10}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } } } // Target shapes
{ { {-1, {1, 5}, 10}, // #0. Dynamic shape 0
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
{ {{0, 15}, 1, 1}, // Dynamic shape 1
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } }, // Target shapes
{ {{0, 12}}, // Dynamic shape 2
{ {10}, {8}, {5} } } }, // Target shapes
{ { {{0, 11}, -1, 10}, // #1. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, 10}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5} } } }, // Target shapes
{ { {{0, 11}, -1, {5, 15}}, // #2. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, -1, {8, 11}}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5} } } }, // Target shapes
{ { {-1, {0, 7}, 10}, // #3. Dynamic shape 0
{ {1, 2, 10}, {1, 3, 10}, {1, 6, 10} } }, // Target shapes
{ {-1, 1, 1}, // Dynamic shape 1
{ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {1, -1, 10}, // #4. Dynamic shape 0
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
{ {1, 1, 10}, // Dynamic shape 1
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
{ {1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {-1, -1, -1}, // #5. Dynamic shape 0
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
{ {-1, -1, -1}, // Dynamic shape 1
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {1}, {1}, {1} } } }, // Target shapes
{ { {7, {1, 5}, 10}, // #6. Dynamic shape 0
{ {7, 2, 10}, {7, 3, 10}, {7, 4, 10} } }, // Target shapes
{ {7, 1, 1}, // Dynamic shape 1
{ {7, 1, 1}, {7, 1, 1}, {7, 1, 1} } } }, // Target shapes
{ { {5, -1, 10}, // #7. Dynamic shape 0
{ {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } }, // Target shapes
{ {5, 1, 10}, // Dynamic shape 1
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } }, // Target shapes
{ { {{0, 11}, -1, 10}, // #8. Dynamic shape 0
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
{ {-1, 1, 10}, // Dynamic shape 1
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
{ {-1}, // Dynamic shape 2
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
};
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNSequenceCPUTest,
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
@ -298,13 +304,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, RNNSequenceCPUTest,
RNNSequenceCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, RNNSequenceCPUTest,
::testing::Combine(::testing::Values(dynamicShapes[5]),
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
::testing::ValuesIn(mode),
::testing::ValuesIn(activations),
::testing::ValuesIn(clip),
::testing::ValuesIn(direction),
::testing::ValuesIn(netPrecisions),
::testing::Values(cpuParamsBatchSizeOne),
::testing::Values(cpuParams),
::testing::Values(std::map<std::string, std::string>{})),
RNNSequenceCPUTest::getTestCaseName);
@ -315,7 +321,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, RNNSequenceCPUTest,
::testing::ValuesIn(clip),
::testing::ValuesIn(direction),
::testing::ValuesIn(netPrecisions),
::testing::Values(cpuParamsBatchSizeOne),
::testing::Values(cpuParams),
::testing::Values(additionalConfig[1])),
RNNSequenceCPUTest::getTestCaseName);
} // namespace