[GPU] add different type support in range kernel (#15422)
* add different type support in range kernel * add functional test case for mixed input data type
This commit is contained in:
@@ -67,6 +67,7 @@ ParamsKey RangeKernelRef::GetSupportedKey() const {
|
||||
k.EnableOutputDataType(Datatype::F32);
|
||||
k.EnableInputLayout(DataLayout::bfyx);
|
||||
k.EnableOutputLayout(DataLayout::bfyx);
|
||||
k.EnableDifferentTypes();
|
||||
return k;
|
||||
}
|
||||
|
||||
|
||||
@@ -91,29 +91,37 @@ protected:
|
||||
const size_t num_inputs = 3;
|
||||
|
||||
std::vector<std::shared_ptr<ov::Node>> input_vec;
|
||||
for (size_t idx = 0; idx < num_inputs; idx++) {
|
||||
// netType=undifined means mixed type test
|
||||
if (netType == ov::element::Type_t::undefined) {
|
||||
input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[0], values[0]));
|
||||
input_vec.push_back(generate_constant<int32_t>(ov::element::Type_t::i32, inputDynamicShapes[1], values[1]));
|
||||
input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[2], values[2]));
|
||||
netType = ov::element::Type_t::f32;
|
||||
} else {
|
||||
for (size_t idx = 0; idx < num_inputs; idx++) {
|
||||
#define CASE(X) case X: input_vec.push_back(generate_constant<element_type_traits<X>::value_type>(netType, inputDynamicShapes[idx], values[idx])); break;
|
||||
switch (netType) {
|
||||
CASE(ov::element::Type_t::boolean)
|
||||
CASE(ov::element::Type_t::i8)
|
||||
CASE(ov::element::Type_t::i16)
|
||||
CASE(ov::element::Type_t::i32)
|
||||
CASE(ov::element::Type_t::i64)
|
||||
CASE(ov::element::Type_t::u8)
|
||||
CASE(ov::element::Type_t::u16)
|
||||
CASE(ov::element::Type_t::u32)
|
||||
CASE(ov::element::Type_t::u64)
|
||||
CASE(ov::element::Type_t::bf16)
|
||||
CASE(ov::element::Type_t::f16)
|
||||
CASE(ov::element::Type_t::f32)
|
||||
CASE(ov::element::Type_t::f64)
|
||||
case ov::element::Type_t::u1:
|
||||
case ov::element::Type_t::i4:
|
||||
case ov::element::Type_t::u4:
|
||||
input_vec.push_back(generate_constant<uint8_t>(netType, inputDynamicShapes[idx], values[idx])); break;
|
||||
default: OPENVINO_UNREACHABLE("Unsupported element type: ", netType);
|
||||
}
|
||||
switch (netType) {
|
||||
CASE(ov::element::Type_t::boolean)
|
||||
CASE(ov::element::Type_t::i8)
|
||||
CASE(ov::element::Type_t::i16)
|
||||
CASE(ov::element::Type_t::i32)
|
||||
CASE(ov::element::Type_t::i64)
|
||||
CASE(ov::element::Type_t::u8)
|
||||
CASE(ov::element::Type_t::u16)
|
||||
CASE(ov::element::Type_t::u32)
|
||||
CASE(ov::element::Type_t::u64)
|
||||
CASE(ov::element::Type_t::bf16)
|
||||
CASE(ov::element::Type_t::f16)
|
||||
CASE(ov::element::Type_t::f32)
|
||||
CASE(ov::element::Type_t::f64)
|
||||
case ov::element::Type_t::u1:
|
||||
case ov::element::Type_t::i4:
|
||||
case ov::element::Type_t::u4:
|
||||
input_vec.push_back(generate_constant<uint8_t>(netType, inputDynamicShapes[idx], values[idx])); break;
|
||||
default: OPENVINO_UNREACHABLE("Unsupported element type: ", netType);
|
||||
}
|
||||
#undef CASE
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_shared<ngraph::opset8::Range>(input_vec[0], input_vec[1], input_vec[2], netType);
|
||||
@@ -125,9 +133,16 @@ protected:
|
||||
std::vector<float> inputValues;
|
||||
ElementType netType;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
ngraph::ParameterVector params;
|
||||
inputValues.clear();
|
||||
std::tie(inputShapes, inputValues, netType, targetDevice, additionalConfig) = basicParamsSet;
|
||||
auto params = builder::makeDynamicParams(netType, {});
|
||||
|
||||
// netType=undifined means mixed type test
|
||||
if (netType == ov::element::Type_t::undefined) {
|
||||
params = builder::makeDynamicParams(ov::element::Type_t::f32, {});
|
||||
} else {
|
||||
params = builder::makeDynamicParams(netType, {});
|
||||
}
|
||||
|
||||
init_input_shapes(inputShapes);
|
||||
|
||||
@@ -204,5 +219,27 @@ const auto testFloatParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInp
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_02, RangeDynamicGPUTest,
|
||||
testFloatParams_smoke, RangeDynamicGPUTest::getTestCaseName);
|
||||
|
||||
const std::vector<std::vector<float>> inputMixedValues = {
|
||||
{
|
||||
// Inputs for Range
|
||||
{4.5f, 12.0f, 1.0f},
|
||||
{2.5f, 19.0f, 1.1f},
|
||||
}
|
||||
};
|
||||
|
||||
const std::vector<ElementType> netMixedPrecisions = {
|
||||
// Mixed type test(start/step:fp32, end:i32)
|
||||
ElementType::undefined
|
||||
};
|
||||
|
||||
|
||||
const auto testMixedParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
|
||||
::testing::ValuesIn(inputMixedValues),
|
||||
::testing::ValuesIn(netMixedPrecisions), // netprec
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::Values(emptyAdditionalConfig));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_diff_types, RangeDynamicGPUTest,
|
||||
testMixedParams_smoke, RangeDynamicGPUTest::getTestCaseName);
|
||||
} // namespace
|
||||
} // namespace GPULayerTestsDefinitions
|
||||
|
||||
Reference in New Issue
Block a user