[CPU] RNN node kernel caching. (#9588)
This commit is contained in:
parent
d5fb0a0c24
commit
77c2c5fab3
@ -1,4 +1,4 @@
|
||||
// Copyright (C) 2018-2021 Intel Corporation
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
#include "mkldnn_input_node.h"
|
||||
#include <mkldnn_extension_utils.h>
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
#include <common/primitive_hashing_utils.hpp>
|
||||
|
||||
#include <ngraph/node.hpp>
|
||||
|
||||
@ -110,6 +111,61 @@ const std::map<Precision, Precision> MKLDNNRNN::weightsByLayerPrec {
|
||||
// {Precision::U8, Precision::I8},
|
||||
};
|
||||
|
||||
|
||||
struct RNNKey {
|
||||
const std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;
|
||||
const std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
|
||||
const std::vector<mkldnn::memory::desc> wDescs;
|
||||
mkldnn::algorithm cellType;
|
||||
|
||||
size_t hash() const;
|
||||
bool operator==(const RNNKey& rhs) const;
|
||||
};
|
||||
|
||||
size_t RNNKey::hash() const {
|
||||
using namespace dnnl::impl;
|
||||
using namespace dnnl::impl::primitive_hashing;
|
||||
|
||||
size_t seed = 0lu;
|
||||
|
||||
for (auto& desc : inDataDescs) {
|
||||
if (desc != nullptr)
|
||||
seed = hash_combine(seed, get_md_hash(desc->getDnnlDesc().data));
|
||||
}
|
||||
for (auto& desc : outDataDescs) {
|
||||
if (desc != nullptr)
|
||||
seed = hash_combine(seed, get_md_hash(desc->getDnnlDesc().data));
|
||||
}
|
||||
for (auto& desc : wDescs) {
|
||||
seed = hash_combine(seed, get_md_hash(desc.data));
|
||||
}
|
||||
seed = hash_combine(seed, cellType);
|
||||
return seed;
|
||||
}
|
||||
|
||||
bool RNNKey::operator==(const RNNKey& rhs) const {
|
||||
if (inDataDescs.size() != rhs.inDataDescs.size() || outDataDescs.size() != rhs.outDataDescs.size() || wDescs.size() != rhs.wDescs.size() ||
|
||||
cellType != rhs.cellType)
|
||||
return false;
|
||||
|
||||
for (size_t i = 0lu; i < inDataDescs.size(); i++) {
|
||||
if (inDataDescs[i] != rhs.inDataDescs[i] && (inDataDescs[i] == nullptr || rhs.inDataDescs[i] == nullptr ||
|
||||
inDataDescs[i]->getDnnlDesc() != rhs.inDataDescs[i]->getDnnlDesc()))
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0lu; i < outDataDescs.size(); i++) {
|
||||
if (outDataDescs[i] != rhs.outDataDescs[i] && (outDataDescs[i] == nullptr || rhs.outDataDescs[i] ||
|
||||
outDataDescs[i]->getDnnlDesc() == rhs.outDataDescs[i]->getDnnlDesc()))
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0lu; i < wDescs.size(); i++) {
|
||||
if (wDescs[i] != rhs.wDescs[i])
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool MKLDNNRNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
if (!one_of(op->get_type_info(),
|
||||
@ -231,7 +287,6 @@ MKLDNNRNN::MKLDNNRNN(const std::shared_ptr<ov::Node>& op, const mkldnn::engine&
|
||||
THROW_ERROR << "does not have original layer for RNNCell.";
|
||||
|
||||
cell_type = ie2dnnl(op);
|
||||
cell_act = mkldnn::algorithm::undef;
|
||||
if (!rnnCellBase->get_activations().empty())
|
||||
cell_act = ie2dnnl(rnnCellBase->get_activations()[0]); // Works only for RNN with one gate
|
||||
|
||||
@ -305,15 +360,15 @@ void MKLDNNRNN::fillCellDesc() {
|
||||
inDataDescs.reserve(S + 1);
|
||||
outDataDescs.reserve(S + 1);
|
||||
|
||||
inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
|
||||
outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
|
||||
|
||||
inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
}
|
||||
|
||||
copyWeightsData();
|
||||
@ -375,15 +430,15 @@ void MKLDNNRNN::fillSequenceDesc() {
|
||||
shapeNTDC {{N.minVal, T.minVal, DC}, {N.maxVal, T.maxVal, DC}};
|
||||
|
||||
// Try to create descriptor and corresponding configuration
|
||||
inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
|
||||
outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
|
||||
|
||||
inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
|
||||
}
|
||||
|
||||
copyWeightsData();
|
||||
@ -424,7 +479,7 @@ void MKLDNNRNN::fillSequenceDesc() {
|
||||
outCandidate.reserve(3);
|
||||
|
||||
if (nativeOrder) {
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outDataDescs[RNNInOutKind::Layer]));
|
||||
outCandidate.emplace_back(outDataDescs[RNNInOutKind::Layer]);
|
||||
} else if (N.isStatic() && N.maxVal == 1) {
|
||||
// WA to avoid reorder after sequence for some models
|
||||
outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, dataType, memory::format_tag::tnc));
|
||||
@ -632,54 +687,54 @@ void MKLDNNRNN::fillDescs() {
|
||||
prop_kind::forward_scoring,
|
||||
cell_act,
|
||||
direction,
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
|
||||
/* Weights data */ wDescs[0],
|
||||
/* Weights state */ wDescs[1],
|
||||
/* Bias */ wDescs[2],
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
|
||||
descs.push_back(desc);
|
||||
} break;
|
||||
case mkldnn::algorithm::vanilla_gru: {
|
||||
MKLDNNDescriptor desc(std::make_shared<gru_forward::desc>(
|
||||
prop_kind::forward_scoring,
|
||||
direction,
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
|
||||
/* Weights data */ wDescs[0],
|
||||
/* Weights state */ wDescs[1],
|
||||
/* Bias */ wDescs[2],
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
|
||||
descs.push_back(desc);
|
||||
} break;
|
||||
case mkldnn::algorithm::lbr_gru: {
|
||||
MKLDNNDescriptor desc(std::make_shared<lbr_gru_forward::desc>(
|
||||
prop_kind::forward_scoring,
|
||||
direction,
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
|
||||
/* Weights data */ wDescs[0],
|
||||
/* Weights state */ wDescs[1],
|
||||
/* Bias */ wDescs[2],
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
|
||||
descs.push_back(desc);
|
||||
} break;
|
||||
case mkldnn::algorithm::vanilla_lstm: {
|
||||
MKLDNNDescriptor desc(std::make_shared<lstm_forward::desc>(
|
||||
prop_kind::forward_scoring,
|
||||
direction,
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
|
||||
/* In State C */ inDataDescs[RNNInOutKind::CellState].getDnnlDesc(),
|
||||
/* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
|
||||
/* In State C */ inDataDescs[RNNInOutKind::CellState]->getDnnlDesc(),
|
||||
/* Weights data */ wDescs[0],
|
||||
/* Weights state */ wDescs[1],
|
||||
/* Bias */ wDescs[2],
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
|
||||
/* Out State C */ outDataDescs[RNNInOutKind::CellState].getDnnlDesc()));
|
||||
/* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
|
||||
/* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
|
||||
/* Out State C */ outDataDescs[RNNInOutKind::CellState]->getDnnlDesc()));
|
||||
descs.push_back(desc);
|
||||
} break;
|
||||
default:
|
||||
@ -740,15 +795,15 @@ void MKLDNNRNN::prepareParams() {
|
||||
const size_t SL = is_cell ? 1lu : dataMemPtr->GetShape().getStaticDims()[1];
|
||||
const Shape shapeS_4D{L, D, B, SC};
|
||||
|
||||
inDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, DC}, dataType, memory::format_tag::tnc);
|
||||
outDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, SC}, dataType, memory::format_tag::tnc);
|
||||
inDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, DC}, dataType, memory::format_tag::tnc);
|
||||
outDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, SC}, dataType, memory::format_tag::tnc);
|
||||
|
||||
inDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
outDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
inDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
outDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
|
||||
|
||||
if (haveCellState(cell_type)) {
|
||||
inDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
outDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
outDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
|
||||
}
|
||||
|
||||
bool wFormatWasChanged = false;
|
||||
@ -771,21 +826,37 @@ void MKLDNNRNN::prepareParams() {
|
||||
wDescs[1] = mkldnn::memory::desc(statesDims, dataType, wFormat);
|
||||
}
|
||||
|
||||
fillDescs();
|
||||
if (cell_type == mkldnn::algorithm::vanilla_rnn) {
|
||||
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
|
||||
prim.reset(new vanilla_rnn_forward(vanilla_rnn_forward::primitive_desc(*desc, getEngine())));
|
||||
} else if (cell_type == mkldnn::algorithm::vanilla_gru) {
|
||||
std::shared_ptr<gru_forward::desc> desc = descs[0];
|
||||
prim.reset(new gru_forward(gru_forward::primitive_desc(*desc, getEngine())));
|
||||
} else if (cell_type == mkldnn::algorithm::lbr_gru) {
|
||||
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
|
||||
prim.reset(new lbr_gru_forward(lbr_gru_forward::primitive_desc(*desc, getEngine())));
|
||||
} else if (cell_type == mkldnn::algorithm::vanilla_lstm) {
|
||||
std::shared_ptr<lstm_forward::desc> desc = descs[0];
|
||||
prim.reset(new lstm_forward(lstm_forward::primitive_desc(*desc, getEngine())));
|
||||
RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type };
|
||||
|
||||
auto builder = [this](const RNNKey& key) -> std::shared_ptr<mkldnn::primitive> {
|
||||
fillDescs();
|
||||
|
||||
if (key.cellType == mkldnn::algorithm::vanilla_rnn) {
|
||||
std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
|
||||
return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, getEngine()));
|
||||
} else if (key.cellType == mkldnn::algorithm::vanilla_gru) {
|
||||
std::shared_ptr<gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, getEngine()));
|
||||
} else if (key.cellType == mkldnn::algorithm::lbr_gru) {
|
||||
std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, getEngine()));
|
||||
} else if (key.cellType == mkldnn::algorithm::vanilla_lstm) {
|
||||
std::shared_ptr<lstm_forward::desc> desc = descs[0];
|
||||
return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, getEngine()));
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
auto cache = getRuntimeCache();
|
||||
auto result = cache->getOrCreate(key, builder);
|
||||
|
||||
if (!result.first) {
|
||||
IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
|
||||
}
|
||||
|
||||
prim = result.first;
|
||||
|
||||
if (!wasMemoryPrepared || wFormatWasChanged) {
|
||||
auto itpd = descs[0].createPrimitiveDescriptorIterator(getEngine(), mkldnn::primitive_attr());
|
||||
prepareMemory(itpd);
|
||||
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <mkldnn_node.h>
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
@ -97,8 +98,8 @@ private:
|
||||
const size_t L = 1; /**< What is it??. Constant for mkldnn impl */
|
||||
const size_t D = 1; /**< Num of direction. 1 or 2 */
|
||||
|
||||
std::vector<DnnlBlockedMemoryDesc> inDataDescs;
|
||||
std::vector<DnnlBlockedMemoryDesc> outDataDescs;
|
||||
std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;
|
||||
std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
|
||||
std::vector<mkldnn::memory::desc> wDescs;
|
||||
|
||||
enum RNNInOutKind {
|
||||
|
@ -166,7 +166,11 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
|
||||
{ { { {1, 10}, {25, 35} }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, -1 }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
|
||||
{ { { {1, 10}, {25, 35} }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, -1 }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUCellCPUTest,
|
||||
|
@ -310,11 +310,17 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
|
||||
{ { {5, -1, 10}, // #7. Dynamic shape 0
|
||||
{ {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {5, 1, 10}, // Dynamic shape 1
|
||||
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } } // Target shapes
|
||||
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, {7, 11}}, // #8. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, {8, 12}}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -326,7 +332,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
|
||||
GRUSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -338,7 +344,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
|
||||
GRUSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -350,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
|
||||
GRUSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, GRUSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6], dynamicShapes[7]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[6], dynamicShapes[7]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
|
@ -167,7 +167,13 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
|
||||
{ { {1, 20}, {8, 12} }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } }, // Target shapes
|
||||
{ { {1, 20}, -1 }, // Dynamic shape 2
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
|
||||
{ { { {1, 20}, {28, 32} }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 20}, {8, 12} }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } }, // Target shapes
|
||||
{ { {1, 20}, -1 }, // Dynamic shape 2
|
||||
{ {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMCellLayerCPUTest,
|
||||
|
@ -319,14 +319,22 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {3, -1, {0, 12}}, // #6. Dynamic shape 0
|
||||
{ {3, 2, 10}, {3, 4, 10}, {3, 5, 10} } }, // Target shapes
|
||||
{ {3, -1, {0, 12}}, // Dynamic shape 1
|
||||
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } }, // Target shapes
|
||||
{ {3, -1, {0, 12}}, // Dynamic shape 2
|
||||
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } } // Target shapes
|
||||
{ {3, -1, {0, 12}}, // Dynamic shape 1
|
||||
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } }, // Target shapes
|
||||
{ {3, -1, {0, 12}}, // Dynamic shape 2
|
||||
{ {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, {5, 15}}, // #7. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, -1}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1, 1, -1}, // Dynamic shape 2
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 3
|
||||
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -337,7 +345,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
|
||||
LSTMSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -348,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
|
||||
LSTMSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[7]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -359,7 +367,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
|
||||
LSTMSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, LSTMSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[6]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
|
@ -139,18 +139,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_static, RNNCellCPUTest,
|
||||
RNNCellCPUTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
|
||||
{ { { {-1}, 1 }, // Dynamic shape 0
|
||||
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
|
||||
{ { {-1}, 1 }, // Dynamic shape 1
|
||||
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
|
||||
{ { { {1, 10}, 30 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, 10 }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
|
||||
{ { { {1, 10}, -1 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
{ { { {-1}, 1 }, // Dynamic shape 0
|
||||
{ {1, 1}, {3, 1}, {5, 1} } }, // Target shapes
|
||||
{ { {-1}, 1 }, // Dynamic shape 1
|
||||
{ {1, 1}, {3, 1}, {5, 1} } } }, // Target shapes
|
||||
{ { { {1, 10}, 30 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, 10 }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
|
||||
{ { { {1, 10}, -1 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {8, 10} } } }, // Target shapes
|
||||
{ { { {1, 10}, -1 }, // Dynamic shape 0
|
||||
{ {2, 30}, {5, 30}, {2, 30}, {8, 30}, {5, 30}, {8, 30} } }, // Target shapes
|
||||
{ { {1, 10}, {1, 11} }, // Dynamic shape 1
|
||||
{ {2, 10}, {5, 10}, {2, 10}, {8, 10}, {5, 10}, {8, 10} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNCellCPUTest,
|
||||
|
@ -229,54 +229,60 @@ INSTANTIATE_TEST_SUITE_P(smoke_static_BatchSizeOne, RNNSequenceCPUTest,
|
||||
RNNSequenceCPUTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<InputShape>> dynamicShapes = {
|
||||
{ { {-1, {1, 5}, 10}, // #0. Dynamic shape 0
|
||||
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
|
||||
{ {{0, 15}, 1, 1}, // Dynamic shape 1
|
||||
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } }, // Target shapes
|
||||
{ {{0, 12}}, // Dynamic shape 2
|
||||
{ {10}, {8}, {5} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, 10}, // #1. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, 10}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, {5, 15}}, // #2. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, -1, {8, 11}}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5} } } }, // Target shapes
|
||||
{ { {-1, {0, 7}, 10}, // #3. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 3, 10}, {1, 6, 10} } }, // Target shapes
|
||||
{ {-1, 1, 1}, // Dynamic shape 1
|
||||
{ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {1, -1, 10}, // #4. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
|
||||
{ {1, 1, 10}, // Dynamic shape 1
|
||||
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
|
||||
{ {1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {-1, -1, -1}, // #5. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
|
||||
{ {-1, -1, -1}, // Dynamic shape 1
|
||||
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {-1, {1, 5}, 10}, // #6. Dynamic shape 0
|
||||
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
|
||||
{ {{0, 15}, 1, 1}, // Dynamic shape 1
|
||||
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, 10}, // #7. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, 10}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } } } // Target shapes
|
||||
{ { {-1, {1, 5}, 10}, // #0. Dynamic shape 0
|
||||
{ {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } }, // Target shapes
|
||||
{ {{0, 15}, 1, 1}, // Dynamic shape 1
|
||||
{ {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } }, // Target shapes
|
||||
{ {{0, 12}}, // Dynamic shape 2
|
||||
{ {10}, {8}, {5} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, 10}, // #1. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, 10}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, {5, 15}}, // #2. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, -1, {8, 11}}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5} } } }, // Target shapes
|
||||
{ { {-1, {0, 7}, 10}, // #3. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 3, 10}, {1, 6, 10} } }, // Target shapes
|
||||
{ {-1, 1, 1}, // Dynamic shape 1
|
||||
{ {1, 1, 1}, {1, 1, 1}, {1, 1, 1} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {1, -1, 10}, // #4. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
|
||||
{ {1, 1, 10}, // Dynamic shape 1
|
||||
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
|
||||
{ {1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {-1, -1, -1}, // #5. Dynamic shape 0
|
||||
{ {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } }, // Target shapes
|
||||
{ {-1, -1, -1}, // Dynamic shape 1
|
||||
{ {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {1}, {1}, {1} } } }, // Target shapes
|
||||
{ { {7, {1, 5}, 10}, // #6. Dynamic shape 0
|
||||
{ {7, 2, 10}, {7, 3, 10}, {7, 4, 10} } }, // Target shapes
|
||||
{ {7, 1, 1}, // Dynamic shape 1
|
||||
{ {7, 1, 1}, {7, 1, 1}, {7, 1, 1} } } }, // Target shapes
|
||||
{ { {5, -1, 10}, // #7. Dynamic shape 0
|
||||
{ {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {5, 1, 10}, // Dynamic shape 1
|
||||
{ {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } }, // Target shapes
|
||||
{ { {{0, 11}, -1, 10}, // #8. Dynamic shape 0
|
||||
{ {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } }, // Target shapes
|
||||
{ {-1, 1, 10}, // Dynamic shape 1
|
||||
{ {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } }, // Target shapes
|
||||
{ {-1}, // Dynamic shape 2
|
||||
{ {10}, {3}, {5}, {10}, {5} } } } // Target shapes
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNSequenceCPUTest,
|
||||
::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
@ -298,13 +304,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, RNNSequenceCPUTest,
|
||||
RNNSequenceCPUTest::getTestCaseName);
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(nightly_dynamic, RNNSequenceCPUTest,
|
||||
::testing::Combine(::testing::Values(dynamicShapes[5]),
|
||||
::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
|
||||
::testing::ValuesIn(mode),
|
||||
::testing::ValuesIn(activations),
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(cpuParamsBatchSizeOne),
|
||||
::testing::Values(cpuParams),
|
||||
::testing::Values(std::map<std::string, std::string>{})),
|
||||
RNNSequenceCPUTest::getTestCaseName);
|
||||
|
||||
@ -315,7 +321,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, RNNSequenceCPUTest,
|
||||
::testing::ValuesIn(clip),
|
||||
::testing::ValuesIn(direction),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(cpuParamsBatchSizeOne),
|
||||
::testing::Values(cpuParams),
|
||||
::testing::Values(additionalConfig[1])),
|
||||
RNNSequenceCPUTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
Loading…
Reference in New Issue
Block a user