[CPU] Implement TopK-11 to CPU plugin (#16522)

This commit is contained in:
Chen Xu 2023-03-31 16:28:20 +08:00 committed by GitHub
parent 6d064d26cb
commit 35398e339d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 130 additions and 74 deletions

View File

@ -8,6 +8,7 @@
#include <ngraph/runtime/reference/convert.hpp> #include <ngraph/runtime/reference/convert.hpp>
#include <openvino/opsets/opset1.hpp> #include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset10.hpp> #include <openvino/opsets/opset10.hpp>
#include <openvino/opsets/opset11.hpp>
#include <openvino/opsets/opset3.hpp> #include <openvino/opsets/opset3.hpp>
#include <openvino/opsets/opset4.hpp> #include <openvino/opsets/opset4.hpp>
#include <openvino/opsets/opset5.hpp> #include <openvino/opsets/opset5.hpp>
@ -334,7 +335,9 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
{opset9::MulticlassNms::get_type_info_static(), fuse_type_to_multiclass_nms}, {opset9::MulticlassNms::get_type_info_static(), fuse_type_to_multiclass_nms},
{opset9::GenerateProposals::get_type_info_static(), fuse_type_to_generate_proposals}, {opset9::GenerateProposals::get_type_info_static(), fuse_type_to_generate_proposals},
{opset6::CTCGreedyDecoderSeqLen::get_type_info_static(), fuse_type_to_ctc_greedy_decoder_seq_len}, {opset6::CTCGreedyDecoderSeqLen::get_type_info_static(), fuse_type_to_ctc_greedy_decoder_seq_len},
{opset1::TopK::get_type_info_static(), fuse_type_to_topk},
{opset4::TopK::get_type_info_static(), fuse_type_to_topk}, {opset4::TopK::get_type_info_static(), fuse_type_to_topk},
{opset11::TopK::get_type_info_static(), fuse_type_to_topk},
{opset8::MaxPool::get_type_info_static(), fuse_type_to_maxpool}, {opset8::MaxPool::get_type_info_static(), fuse_type_to_maxpool},
{opset4::NonZero::get_type_info_static(), fuse_type_to_nonzero}, {opset4::NonZero::get_type_info_static(), fuse_type_to_nonzero},
{opset4::Bucketize::get_type_info_static(), fuse_type_to_bucketize}, {opset4::Bucketize::get_type_info_static(), fuse_type_to_bucketize},
@ -653,7 +656,7 @@ bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node,
} }
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) { bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
if (auto topk = ov::as_type_ptr<opset4::TopK>(node)) { if (auto topk = ov::as_type_ptr<ov::op::util::TopKBase>(node)) {
return update_type(1, node, precisions, [&](const element::Type& to) { return update_type(1, node, precisions, [&](const element::Type& to) {
topk->set_index_element_type(to); topk->set_index_element_type(to);
}); });

View File

@ -270,12 +270,12 @@ private:
inline void topk_loop() { inline void topk_loop() {
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) { if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) { if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
if (jcp_.top_k == 1) { if (jcp_.top_k == 1 && !jcp_.stable) {
topk_bubble_horiz(); topk_bubble_horiz();
} else { } else {
topk_bubble_BLK_on_channel_verti(); topk_bubble_BLK_on_channel_verti();
} }
} else if (jcp_.topk_innermost && jcp_.top_k == 1) { } else if (jcp_.topk_innermost && jcp_.top_k == 1 && !jcp_.stable) {
topk_bubble_horiz(); topk_bubble_horiz();
} else { } else {
topk_bubble_vector(); topk_bubble_vector();
@ -1788,30 +1788,32 @@ private:
} }
}; };
bool TopK::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept { bool TopK::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try { try {
const auto topKOp = ngraph::as_type_ptr<const ngraph::op::v1::TopK>(op); if (!one_of(op->get_type_info(), ov::op::v1::TopK::get_type_info_static(),
if (!topKOp) { ov::op::v3::TopK::get_type_info_static(),
errorMessage = "Node is not an instance of the TopK from the operations set v1 or v3"; ov::op::v11::TopK::get_type_info_static())) {
errorMessage = "Node is not an instance of the TopK from the operation sets v1, v3 or v11";
return false; return false;
} }
auto topKOp = ov::as_type_ptr<const ov::op::util::TopKBase>(op);
if (!isDynamicNgraphNode(op)) { if (!isDynamicNgraphNode(op)) {
auto topKConst = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K)); auto topKConst = std::dynamic_pointer_cast<const ov::op::v0::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
if (!topKConst) { if (!topKConst) {
errorMessage = "Second tensor is not constant in static shape mode"; errorMessage = "Second tensor is not constant in static shape mode";
return false; return false;
} }
} }
if (topKOp->get_mode() != ngraph::op::TopKMode::MAX && if (topKOp->get_mode() != ov::op::TopKMode::MAX &&
topKOp->get_mode() != ngraph::op::TopKMode::MIN) { topKOp->get_mode() != ov::op::TopKMode::MIN) {
errorMessage = "Unsupported mode."; errorMessage = "Unsupported mode.";
return false; return false;
} }
if (!one_of(topKOp->get_sort_type(), ngraph::op::TopKSortType::NONE, if (!one_of(topKOp->get_sort_type(), ov::op::TopKSortType::NONE,
ngraph::op::TopKSortType::SORT_VALUES, ov::op::TopKSortType::SORT_VALUES,
ngraph::op::TopKSortType::SORT_INDICES)) { ov::op::TopKSortType::SORT_INDICES)) {
errorMessage = "Unsupported sort type."; errorMessage = "Unsupported sort type.";
return false; return false;
} }
@ -1821,13 +1823,13 @@ bool TopK::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, s
return true; return true;
} }
TopK::TopK(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context) TopK::TopK(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
: Node(op, context, NgraphShapeInferFactory(op, PortMask(TOPK_K))) { : Node(op, context, NgraphShapeInferFactory(op, PortMask(TOPK_K))) {
std::string errorMessage; std::string errorMessage;
if (isSupportedOperation(op, errorMessage)) { if (isSupportedOperation(op, errorMessage)) {
errorPrefix = "TopK layer with name '" + getName() + "'"; errorPrefix = "TopK layer with name '" + getName() + "'";
auto topKOp = ngraph::as_type_ptr<ngraph::op::v1::TopK>(op); auto topKOp = ov::as_type_ptr<const ov::op::util::TopKBase>(op);
auto in_dims = topKOp->get_input_partial_shape(TOPK_DATA); auto in_dims = topKOp->get_input_partial_shape(TOPK_DATA);
auto out_dims = topKOp->get_output_partial_shape(TOPK_DATA); auto out_dims = topKOp->get_output_partial_shape(TOPK_DATA);
@ -1835,15 +1837,23 @@ TopK::TopK(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr con
auto in_dims_size = in_dims.size(); auto in_dims_size = in_dims.size();
if (!isDynamicNgraphNode(op)) { if (!isDynamicNgraphNode(op)) {
auto topKConst = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K)); auto topKConst = std::dynamic_pointer_cast<const ov::op::v0::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
if (!topKConst) { if (!topKConst) {
IE_THROW() << errorPrefix << "gets non-constant second tensor in static shape mode!"; IE_THROW() << errorPrefix << "gets non-constant second tensor in static shape mode!";
} }
} }
axis = topKOp->get_axis(); axis = topKOp->get_axis();
mode_max = topKOp->get_mode() == ngraph::op::TopKMode::MAX; mode_max = topKOp->get_mode() == ov::op::TopKMode::MAX;
sort_index = topKOp->get_sort_type() == ngraph::op::TopKSortType::SORT_INDICES; sort_index = topKOp->get_sort_type() == ov::op::TopKSortType::SORT_INDICES;
stable = false;
if (!sort_index) {
const auto topKOpV11 = ngraph::as_type_ptr<const ov::op::v11::TopK>(op);
if (topKOpV11) {
stable = topKOpV11->get_stable();
}
}
top_k = 0; top_k = 0;
preset_params_done = false; preset_params_done = false;
@ -1959,7 +1969,10 @@ void TopK::preset_params() {
} }
if (isDynamicNode()) { if (isDynamicNode()) {
if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) { if (stable) {
algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = false;
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
algorithm = TopKAlgorithm::topk_heap_sort; algorithm = TopKAlgorithm::topk_heap_sort;
} else { } else {
algorithm = TopKAlgorithm::topk_bubble_sort; algorithm = TopKAlgorithm::topk_bubble_sort;
@ -2006,8 +2019,11 @@ void TopK::prepareParams() {
// [case 1]: if 2 * (top_k + 1) + 2 <= count_xmm, thus top_k is small enough that the vector registers are sufficient // [case 1]: if 2 * (top_k + 1) + 2 <= count_xmm, thus top_k is small enough that the vector registers are sufficient
// to keep all necessary data for sorting, no need to load and store frequently, use inplace bubble sort; // to keep all necessary data for sorting, no need to load and store frequently, use inplace bubble sort;
// (horizotal sorting cases not included) // (horizotal sorting cases not included)
// [case 2]: only when topk is imposed on innermost dimsension of planar(ncsp/nspc) layout, should heap sort be used; // [case 2]: if stable sorting is required, bubble sort(topk_bubble_vector/topk_bubble_BLK_on_channel_verti) will be
// [case 3]: by default, use bitonic sort when alg_cost_bitonic < alg_cost_bubble, otherwise use bubble sort. // applied currently, because among the implemented sorting algorithms, these bubble sort implementations
// are the only stable ones;
// [case 3]: only when topk is imposed on innermost dimsension of planar(ncsp/nspc) layout, should heap sort be used;
// [case 4]: by default, use bitonic sort when alg_cost_bitonic < alg_cost_bubble, otherwise use bubble sort.
// alg_cost_bitonic = (N / 4) * logN * (logN + 1) // alg_cost_bitonic = (N / 4) * logN * (logN + 1)
// alg_cost_bubble = K * (K - 1) / 2 + (N - K) * K // alg_cost_bubble = K * (K - 1) / 2 + (N - K) * K
// where, N = axis_dim, K = topk_k // where, N = axis_dim, K = topk_k
@ -2018,6 +2034,9 @@ void TopK::prepareParams() {
if (top_k <= count_xmm / 2 - 2) { if (top_k <= count_xmm / 2 - 2) {
algorithm = TopKAlgorithm::topk_bubble_sort; algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = topk_innermost && top_k == 1 ? false : true; bubble_inplace = topk_innermost && top_k == 1 ? false : true;
} else if (stable) {
algorithm = TopKAlgorithm::topk_bubble_sort;
bubble_inplace = false;
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) { } else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
algorithm = TopKAlgorithm::topk_heap_sort; algorithm = TopKAlgorithm::topk_heap_sort;
} else { } else {
@ -2074,6 +2093,7 @@ void TopK::createPrimitive() {
jcp.topk_innermost = topk_innermost; jcp.topk_innermost = topk_innermost;
jcp.algorithm = algorithm; jcp.algorithm = algorithm;
jcp.bubble_inplace = bubble_inplace; jcp.bubble_inplace = bubble_inplace;
jcp.stable = stable;
jcp.sort_stride = static_cast<int>(I); jcp.sort_stride = static_cast<int>(I);
jcp.work_amount = static_cast<int>(I); jcp.work_amount = static_cast<int>(I);
jcp.bitonic_idx_cnt = 0; jcp.bitonic_idx_cnt = 0;
@ -2207,7 +2227,9 @@ inline void TopK::prepare_original_idx() {
bool shape_agnostic_alg = algorithm == TopKAlgorithm::topk_heap_sort || bool shape_agnostic_alg = algorithm == TopKAlgorithm::topk_heap_sort ||
(algorithm == TopKAlgorithm::topk_bubble_sort && !bubble_inplace); (algorithm == TopKAlgorithm::topk_bubble_sort && !bubble_inplace);
if (shape_agnostic_alg) { if (shape_agnostic_alg) {
if (topk_innermost) { bool use_idx_seq = stable ? topk_innermost && (layout == TopKLayoutType::topk_blocked || (top_k == 1 && !stable))
: topk_innermost;
if (use_idx_seq) {
if (vec_idx_seq.empty()) { if (vec_idx_seq.empty()) {
vec_idx_seq.resize(axis_dim); vec_idx_seq.resize(axis_dim);
std::iota(vec_idx_seq.begin(), vec_idx_seq.end(), 0); std::iota(vec_idx_seq.begin(), vec_idx_seq.end(), 0);

View File

@ -32,6 +32,7 @@ struct jit_topk_config_params {
bool sort_index; // sort by value or index. true: index; false: value bool sort_index; // sort by value or index. true: index; false: value
bool topk_innermost; // if topk sorting is applied on innermost dimension or other dimension bool topk_innermost; // if topk sorting is applied on innermost dimension or other dimension
bool bubble_inplace; // all the elements in sorting is right in the register, no need to load and store for each comparison bool bubble_inplace; // all the elements in sorting is right in the register, no need to load and store for each comparison
bool stable; // if require stable sorting
TopKLayoutType layout; // memory layout TopKLayoutType layout; // memory layout
TopKAlgorithm algorithm; // topk sorting algorithm TopKAlgorithm algorithm; // topk sorting algorithm
InferenceEngine::Precision precision; // precision InferenceEngine::Precision precision; // precision
@ -115,6 +116,7 @@ private:
bool topk_innermost; bool topk_innermost;
bool jit_mode; bool jit_mode;
bool sort_index; bool sort_index;
bool stable;
bool mode_max; bool mode_max;
int axis; int axis;
static const size_t TOPK_DATA = 0; static const size_t TOPK_DATA = 0;

View File

@ -71,6 +71,8 @@
#include "transformations/op_conversions/softsign_decomposition.hpp" #include "transformations/op_conversions/softsign_decomposition.hpp"
#include "transformations/op_conversions/softmax_decomposition.hpp" #include "transformations/op_conversions/softmax_decomposition.hpp"
#include "transformations/op_conversions/unique_decomposition.hpp" #include "transformations/op_conversions/unique_decomposition.hpp"
#include "transformations/op_conversions/convert_topk3.hpp"
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
#include "transformations/opset_conversions/convert_opset2_to_opset1.hpp" #include "transformations/opset_conversions/convert_opset2_to_opset1.hpp"
#include "transformations/opset_conversions/convert_opset3_to_opset2.hpp" #include "transformations/opset_conversions/convert_opset3_to_opset2.hpp"
#include "transformations/smart_reshape/matmul_sr.hpp" #include "transformations/smart_reshape/matmul_sr.hpp"
@ -398,6 +400,8 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
pass_config->disable<ov::pass::ConvertROIAlign9To3>(); pass_config->disable<ov::pass::ConvertROIAlign9To3>();
pass_config->disable<ov::pass::SoftSignDecomposition>(); pass_config->disable<ov::pass::SoftSignDecomposition>();
pass_config->disable<ov::pass::UniqueDecomposition>(); pass_config->disable<ov::pass::UniqueDecomposition>();
pass_config->disable<ov::pass::ConvertTopK3>();
pass_config->disable<ov::pass::ConvertTopK11ToTopK3>();
pass_config->enable<ov::pass::NormalizeL2Decomposition>(); pass_config->enable<ov::pass::NormalizeL2Decomposition>();
pass_config->enable<ov::pass::ConvertInterpolate1ToInterpolate4>(); pass_config->enable<ov::pass::ConvertInterpolate1ToInterpolate4>();

View File

@ -10,18 +10,20 @@
using namespace InferenceEngine; using namespace InferenceEngine;
using namespace CPUTestUtils; using namespace CPUTestUtils;
using namespace ov::test; using namespace ov::test;
using SortMode = ov::op::TopKMode;
using SortType = ov::op::TopKSortType;
namespace CPULayerTestsDefinitions { namespace CPULayerTestsDefinitions {
typedef std::tuple< typedef std::tuple<
int64_t, // keepK int64_t, // keepK
int64_t, // axis int64_t, // axis
ngraph::opset4::TopK::Mode, // mode SortMode, // mode
ngraph::opset4::TopK::SortType, // sort std::tuple<SortType, bool>, // sort and stable
ElementType, // Net precision ElementType, // Net precision
ElementType, // Input precision ElementType, // Input precision
ElementType, // Output precision ElementType, // Output precision
InputShape // inputShape InputShape // inputShape
> basicTopKParams; > basicTopKParams;
typedef std::tuple< typedef std::tuple<
@ -39,11 +41,13 @@ public:
std::tie(basicParamsSet, cpuParams, additionalConfig) = obj.param; std::tie(basicParamsSet, cpuParams, additionalConfig) = obj.param;
int64_t keepK, axis; int64_t keepK, axis;
ngraph::opset4::TopK::Mode mode; SortMode mode;
ngraph::opset4::TopK::SortType sort; std::tuple<SortType, bool> sortTypeStable;
ElementType netPrecision, inPrc, outPrc; ElementType netPrecision, inPrc, outPrc;
InputShape inputShape; InputShape inputShape;
std::tie(keepK, axis, mode, sort, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet; std::tie(keepK, axis, mode, sortTypeStable, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
SortType sort = std::get<0>(sortTypeStable);
bool stable = std::get<1>(sortTypeStable);
std::ostringstream result; std::ostringstream result;
bool staticShape = inputShape.first.rank() == 0; bool staticShape = inputShape.first.rank() == 0;
@ -52,6 +56,7 @@ public:
result << "axis=" << axis << "_"; result << "axis=" << axis << "_";
result << "mode=" << mode << "_"; result << "mode=" << mode << "_";
result << "sort=" << sort << "_"; result << "sort=" << sort << "_";
result << "stable=" << (stable ? "True" : "False") << "_";
result << "netPRC=" << netPrecision << "_"; result << "netPRC=" << netPrecision << "_";
result << "inPRC=" << inPrc << "_"; result << "inPRC=" << inPrc << "_";
result << "outPRC=" << outPrc << "_"; result << "outPRC=" << outPrc << "_";
@ -85,11 +90,13 @@ protected:
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
int64_t keepK; int64_t keepK;
ngraph::opset4::TopK::Mode mode; SortMode mode;
ngraph::opset4::TopK::SortType sort; std::tuple<SortType, bool> sortTypeStable;
ElementType inPrc, outPrc; ElementType inPrc, outPrc;
InputShape inputShape; InputShape inputShape;
std::tie(keepK, axis, mode, sort, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet; std::tie(keepK, axis, mode, sortTypeStable, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
sort = std::get<0>(sortTypeStable);
stable = std::get<1>(sortTypeStable);
if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES) if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
inPrc = outPrc = netPrecision = ElementType::bf16; inPrc = outPrc = netPrecision = ElementType::bf16;
@ -112,33 +119,34 @@ protected:
auto params = ngraph::builder::makeDynamicParams(netPrecision, {inputDynamicShapes[0]}); auto params = ngraph::builder::makeDynamicParams(netPrecision, {inputDynamicShapes[0]});
// static shape need specific const k to test different sorting algorithms, dynamic shape tests random param k // static shape need specific const k to test different sorting algorithms, dynamic shape tests random param k
std::shared_ptr<ngraph::opset4::TopK> topk; std::shared_ptr<ov::op::v11::TopK> topk;
if (staticShape) { if (staticShape) {
auto k = std::make_shared<ngraph::opset3::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{}, &keepK); auto k = std::make_shared<ov::op::v0::Constant>(ElementType::i64, ov::Shape{}, &keepK);
topk = std::dynamic_pointer_cast<ngraph::opset4::TopK>( topk = std::dynamic_pointer_cast<ov::op::v11::TopK>(
std::make_shared<ngraph::opset4::TopK>(params[0], k, axis, mode, sort)); std::make_shared<ov::op::v11::TopK>(params[0], k, axis, mode, sort, ElementType::i32, stable));
} else { } else {
auto k = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::Type_t::i64, inputDynamicShapes[1]); auto k = std::make_shared<ov::op::v0::Parameter>(ElementType::i64, inputDynamicShapes[1]);
params.push_back(k); params.push_back(k);
topk = std::dynamic_pointer_cast<ngraph::opset4::TopK>( topk = std::dynamic_pointer_cast<ov::op::v11::TopK>(
std::make_shared<ngraph::opset4::TopK>(params[0], k, axis, mode, sort)); std::make_shared<ov::op::v11::TopK>(params[0], k, axis, mode, sort, ElementType::i32, stable));
} }
topk->get_rt_info() = getCPUInfo(); topk->get_rt_info() = getCPUInfo();
ngraph::ResultVector results; ngraph::ResultVector results;
for (size_t i = 0; i < topk->get_output_size(); i++) { for (size_t i = 0; i < topk->get_output_size(); i++) {
results.push_back(std::make_shared<ngraph::opset4::Result>(topk->output(i))); results.push_back(std::make_shared<ov::op::v0::Result>(topk->output(i)));
} }
function = std::make_shared<ngraph::Function>(results, params, "TopK"); function = std::make_shared<ngraph::Function>(results, params, "TopK");
} }
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override { void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
inputs.clear(); inputs.clear();
const auto& funcInputs = function->inputs(); const auto& funcInputs = function->inputs();
// Spec TopK_3.md allows to use unstable sorting, thus generate unreapeated input data to avoid a. and b. // For unstable sorting, generate unrepeated input data to avoid a. and b. While for stable sorting,
// repeating values are explicitly set.
// a. Skip comparing of index results, because an element in actual index tensor can be different with // a. Skip comparing of index results, because an element in actual index tensor can be different with
// its counterpart in expected index tensor // its counterpart in expected index tensor
// b. If SortType is SORT_INDICES or NONE, the test program still needs to apply std::sort for all pairs // b. If SortType is SORT_INDICES or NONE, the test program still needs to apply std::sort for all pairs
@ -153,7 +161,11 @@ protected:
// For int32, deliberately set big numbers which are not accurately representable in fp32 // For int32, deliberately set big numbers which are not accurately representable in fp32
int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2); int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2);
std::iota(data.begin(), data.end(), start); size_t set_size = sort == SortType::SORT_VALUES && stable ? size / 2 : size;
std::iota(data.begin(), data.begin() + set_size, start);
if (sort == SortType::SORT_VALUES && stable) {
std::copy(data.begin(), data.begin() + set_size, data.begin() + set_size);
}
std::mt19937 gen(0); std::mt19937 gen(0);
std::shuffle(data.begin(), data.end(), gen); std::shuffle(data.begin(), data.end(), gen);
@ -178,7 +190,7 @@ protected:
if (O * A * I != size) if (O * A * I != size)
FAIL() << "Incorrect blob shape " << shape; FAIL() << "Incorrect blob shape " << shape;
auto *rawBlobDataPtr = static_cast<ngraph::bfloat16 *>(tensor.data()); auto *rawBlobDataPtr = static_cast<ov::bfloat16 *>(tensor.data());
for (size_t o = 0; o < O; o++) { for (size_t o = 0; o < O; o++) {
for (size_t i = 0; i < I; i++) { for (size_t i = 0; i < I; i++) {
std::vector<int> data(A); std::vector<int> data(A);
@ -188,7 +200,7 @@ protected:
std::mt19937 gen(seed); std::mt19937 gen(seed);
std::shuffle(data.begin(), data.end(), gen); std::shuffle(data.begin(), data.end(), gen);
for (size_t a = 0; a < A; a++) { for (size_t a = 0; a < A; a++) {
rawBlobDataPtr[o * A * I + a * I + i] = static_cast<ngraph::bfloat16>(data[a]); rawBlobDataPtr[o * A * I + a * I + i] = static_cast<ov::bfloat16>(data[a]);
} }
} }
} }
@ -198,20 +210,28 @@ protected:
inputs.insert({funcInputs[0].get_node_shared_ptr(), tensor}); inputs.insert({funcInputs[0].get_node_shared_ptr(), tensor});
if (!staticShape) { if (!staticShape) {
const auto& kPrecision = funcInputs[1].get_element_type(); generate_dynamic_k(funcInputs, targetInputStaticShapes);
const auto& kShape = targetInputStaticShapes[1];
const size_t startFrom = 1;
const size_t range = targetInputStaticShapes[0][axis];
const size_t seed = inferRequestNum++;
const auto kTensor = ov::test::utils::create_and_fill_tensor(kPrecision, kShape, range, startFrom, 1, seed);
inputs.insert({funcInputs[1].get_node_shared_ptr(), kTensor});
} }
} }
private:
void generate_dynamic_k(const std::vector<ov::Output<ov::Node>>& funcInputs,
const std::vector<ov::Shape>& targetInputStaticShapes) {
const auto& kPrecision = funcInputs[1].get_element_type();
const auto& kShape = targetInputStaticShapes[1];
const size_t startFrom = 1;
const size_t range = targetInputStaticShapes[0][axis];
const size_t seed = inferRequestNum++;
const auto kTensor = ov::test::utils::create_and_fill_tensor(kPrecision, kShape, range, startFrom, 1, seed);
inputs.insert({funcInputs[1].get_node_shared_ptr(), kTensor});
}
private: private:
int64_t axis; int64_t axis;
SortType sort;
bool stable;
size_t inferRequestNum = 0; size_t inferRequestNum = 0;
ElementType netPrecision; ElementType netPrecision;
bool staticShape; bool staticShape;
@ -236,14 +256,15 @@ std::vector<std::map<std::string, std::string>> additionalConfig = {
const std::vector<int64_t> axes = {0, 1, 2, 3}; const std::vector<int64_t> axes = {0, 1, 2, 3};
const std::vector<int64_t> k = {1, 5, 7, 18, 21}; const std::vector<int64_t> k = {1, 5, 7, 18, 21};
const std::vector<ngraph::opset4::TopK::Mode> modes = { const std::vector<SortMode> modes = {
ngraph::opset4::TopK::Mode::MIN, SortMode::MIN,
ngraph::opset4::TopK::Mode::MAX SortMode::MAX
}; };
const std::vector<ngraph::opset4::TopK::SortType> sortTypes = { const std::vector<std::tuple<SortType, bool>> sortTypeStable = {
ngraph::opset4::TopK::SortType::SORT_VALUES, std::tuple<SortType, bool>{SortType::SORT_VALUES, false},
ngraph::opset4::TopK::SortType::SORT_INDICES, std::tuple<SortType, bool>{SortType::SORT_VALUES, true},
std::tuple<SortType, bool>{SortType::SORT_INDICES, false}
}; };
std::vector<ov::test::InputShape> inputShapes = { std::vector<ov::test::InputShape> inputShapes = {
@ -266,7 +287,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK, TopKLayerCPUTest,
::testing::ValuesIn(k), ::testing::ValuesIn(k),
::testing::ValuesIn(axes), ::testing::ValuesIn(axes),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -281,7 +302,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_dynamic, TopKLayerCPUTest,
::testing::Values(1), ::testing::Values(1),
::testing::ValuesIn(axes), ::testing::ValuesIn(axes),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -306,7 +327,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_int32, TopKLayerCPUTest,
::testing::ValuesIn(k_int32), ::testing::ValuesIn(k_int32),
::testing::ValuesIn(axes), ::testing::ValuesIn(axes),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::Values(ElementType::i32), ::testing::Values(ElementType::i32),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -321,7 +342,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_int32_dynamic, TopKLayerCPUTest,
::testing::Values(1), ::testing::Values(1),
::testing::ValuesIn(axes), ::testing::ValuesIn(axes),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::Values(ElementType::i32), ::testing::Values(ElementType::i32),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -344,7 +365,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_bubble_BLK_on_channel_horiz, TopKLayerCPUTest
::testing::Values(1), ::testing::Values(1),
::testing::Values(1), ::testing::Values(1),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -359,7 +380,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_bubble_BLK_on_channel_horiz_dynamic, TopKLaye
::testing::Values(1), ::testing::Values(1),
::testing::Values(1), ::testing::Values(1),
::testing::ValuesIn(modes), ::testing::ValuesIn(modes),
::testing::ValuesIn(sortTypes), ::testing::ValuesIn(sortTypeStable),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -381,8 +402,8 @@ INSTANTIATE_TEST_CASE_P(smoke_Top1, TopKLayerCPUTest,
::testing::Combine( ::testing::Combine(
::testing::Values(1), ::testing::Values(1),
::testing::Values(3), ::testing::Values(3),
::testing::Values(ngraph::opset4::TopK::Mode::MAX), ::testing::Values(SortMode::MAX),
::testing::Values(ngraph::opset4::TopK::SortType::SORT_INDICES), ::testing::Values(std::tuple<SortType, bool>(SortType::SORT_INDICES, false)),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
@ -396,8 +417,8 @@ INSTANTIATE_TEST_CASE_P(smoke_Top1_dynamic, TopKLayerCPUTest,
::testing::Combine( ::testing::Combine(
::testing::Values(1), ::testing::Values(1),
::testing::Values(3), ::testing::Values(3),
::testing::Values(ngraph::opset4::TopK::Mode::MAX), ::testing::Values(SortMode::MAX),
::testing::Values(ngraph::opset4::TopK::SortType::SORT_INDICES), ::testing::Values(std::tuple<SortType, bool>(SortType::SORT_INDICES, false)),
::testing::ValuesIn(netPrecisions), ::testing::ValuesIn(netPrecisions),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined), ::testing::Values(ElementType::undefined),

View File

@ -80,6 +80,7 @@ CompareMap getCompareMap() {
#include "openvino/opsets/opset8_tbl.hpp" #include "openvino/opsets/opset8_tbl.hpp"
#include "openvino/opsets/opset9_tbl.hpp" #include "openvino/opsets/opset9_tbl.hpp"
#include "openvino/opsets/opset10_tbl.hpp" #include "openvino/opsets/opset10_tbl.hpp"
#include "openvino/opsets/opset11_tbl.hpp"
#include "ov_ops/opset_private_tbl.hpp" #include "ov_ops/opset_private_tbl.hpp"
#undef _OPENVINO_OP_REG #undef _OPENVINO_OP_REG

View File

@ -835,6 +835,7 @@ InputsMap getInputMap() {
#include "openvino/opsets/opset8_tbl.hpp" #include "openvino/opsets/opset8_tbl.hpp"
#include "openvino/opsets/opset9_tbl.hpp" #include "openvino/opsets/opset9_tbl.hpp"
#include "openvino/opsets/opset10_tbl.hpp" #include "openvino/opsets/opset10_tbl.hpp"
#include "openvino/opsets/opset11_tbl.hpp"
#include "ov_ops/opset_private_tbl.hpp" #include "ov_ops/opset_private_tbl.hpp"
#undef _OPENVINO_OP_REG #undef _OPENVINO_OP_REG

View File

@ -16,6 +16,8 @@
#include <ngraph/opsets/opset7.hpp> #include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp> #include <ngraph/opsets/opset8.hpp>
#include <ngraph/opsets/opset9.hpp> #include <ngraph/opsets/opset9.hpp>
#include <ngraph/opsets/opset10.hpp>
#include <ngraph/opsets/opset11.hpp>
#include "ngraph_functions/utils/data_utils.hpp" #include "ngraph_functions/utils/data_utils.hpp"
#include "openvino/core/partial_shape.hpp" #include "openvino/core/partial_shape.hpp"