VariadicSplitLayerTest
refactoring to API2.0 (#19648)
This commit is contained in:
parent
da79964bd3
commit
fb59d0eb36
@ -2,39 +2,35 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "single_op_tests/variadic_split.hpp"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "single_layer_tests/variadic_split.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
using ov::test::VariadicSplitLayerTest;
|
||||
|
||||
namespace {
|
||||
const std::vector<ov::element::Type> model_types = {ov::element::f32, ov::element::f16};
|
||||
|
||||
const std::vector<InferenceEngine::Precision> netPrecisions = {
|
||||
InferenceEngine::Precision::FP32,
|
||||
InferenceEngine::Precision::FP16
|
||||
};
|
||||
// Sum of elements numSplits = inputShapes[Axis]
|
||||
const std::vector<std::vector<size_t>> num_splits = {{1, 16, 5, 8},
|
||||
{2, 19, 5, 4},
|
||||
{7, 13, 2, 8},
|
||||
{5, 8, 12, 5},
|
||||
{4, 11, 6, 9}};
|
||||
|
||||
// Sum of elements numSplits = inputShapes[Axis]
|
||||
const std::vector<std::vector<size_t>> numSplits = {
|
||||
{1, 16, 5, 8},
|
||||
{2, 19, 5, 4},
|
||||
{7, 13, 2, 8},
|
||||
{5, 8, 12, 5},
|
||||
{4, 11, 6, 9}
|
||||
};
|
||||
const std::vector<std::vector<ov::Shape>> input_shapes_static = {
|
||||
{{30, 30, 30, 30}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_NumSplitsCheck, VariadicSplitLayerTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(numSplits),
|
||||
::testing::Values(0, 1, 2, 3),
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Precision::UNSPECIFIED),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(InferenceEngine::Layout::ANY),
|
||||
::testing::Values(std::vector<size_t>({30, 30, 30, 30})),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU)),
|
||||
VariadicSplitLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_NumSplitsCheck,
|
||||
VariadicSplitLayerTest,
|
||||
::testing::Combine(::testing::ValuesIn(num_splits),
|
||||
::testing::Values(0, 1, 2, 3),
|
||||
::testing::ValuesIn(model_types),
|
||||
::testing::ValuesIn(ov::test::static_shapes_to_test_representation(input_shapes_static)),
|
||||
::testing::Values(ov::test::utils::DEVICE_CPU)),
|
||||
VariadicSplitLayerTest::getTestCaseName);
|
||||
} // namespace
|
||||
|
@ -0,0 +1,15 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/single_op/variadic_split.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
TEST_P(VariadicSplitLayerTest, Inference) {
|
||||
run();
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -0,0 +1,33 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
typedef std::tuple<std::vector<size_t>, // Num splits
|
||||
int64_t, // Axis
|
||||
ov::element::Type, // Model type
|
||||
std::vector<InputShape>, // Input shapes
|
||||
std::string // Target device name
|
||||
>
|
||||
VariadicSplitParams;
|
||||
|
||||
class VariadicSplitLayerTest : public testing::WithParamInterface<VariadicSplitParams>,
|
||||
virtual public ov::test::SubgraphBaseTest {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<VariadicSplitParams>& obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
} // namespace test
|
||||
} // namespace ov
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "shared_test_classes/single_op/variadic_split.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
std::string VariadicSplitLayerTest::getTestCaseName(const testing::TestParamInfo<VariadicSplitParams>& obj) {
|
||||
int64_t axis;
|
||||
std::vector<size_t> num_splits;
|
||||
ov::element::Type model_type;
|
||||
std::vector<InputShape> input_shapes;
|
||||
ov::test::TargetDevice target_device;
|
||||
std::tie(num_splits, axis, model_type, input_shapes, target_device) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << "IS=(";
|
||||
for (size_t i = 0lu; i < input_shapes.size(); i++) {
|
||||
result << ov::test::utils::partialShape2str({input_shapes[i].first})
|
||||
<< (i < input_shapes.size() - 1lu ? "_" : "");
|
||||
}
|
||||
result << ")_TS=";
|
||||
for (size_t i = 0lu; i < input_shapes.front().second.size(); i++) {
|
||||
result << "{";
|
||||
for (size_t j = 0lu; j < input_shapes.size(); j++) {
|
||||
result << ov::test::utils::vec2str(input_shapes[j].second[i]) << (j < input_shapes.size() - 1lu ? "_" : "");
|
||||
}
|
||||
result << "}_";
|
||||
}
|
||||
result << "numSplits=" << ov::test::utils::vec2str(num_splits) << "_";
|
||||
result << "axis=" << axis << "_";
|
||||
result << "modelType=" << model_type.to_string() << "_";
|
||||
result << "trgDev=" << target_device;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void VariadicSplitLayerTest::SetUp() {
|
||||
int64_t axis;
|
||||
std::vector<size_t> num_splits;
|
||||
std::vector<InputShape> input_shapes;
|
||||
ov::element::Type model_type;
|
||||
std::tie(num_splits, axis, model_type, input_shapes, targetDevice) = this->GetParam();
|
||||
|
||||
init_input_shapes(input_shapes);
|
||||
|
||||
auto param = std::make_shared<ov::op::v0::Parameter>(model_type, inputDynamicShapes.front());
|
||||
auto split_axis_const =
|
||||
std::make_shared<ov::op::v0::Constant>(element::i64, ngraph::Shape{}, std::vector<int64_t>{axis});
|
||||
auto num_split_const =
|
||||
std::make_shared<ov::op::v0::Constant>(element::u64, ngraph::Shape{num_splits.size()}, num_splits);
|
||||
auto variadic_split = std::make_shared<ov::op::v1::VariadicSplit>(param, split_axis_const, num_split_const);
|
||||
function = std::make_shared<ov::Model>(variadic_split->outputs(), ov::ParameterVector{param}, "VariadicSplit");
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace ov
|
Loading…
Reference in New Issue
Block a user