[GPU] add different type support in range kernel (#15422)

* add different type support in range kernel

* add functional test case for mixed input data type
This commit is contained in:
Wilson Seok
2023-02-09 09:36:25 +09:00
committed by GitHub
parent 4301ede385
commit 0d06e525db
2 changed files with 60 additions and 22 deletions

View File

@@ -67,6 +67,7 @@ ParamsKey RangeKernelRef::GetSupportedKey() const {
k.EnableOutputDataType(Datatype::F32);
k.EnableInputLayout(DataLayout::bfyx);
k.EnableOutputLayout(DataLayout::bfyx);
k.EnableDifferentTypes();
return k;
}

View File

@@ -91,29 +91,37 @@ protected:
const size_t num_inputs = 3;
std::vector<std::shared_ptr<ov::Node>> input_vec;
for (size_t idx = 0; idx < num_inputs; idx++) {
// netType=undifined means mixed type test
if (netType == ov::element::Type_t::undefined) {
input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[0], values[0]));
input_vec.push_back(generate_constant<int32_t>(ov::element::Type_t::i32, inputDynamicShapes[1], values[1]));
input_vec.push_back(generate_constant<float>(ov::element::Type_t::f32, inputDynamicShapes[2], values[2]));
netType = ov::element::Type_t::f32;
} else {
for (size_t idx = 0; idx < num_inputs; idx++) {
#define CASE(X) case X: input_vec.push_back(generate_constant<element_type_traits<X>::value_type>(netType, inputDynamicShapes[idx], values[idx])); break;
switch (netType) {
CASE(ov::element::Type_t::boolean)
CASE(ov::element::Type_t::i8)
CASE(ov::element::Type_t::i16)
CASE(ov::element::Type_t::i32)
CASE(ov::element::Type_t::i64)
CASE(ov::element::Type_t::u8)
CASE(ov::element::Type_t::u16)
CASE(ov::element::Type_t::u32)
CASE(ov::element::Type_t::u64)
CASE(ov::element::Type_t::bf16)
CASE(ov::element::Type_t::f16)
CASE(ov::element::Type_t::f32)
CASE(ov::element::Type_t::f64)
case ov::element::Type_t::u1:
case ov::element::Type_t::i4:
case ov::element::Type_t::u4:
input_vec.push_back(generate_constant<uint8_t>(netType, inputDynamicShapes[idx], values[idx])); break;
default: OPENVINO_UNREACHABLE("Unsupported element type: ", netType);
}
switch (netType) {
CASE(ov::element::Type_t::boolean)
CASE(ov::element::Type_t::i8)
CASE(ov::element::Type_t::i16)
CASE(ov::element::Type_t::i32)
CASE(ov::element::Type_t::i64)
CASE(ov::element::Type_t::u8)
CASE(ov::element::Type_t::u16)
CASE(ov::element::Type_t::u32)
CASE(ov::element::Type_t::u64)
CASE(ov::element::Type_t::bf16)
CASE(ov::element::Type_t::f16)
CASE(ov::element::Type_t::f32)
CASE(ov::element::Type_t::f64)
case ov::element::Type_t::u1:
case ov::element::Type_t::i4:
case ov::element::Type_t::u4:
input_vec.push_back(generate_constant<uint8_t>(netType, inputDynamicShapes[idx], values[idx])); break;
default: OPENVINO_UNREACHABLE("Unsupported element type: ", netType);
}
#undef CASE
}
}
return std::make_shared<ngraph::opset8::Range>(input_vec[0], input_vec[1], input_vec[2], netType);
@@ -125,9 +133,16 @@ protected:
std::vector<float> inputValues;
ElementType netType;
std::map<std::string, std::string> additionalConfig;
ngraph::ParameterVector params;
inputValues.clear();
std::tie(inputShapes, inputValues, netType, targetDevice, additionalConfig) = basicParamsSet;
auto params = builder::makeDynamicParams(netType, {});
// netType=undifined means mixed type test
if (netType == ov::element::Type_t::undefined) {
params = builder::makeDynamicParams(ov::element::Type_t::f32, {});
} else {
params = builder::makeDynamicParams(netType, {});
}
init_input_shapes(inputShapes);
@@ -204,5 +219,27 @@ const auto testFloatParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInp
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_range_02, RangeDynamicGPUTest,
testFloatParams_smoke, RangeDynamicGPUTest::getTestCaseName);
const std::vector<std::vector<float>> inputMixedValues = {
{
// Inputs for Range
{4.5f, 12.0f, 1.0f},
{2.5f, 19.0f, 1.1f},
}
};
const std::vector<ElementType> netMixedPrecisions = {
// Mixed type test(start/step:fp32, end:i32)
ElementType::undefined
};
const auto testMixedParams_smoke = ::testing::Combine(::testing::ValuesIn(dynInputShapes),
::testing::ValuesIn(inputMixedValues),
::testing::ValuesIn(netMixedPrecisions), // netprec
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::Values(emptyAdditionalConfig));
INSTANTIATE_TEST_SUITE_P(smoke_dynamic_diff_types, RangeDynamicGPUTest,
testMixedParams_smoke, RangeDynamicGPUTest::getTestCaseName);
} // namespace
} // namespace GPULayerTestsDefinitions