TopKLayerTest
to API2.0 (#19738)
This commit is contained in:
parent
47fe50ca35
commit
7bb22b43b3
@ -4,15 +4,15 @@
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/topk.hpp"
|
||||
#include "single_op_tests/topk.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using ov::test::TopKLayerTest;
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
const std::vector<ov::element::Type> model_types = {
|
||||
ov::element::f32,
|
||||
ov::element::f16
|
||||
};
|
||||
|
||||
const std::vector<int64_t> axes = {
|
||||
@ -30,28 +30,28 @@ const std::vector<int64_t> k = {
|
||||
21
|
||||
};
|
||||
|
||||
const std::vector<ngraph::opset4::TopK::Mode> modes = {
|
||||
ngraph::opset4::TopK::Mode::MIN,
|
||||
ngraph::opset4::TopK::Mode::MAX
|
||||
const std::vector<ov::op::v1::TopK::Mode> modes = {
|
||||
ov::op::v1::TopK::Mode::MIN,
|
||||
ov::op::v1::TopK::Mode::MAX
|
||||
};
|
||||
|
||||
const std::vector<ngraph::opset4::TopK::SortType> sortTypes = {
|
||||
ngraph::opset4::TopK::SortType::SORT_INDICES,
|
||||
ngraph::opset4::TopK::SortType::SORT_VALUES,
|
||||
const std::vector<ov::op::v1::TopK::SortType> sort_types = {
|
||||
ov::op::v1::TopK::SortType::SORT_INDICES,
|
||||
ov::op::v1::TopK::SortType::SORT_VALUES,
|
||||
};
|
||||
|
||||
const std::vector<std::vector<ov::Shape>> input_shape_static = {
|
||||
{{21, 21, 21, 21}}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_TopK, TopKLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(k),
|
||||
::testing::ValuesIn(axes),
|
||||
::testing::ValuesIn(modes),
|
||||
::testing::ValuesIn(sortTypes),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(std::vector<size_t>({21, 21, 21, 21})),
|
||||
::testing::ValuesIn(sort_types),
|
||||
::testing::ValuesIn(model_types),
|
||||
::testing::ValuesIn(ov::test::static_shapes_to_test_representation(input_shape_static)),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU)),
|
||||
TopKLayerTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -190,6 +190,11 @@ std::vector<std::string> disabledTestPatterns() {
|
||||
R"(.*smoke_AutoBatching_CPU/AutoBatching_Test_DetectionOutput.*)",
|
||||
// Issue: 117837
|
||||
R"(.*smoke_4D_out_of_range/GatherInPlaceLayerTestCPU.*_indices=\(\-15\).*)",
|
||||
// Issue: 120222
|
||||
R"(.*smoke_TopK/TopKLayerTest.Inference.*_k=1_axis=3_.*_modelType=f16_trgDev=CPU.*)",
|
||||
R"(.*smoke_TopK/TopKLayerTest.Inference.*_k=7_axis=3_.*_modelType=f16_trgDev=CPU.*)",
|
||||
R"(.*smoke_TopK/TopKLayerTest.Inference.*_k=18_.*_modelType=f16_trgDev=CPU.*)",
|
||||
R"(.*smoke_TopK/TopKLayerTest.Inference.*_k=21_.*_sort=value_modelType=f16_trgDev=CPU.*)",
|
||||
};
|
||||
|
||||
#if defined(OPENVINO_ARCH_X86)
|
||||
|
@ -0,0 +1,15 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/single_op/topk.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
TEST_P(TopKLayerTest, Inference) {
|
||||
run();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
#include <string>
|
||||
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
typedef std::tuple<
|
||||
int64_t, // keepK
|
||||
int64_t, // axis
|
||||
ov::op::v1::TopK::Mode, // mode
|
||||
ov::op::v1::TopK::SortType, // sort
|
||||
ov::element::Type, // Model type
|
||||
std::vector<InputShape>, // Input shape
|
||||
std::string // Target device name
|
||||
> TopKParams;
|
||||
|
||||
class TopKLayerTest : public testing::WithParamInterface<TopKParams>,
|
||||
virtual public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<TopKParams>& obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -934,6 +934,35 @@ ov::runtime::Tensor generate(const
|
||||
}
|
||||
}
|
||||
|
||||
ov::runtime::Tensor generate(const
|
||||
std::shared_ptr<ov::op::v1::TopK>& node,
|
||||
size_t port,
|
||||
const ov::element::Type& elemType,
|
||||
const ov::Shape& targetShape) {
|
||||
auto tensor = ov::Tensor{elemType, targetShape};
|
||||
size_t size = tensor.get_size();
|
||||
int start = - static_cast<int>(size / 2);
|
||||
std::vector<int> data(size);
|
||||
std::iota(data.begin(), data.end(), start);
|
||||
std::mt19937 gen(0);
|
||||
std::shuffle(data.begin(), data.end(), gen);
|
||||
|
||||
float divisor = size / 10.0;
|
||||
|
||||
if (tensor.get_element_type() == ov::element::f32) {
|
||||
auto *p = tensor.data<float>();
|
||||
for (size_t i = 0; i < size; i++)
|
||||
p[i] = static_cast<float>(data[i] / divisor);
|
||||
} else if (tensor.get_element_type() == ov::element::f16) {
|
||||
auto *p = tensor.data<ov::float16>();
|
||||
for (size_t i = 0; i < size; i++)
|
||||
p[i] = static_cast<ov::float16>(data[i] / divisor);
|
||||
} else {
|
||||
OPENVINO_THROW("Unsupported element type: ", tensor.get_element_type());
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
|
||||
size_t port,
|
||||
|
@ -0,0 +1,57 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/single_op/topk.hpp"
|
||||
#include <random>
|
||||
#include <common_test_utils/ov_tensor_utils.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
std::string TopKLayerTest::getTestCaseName(const testing::TestParamInfo<TopKParams>& obj) {
|
||||
ov::element::Type model_type;
|
||||
std::vector<InputShape> input_shapes;
|
||||
std::string target_device;
|
||||
int64_t keepK, axis;
|
||||
ov::op::v1::TopK::Mode mode;
|
||||
ov::op::v1::TopK::SortType sort;
|
||||
std::tie(keepK, axis, mode, sort, model_type, input_shapes, target_device) = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "IS=(";
|
||||
for (size_t i = 0lu; i < input_shapes.size(); i++) {
|
||||
result << ov::test::utils::partialShape2str({input_shapes[i].first})
|
||||
<< (i < input_shapes.size() - 1lu ? "_" : "");
|
||||
}
|
||||
result << ")_TS=";
|
||||
for (size_t i = 0lu; i < input_shapes.front().second.size(); i++) {
|
||||
result << "{";
|
||||
for (size_t j = 0lu; j < input_shapes.size(); j++) {
|
||||
result << ov::test::utils::vec2str(input_shapes[j].second[i]) << (j < input_shapes.size() - 1lu ? "_" : "");
|
||||
}
|
||||
result << "}_";
|
||||
}
|
||||
result << "k=" << keepK << "_";
|
||||
result << "axis=" << axis << "_";
|
||||
result << "mode=" << mode << "_";
|
||||
result << "sort=" << sort << "_";
|
||||
result << "modelType=" << model_type.to_string() << "_";
|
||||
result << "trgDev=" << target_device;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void TopKLayerTest::SetUp() {
|
||||
std::vector<InputShape> input_shapes;
|
||||
ov::element::Type model_type;
|
||||
int64_t keepK, axis;
|
||||
ov::op::v1::TopK::Mode mode;
|
||||
ov::op::v1::TopK::SortType sort;
|
||||
std::tie(keepK, axis, mode, sort, model_type, input_shapes, targetDevice) = this->GetParam();
|
||||
init_input_shapes(input_shapes);
|
||||
|
||||
auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.front());
|
||||
auto k = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{}, &keepK);
|
||||
auto topk = std::make_shared<ov::op::v1::TopK>(param, k, axis, mode, sort);
|
||||
function = std::make_shared<ngraph::Function>(topk->outputs(), ov::ParameterVector{param}, "TopK");
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
Loading…
Reference in New Issue
Block a user