[CPU] Implement TopK-11 to CPU plugin (#16522)
This commit is contained in:
parent
6d064d26cb
commit
35398e339d
@ -8,6 +8,7 @@
|
|||||||
#include <ngraph/runtime/reference/convert.hpp>
|
#include <ngraph/runtime/reference/convert.hpp>
|
||||||
#include <openvino/opsets/opset1.hpp>
|
#include <openvino/opsets/opset1.hpp>
|
||||||
#include <openvino/opsets/opset10.hpp>
|
#include <openvino/opsets/opset10.hpp>
|
||||||
|
#include <openvino/opsets/opset11.hpp>
|
||||||
#include <openvino/opsets/opset3.hpp>
|
#include <openvino/opsets/opset3.hpp>
|
||||||
#include <openvino/opsets/opset4.hpp>
|
#include <openvino/opsets/opset4.hpp>
|
||||||
#include <openvino/opsets/opset5.hpp>
|
#include <openvino/opsets/opset5.hpp>
|
||||||
@ -334,7 +335,9 @@ bool ov::pass::ConvertPrecision::run_on_model(const std::shared_ptr<ngraph::Func
|
|||||||
{opset9::MulticlassNms::get_type_info_static(), fuse_type_to_multiclass_nms},
|
{opset9::MulticlassNms::get_type_info_static(), fuse_type_to_multiclass_nms},
|
||||||
{opset9::GenerateProposals::get_type_info_static(), fuse_type_to_generate_proposals},
|
{opset9::GenerateProposals::get_type_info_static(), fuse_type_to_generate_proposals},
|
||||||
{opset6::CTCGreedyDecoderSeqLen::get_type_info_static(), fuse_type_to_ctc_greedy_decoder_seq_len},
|
{opset6::CTCGreedyDecoderSeqLen::get_type_info_static(), fuse_type_to_ctc_greedy_decoder_seq_len},
|
||||||
|
{opset1::TopK::get_type_info_static(), fuse_type_to_topk},
|
||||||
{opset4::TopK::get_type_info_static(), fuse_type_to_topk},
|
{opset4::TopK::get_type_info_static(), fuse_type_to_topk},
|
||||||
|
{opset11::TopK::get_type_info_static(), fuse_type_to_topk},
|
||||||
{opset8::MaxPool::get_type_info_static(), fuse_type_to_maxpool},
|
{opset8::MaxPool::get_type_info_static(), fuse_type_to_maxpool},
|
||||||
{opset4::NonZero::get_type_info_static(), fuse_type_to_nonzero},
|
{opset4::NonZero::get_type_info_static(), fuse_type_to_nonzero},
|
||||||
{opset4::Bucketize::get_type_info_static(), fuse_type_to_bucketize},
|
{opset4::Bucketize::get_type_info_static(), fuse_type_to_bucketize},
|
||||||
@ -653,7 +656,7 @@ bool fuse_type_to_generate_proposals(const std::shared_ptr<ngraph::Node>& node,
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
|
bool fuse_type_to_topk(const std::shared_ptr<ngraph::Node>& node, const precisions_map& precisions) {
|
||||||
if (auto topk = ov::as_type_ptr<opset4::TopK>(node)) {
|
if (auto topk = ov::as_type_ptr<ov::op::util::TopKBase>(node)) {
|
||||||
return update_type(1, node, precisions, [&](const element::Type& to) {
|
return update_type(1, node, precisions, [&](const element::Type& to) {
|
||||||
topk->set_index_element_type(to);
|
topk->set_index_element_type(to);
|
||||||
});
|
});
|
||||||
|
@ -270,12 +270,12 @@ private:
|
|||||||
inline void topk_loop() {
|
inline void topk_loop() {
|
||||||
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
|
if (jcp_.algorithm == TopKAlgorithm::topk_bubble_sort) {
|
||||||
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
|
if (jcp_.layout == TopKLayoutType::topk_blocked && jcp_.topk_innermost) {
|
||||||
if (jcp_.top_k == 1) {
|
if (jcp_.top_k == 1 && !jcp_.stable) {
|
||||||
topk_bubble_horiz();
|
topk_bubble_horiz();
|
||||||
} else {
|
} else {
|
||||||
topk_bubble_BLK_on_channel_verti();
|
topk_bubble_BLK_on_channel_verti();
|
||||||
}
|
}
|
||||||
} else if (jcp_.topk_innermost && jcp_.top_k == 1) {
|
} else if (jcp_.topk_innermost && jcp_.top_k == 1 && !jcp_.stable) {
|
||||||
topk_bubble_horiz();
|
topk_bubble_horiz();
|
||||||
} else {
|
} else {
|
||||||
topk_bubble_vector();
|
topk_bubble_vector();
|
||||||
@ -1788,30 +1788,32 @@ private:
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
bool TopK::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
bool TopK::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||||
try {
|
try {
|
||||||
const auto topKOp = ngraph::as_type_ptr<const ngraph::op::v1::TopK>(op);
|
if (!one_of(op->get_type_info(), ov::op::v1::TopK::get_type_info_static(),
|
||||||
if (!topKOp) {
|
ov::op::v3::TopK::get_type_info_static(),
|
||||||
errorMessage = "Node is not an instance of the TopK from the operations set v1 or v3";
|
ov::op::v11::TopK::get_type_info_static())) {
|
||||||
|
errorMessage = "Node is not an instance of the TopK from the operation sets v1, v3 or v11";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto topKOp = ov::as_type_ptr<const ov::op::util::TopKBase>(op);
|
||||||
if (!isDynamicNgraphNode(op)) {
|
if (!isDynamicNgraphNode(op)) {
|
||||||
auto topKConst = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
|
auto topKConst = std::dynamic_pointer_cast<const ov::op::v0::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
|
||||||
if (!topKConst) {
|
if (!topKConst) {
|
||||||
errorMessage = "Second tensor is not constant in static shape mode";
|
errorMessage = "Second tensor is not constant in static shape mode";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (topKOp->get_mode() != ngraph::op::TopKMode::MAX &&
|
if (topKOp->get_mode() != ov::op::TopKMode::MAX &&
|
||||||
topKOp->get_mode() != ngraph::op::TopKMode::MIN) {
|
topKOp->get_mode() != ov::op::TopKMode::MIN) {
|
||||||
errorMessage = "Unsupported mode.";
|
errorMessage = "Unsupported mode.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!one_of(topKOp->get_sort_type(), ngraph::op::TopKSortType::NONE,
|
if (!one_of(topKOp->get_sort_type(), ov::op::TopKSortType::NONE,
|
||||||
ngraph::op::TopKSortType::SORT_VALUES,
|
ov::op::TopKSortType::SORT_VALUES,
|
||||||
ngraph::op::TopKSortType::SORT_INDICES)) {
|
ov::op::TopKSortType::SORT_INDICES)) {
|
||||||
errorMessage = "Unsupported sort type.";
|
errorMessage = "Unsupported sort type.";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -1821,13 +1823,13 @@ bool TopK::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, s
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
TopK::TopK(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr context)
|
TopK::TopK(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
|
||||||
: Node(op, context, NgraphShapeInferFactory(op, PortMask(TOPK_K))) {
|
: Node(op, context, NgraphShapeInferFactory(op, PortMask(TOPK_K))) {
|
||||||
std::string errorMessage;
|
std::string errorMessage;
|
||||||
if (isSupportedOperation(op, errorMessage)) {
|
if (isSupportedOperation(op, errorMessage)) {
|
||||||
errorPrefix = "TopK layer with name '" + getName() + "'";
|
errorPrefix = "TopK layer with name '" + getName() + "'";
|
||||||
|
|
||||||
auto topKOp = ngraph::as_type_ptr<ngraph::op::v1::TopK>(op);
|
auto topKOp = ov::as_type_ptr<const ov::op::util::TopKBase>(op);
|
||||||
|
|
||||||
auto in_dims = topKOp->get_input_partial_shape(TOPK_DATA);
|
auto in_dims = topKOp->get_input_partial_shape(TOPK_DATA);
|
||||||
auto out_dims = topKOp->get_output_partial_shape(TOPK_DATA);
|
auto out_dims = topKOp->get_output_partial_shape(TOPK_DATA);
|
||||||
@ -1835,15 +1837,23 @@ TopK::TopK(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr con
|
|||||||
auto in_dims_size = in_dims.size();
|
auto in_dims_size = in_dims.size();
|
||||||
|
|
||||||
if (!isDynamicNgraphNode(op)) {
|
if (!isDynamicNgraphNode(op)) {
|
||||||
auto topKConst = std::dynamic_pointer_cast<const ngraph::opset1::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
|
auto topKConst = std::dynamic_pointer_cast<const ov::op::v0::Constant>(topKOp->get_input_node_shared_ptr(TOPK_K));
|
||||||
if (!topKConst) {
|
if (!topKConst) {
|
||||||
IE_THROW() << errorPrefix << "gets non-constant second tensor in static shape mode!";
|
IE_THROW() << errorPrefix << "gets non-constant second tensor in static shape mode!";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
axis = topKOp->get_axis();
|
axis = topKOp->get_axis();
|
||||||
mode_max = topKOp->get_mode() == ngraph::op::TopKMode::MAX;
|
mode_max = topKOp->get_mode() == ov::op::TopKMode::MAX;
|
||||||
sort_index = topKOp->get_sort_type() == ngraph::op::TopKSortType::SORT_INDICES;
|
sort_index = topKOp->get_sort_type() == ov::op::TopKSortType::SORT_INDICES;
|
||||||
|
|
||||||
|
stable = false;
|
||||||
|
if (!sort_index) {
|
||||||
|
const auto topKOpV11 = ngraph::as_type_ptr<const ov::op::v11::TopK>(op);
|
||||||
|
if (topKOpV11) {
|
||||||
|
stable = topKOpV11->get_stable();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
top_k = 0;
|
top_k = 0;
|
||||||
preset_params_done = false;
|
preset_params_done = false;
|
||||||
@ -1959,7 +1969,10 @@ void TopK::preset_params() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (isDynamicNode()) {
|
if (isDynamicNode()) {
|
||||||
if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
if (stable) {
|
||||||
|
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||||
|
bubble_inplace = false;
|
||||||
|
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
||||||
algorithm = TopKAlgorithm::topk_heap_sort;
|
algorithm = TopKAlgorithm::topk_heap_sort;
|
||||||
} else {
|
} else {
|
||||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||||
@ -2006,8 +2019,11 @@ void TopK::prepareParams() {
|
|||||||
// [case 1]: if 2 * (top_k + 1) + 2 <= count_xmm, thus top_k is small enough that the vector registers are sufficient
|
// [case 1]: if 2 * (top_k + 1) + 2 <= count_xmm, thus top_k is small enough that the vector registers are sufficient
|
||||||
// to keep all necessary data for sorting, no need to load and store frequently, use inplace bubble sort;
|
// to keep all necessary data for sorting, no need to load and store frequently, use inplace bubble sort;
|
||||||
// (horizontal sorting cases not included)
|
// (horizontal sorting cases not included)
|
||||||
// [case 2]: only when topk is imposed on innermost dimension of planar(ncsp/nspc) layout, should heap sort be used;
|
// [case 2]: if stable sorting is required, bubble sort(topk_bubble_vector/topk_bubble_BLK_on_channel_verti) will be
|
||||||
// [case 3]: by default, use bitonic sort when alg_cost_bitonic < alg_cost_bubble, otherwise use bubble sort.
|
// applied currently, because among the implemented sorting algorithms, these bubble sort implementations
|
||||||
|
// are the only stable ones;
|
||||||
|
// [case 3]: only when topk is imposed on innermost dimension of planar(ncsp/nspc) layout, should heap sort be used;
|
||||||
|
// [case 4]: by default, use bitonic sort when alg_cost_bitonic < alg_cost_bubble, otherwise use bubble sort.
|
||||||
// alg_cost_bitonic = (N / 4) * logN * (logN + 1)
|
// alg_cost_bitonic = (N / 4) * logN * (logN + 1)
|
||||||
// alg_cost_bubble = K * (K - 1) / 2 + (N - K) * K
|
// alg_cost_bubble = K * (K - 1) / 2 + (N - K) * K
|
||||||
// where, N = axis_dim, K = topk_k
|
// where, N = axis_dim, K = topk_k
|
||||||
@ -2018,6 +2034,9 @@ void TopK::prepareParams() {
|
|||||||
if (top_k <= count_xmm / 2 - 2) {
|
if (top_k <= count_xmm / 2 - 2) {
|
||||||
algorithm = TopKAlgorithm::topk_bubble_sort;
|
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||||
bubble_inplace = topk_innermost && top_k == 1 ? false : true;
|
bubble_inplace = topk_innermost && top_k == 1 ? false : true;
|
||||||
|
} else if (stable) {
|
||||||
|
algorithm = TopKAlgorithm::topk_bubble_sort;
|
||||||
|
bubble_inplace = false;
|
||||||
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
} else if ((layout == TopKLayoutType::topk_ncsp || layout == TopKLayoutType::topk_nspc) && topk_innermost) {
|
||||||
algorithm = TopKAlgorithm::topk_heap_sort;
|
algorithm = TopKAlgorithm::topk_heap_sort;
|
||||||
} else {
|
} else {
|
||||||
@ -2074,6 +2093,7 @@ void TopK::createPrimitive() {
|
|||||||
jcp.topk_innermost = topk_innermost;
|
jcp.topk_innermost = topk_innermost;
|
||||||
jcp.algorithm = algorithm;
|
jcp.algorithm = algorithm;
|
||||||
jcp.bubble_inplace = bubble_inplace;
|
jcp.bubble_inplace = bubble_inplace;
|
||||||
|
jcp.stable = stable;
|
||||||
jcp.sort_stride = static_cast<int>(I);
|
jcp.sort_stride = static_cast<int>(I);
|
||||||
jcp.work_amount = static_cast<int>(I);
|
jcp.work_amount = static_cast<int>(I);
|
||||||
jcp.bitonic_idx_cnt = 0;
|
jcp.bitonic_idx_cnt = 0;
|
||||||
@ -2207,7 +2227,9 @@ inline void TopK::prepare_original_idx() {
|
|||||||
bool shape_agnostic_alg = algorithm == TopKAlgorithm::topk_heap_sort ||
|
bool shape_agnostic_alg = algorithm == TopKAlgorithm::topk_heap_sort ||
|
||||||
(algorithm == TopKAlgorithm::topk_bubble_sort && !bubble_inplace);
|
(algorithm == TopKAlgorithm::topk_bubble_sort && !bubble_inplace);
|
||||||
if (shape_agnostic_alg) {
|
if (shape_agnostic_alg) {
|
||||||
if (topk_innermost) {
|
bool use_idx_seq = stable ? topk_innermost && (layout == TopKLayoutType::topk_blocked || (top_k == 1 && !stable))
|
||||||
|
: topk_innermost;
|
||||||
|
if (use_idx_seq) {
|
||||||
if (vec_idx_seq.empty()) {
|
if (vec_idx_seq.empty()) {
|
||||||
vec_idx_seq.resize(axis_dim);
|
vec_idx_seq.resize(axis_dim);
|
||||||
std::iota(vec_idx_seq.begin(), vec_idx_seq.end(), 0);
|
std::iota(vec_idx_seq.begin(), vec_idx_seq.end(), 0);
|
||||||
|
@ -32,6 +32,7 @@ struct jit_topk_config_params {
|
|||||||
bool sort_index; // sort by value or index. true: index; false: value
|
bool sort_index; // sort by value or index. true: index; false: value
|
||||||
bool topk_innermost; // if topk sorting is applied on innermost dimension or other dimension
|
bool topk_innermost; // if topk sorting is applied on innermost dimension or other dimension
|
||||||
bool bubble_inplace; // all the elements being sorted are held in registers, no need to load and store for each comparison
|
bool bubble_inplace; // all the elements being sorted are held in registers, no need to load and store for each comparison
|
||||||
|
bool stable; // if require stable sorting
|
||||||
TopKLayoutType layout; // memory layout
|
TopKLayoutType layout; // memory layout
|
||||||
TopKAlgorithm algorithm; // topk sorting algorithm
|
TopKAlgorithm algorithm; // topk sorting algorithm
|
||||||
InferenceEngine::Precision precision; // precision
|
InferenceEngine::Precision precision; // precision
|
||||||
@ -115,6 +116,7 @@ private:
|
|||||||
bool topk_innermost;
|
bool topk_innermost;
|
||||||
bool jit_mode;
|
bool jit_mode;
|
||||||
bool sort_index;
|
bool sort_index;
|
||||||
|
bool stable;
|
||||||
bool mode_max;
|
bool mode_max;
|
||||||
int axis;
|
int axis;
|
||||||
static const size_t TOPK_DATA = 0;
|
static const size_t TOPK_DATA = 0;
|
||||||
|
@ -71,6 +71,8 @@
|
|||||||
#include "transformations/op_conversions/softsign_decomposition.hpp"
|
#include "transformations/op_conversions/softsign_decomposition.hpp"
|
||||||
#include "transformations/op_conversions/softmax_decomposition.hpp"
|
#include "transformations/op_conversions/softmax_decomposition.hpp"
|
||||||
#include "transformations/op_conversions/unique_decomposition.hpp"
|
#include "transformations/op_conversions/unique_decomposition.hpp"
|
||||||
|
#include "transformations/op_conversions/convert_topk3.hpp"
|
||||||
|
#include "transformations/op_conversions/convert_topk11_downgrade.hpp"
|
||||||
#include "transformations/opset_conversions/convert_opset2_to_opset1.hpp"
|
#include "transformations/opset_conversions/convert_opset2_to_opset1.hpp"
|
||||||
#include "transformations/opset_conversions/convert_opset3_to_opset2.hpp"
|
#include "transformations/opset_conversions/convert_opset3_to_opset2.hpp"
|
||||||
#include "transformations/smart_reshape/matmul_sr.hpp"
|
#include "transformations/smart_reshape/matmul_sr.hpp"
|
||||||
@ -398,6 +400,8 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
|||||||
pass_config->disable<ov::pass::ConvertROIAlign9To3>();
|
pass_config->disable<ov::pass::ConvertROIAlign9To3>();
|
||||||
pass_config->disable<ov::pass::SoftSignDecomposition>();
|
pass_config->disable<ov::pass::SoftSignDecomposition>();
|
||||||
pass_config->disable<ov::pass::UniqueDecomposition>();
|
pass_config->disable<ov::pass::UniqueDecomposition>();
|
||||||
|
pass_config->disable<ov::pass::ConvertTopK3>();
|
||||||
|
pass_config->disable<ov::pass::ConvertTopK11ToTopK3>();
|
||||||
|
|
||||||
pass_config->enable<ov::pass::NormalizeL2Decomposition>();
|
pass_config->enable<ov::pass::NormalizeL2Decomposition>();
|
||||||
pass_config->enable<ov::pass::ConvertInterpolate1ToInterpolate4>();
|
pass_config->enable<ov::pass::ConvertInterpolate1ToInterpolate4>();
|
||||||
|
@ -10,14 +10,16 @@
|
|||||||
using namespace InferenceEngine;
|
using namespace InferenceEngine;
|
||||||
using namespace CPUTestUtils;
|
using namespace CPUTestUtils;
|
||||||
using namespace ov::test;
|
using namespace ov::test;
|
||||||
|
using SortMode = ov::op::TopKMode;
|
||||||
|
using SortType = ov::op::TopKSortType;
|
||||||
|
|
||||||
namespace CPULayerTestsDefinitions {
|
namespace CPULayerTestsDefinitions {
|
||||||
|
|
||||||
typedef std::tuple<
|
typedef std::tuple<
|
||||||
int64_t, // keepK
|
int64_t, // keepK
|
||||||
int64_t, // axis
|
int64_t, // axis
|
||||||
ngraph::opset4::TopK::Mode, // mode
|
SortMode, // mode
|
||||||
ngraph::opset4::TopK::SortType, // sort
|
std::tuple<SortType, bool>, // sort and stable
|
||||||
ElementType, // Net precision
|
ElementType, // Net precision
|
||||||
ElementType, // Input precision
|
ElementType, // Input precision
|
||||||
ElementType, // Output precision
|
ElementType, // Output precision
|
||||||
@ -39,11 +41,13 @@ public:
|
|||||||
std::tie(basicParamsSet, cpuParams, additionalConfig) = obj.param;
|
std::tie(basicParamsSet, cpuParams, additionalConfig) = obj.param;
|
||||||
|
|
||||||
int64_t keepK, axis;
|
int64_t keepK, axis;
|
||||||
ngraph::opset4::TopK::Mode mode;
|
SortMode mode;
|
||||||
ngraph::opset4::TopK::SortType sort;
|
std::tuple<SortType, bool> sortTypeStable;
|
||||||
ElementType netPrecision, inPrc, outPrc;
|
ElementType netPrecision, inPrc, outPrc;
|
||||||
InputShape inputShape;
|
InputShape inputShape;
|
||||||
std::tie(keepK, axis, mode, sort, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
|
std::tie(keepK, axis, mode, sortTypeStable, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
|
||||||
|
SortType sort = std::get<0>(sortTypeStable);
|
||||||
|
bool stable = std::get<1>(sortTypeStable);
|
||||||
|
|
||||||
std::ostringstream result;
|
std::ostringstream result;
|
||||||
bool staticShape = inputShape.first.rank() == 0;
|
bool staticShape = inputShape.first.rank() == 0;
|
||||||
@ -52,6 +56,7 @@ public:
|
|||||||
result << "axis=" << axis << "_";
|
result << "axis=" << axis << "_";
|
||||||
result << "mode=" << mode << "_";
|
result << "mode=" << mode << "_";
|
||||||
result << "sort=" << sort << "_";
|
result << "sort=" << sort << "_";
|
||||||
|
result << "stable=" << (stable ? "True" : "False") << "_";
|
||||||
result << "netPRC=" << netPrecision << "_";
|
result << "netPRC=" << netPrecision << "_";
|
||||||
result << "inPRC=" << inPrc << "_";
|
result << "inPRC=" << inPrc << "_";
|
||||||
result << "outPRC=" << outPrc << "_";
|
result << "outPRC=" << outPrc << "_";
|
||||||
@ -85,11 +90,13 @@ protected:
|
|||||||
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
|
std::tie(inFmts, outFmts, priority, selectedType) = cpuParams;
|
||||||
|
|
||||||
int64_t keepK;
|
int64_t keepK;
|
||||||
ngraph::opset4::TopK::Mode mode;
|
SortMode mode;
|
||||||
ngraph::opset4::TopK::SortType sort;
|
std::tuple<SortType, bool> sortTypeStable;
|
||||||
ElementType inPrc, outPrc;
|
ElementType inPrc, outPrc;
|
||||||
InputShape inputShape;
|
InputShape inputShape;
|
||||||
std::tie(keepK, axis, mode, sort, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
|
std::tie(keepK, axis, mode, sortTypeStable, netPrecision, inPrc, outPrc, inputShape) = basicParamsSet;
|
||||||
|
sort = std::get<0>(sortTypeStable);
|
||||||
|
stable = std::get<1>(sortTypeStable);
|
||||||
|
|
||||||
if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
|
if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
|
||||||
inPrc = outPrc = netPrecision = ElementType::bf16;
|
inPrc = outPrc = netPrecision = ElementType::bf16;
|
||||||
@ -112,33 +119,34 @@ protected:
|
|||||||
auto params = ngraph::builder::makeDynamicParams(netPrecision, {inputDynamicShapes[0]});
|
auto params = ngraph::builder::makeDynamicParams(netPrecision, {inputDynamicShapes[0]});
|
||||||
|
|
||||||
// static shape need specific const k to test different sorting algorithms, dynamic shape tests random param k
|
// static shape need specific const k to test different sorting algorithms, dynamic shape tests random param k
|
||||||
std::shared_ptr<ngraph::opset4::TopK> topk;
|
std::shared_ptr<ov::op::v11::TopK> topk;
|
||||||
if (staticShape) {
|
if (staticShape) {
|
||||||
auto k = std::make_shared<ngraph::opset3::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{}, &keepK);
|
auto k = std::make_shared<ov::op::v0::Constant>(ElementType::i64, ov::Shape{}, &keepK);
|
||||||
topk = std::dynamic_pointer_cast<ngraph::opset4::TopK>(
|
topk = std::dynamic_pointer_cast<ov::op::v11::TopK>(
|
||||||
std::make_shared<ngraph::opset4::TopK>(params[0], k, axis, mode, sort));
|
std::make_shared<ov::op::v11::TopK>(params[0], k, axis, mode, sort, ElementType::i32, stable));
|
||||||
} else {
|
} else {
|
||||||
auto k = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::Type_t::i64, inputDynamicShapes[1]);
|
auto k = std::make_shared<ov::op::v0::Parameter>(ElementType::i64, inputDynamicShapes[1]);
|
||||||
params.push_back(k);
|
params.push_back(k);
|
||||||
topk = std::dynamic_pointer_cast<ngraph::opset4::TopK>(
|
topk = std::dynamic_pointer_cast<ov::op::v11::TopK>(
|
||||||
std::make_shared<ngraph::opset4::TopK>(params[0], k, axis, mode, sort));
|
std::make_shared<ov::op::v11::TopK>(params[0], k, axis, mode, sort, ElementType::i32, stable));
|
||||||
}
|
}
|
||||||
|
|
||||||
topk->get_rt_info() = getCPUInfo();
|
topk->get_rt_info() = getCPUInfo();
|
||||||
|
|
||||||
ngraph::ResultVector results;
|
ngraph::ResultVector results;
|
||||||
for (size_t i = 0; i < topk->get_output_size(); i++) {
|
for (size_t i = 0; i < topk->get_output_size(); i++) {
|
||||||
results.push_back(std::make_shared<ngraph::opset4::Result>(topk->output(i)));
|
results.push_back(std::make_shared<ov::op::v0::Result>(topk->output(i)));
|
||||||
}
|
}
|
||||||
|
|
||||||
function = std::make_shared<ngraph::Function>(results, params, "TopK");
|
function = std::make_shared<ngraph::Function>(results, params, "TopK");
|
||||||
}
|
}
|
||||||
|
|
||||||
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
|
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
|
||||||
inputs.clear();
|
inputs.clear();
|
||||||
const auto& funcInputs = function->inputs();
|
const auto& funcInputs = function->inputs();
|
||||||
|
|
||||||
// Spec TopK_3.md allows unstable sorting, thus generate unrepeated input data to avoid a. and b.
|
// For unstable sorting, generate unrepeated input data to avoid a. and b. While for stable sorting,
|
||||||
|
// repeating values are explicitly set.
|
||||||
// a. Skip comparing of index results, because an element in actual index tensor can be different with
|
// a. Skip comparing of index results, because an element in actual index tensor can be different with
|
||||||
// its counterpart in expected index tensor
|
// its counterpart in expected index tensor
|
||||||
// b. If SortType is SORT_INDICES or NONE, the test program still needs to apply std::sort for all pairs
|
// b. If SortType is SORT_INDICES or NONE, the test program still needs to apply std::sort for all pairs
|
||||||
@ -153,7 +161,11 @@ protected:
|
|||||||
|
|
||||||
// For int32, deliberately set big numbers which are not accurately representable in fp32
|
// For int32, deliberately set big numbers which are not accurately representable in fp32
|
||||||
int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2);
|
int start = netPrecision == ElementType::i32 ? pow(2, 30) + 1 : - static_cast<int>(size / 2);
|
||||||
std::iota(data.begin(), data.end(), start);
|
size_t set_size = sort == SortType::SORT_VALUES && stable ? size / 2 : size;
|
||||||
|
std::iota(data.begin(), data.begin() + set_size, start);
|
||||||
|
if (sort == SortType::SORT_VALUES && stable) {
|
||||||
|
std::copy(data.begin(), data.begin() + set_size, data.begin() + set_size);
|
||||||
|
}
|
||||||
std::mt19937 gen(0);
|
std::mt19937 gen(0);
|
||||||
std::shuffle(data.begin(), data.end(), gen);
|
std::shuffle(data.begin(), data.end(), gen);
|
||||||
|
|
||||||
@ -178,7 +190,7 @@ protected:
|
|||||||
if (O * A * I != size)
|
if (O * A * I != size)
|
||||||
FAIL() << "Incorrect blob shape " << shape;
|
FAIL() << "Incorrect blob shape " << shape;
|
||||||
|
|
||||||
auto *rawBlobDataPtr = static_cast<ngraph::bfloat16 *>(tensor.data());
|
auto *rawBlobDataPtr = static_cast<ov::bfloat16 *>(tensor.data());
|
||||||
for (size_t o = 0; o < O; o++) {
|
for (size_t o = 0; o < O; o++) {
|
||||||
for (size_t i = 0; i < I; i++) {
|
for (size_t i = 0; i < I; i++) {
|
||||||
std::vector<int> data(A);
|
std::vector<int> data(A);
|
||||||
@ -188,7 +200,7 @@ protected:
|
|||||||
std::mt19937 gen(seed);
|
std::mt19937 gen(seed);
|
||||||
std::shuffle(data.begin(), data.end(), gen);
|
std::shuffle(data.begin(), data.end(), gen);
|
||||||
for (size_t a = 0; a < A; a++) {
|
for (size_t a = 0; a < A; a++) {
|
||||||
rawBlobDataPtr[o * A * I + a * I + i] = static_cast<ngraph::bfloat16>(data[a]);
|
rawBlobDataPtr[o * A * I + a * I + i] = static_cast<ov::bfloat16>(data[a]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -198,6 +210,13 @@ protected:
|
|||||||
inputs.insert({funcInputs[0].get_node_shared_ptr(), tensor});
|
inputs.insert({funcInputs[0].get_node_shared_ptr(), tensor});
|
||||||
|
|
||||||
if (!staticShape) {
|
if (!staticShape) {
|
||||||
|
generate_dynamic_k(funcInputs, targetInputStaticShapes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void generate_dynamic_k(const std::vector<ov::Output<ov::Node>>& funcInputs,
|
||||||
|
const std::vector<ov::Shape>& targetInputStaticShapes) {
|
||||||
const auto& kPrecision = funcInputs[1].get_element_type();
|
const auto& kPrecision = funcInputs[1].get_element_type();
|
||||||
const auto& kShape = targetInputStaticShapes[1];
|
const auto& kShape = targetInputStaticShapes[1];
|
||||||
|
|
||||||
@ -208,10 +227,11 @@ protected:
|
|||||||
|
|
||||||
inputs.insert({funcInputs[1].get_node_shared_ptr(), kTensor});
|
inputs.insert({funcInputs[1].get_node_shared_ptr(), kTensor});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
int64_t axis;
|
int64_t axis;
|
||||||
|
SortType sort;
|
||||||
|
bool stable;
|
||||||
size_t inferRequestNum = 0;
|
size_t inferRequestNum = 0;
|
||||||
ElementType netPrecision;
|
ElementType netPrecision;
|
||||||
bool staticShape;
|
bool staticShape;
|
||||||
@ -236,14 +256,15 @@ std::vector<std::map<std::string, std::string>> additionalConfig = {
|
|||||||
const std::vector<int64_t> axes = {0, 1, 2, 3};
|
const std::vector<int64_t> axes = {0, 1, 2, 3};
|
||||||
const std::vector<int64_t> k = {1, 5, 7, 18, 21};
|
const std::vector<int64_t> k = {1, 5, 7, 18, 21};
|
||||||
|
|
||||||
const std::vector<ngraph::opset4::TopK::Mode> modes = {
|
const std::vector<SortMode> modes = {
|
||||||
ngraph::opset4::TopK::Mode::MIN,
|
SortMode::MIN,
|
||||||
ngraph::opset4::TopK::Mode::MAX
|
SortMode::MAX
|
||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<ngraph::opset4::TopK::SortType> sortTypes = {
|
const std::vector<std::tuple<SortType, bool>> sortTypeStable = {
|
||||||
ngraph::opset4::TopK::SortType::SORT_VALUES,
|
std::tuple<SortType, bool>{SortType::SORT_VALUES, false},
|
||||||
ngraph::opset4::TopK::SortType::SORT_INDICES,
|
std::tuple<SortType, bool>{SortType::SORT_VALUES, true},
|
||||||
|
std::tuple<SortType, bool>{SortType::SORT_INDICES, false}
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<ov::test::InputShape> inputShapes = {
|
std::vector<ov::test::InputShape> inputShapes = {
|
||||||
@ -266,7 +287,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK, TopKLayerCPUTest,
|
|||||||
::testing::ValuesIn(k),
|
::testing::ValuesIn(k),
|
||||||
::testing::ValuesIn(axes),
|
::testing::ValuesIn(axes),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -281,7 +302,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_dynamic, TopKLayerCPUTest,
|
|||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::ValuesIn(axes),
|
::testing::ValuesIn(axes),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -306,7 +327,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_int32, TopKLayerCPUTest,
|
|||||||
::testing::ValuesIn(k_int32),
|
::testing::ValuesIn(k_int32),
|
||||||
::testing::ValuesIn(axes),
|
::testing::ValuesIn(axes),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::Values(ElementType::i32),
|
::testing::Values(ElementType::i32),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -321,7 +342,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_int32_dynamic, TopKLayerCPUTest,
|
|||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::ValuesIn(axes),
|
::testing::ValuesIn(axes),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::Values(ElementType::i32),
|
::testing::Values(ElementType::i32),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -344,7 +365,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_bubble_BLK_on_channel_horiz, TopKLayerCPUTest
|
|||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -359,7 +380,7 @@ INSTANTIATE_TEST_CASE_P(smoke_TopK_bubble_BLK_on_channel_horiz_dynamic, TopKLaye
|
|||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::ValuesIn(modes),
|
::testing::ValuesIn(modes),
|
||||||
::testing::ValuesIn(sortTypes),
|
::testing::ValuesIn(sortTypeStable),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -381,8 +402,8 @@ INSTANTIATE_TEST_CASE_P(smoke_Top1, TopKLayerCPUTest,
|
|||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::Values(3),
|
::testing::Values(3),
|
||||||
::testing::Values(ngraph::opset4::TopK::Mode::MAX),
|
::testing::Values(SortMode::MAX),
|
||||||
::testing::Values(ngraph::opset4::TopK::SortType::SORT_INDICES),
|
::testing::Values(std::tuple<SortType, bool>(SortType::SORT_INDICES, false)),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
@ -396,8 +417,8 @@ INSTANTIATE_TEST_CASE_P(smoke_Top1_dynamic, TopKLayerCPUTest,
|
|||||||
::testing::Combine(
|
::testing::Combine(
|
||||||
::testing::Values(1),
|
::testing::Values(1),
|
||||||
::testing::Values(3),
|
::testing::Values(3),
|
||||||
::testing::Values(ngraph::opset4::TopK::Mode::MAX),
|
::testing::Values(SortMode::MAX),
|
||||||
::testing::Values(ngraph::opset4::TopK::SortType::SORT_INDICES),
|
::testing::Values(std::tuple<SortType, bool>(SortType::SORT_INDICES, false)),
|
||||||
::testing::ValuesIn(netPrecisions),
|
::testing::ValuesIn(netPrecisions),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
::testing::Values(ElementType::undefined),
|
::testing::Values(ElementType::undefined),
|
||||||
|
@ -80,6 +80,7 @@ CompareMap getCompareMap() {
|
|||||||
#include "openvino/opsets/opset8_tbl.hpp"
|
#include "openvino/opsets/opset8_tbl.hpp"
|
||||||
#include "openvino/opsets/opset9_tbl.hpp"
|
#include "openvino/opsets/opset9_tbl.hpp"
|
||||||
#include "openvino/opsets/opset10_tbl.hpp"
|
#include "openvino/opsets/opset10_tbl.hpp"
|
||||||
|
#include "openvino/opsets/opset11_tbl.hpp"
|
||||||
|
|
||||||
#include "ov_ops/opset_private_tbl.hpp"
|
#include "ov_ops/opset_private_tbl.hpp"
|
||||||
#undef _OPENVINO_OP_REG
|
#undef _OPENVINO_OP_REG
|
||||||
|
@ -835,6 +835,7 @@ InputsMap getInputMap() {
|
|||||||
#include "openvino/opsets/opset8_tbl.hpp"
|
#include "openvino/opsets/opset8_tbl.hpp"
|
||||||
#include "openvino/opsets/opset9_tbl.hpp"
|
#include "openvino/opsets/opset9_tbl.hpp"
|
||||||
#include "openvino/opsets/opset10_tbl.hpp"
|
#include "openvino/opsets/opset10_tbl.hpp"
|
||||||
|
#include "openvino/opsets/opset11_tbl.hpp"
|
||||||
|
|
||||||
#include "ov_ops/opset_private_tbl.hpp"
|
#include "ov_ops/opset_private_tbl.hpp"
|
||||||
#undef _OPENVINO_OP_REG
|
#undef _OPENVINO_OP_REG
|
||||||
|
@ -16,6 +16,8 @@
|
|||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
#include <ngraph/opsets/opset8.hpp>
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
#include <ngraph/opsets/opset9.hpp>
|
#include <ngraph/opsets/opset9.hpp>
|
||||||
|
#include <ngraph/opsets/opset10.hpp>
|
||||||
|
#include <ngraph/opsets/opset11.hpp>
|
||||||
|
|
||||||
#include "ngraph_functions/utils/data_utils.hpp"
|
#include "ngraph_functions/utils/data_utils.hpp"
|
||||||
#include "openvino/core/partial_shape.hpp"
|
#include "openvino/core/partial_shape.hpp"
|
||||||
|
Loading…
Reference in New Issue
Block a user