[GPU] Added shape agnostic kernel support for Range + fix dynamic tests (#15640)

parent e4551b66c4
commit 803a927e70
@@ -14,11 +14,18 @@ struct range: public primitive_base<range> {
     /// @param id This primitive id.
     /// @param inputs Input primitive id vector.
     /// @param output_layout requested range output layout
-    range(const primitive_id &id,
-          const std::vector<input_info> &inputs,
-          const layout &output_layout)
-        : primitive_base{ id, inputs, {output_layout.data_padding}, {output_layout.data_type} },
-          output_layout { output_layout } { }
+    range(const primitive_id& id,
+          const std::vector<input_info>& inputs,
+          const layout& output_layout)
+        : primitive_base(id, inputs, {output_layout.data_padding}, {output_layout.data_type}),
+          output_layout(output_layout) {}
+
+    range(const primitive_id& id,
+          const std::vector<input_info>& inputs,
+          const data_types data_type,
+          const padding& output_padding = padding())
+        : primitive_base(id, inputs, {output_padding}, {optional_data_type(data_type)}),
+          output_layout({}) {}
 
     /// @brief requested range output layout
     layout output_layout;
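Note: the second constructor added here is the shape-agnostic entry point: only the element type is fixed and output_layout stays empty until shape inference runs. A minimal usage sketch (primitive ids and input names are illustrative; the pattern mirrors the unit tests later in this diff):

    // Static shape known up front: pass a full layout (pre-existing path).
    topology.add(cldnn::range{ "range_static",
                               { input_info("start"), input_info("stop"), input_info("step") },
                               layout{ ov::PartialShape{25}, data_types::i32, format::bfyx } });

    // Shape known only at runtime: pass just the element type (new path).
    topology.add(cldnn::range{ "range_dynamic",
                               { input_info("start"), input_info("stop"), input_info("step") },
                               data_types::i32 });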
@@ -33,22 +33,34 @@ struct range_impl : typed_primitive_impl_ocl<range> {
 
         return {params, optional_params};
     }
 
+    void update_dispatch_data(const kernel_impl_params& impl_param) override {
+        auto kernel_params = get_kernel_params(impl_param);
+        (_kernel_data.update_dispatch_data_func)(kernel_params.first, _kernel_data);
+    }
 };
 
 namespace detail {
 
 attach_range_impl::attach_range_impl() {
-    implementation_map<range>::add(
-        impl_types::ocl,
-        typed_primitive_impl_ocl<range>::create<range_impl>,
-        {
-            std::make_tuple(data_types::u8, format::bfyx),
-            std::make_tuple(data_types::i8, format::bfyx),
-            std::make_tuple(data_types::f16, format::bfyx),
-            std::make_tuple(data_types::f32, format::bfyx),
-            std::make_tuple(data_types::i32, format::bfyx),
-            std::make_tuple(data_types::i64, format::bfyx),
-        });
+    auto types = {
+        data_types::f32,
+        data_types::f16,
+        data_types::i32,
+        data_types::i64,
+        data_types::i8,
+        data_types::u8
+    };
+
+    auto formats = {
+        format::bfyx
+    };
+
+    implementation_map<range>::add(impl_types::ocl,
+                                   shape_types::any,
+                                   typed_primitive_impl_ocl<range>::create<range_impl>,
+                                   types,
+                                   formats);
 }
 
 } // namespace detail
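Note: registering the implementation under shape_types::any is what lets one compiled kernel serve every shape; when concrete shapes arrive, only the dispatch sizes are refreshed. A sketch of that flow; the caller shown is an assumption (the real call site lives in the runtime's dynamic-execution path, not in this diff):

    // Hypothetical caller; update_dispatch_data() is the method added above.
    void on_shapes_resolved(range_impl& impl, const kernel_impl_params& actual_params) {
        impl.update_dispatch_data(actual_params);  // recomputes gws/lws, no re-JIT
    }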
@@ -2,7 +2,10 @@
 // SPDX-License-Identifier: Apache-2.0
 //
 
-KERNEL (range_ref)(const __global INPUT0_TYPE *startP, const __global INPUT2_TYPE *stepP, __global OUTPUT_TYPE *output)
+KERNEL(range_ref)(OPTIONAL_SHAPE_INFO_ARG
+                  const __global INPUT0_TYPE *startP,
+                  const __global INPUT2_TYPE *stepP,
+                  __global OUTPUT_TYPE *output)
 {
     const uint i = get_global_id(2);
     const OUTPUT_TYPE start = TO_OUTPUT_TYPE(*startP);
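Note: OPTIONAL_SHAPE_INFO_ARG is the shape-agnostic hook. Based on the host-side argument shift visible later in this diff (the "1 + is_dynamic" erase), it presumably expands as sketched below; this is an inference, not a quote of the JIT generator:

    // Dynamic build (assumed expansion): a shape-info buffer becomes argument 0.
    //   KERNEL(range_ref)(__global const int* shape_info,
    //                     const __global INPUT0_TYPE *startP, ...)
    // Static build: the macro expands to nothing, leaving the signature unchanged.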
@@ -23,22 +23,34 @@ CommonDispatchData SetDefault(const range_params &params) {
 } // namespace
 
 KernelsData RangeKernelRef::GetKernelsData(const Params &params, const optional_params &options) const {
-    KernelsData kernels_data;
     if (!Validate(params, options))
-        return kernels_data;
-    kernels_data.push_back(KernelData::Default<range_params>(params));
-    KernelData &kernel_data = kernels_data.front();
-    auto &derived_params = dynamic_cast<range_params&>(*kernel_data.params.get());
-    auto dispatch_data = SetDefault(derived_params);
-    auto entry_point = GetEntryPoint(kernelName, derived_params.layerID, params, options);
-    auto jit_constants = MakeBaseParamsJitConstants(derived_params);
+        return {};
+
+    KernelData kernel_data = KernelData::Default<range_params>(params);
+    const auto& prim_params = static_cast<const range_params&>(params);
+
+    auto dispatch_data = SetDefault(prim_params);
+    auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params, options);
+    auto jit_constants = MakeBaseParamsJitConstants(prim_params);
     auto jit = CreateJit(kernelName, jit_constants, entry_point);
+
+    kernel_data.update_dispatch_data_func = [this](const Params& params, KernelData& kd) {
+        const auto& prim_params = static_cast<const range_params&>(params);
+        auto dispatchData = SetDefault(prim_params);
+        OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
+        kd.kernels[0].params.workGroups.global = dispatchData.gws;
+        kd.kernels[0].params.workGroups.local = dispatchData.lws;
+    };
+
     auto &clKernelData = kernel_data.kernels[0];
-    FillCLKernelData(clKernelData, dispatch_data, params.engineInfo, kernelName, jit, entry_point, EXE_MODE_DEFAULT,
-                     false, false, 3);
+    bool is_dynamic = prim_params.has_dynamic_tensors();
+    FillCLKernelData(clKernelData, dispatch_data, params.engineInfo, kernelName, jit, entry_point,
+                     EXE_MODE_DEFAULT, false, false, 3, 0, 1, is_dynamic);
+
     auto &arguments = clKernelData.params.arguments;
-    arguments.erase(arguments.begin()+1); // stop is not used by kernel
-    return kernels_data;
+    arguments.erase(arguments.begin() + 1 + static_cast<int>(is_dynamic)); // stop is not used by kernel
+
+    return {kernel_data};
 }
 
 KernelsPriority RangeKernelRef::GetKernelsPriority(const Params& /*params*/, const optional_params& /*options*/) const {
@@ -68,6 +80,7 @@ ParamsKey RangeKernelRef::GetSupportedKey() const {
     k.EnableInputLayout(DataLayout::bfyx);
     k.EnableOutputLayout(DataLayout::bfyx);
     k.EnableDifferentTypes();
+    k.EnableDynamicShapesSupport();
    return k;
 }
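Note: the changed erase index in GetKernelsData follows from the argument layout: in a dynamic build the shape-info buffer occupies slot 0 and shifts every user buffer by one, so the unused stop input sits at 1 + is_dynamic. As a sketch (helper name is hypothetical):

    // Argument order implied by the diff:
    //   static : [start, stop, step, output]             -> stop at index 1
    //   dynamic: [shape_info, start, stop, step, output] -> stop at index 2
    size_t stop_arg_index(bool is_dynamic) {
        return 1 + static_cast<size_t>(is_dynamic);
    }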
@@ -14,21 +14,21 @@ namespace intel_gpu {
 static void CreateRangeOp(Program &p, const std::shared_ptr<ngraph::op::v4::Range> &op) {
     validate_inputs_count(op, { 3 });
     auto output_pshape = op->get_output_partial_shape(0);
+    OPENVINO_ASSERT(output_pshape.rank().get_length() == 1 , "[GPU] range v4 output rank should be 1");
     auto output_dtype = cldnn::element_type_to_data_type(op->get_output_element_type(0));
 
-    std::shared_ptr<cldnn::layout> outLayout = nullptr;
-    if (output_pshape.is_static()) {
-        OPENVINO_ASSERT(output_pshape.rank().get_length() == 1 , "[GPU] range v4 output rank should be 1");
-        auto& out_shape = op->get_output_shape(0);
-        outLayout = std::make_shared<cldnn::layout>(output_dtype, cldnn::format::bfyx, cldnn::tensor(cldnn::batch(out_shape[0])));
+    std::shared_ptr<cldnn::range> range_prim = nullptr;
+    if (p.use_new_shape_infer()) {
+        range_prim = std::make_shared<cldnn::range>(layer_type_name_ID(op),
+                                                    p.GetInputInfo(op),
+                                                    output_dtype);
     } else {
-        outLayout = std::make_shared<cldnn::layout>(output_pshape, output_dtype, cldnn::format::bfyx);
+        auto outLayout = cldnn::layout{ output_pshape, output_dtype, cldnn::format::bfyx };
+        range_prim = std::make_shared<cldnn::range>(layer_type_name_ID(op),
+                                                    p.GetInputInfo(op),
+                                                    outLayout);
     }
-
-    cldnn::range prim(layer_type_name_ID(op),
-                      p.GetInputInfo(op),
-                      *outLayout);
-    p.add_primitive(*op, prim);
+    p.add_primitive(*op, range_prim);
 }
 
 REGISTER_FACTORY_IMPL(v4, Range);
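Note: the plugin now branches on p.use_new_shape_infer(). Summarized (illustration only, using names from the hunk above):

    // use_new_shape_infer() == true:
    //   range(id, inputs, output_dtype)   // output layout resolved by shape inference
    // use_new_shape_infer() == false:
    //   range(id, inputs, layout{ output_pshape, output_dtype, format::bfyx })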
@@ -8,6 +8,8 @@
 #include <intel_gpu/primitives/range.hpp>
+#include <intel_gpu/primitives/select.hpp>
 #include <intel_gpu/primitives/data.hpp>
+#include "range_inst.h"
 
 using namespace ::tests;
 using namespace testing;
@@ -162,19 +164,19 @@ struct range_test_param_generator : std::vector<range_test_params> {
     }
 };
 
-std::vector<data_types> signed_types = {data_types::i8};
-std::vector<data_types> general_types = {data_types::u8, data_types::i32, data_types::i32, data_types::f16, data_types::f32};
-std::vector<data_types> float_types = {data_types::f16, data_types::f32};
+std::vector<data_types> signed_types = {data_types::i8};
+std::vector<data_types> general_types = {data_types::u8, data_types::i32, data_types::i32, data_types::f16, data_types::f32};
+std::vector<data_types> float_types = {data_types::f16, data_types::f32};
 
 INSTANTIATE_TEST_SUITE_P(range_gpu_test,
                          smoke_range_test,
                          testing::ValuesIn(
                              range_test_param_generator()
-                             .simple_params(general_types, 2, 23, 3)
-                             .simple_params(general_types, 1, 21, 2)
-                             .simple_params(float_types, 1, 2.5f, 0.5f)
-                             .simple_params(signed_types, 23, 2, -3)
-                             .simple_params(signed_types, 4, 0, -1)
+                             .simple_params(general_types, 2, 23, 3)
+                             .simple_params(general_types, 1, 21, 2)
+                             .simple_params(float_types, 1, 2.5f, 0.5f)
+                             .simple_params(signed_types, 23, 2, -3)
+                             .simple_params(signed_types, 4, 0, -1)
                          ));
 
 TEST(range_gpu_test, range_with_select) {
@@ -184,22 +186,20 @@ TEST(range_gpu_test, range_with_select) {
     int32_t step_val = 1;
     int32_t expected_dim = 25;
 
-
-    auto select_input1 = engine.allocate_memory({ { 1 }, data_types::u8, format::bfyx });
-    auto select_input2 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
-    auto select_mask = engine.allocate_memory({ { 1 }, data_types::i32, format::bfyx });
-    auto input0 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
-    auto input2 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
+    auto select_input1 = engine.allocate_memory({ {1}, data_types::u8, format::bfyx });
+    auto select_input2 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
+    auto select_mask = engine.allocate_memory({ {1}, data_types::i32, format::bfyx });
+    auto input0 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
+    auto input2 = engine.allocate_memory({ { }, data_types::i32, format::bfyx });
 
     topology topology;
-    topology.add(data("select_input1", select_input1));
-    topology.add(data("select_input2", select_input2));
-    topology.add(data("select_mask", select_mask));
-    topology.add(data("input0", input0));
-    topology.add(data("input2", input2));
+    topology.add(data("select_input1", select_input1));
+    topology.add(data("select_input2", select_input2));
+    topology.add(data("select_mask", select_mask));
+    topology.add(data("input0", input0));
+    topology.add(data("input2", input2));
     topology.add(cldnn::select("select", input_info("select_input1"), input_info("select_input2"), input_info("select_mask")));
-    topology.add(range { "range", { input_info("input0"), input_info("select"), input_info("input2") }, { data_types::i32, format::bfyx, tensor{batch(expected_dim)} } });
-
+    topology.add(range{ "range", { input_info("input0"), input_info("select"), input_info("input2") }, data_types::i32 });
 
     set_values<uint8_t>(select_input1, {0});
     set_values<int32_t>(select_input2, {384});
@@ -221,5 +221,132 @@ TEST(range_gpu_test, range_with_select) {
         ASSERT_EQ(start_val + i * step_val, output_ptr[i]);
     }
 }
+
+TEST(range_gpu_test, constant_folding) {
+    auto& engine = get_test_engine();
+
+    int32_t start_val = 0;
+    int32_t step_val = 1;
+    int32_t expected_dim = 25;
+
+    auto input0 = engine.allocate_memory({ ov::PartialShape::dynamic(0), data_types::i32, format::bfyx });
+    auto input1 = engine.allocate_memory({ ov::PartialShape::dynamic(0), data_types::i32, format::bfyx });
+    auto input2 = engine.allocate_memory({ ov::PartialShape::dynamic(0), data_types::i32, format::bfyx });
+
+    set_values<int32_t>(input0, { start_val });
+    set_values<int32_t>(input1, { expected_dim });
+    set_values<int32_t>(input2, { step_val });
+
+    topology topology;
+    topology.add(data("input0", input0));
+    topology.add(data("input1", input1));
+    topology.add(data("input2", input2));
+    topology.add(range{ "range", { input_info("input0"), input_info("input1"), input_info("input2") }, data_types::i32});
+
+    ExecutionConfig config;
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+
+    network network(engine, topology, config);
+
+    auto outputs = network.execute();
+    auto output = outputs.at("range").get_memory();
+
+    mem_lock<int32_t> output_ptr(output, tests::get_test_stream());
+
+    for (size_t i = 0; i < static_cast<size_t>(expected_dim); ++i) {
+        ASSERT_EQ(start_val + i * step_val, output_ptr[i]);
+    }
+}
+
+TEST(range_gpu_test, dynamic_all) {
+    auto& engine = get_test_engine();
+
+    int32_t start_val = 0;
+    int32_t step_val = 1;
+    int32_t expected_dim = 25;
+
+    auto dynamic_input_layout = layout{ ov::PartialShape::dynamic(0), data_types::i32, format::bfyx };
+
+    auto input0 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+    auto input1 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+    auto input2 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+
+    set_values<int32_t>(input0, { start_val });
+    set_values<int32_t>(input1, { expected_dim });
+    set_values<int32_t>(input2, { step_val });
+
+    topology topology;
+    topology.add(input_layout("input0", dynamic_input_layout));
+    topology.add(input_layout("input1", dynamic_input_layout));
+    topology.add(input_layout("input2", dynamic_input_layout));
+    topology.add(range{ "range", { input_info("input0"), input_info("input1"), input_info("input2") }, data_types::i32});
+
+    ExecutionConfig config;
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+
+    network network(engine, topology, config);
+    network.set_input_data("input0", input0);
+    network.set_input_data("input1", input1);
+    network.set_input_data("input2", input2);
+
+    auto inst = network.get_primitive("range");
+    auto impl = inst->get_impl();
+    ASSERT_TRUE(impl != nullptr);
+    ASSERT_TRUE(impl->is_dynamic());
+
+    auto outputs = network.execute();
+    auto output = outputs.at("range").get_memory();
+
+    mem_lock<int32_t> output_ptr(output, tests::get_test_stream());
+
+    for (size_t i = 0; i < static_cast<size_t>(expected_dim); ++i) {
+        ASSERT_EQ(start_val + i * step_val, output_ptr[i]);
+    }
+}
+
+TEST(range_gpu_test, dynamic_stop) {
+    auto& engine = get_test_engine();
+
+    int32_t start_val = 0;
+    int32_t step_val = 1;
+    int32_t expected_dim = 25;
+
+    auto dynamic_input_layout = layout{ ov::PartialShape::dynamic(0), data_types::i32, format::bfyx };
+
+    auto input0 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+    auto input1 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+    auto input2 = engine.allocate_memory({ {}, data_types::i32, format::bfyx });
+
+    set_values<int32_t>(input0, { start_val });
+    set_values<int32_t>(input1, { expected_dim });
+    set_values<int32_t>(input2, { step_val });
+
+    topology topology;
+    topology.add(data("input0", input0));
+    topology.add(input_layout("input1", dynamic_input_layout));
+    topology.add(data("input2", input2));
+    topology.add(range{ "range", { input_info("input0"), input_info("input1"), input_info("input2") }, data_types::i32});
+
+    ExecutionConfig config;
+    config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
+
+    network network(engine, topology, config);
+    network.set_input_data("input1", input1);
+
+    auto inst = network.get_primitive("range");
+    auto impl = inst->get_impl();
+    ASSERT_TRUE(impl != nullptr);
+    ASSERT_TRUE(impl->is_dynamic());
+
+    auto outputs = network.execute();
+    auto output = outputs.at("range").get_memory();
+
+    mem_lock<int32_t> output_ptr(output, tests::get_test_stream());
+
+    for (size_t i = 0; i < static_cast<size_t>(expected_dim); ++i) {
+        ASSERT_EQ(start_val + i * step_val, output_ptr[i]);
+    }
+}
 
 } // namespace
 } // namespace cldnn
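Note: the three new cases cover constant-folded, fully dynamic, and partially dynamic (stop only) topologies. To run just this group locally, a gtest filter works; the binary name below is an assumption and varies across OpenVINO versions and build layouts:

    ./ov_gpu_unit_tests --gtest_filter='range_gpu_test.*'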
@@ -1,36 +1,26 @@
 // Copyright (C) 2018-2023 Intel Corporation
 // SPDX-License-Identifier: Apache-2.0
 //
-#include <tuple>
-#include <string>
-#include <vector>
-#include <memory>
-#include "ngraph_functions/utils/ngraph_helpers.hpp"
 
 #include "ngraph_functions/builders.hpp"
 #include "shared_test_classes/base/ov_subgraph.hpp"
-#include "shared_test_classes/single_layer/shape_of.hpp"
-#include "shared_test_classes/single_layer/strided_slice.hpp"
-#include <shared_test_classes/single_layer/eltwise.hpp>
 #include <common_test_utils/ov_tensor_utils.hpp>
-#include "shared_test_classes/base/layer_test_utils.hpp"
 
 using namespace ngraph;
-using namespace InferenceEngine;
 using namespace ov::test;
 
 namespace GPULayerTestsDefinitions {
 
 typedef std::tuple<
-    std::vector<InputShape>,            // input shapes
-    std::vector<float>,                 // input values
-    ElementType,                        // Network precision
-    TargetDevice,                       // Device name
-    std::map<std::string, std::string>  // Additional network configuration
+    std::vector<InputShape>,            // input shapes
+    std::vector<float>,                 // input values
+    ElementType,                        // Network precision
+    TargetDevice,                       // Device name
+    std::map<std::string, std::string>  // Additional network configuration
 > RangeDynamicGPUTestParamsSet;
 
 
 class RangeDynamicGPUTest : public testing::WithParamInterface<RangeDynamicGPUTestParamsSet>,
-                            virtual public SubgraphBaseTest {
+                            virtual public SubgraphBaseTest {
 public:
     static std::string getTestCaseName(const testing::TestParamInfo<RangeDynamicGPUTestParamsSet>& obj) {
         RangeDynamicGPUTestParamsSet basicParamsSet = obj.param;
@@ -82,49 +72,55 @@ protected:
     }
 
     template<typename T>
-    std::shared_ptr<ov::Node> inline generate_constant(ElementType netType, ov::PartialShape& pshape, const float value) {
-        std::vector<T> data_vec = {static_cast<T>(value)};
-        return builder::makeConstant(netType, pshape.to_shape(), data_vec);
+    void add_scalar_to_tensor(T scalar, ov::Tensor& tensor) {
+#define CASE(X)                                                                \
+    case X: {                                                                  \
+        auto *dataPtr = tensor.data<element_type_traits<X>::value_type>();     \
+        dataPtr[0] = static_cast<element_type_traits<X>::value_type>(scalar);  \
+        break;                                                                 \
+    }
+
+        switch (tensor.get_element_type()) {
+            CASE(ElementType::boolean)
+            CASE(ElementType::i8)
+            CASE(ElementType::i16)
+            CASE(ElementType::i32)
+            CASE(ElementType::i64)
+            CASE(ElementType::u8)
+            CASE(ElementType::u16)
+            CASE(ElementType::u32)
+            CASE(ElementType::u64)
+            CASE(ElementType::bf16)
+            CASE(ElementType::f16)
+            CASE(ElementType::f32)
+            CASE(ElementType::f64)
+            CASE(ElementType::u1)
+            CASE(ElementType::i4)
+            CASE(ElementType::u4)
+            default: OPENVINO_UNREACHABLE("Unsupported element type: ", tensor.get_element_type());
+        }
+#undef CASE
     }
 
-    std::shared_ptr<ngraph::opset8::Range> generate_range_op(ElementType netType, std::vector<ov::PartialShape>& pshapes, std::vector<float>& values) {
-        const size_t num_inputs = 3;
+    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
+        inputs.clear();
+        const auto& funcInputs = function->inputs();
 
-        std::vector<std::shared_ptr<ov::Node>> input_vec;
-        // netType=undifined means mixed type test
-        if (netType == ov::element::Type_t::undefined) {
-            input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[0], values[0]));
-            input_vec.push_back(generate_constant<int32_t>(ov::element::Type_t::i32, inputDynamicShapes[1], values[1]));
-            input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[2], values[2]));
-            netType = ov::element::Type_t::f32;
+        auto generate_input = [&](size_t index, ElementType element_type) {
+            ov::Tensor tensor(element_type, targetInputStaticShapes[index]);
+            add_scalar_to_tensor<float>(input_values[index], tensor);
+            inputs.insert({funcInputs[index].get_node_shared_ptr(), tensor});
+        };
+
+        // net_type=undifined means mixed type test
+        if (net_type == ElementType::undefined) {
+            generate_input(0, ElementType::f32);
+            generate_input(1, ElementType::i32);
+            generate_input(2, ElementType::f32);
         } else {
-            for (size_t idx = 0; idx < num_inputs; idx++) {
-#define CASE(X) case X: input_vec.push_back(generate_constant<element_type_traits<X>::value_type>(netType, inputDynamicShapes[idx], values[idx])); break;
-                switch (netType) {
-                    CASE(ov::element::Type_t::boolean)
-                    CASE(ov::element::Type_t::i8)
-                    CASE(ov::element::Type_t::i16)
-                    CASE(ov::element::Type_t::i32)
-                    CASE(ov::element::Type_t::i64)
-                    CASE(ov::element::Type_t::u8)
-                    CASE(ov::element::Type_t::u16)
-                    CASE(ov::element::Type_t::u32)
-                    CASE(ov::element::Type_t::u64)
-                    CASE(ov::element::Type_t::bf16)
-                    CASE(ov::element::Type_t::f16)
-                    CASE(ov::element::Type_t::f32)
-                    CASE(ov::element::Type_t::f64)
-                    case ov::element::Type_t::u1:
-                    case ov::element::Type_t::i4:
-                    case ov::element::Type_t::u4:
-                        input_vec.push_back(generate_constant<uint8_t>(netType, inputDynamicShapes[idx], values[idx])); break;
-                    default: OPENVINO_UNREACHABLE("Unsupported element type: ", netType);
-                }
-#undef CASE
+            for (size_t i = 0; i < funcInputs.size(); ++i) {
+                generate_input(i, funcInputs[i].get_element_type());
             }
         }
-
-        return std::make_shared<ngraph::opset8::Range>(input_vec[0], input_vec[1], input_vec[2], netType);
     }
 
     void SetUp() override {
@@ -134,22 +130,29 @@ protected:
         ElementType netType;
         std::map<std::string, std::string> additionalConfig;
         ngraph::ParameterVector params;
+        inputValues.clear();
         std::tie(inputShapes, inputValues, netType, targetDevice, additionalConfig) = basicParamsSet;
 
-        // netType=undifined means mixed type test
-        if (netType == ov::element::Type_t::undefined) {
-            params = builder::makeDynamicParams(ov::element::Type_t::f32, {});
-        } else {
-            params = builder::makeDynamicParams(netType, {});
-        }
+        input_values = inputValues;
+        net_type = netType;
 
         init_input_shapes(inputShapes);
 
-        const auto range = generate_range_op(netType, inputDynamicShapes, inputValues);
+        if (netType == ElementType::undefined) {
+            std::vector<element::Type> types = { ElementType::f32, ElementType::i32, ElementType::f32 };
+            params = builder::makeDynamicParams(types, inputDynamicShapes);
+            netType = ElementType::f32;
+        } else {
+            params = builder::makeDynamicParams(netType, inputDynamicShapes);
+        }
+        const auto range = std::make_shared<ngraph::opset8::Range>(params[0], params[1], params[2], netType);
 
         ngraph::ResultVector results = {std::make_shared<ngraph::opset1::Result>(range)};
         function = std::make_shared<ngraph::Function>(results, params, "shapeof_out");
     }
 
+private:
+    std::vector<float> input_values;
+    ElementType net_type;
 };
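Note: this restructuring is the actual "fix dynamic tests" part of the commit: previously start/stop/step were baked in as Constants, so the Range output folded to a static shape and the dynamic GPU path was never exercised. Now they are dynamic Parameters and the scalar values are injected per inference through generate_inputs(). Schematically:

    // Before: Range(Constant, Constant, Constant)  -> output folds to a static shape
    //         at graph build time; the dynamic path never runs.
    // After : Range(Parameter, Parameter, Parameter) over dynamic input shapes
    //         -> output shape resolved at inference time; dynamic kernels actually run.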
@@ -174,8 +177,8 @@ const std::vector<std::vector<float>> inputValues = {
         // Inputs for Range
         {2, 23, 3},
         {1, 21, 2},
-        {23, 2, -3},
-        {4, 0, -1},
+        {23, 2, -3},
+        {4, 0, -1},
     }
 };
@@ -185,12 +188,11 @@ const std::vector<ElementType> netPrecisions = {
     ElementType::i64,
 };
 
-
 const auto testParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
-                                                 ::testing::ValuesIn(inputValues),
-                                                 ::testing::ValuesIn(netPrecisions), // netprec
-                                                 ::testing::Values(CommonTestUtils::DEVICE_GPU),
-                                                 ::testing::Values(emptyAdditionalConfig));
+                                                 ::testing::ValuesIn(inputValues),
+                                                 ::testing::ValuesIn(netPrecisions),
+                                                 ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                 ::testing::Values(emptyAdditionalConfig));
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_01, RangeDynamicGPUTest,
                          testParams_smoke, RangeDynamicGPUTest::getTestCaseName);
@@ -199,8 +201,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_01, RangeDynamicGPUTest,
 const std::vector<std::vector<float>> inputFloatValues = {
     {
         // Inputs for Range
-        {1.0f, 2.5f, 0.5f},
-        {23.0f, 5.0f, -2.0f},
+        {1.0f, 2.5f, 0.5f},
+        {23.0f, 5.0f, -2.0f},
     }
 };
@@ -209,12 +211,11 @@ const std::vector<ElementType> netFloatPrecisions = {
     ElementType::f32,
 };
 
-
 const auto testFloatParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
-                                                      ::testing::ValuesIn(inputFloatValues),
-                                                      ::testing::ValuesIn(netFloatPrecisions), // netprec
-                                                      ::testing::Values(CommonTestUtils::DEVICE_GPU),
-                                                      ::testing::Values(emptyAdditionalConfig));
+                                                      ::testing::ValuesIn(inputFloatValues),
+                                                      ::testing::ValuesIn(netFloatPrecisions),
+                                                      ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                      ::testing::Values(emptyAdditionalConfig));
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_02, RangeDynamicGPUTest,
                          testFloatParams_smoke, RangeDynamicGPUTest::getTestCaseName);
@@ -222,8 +223,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_02, RangeDynamicGPUTest,
 const std::vector<std::vector<float>> inputMixedValues = {
     {
         // Inputs for Range
-        {4.5f, 12.0f, 1.0f},
-        {2.5f, 19.0f, 1.1f},
+        {4.5f, 12.0f, 1.0f},
+        {2.5f, 19.0f, 1.1f},
     }
 };
@@ -234,10 +235,10 @@ const std::vector<ElementType> netMixedPrecisions = {
 
 
 const auto testMixedParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
-                                                      ::testing::ValuesIn(inputMixedValues),
-                                                      ::testing::ValuesIn(netMixedPrecisions), // netprec
-                                                      ::testing::Values(CommonTestUtils::DEVICE_GPU),
-                                                      ::testing::Values(emptyAdditionalConfig));
+                                                      ::testing::ValuesIn(inputMixedValues),
+                                                      ::testing::ValuesIn(netMixedPrecisions),
+                                                      ::testing::Values(CommonTestUtils::DEVICE_GPU),
+                                                      ::testing::Values(emptyAdditionalConfig));
 
 INSTANTIATE_TEST_SUITE_P(smoke_dynamic_diff_types, RangeDynamicGPUTest,
                          testMixedParams_smoke, RangeDynamicGPUTest::getTestCaseName);