[CPU] RNN node kernel caching. (#9588)

parent d5fb0a0c24
commit 77c2c5fab3
@@ -1,4 +1,4 @@
-// Copyright (C) 2018-2021 Intel Corporation
+// Copyright (C) 2018-2022 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
 
@@ -10,6 +10,7 @@
 #include "mkldnn_input_node.h"
 #include <mkldnn_extension_utils.h>
 #include "memory_desc/dnnl_blocked_memory_desc.h"
+#include <common/primitive_hashing_utils.hpp>
 
 #include <ngraph/node.hpp>
 
@@ -110,6 +111,61 @@ const std::map<Precision, Precision> MKLDNNRNN::weightsByLayerPrec {
     // {Precision::U8, Precision::I8},
 };
 
+
+struct RNNKey {
+    const std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;
+    const std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
+    const std::vector<mkldnn::memory::desc> wDescs;
+    mkldnn::algorithm cellType;
+
+    size_t hash() const;
+    bool operator==(const RNNKey& rhs) const;
+};
+
+size_t RNNKey::hash() const {
+    using namespace dnnl::impl;
+    using namespace dnnl::impl::primitive_hashing;
+
+    size_t seed = 0lu;
+
+    for (auto& desc : inDataDescs) {
+        if (desc != nullptr)
+            seed = hash_combine(seed, get_md_hash(desc->getDnnlDesc().data));
+    }
+    for (auto& desc : outDataDescs) {
+        if (desc != nullptr)
+            seed = hash_combine(seed, get_md_hash(desc->getDnnlDesc().data));
+    }
+    for (auto& desc : wDescs) {
+        seed = hash_combine(seed, get_md_hash(desc.data));
+    }
+    seed = hash_combine(seed, cellType);
+    return seed;
+}
+
+bool RNNKey::operator==(const RNNKey& rhs) const {
+    if (inDataDescs.size() != rhs.inDataDescs.size() || outDataDescs.size() != rhs.outDataDescs.size() || wDescs.size() != rhs.wDescs.size() ||
+            cellType != rhs.cellType)
+        return false;
+
+    for (size_t i = 0lu; i < inDataDescs.size(); i++) {
+        if (inDataDescs[i] != rhs.inDataDescs[i] && (inDataDescs[i] == nullptr || rhs.inDataDescs[i] == nullptr ||
+                inDataDescs[i]->getDnnlDesc() != rhs.inDataDescs[i]->getDnnlDesc()))
+            return false;
+    }
+    for (size_t i = 0lu; i < outDataDescs.size(); i++) {
+        if (outDataDescs[i] != rhs.outDataDescs[i] && (outDataDescs[i] == nullptr || rhs.outDataDescs[i] == nullptr ||
+                outDataDescs[i]->getDnnlDesc() != rhs.outDataDescs[i]->getDnnlDesc()))
+            return false;
+    }
+    for (size_t i = 0lu; i < wDescs.size(); i++) {
+        if (wDescs[i] != rhs.wDescs[i])
+            return false;
+    }
+
+    return true;
+}
+
 bool MKLDNNRNN::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
     try {
         if (!one_of(op->get_type_info(),
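
Note: RNNKey above is the lookup key for the runtime primitive cache used later in prepareParams(): it hashes the oneDNN memory descriptors of inputs, outputs and weights together with the cell algorithm, and operator== resolves hash collisions. A minimal standalone sketch of the same key/hash-combine idea (illustrative only; simplified field types, and hash_combine here is a Boost-style mix rather than dnnl::impl's helper):

    #include <cstddef>
    #include <functional>
    #include <vector>

    // Boost-style hash mixing: folds an element hash into a running seed.
    template <typename T>
    void hash_combine(std::size_t& seed, const T& v) {
        seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }

    // Stand-in for RNNKey: plain ints replace the memory descriptors and
    // mkldnn::algorithm. Hash and equality must agree, exactly as in RNNKey.
    struct ToyRnnKey {
        std::vector<int> shapes;
        int cellType = 0;

        std::size_t hash() const {
            std::size_t seed = 0;
            for (int s : shapes)
                hash_combine(seed, s);
            hash_combine(seed, cellType);
            return seed;
        }

        bool operator==(const ToyRnnKey& rhs) const {
            return shapes == rhs.shapes && cellType == rhs.cellType;
        }
    };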
@@ -231,7 +287,6 @@ MKLDNNRNN::MKLDNNRNN(const std::shared_ptr<ov::Node>& op, const mkldnn::engine&
         THROW_ERROR << "does not have original layer for RNNCell.";
 
     cell_type = ie2dnnl(op);
-    cell_act = mkldnn::algorithm::undef;
     if (!rnnCellBase->get_activations().empty())
         cell_act = ie2dnnl(rnnCellBase->get_activations()[0]);  // Works only for RNN with one gate
 
@@ -305,15 +360,15 @@ void MKLDNNRNN::fillCellDesc() {
     inDataDescs.reserve(S + 1);
     outDataDescs.reserve(S + 1);
 
-    inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
-    outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
+    inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
+    outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
 
-    inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
-    outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
+    inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
+    outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
 
     if (haveCellState(cell_type)) {
-        inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
-        outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
+        inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
+        outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
     }
 
     copyWeightsData();
@@ -375,15 +430,15 @@ void MKLDNNRNN::fillSequenceDesc() {
           shapeNTDC {{N.minVal, T.minVal, DC}, {N.maxVal, T.maxVal, DC}};
 
     // Try to create descriptor and corresponding configuration
-    inDataDescs.emplace_back(inShape, dataType, memory::format_tag::tnc);
-    outDataDescs.emplace_back(outShape, dataType, memory::format_tag::tnc);
+    inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(inShape, dataType, memory::format_tag::tnc));
+    outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outShape, dataType, memory::format_tag::tnc));
 
-    inDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
-    outDataDescs.emplace_back(shapeS_4D, dataType, memory::format_tag::ldnc);
+    inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
+    outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc));
 
     if (haveCellState(cell_type)) {
-        inDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
-        outDataDescs.emplace_back(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
+        inDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
+        outDataDescs.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc));
     }
 
     copyWeightsData();
@@ -424,7 +479,7 @@ void MKLDNNRNN::fillSequenceDesc() {
     outCandidate.reserve(3);
 
     if (nativeOrder) {
-        outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(outDataDescs[RNNInOutKind::Layer]));
+        outCandidate.emplace_back(outDataDescs[RNNInOutKind::Layer]);
     } else if (N.isStatic() && N.maxVal == 1) {
         // WA to avoid reorder after sequence for some models
         outCandidate.emplace_back(std::make_shared<DnnlBlockedMemoryDesc>(shapeNTSC, dataType, memory::format_tag::tnc));
@@ -632,54 +687,54 @@ void MKLDNNRNN::fillDescs() {
                 prop_kind::forward_scoring,
                 cell_act,
                 direction,
-                /* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
+                /* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
                 /* Weights data */ wDescs[0],
                 /* Weights state */ wDescs[1],
                 /* Bias */ wDescs[2],
-                /* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
+                /* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
             descs.push_back(desc);
         } break;
         case mkldnn::algorithm::vanilla_gru: {
             MKLDNNDescriptor desc(std::make_shared<gru_forward::desc>(
                 prop_kind::forward_scoring,
                 direction,
-                /* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
+                /* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
                 /* Weights data */ wDescs[0],
                 /* Weights state */ wDescs[1],
                 /* Bias */ wDescs[2],
-                /* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
+                /* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
             descs.push_back(desc);
         } break;
         case mkldnn::algorithm::lbr_gru: {
             MKLDNNDescriptor desc(std::make_shared<lbr_gru_forward::desc>(
                 prop_kind::forward_scoring,
                 direction,
-                /* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
+                /* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
                 /* Weights data */ wDescs[0],
                 /* Weights state */ wDescs[1],
                 /* Bias */ wDescs[2],
-                /* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc()));
+                /* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc()));
             descs.push_back(desc);
         } break;
         case mkldnn::algorithm::vanilla_lstm: {
             MKLDNNDescriptor desc(std::make_shared<lstm_forward::desc>(
                 prop_kind::forward_scoring,
                 direction,
-                /* In Data */ inDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* In State */ inDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
-                /* In State C */ inDataDescs[RNNInOutKind::CellState].getDnnlDesc(),
+                /* In Data */ inDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* In State */ inDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
+                /* In State C */ inDataDescs[RNNInOutKind::CellState]->getDnnlDesc(),
                 /* Weights data */ wDescs[0],
                 /* Weights state */ wDescs[1],
                 /* Bias */ wDescs[2],
-                /* Out Data */ outDataDescs[RNNInOutKind::Layer].getDnnlDesc(),
-                /* Out State */ outDataDescs[RNNInOutKind::HiddenState].getDnnlDesc(),
-                /* Out State C */ outDataDescs[RNNInOutKind::CellState].getDnnlDesc()));
+                /* Out Data */ outDataDescs[RNNInOutKind::Layer]->getDnnlDesc(),
+                /* Out State */ outDataDescs[RNNInOutKind::HiddenState]->getDnnlDesc(),
+                /* Out State C */ outDataDescs[RNNInOutKind::CellState]->getDnnlDesc()));
            descs.push_back(desc);
        } break;
        default:
@@ -740,15 +795,15 @@ void MKLDNNRNN::prepareParams() {
     const size_t SL = is_cell ? 1lu : dataMemPtr->GetShape().getStaticDims()[1];
     const Shape shapeS_4D{L, D, B, SC};
 
-    inDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, DC}, dataType, memory::format_tag::tnc);
-    outDataDescs[0] = DnnlBlockedMemoryDesc({SL, B, SC}, dataType, memory::format_tag::tnc);
+    inDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, DC}, dataType, memory::format_tag::tnc);
+    outDataDescs[0] = std::make_shared<DnnlBlockedMemoryDesc>(Shape{SL, B, SC}, dataType, memory::format_tag::tnc);
 
-    inDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
-    outDataDescs[1] = DnnlBlockedMemoryDesc(shapeS_4D, dataType, memory::format_tag::ldnc);
+    inDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
+    outDataDescs[1] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, dataType, memory::format_tag::ldnc);
 
     if (haveCellState(cell_type)) {
-        inDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
-        outDataDescs[2] = DnnlBlockedMemoryDesc(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
+        inDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
+        outDataDescs[2] = std::make_shared<DnnlBlockedMemoryDesc>(shapeS_4D, memory::data_type::f32, memory::format_tag::ldnc);
     }
 
     bool wFormatWasChanged = false;
@@ -771,21 +826,37 @@ void MKLDNNRNN::prepareParams() {
         wDescs[1] = mkldnn::memory::desc(statesDims, dataType, wFormat);
     }
 
-    fillDescs();
-    if (cell_type == mkldnn::algorithm::vanilla_rnn) {
-        std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
-        prim.reset(new vanilla_rnn_forward(vanilla_rnn_forward::primitive_desc(*desc, getEngine())));
-    } else if (cell_type == mkldnn::algorithm::vanilla_gru) {
-        std::shared_ptr<gru_forward::desc> desc = descs[0];
-        prim.reset(new gru_forward(gru_forward::primitive_desc(*desc, getEngine())));
-    } else if (cell_type == mkldnn::algorithm::lbr_gru) {
-        std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
-        prim.reset(new lbr_gru_forward(lbr_gru_forward::primitive_desc(*desc, getEngine())));
-    } else if (cell_type == mkldnn::algorithm::vanilla_lstm) {
-        std::shared_ptr<lstm_forward::desc> desc = descs[0];
-        prim.reset(new lstm_forward(lstm_forward::primitive_desc(*desc, getEngine())));
+    RNNKey key = { inDataDescs, outDataDescs, wDescs, cell_type };
+
+    auto builder = [this](const RNNKey& key) -> std::shared_ptr<mkldnn::primitive> {
+        fillDescs();
+
+        if (key.cellType == mkldnn::algorithm::vanilla_rnn) {
+            std::shared_ptr<vanilla_rnn_forward::desc> desc = descs[0];
+            return std::make_shared<vanilla_rnn_forward>(vanilla_rnn_forward::primitive_desc(*desc, getEngine()));
+        } else if (key.cellType == mkldnn::algorithm::vanilla_gru) {
+            std::shared_ptr<gru_forward::desc> desc = descs[0];
+            return std::make_shared<gru_forward>(gru_forward::primitive_desc(*desc, getEngine()));
+        } else if (key.cellType == mkldnn::algorithm::lbr_gru) {
+            std::shared_ptr<lbr_gru_forward::desc> desc = descs[0];
+            return std::make_shared<lbr_gru_forward>(lbr_gru_forward::primitive_desc(*desc, getEngine()));
+        } else if (key.cellType == mkldnn::algorithm::vanilla_lstm) {
+            std::shared_ptr<lstm_forward::desc> desc = descs[0];
+            return std::make_shared<lstm_forward>(lstm_forward::primitive_desc(*desc, getEngine()));
+        } else {
+            return nullptr;
+        }
+    };
+
+    auto cache = getRuntimeCache();
+    auto result = cache->getOrCreate(key, builder);
+
+    if (!result.first) {
+        IE_THROW() << "Primitive descriptor was not found for node " << getName() << ".";
     }
 
+    prim = result.first;
+
     if (!wasMemoryPrepared || wFormatWasChanged) {
         auto itpd = descs[0].createPrimitiveDescriptorIterator(getEngine(), mkldnn::primitive_attr());
         prepareMemory(itpd);
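
Note: prepareParams() now resolves the primitive through the node's runtime cache: the RNNKey and a builder lambda are handed to getOrCreate(), so a primitive built once for a given descriptor/algorithm combination is reused on later shape updates instead of being recreated. A minimal sketch of such a getOrCreate() cache, assuming a key type that provides hash() and operator== like the one above (illustrative only; not the plugin's cache class, and without the eviction policy or locking a real multi-stream cache needs):

    #include <cstddef>
    #include <functional>
    #include <unordered_map>
    #include <utility>

    template <typename Key, typename Value>
    class SimpleRuntimeCache {
    public:
        using Builder = std::function<Value(const Key&)>;

        // Returns the cached value for `key`, or builds and stores it on a miss.
        // The bool reports whether the value came from the cache.
        std::pair<Value, bool> getOrCreate(const Key& key, const Builder& builder) {
            auto it = map_.find(key);
            if (it != map_.end())
                return {it->second, true};   // hit: reuse the previously built value
            Value value = builder(key);      // miss: build exactly once
            map_.emplace(key, value);
            return {value, false};
        }

    private:
        struct Hash {
            std::size_t operator()(const Key& k) const { return k.hash(); }
        };
        // operator== of Key resolves hash collisions, mirroring RNNKey::operator==.
        std::unordered_map<Key, Value, Hash> map_;
    };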
@@ -5,6 +5,7 @@
 #pragma once
 
 #include <mkldnn_node.h>
+#include "memory_desc/dnnl_blocked_memory_desc.h"
 
 #include <string>
 #include <memory>
@@ -97,8 +98,8 @@ private:
     const size_t L = 1;   /**< What is it??. Constant for mkldnn impl */
     const size_t D = 1;   /**< Num of direction. 1 or 2 */
 
-    std::vector<DnnlBlockedMemoryDesc> inDataDescs;
-    std::vector<DnnlBlockedMemoryDesc> outDataDescs;
+    std::vector<DnnlBlockedMemoryDescPtr> inDataDescs;
+    std::vector<DnnlBlockedMemoryDescPtr> outDataDescs;
     std::vector<mkldnn::memory::desc> wDescs;
 
     enum RNNInOutKind {
@@ -166,7 +166,11 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
     { { { {1, 10}, {25, 35} },  // Dynamic shape 0
         { {2, 30}, {5, 30}, {8, 30} } },  // Target shapes
       { { {1, 10}, -1 },  // Dynamic shape 1
-        { {2, 10}, {5, 10}, {8, 10} } } }  // Target shapes
+        { {2, 10}, {5, 10}, {8, 10} } } },  // Target shapes
+    { { { {1, 10}, {25, 35} },  // Dynamic shape 0
+        { {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } },  // Target shapes
+      { { {1, 10}, -1 },  // Dynamic shape 1
+        { {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUCellCPUTest,
@@ -310,11 +310,17 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
     { { {5, -1, 10},  // #7. Dynamic shape 0
         { {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } },  // Target shapes
       { {5, 1, 10},  // Dynamic shape 1
-        { {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } }  // Target shapes
+        { {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } },  // Target shapes
+    { { {{0, 11}, -1, {7, 11}},  // #8. Dynamic shape 0
+        { {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } },  // Target shapes
+      { {-1, 1, {8, 12}},  // Dynamic shape 1
+        { {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } },  // Target shapes
+      { {-1},  // Dynamic shape 2
+        { {10}, {3}, {5}, {10}, {5} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -326,7 +332,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, GRUSequenceCPUTest,
         GRUSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -338,7 +344,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, GRUSequenceCPUTest,
         GRUSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -350,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, GRUSequenceCPUTest,
         GRUSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, GRUSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6], dynamicShapes[7]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[6], dynamicShapes[7]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -167,7 +167,13 @@ const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
       { { {1, 20}, {8, 12} },  // Dynamic shape 1
         { {2, 10}, {5, 10}, {8, 10} } },  // Target shapes
       { { {1, 20}, -1 },  // Dynamic shape 2
-        { {2, 10}, {5, 10}, {8, 10} } } }  // Target shapes
+        { {2, 10}, {5, 10}, {8, 10} } } },  // Target shapes
+    { { { {1, 20}, {28, 32} },  // Dynamic shape 0
+        { {2, 30}, {5, 30}, {8, 30}, {2, 30}, {5, 30}, {8, 30} } },  // Target shapes
+      { { {1, 20}, {8, 12} },  // Dynamic shape 1
+        { {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } },  // Target shapes
+      { { {1, 20}, -1 },  // Dynamic shape 2
+        { {2, 10}, {5, 10}, {8, 10}, {2, 10}, {5, 10}, {8, 10} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMCellLayerCPUTest,
@@ -319,14 +319,22 @@ const std::vector<std::vector<InputShape>> dynamicShapes = {
         { {1}, {1}, {1} } } },  // Target shapes
     { { {3, -1, {0, 12}},  // #6. Dynamic shape 0
         { {3, 2, 10}, {3, 4, 10}, {3, 5, 10} } },  // Target shapes
       { {3, -1, {0, 12}},  // Dynamic shape 1
         { {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } },  // Target shapes
       { {3, -1, {0, 12}},  // Dynamic shape 2
-        { {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } }  // Target shapes
+        { {3, 1, 10}, {3, 1, 10}, {3, 1, 10} } } },  // Target shapes
+    { { {{0, 11}, -1, {5, 15}},  // #7. Dynamic shape 0
+        { {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } },  // Target shapes
+      { {-1, 1, -1},  // Dynamic shape 1
+        { {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } },  // Target shapes
+      { {-1, 1, -1},  // Dynamic shape 2
+        { {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } },  // Target shapes
+      { {-1},  // Dynamic shape 3
+        { {10}, {3}, {5}, {10}, {5} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -337,7 +345,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic, LSTMSequenceCPUTest,
         LSTMSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[4]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[4]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -348,7 +356,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, LSTMSequenceCPUTest,
         LSTMSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[5]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[7]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -359,7 +367,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic, LSTMSequenceCPUTest,
         LSTMSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, LSTMSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[6]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[6]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -139,18 +139,22 @@ INSTANTIATE_TEST_SUITE_P(smoke_static, RNNCellCPUTest,
         RNNCellCPUTest::getTestCaseName);
 
 const std::vector<std::vector<ov::test::InputShape>> dynamicShapes = {
     { { { {-1}, 1 },  // Dynamic shape 0
         { {1, 1}, {3, 1}, {5, 1} } },  // Target shapes
       { { {-1}, 1 },  // Dynamic shape 1
         { {1, 1}, {3, 1}, {5, 1} } } },  // Target shapes
     { { { {1, 10}, 30 },  // Dynamic shape 0
         { {2, 30}, {5, 30}, {8, 30} } },  // Target shapes
       { { {1, 10}, 10 },  // Dynamic shape 1
         { {2, 10}, {5, 10}, {8, 10} } } },  // Target shapes
     { { { {1, 10}, -1 },  // Dynamic shape 0
         { {2, 30}, {5, 30}, {8, 30} } },  // Target shapes
       { { {1, 10}, {1, 11} },  // Dynamic shape 1
-        { {2, 10}, {5, 10}, {8, 10} } } }  // Target shapes
+        { {2, 10}, {5, 10}, {8, 10} } } },  // Target shapes
+    { { { {1, 10}, -1 },  // Dynamic shape 0
+        { {2, 30}, {5, 30}, {2, 30}, {8, 30}, {5, 30}, {8, 30} } },  // Target shapes
+      { { {1, 10}, {1, 11} },  // Dynamic shape 1
+        { {2, 10}, {5, 10}, {2, 10}, {8, 10}, {5, 10}, {8, 10} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNCellCPUTest,
@@ -229,54 +229,60 @@ INSTANTIATE_TEST_SUITE_P(smoke_static_BatchSizeOne, RNNSequenceCPUTest,
         RNNSequenceCPUTest::getTestCaseName);
 
 const std::vector<std::vector<InputShape>> dynamicShapes = {
     { { {-1, {1, 5}, 10},  // #0. Dynamic shape 0
         { {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } },  // Target shapes
       { {{0, 15}, 1, 1},  // Dynamic shape 1
         { {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } },  // Target shapes
       { {{0, 12}},  // Dynamic shape 2
         { {10}, {8}, {5} } } },  // Target shapes
     { { {{0, 11}, -1, 10},  // #1. Dynamic shape 0
         { {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } },  // Target shapes
       { {-1, 1, 10},  // Dynamic shape 1
         { {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } },  // Target shapes
       { {-1},  // Dynamic shape 2
         { {10}, {3}, {5} } } },  // Target shapes
     { { {{0, 11}, -1, {5, 15}},  // #2. Dynamic shape 0
         { {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } },  // Target shapes
       { {-1, -1, {8, 11}},  // Dynamic shape 1
         { {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } },  // Target shapes
       { {-1},  // Dynamic shape 2
         { {10}, {3}, {5} } } },  // Target shapes
     { { {-1, {0, 7}, 10},  // #3. Dynamic shape 0
         { {1, 2, 10}, {1, 3, 10}, {1, 6, 10} } },  // Target shapes
       { {-1, 1, 1},  // Dynamic shape 1
         { {1, 1, 1}, {1, 1, 1}, {1, 1, 1} } },  // Target shapes
       { {-1},  // Dynamic shape 2
         { {1}, {1}, {1} } } },  // Target shapes
     { { {1, -1, 10},  // #4. Dynamic shape 0
         { {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } },  // Target shapes
       { {1, 1, 10},  // Dynamic shape 1
         { {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } },  // Target shapes
       { {1},  // Dynamic shape 2
         { {1}, {1}, {1} } } },  // Target shapes
     { { {-1, -1, -1},  // #5. Dynamic shape 0
         { {1, 2, 10}, {1, 4, 10}, {1, 8, 10} } },  // Target shapes
       { {-1, -1, -1},  // Dynamic shape 1
         { {1, 1, 10}, {1, 1, 10}, {1, 1, 10} } },  // Target shapes
       { {-1},  // Dynamic shape 2
         { {1}, {1}, {1} } } },  // Target shapes
-    { { {-1, {1, 5}, 10},  // #6. Dynamic shape 0
-        { {10, 2, 10}, {8, 3, 10}, {5, 4, 10} } },  // Target shapes
-      { {{0, 15}, 1, 1},  // Dynamic shape 1
-        { {10, 1, 1}, {8, 1, 1}, {5, 1, 1} } } },  // Target shapes
-    { { {{0, 11}, -1, 10},  // #7. Dynamic shape 0
-        { {10, 2, 10}, {3, 4, 10}, {5, 5, 10} } },  // Target shapes
-      { {-1, 1, 10},  // Dynamic shape 1
-        { {10, 1, 10}, {3, 1, 10}, {5, 1, 10} } } }  // Target shapes
+    { { {7, {1, 5}, 10},  // #6. Dynamic shape 0
+        { {7, 2, 10}, {7, 3, 10}, {7, 4, 10} } },  // Target shapes
+      { {7, 1, 1},  // Dynamic shape 1
+        { {7, 1, 1}, {7, 1, 1}, {7, 1, 1} } } },  // Target shapes
+    { { {5, -1, 10},  // #7. Dynamic shape 0
+        { {5, 2, 10}, {5, 4, 10}, {5, 5, 10} } },  // Target shapes
+      { {5, 1, 10},  // Dynamic shape 1
+        { {5, 1, 10}, {5, 1, 10}, {5, 1, 10} } } },  // Target shapes
+    { { {{0, 11}, -1, 10},  // #8. Dynamic shape 0
+        { {10, 2, 10}, {3, 4, 10}, {5, 5, 10}, {10, 2, 10}, {5, 5, 10} } },  // Target shapes
+      { {-1, 1, 10},  // Dynamic shape 1
+        { {10, 1, 10}, {3, 1, 10}, {5, 1, 10}, {10, 1, 10}, {5, 1, 10} } },  // Target shapes
+      { {-1},  // Dynamic shape 2
+        { {10}, {3}, {5}, {10}, {5} } } }  // Target shapes
 };
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic, RNNSequenceCPUTest,
-        ::testing::Combine(::testing::ValuesIn(std::vector<std::vector<InputShape>>{dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[0], dynamicShapes[1], dynamicShapes[2]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
@@ -298,13 +304,13 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_BatchSizeOne, RNNSequenceCPUTest,
         RNNSequenceCPUTest::getTestCaseName);
 
 INSTANTIATE_TEST_SUITE_P(nightly_dynamic, RNNSequenceCPUTest,
-        ::testing::Combine(::testing::Values(dynamicShapes[5]),
+        ::testing::Combine(::testing::ValuesIn({dynamicShapes[5], dynamicShapes[8]}),
                 ::testing::ValuesIn(mode),
                 ::testing::ValuesIn(activations),
                 ::testing::ValuesIn(clip),
                 ::testing::ValuesIn(direction),
                 ::testing::ValuesIn(netPrecisions),
-                ::testing::Values(cpuParamsBatchSizeOne),
+                ::testing::Values(cpuParams),
                 ::testing::Values(std::map<std::string, std::string>{})),
         RNNSequenceCPUTest::getTestCaseName);
 
@@ -315,7 +321,7 @@ INSTANTIATE_TEST_SUITE_P(nightly_dynamic_bf16, RNNSequenceCPUTest,
                 ::testing::ValuesIn(clip),
                 ::testing::ValuesIn(direction),
                 ::testing::ValuesIn(netPrecisions),
-                ::testing::Values(cpuParamsBatchSizeOne),
+                ::testing::Values(cpuParams),
                 ::testing::Values(additionalConfig[1])),
         RNNSequenceCPUTest::getTestCaseName);
 }  // namespace