[GPU] Add blocked layouts support for Roll operation (#12125)
* Add blocked layouts support for Roll operation
* Reduce number of Roll unit tests
This commit is contained in:
@@ -41,28 +41,32 @@ struct roll_impl : typed_primitive_impl_ocl<roll> {
|
||||
namespace detail {
|
||||
|
||||
// Registers roll_impl::create with implementation_map<roll> for impl_types::ocl.
// NOTE(review): this text is a commit diff rendered without +/- markers, so the body
// contains two registrations back to back: an explicit tuple list limited to
// bfyx/bfzyx/bfwzyx, then a generated key set that also covers blocked formats.
// Presumably the explicit list is the removed (pre-change) side and the generated
// set is its replacement - confirm against the original commit.
// The lone `|` / `||||` lines are page-rendering residue, not code.
attach_roll_impl::attach_roll_impl() {
|
||||
// Explicit registration: one (data type, format) tuple per supported combination.
implementation_map<roll>::add(impl_types::ocl,
|
||||
roll_impl::create,
|
||||
{
|
||||
std::make_tuple(data_types::u8, format::bfyx),
|
||||
std::make_tuple(data_types::u8, format::bfzyx),
|
||||
std::make_tuple(data_types::u8, format::bfwzyx),
|
||||
std::make_tuple(data_types::i8, format::bfyx),
|
||||
std::make_tuple(data_types::i8, format::bfzyx),
|
||||
std::make_tuple(data_types::i8, format::bfwzyx),
|
||||
std::make_tuple(data_types::f16, format::bfyx),
|
||||
std::make_tuple(data_types::f16, format::bfzyx),
|
||||
std::make_tuple(data_types::f16, format::bfwzyx),
|
||||
std::make_tuple(data_types::f32, format::bfyx),
|
||||
std::make_tuple(data_types::f32, format::bfzyx),
|
||||
std::make_tuple(data_types::f32, format::bfwzyx),
|
||||
std::make_tuple(data_types::i32, format::bfyx),
|
||||
std::make_tuple(data_types::i32, format::bfzyx),
|
||||
std::make_tuple(data_types::i32, format::bfwzyx),
|
||||
std::make_tuple(data_types::i64, format::bfyx),
|
||||
std::make_tuple(data_types::i64, format::bfzyx),
|
||||
std::make_tuple(data_types::i64, format::bfwzyx),
|
||||
});
|
||||
// Generated registration: every listed data type is paired with every listed format.
auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::u8, data_types::i32, data_types::i64};
|
||||
auto formats = {
|
||||
// Formats whose names carry the "yx" suffix (plain bfyx plus blocked variants).
format::bfyx,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::bs_fs_yx_bsv16_fsv16,
|
||||
format::bs_fs_yx_bsv32_fsv32,
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
|
||||
// Formats whose names carry the "zyx" suffix (plain bfzyx plus blocked variants).
format::bfzyx,
|
||||
format::b_fs_zyx_fsv16,
|
||||
format::b_fs_zyx_fsv32,
|
||||
format::bs_fs_zyx_bsv16_fsv32,
|
||||
format::bs_fs_zyx_bsv16_fsv16,
|
||||
format::bs_fs_zyx_bsv32_fsv32,
|
||||
format::bs_fs_zyx_bsv32_fsv16,
|
||||
|
||||
// bfwzyx is listed only in its plain form.
format::bfwzyx
|
||||
};
|
||||
// Collect the (type, format) cross product; std::set keeps each pair once.
std::set<std::tuple<data_types, format::type>> keys;
|
||||
for (const auto& t : types) {
|
||||
for (const auto& f : formats) {
|
||||
keys.emplace(t, f);
|
||||
}
|
||||
}
|
||||
implementation_map<roll>::add(impl_types::ocl, roll_impl::create, keys);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@@ -1424,7 +1424,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::non_max_suppression::type_id() &&
|
||||
prim.type() != cldnn::roi_align::type_id() &&
|
||||
prim.type() != cldnn::adaptive_pooling::type_id() &&
|
||||
prim.type() != cldnn::bucketize::type_id()) {
|
||||
prim.type() != cldnn::bucketize::type_id() &&
|
||||
prim.type() != cldnn::roll::type_id()) {
|
||||
can_use_fsv16 = false;
|
||||
}
|
||||
|
||||
@@ -1455,7 +1456,8 @@ void program::set_layout_optimizer_attributes(layout_optimizer& lo) {
|
||||
prim.type() != cldnn::non_max_suppression::type_id() &&
|
||||
prim.type() != cldnn::roi_align::type_id() &&
|
||||
prim.type() != cldnn::adaptive_pooling::type_id() &&
|
||||
prim.type() != cldnn::bucketize::type_id()) {
|
||||
prim.type() != cldnn::bucketize::type_id() &&
|
||||
prim.type() != cldnn::roll::type_id()) {
|
||||
can_use_bs_fs_yx_bsv16_fsv16 = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,20 +17,20 @@ CommonDispatchData SetDefault(const roll_params& kernel_params) {
|
||||
const auto out_layout = output.GetLayout();
|
||||
std::vector<std::vector<Tensor::DataChannelName>> dims_by_gws;
|
||||
|
||||
switch (out_layout) {
|
||||
case DataLayout::bfyx:
|
||||
switch (output.Dimentions()) {
|
||||
case 4:
|
||||
dispatch_data.gws = {output.X().v, output.Y().v, output.Feature().v * output.Batch().v};
|
||||
dims_by_gws = {{Tensor::DataChannelName::X},
|
||||
{Tensor::DataChannelName::Y},
|
||||
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
|
||||
break;
|
||||
case DataLayout::bfzyx:
|
||||
case 5:
|
||||
dispatch_data.gws = {output.X().v * output.Y().v, output.Z().v, output.Feature().v * output.Batch().v};
|
||||
dims_by_gws = {{Tensor::DataChannelName::X, Tensor::DataChannelName::Y},
|
||||
{Tensor::DataChannelName::Z},
|
||||
{Tensor::DataChannelName::FEATURE, Tensor::DataChannelName::BATCH}};
|
||||
break;
|
||||
case DataLayout::bfwzyx:
|
||||
case 6:
|
||||
dispatch_data.gws = {output.X().v * output.Y().v,
|
||||
output.Z().v * output.W().v,
|
||||
output.Feature().v * output.Batch().v};
|
||||
@@ -72,12 +72,8 @@ ParamsKey RollKernelRef::GetSupportedKey() const {
|
||||
ParamsKey key;
|
||||
key.EnableAllInputDataType();
|
||||
key.EnableAllOutputDataType();
|
||||
key.EnableInputLayout(DataLayout::bfyx);
|
||||
key.EnableInputLayout(DataLayout::bfzyx);
|
||||
key.EnableInputLayout(DataLayout::bfwzyx);
|
||||
key.EnableOutputLayout(DataLayout::bfyx);
|
||||
key.EnableOutputLayout(DataLayout::bfzyx);
|
||||
key.EnableOutputLayout(DataLayout::bfwzyx);
|
||||
key.EnableAllInputLayout();
|
||||
key.EnableAllOutputLayout();
|
||||
key.EnableTensorOffset();
|
||||
key.EnableTensorPitches();
|
||||
key.EnableBatching();
|
||||
|
||||
@@ -25,36 +25,40 @@ std::string vec2str(const std::vector<vecElementType>& vec) {
|
||||
}
|
||||
|
||||
// Description of a single roll test case, parameterized on the element type T.
// NOTE(review): the two consecutive `struct` header lines below are the old name
// (roll_test_params) and the new name (roll_test_input) of the same struct, shown
// together because this is a diff rendered without +/- markers - presumably only
// the second one exists after the commit; confirm against the original commit.
template <class T>
|
||||
struct roll_test_params {
|
||||
struct roll_test_input {
|
||||
// Dimensions handed to tensor(...) when the test builds the input layout.
std::vector<int32_t> input_shape;
|
||||
// Values written into the input memory via set_values().
std::vector<T> input_values;
|
||||
// Shift amounts handed to the roll primitive as a tensor.
std::vector<int32_t> shift;
|
||||
// Expected output; the test asserts the output size equals this vector's size.
// Presumably the elements are compared as well in code outside this view.
std::vector<T> expected_values;
|
||||
};
|
||||
|
||||
// Full test parameter: a test case paired with the format::type the test
// reorders the input into before running roll.
template <class T>
|
||||
using roll_test_params = std::tuple<roll_test_input<T>, format::type>;
|
||||
|
||||
template <class T>
|
||||
struct roll_test : testing::TestWithParam<roll_test_params<T>> {
|
||||
void test() {
|
||||
auto p = testing::TestWithParam<roll_test_params<T>>::GetParam();
|
||||
roll_test_input<T> p;
|
||||
format::type input_format;
|
||||
std::tie(p, input_format) = testing::TestWithParam<roll_test_params<T>>::GetParam();
|
||||
auto& engine = get_test_engine();
|
||||
|
||||
const auto input_format = format::get_default_format(p.input_shape.size());
|
||||
const layout data_layout(type_to_data_type<T>::value, input_format, tensor(input_format, p.input_shape));
|
||||
format::type plane_format = format::get_default_format(p.input_shape.size());
|
||||
const layout data_layout(type_to_data_type<T>::value, plane_format, tensor(input_format, p.input_shape));
|
||||
auto input = engine.allocate_memory(data_layout);
|
||||
set_values(input, p.input_values);
|
||||
|
||||
topology topology;
|
||||
topology.add(input_layout("input", input->get_layout()));
|
||||
topology.add(roll("roll", "input", tensor(input_format, p.shift)));
|
||||
topology.add(reorder("reordered_input", "input", input_format, type_to_data_type<T>::value));
|
||||
topology.add(roll("roll", "reordered_input", tensor(input_format, p.shift)));
|
||||
topology.add(reorder("reordered_roll", "roll", plane_format, type_to_data_type<T>::value));
|
||||
|
||||
network network(engine, topology);
|
||||
network.set_input_data("input", input);
|
||||
const auto outputs = network.execute();
|
||||
|
||||
EXPECT_EQ(outputs.size(), size_t(1));
|
||||
EXPECT_EQ(outputs.begin()->first, "roll");
|
||||
|
||||
auto output = outputs.at("roll").get_memory();
|
||||
auto output = outputs.at("reordered_roll").get_memory();
|
||||
cldnn::mem_lock<T> output_ptr(output, get_test_stream());
|
||||
|
||||
ASSERT_EQ(output_ptr.size(), p.expected_values.size());
|
||||
@@ -64,17 +68,18 @@ struct roll_test : testing::TestWithParam<roll_test_params<T>> {
|
||||
}
|
||||
|
||||
// Builds a readable test-case name from the parameter: input shape, element
// precision, shift and (in the newer lines) the tested format.
// NOTE(review): this is a diff rendered without +/- markers, so the two `auto& p`
// lines and the two "Shift=" lines are old/new versions of the same statement.
// Presumably the second line of each pair is the one kept - confirm against the commit.
static std::string PrintToStringParamName(const testing::TestParamInfo<roll_test_params<T>>& info) {
|
||||
auto& p = info.param;
|
||||
// Element 0 of the tuple is the roll_test_input; element 1 is the format.
auto& p = std::get<0>(info.param);
|
||||
std::ostringstream result;
|
||||
result << "InputShape=" << vec2str(p.input_shape) << "_";
|
||||
result << "Precision=" << data_type_traits::name(type_to_data_type<T>::value) << "_";
|
||||
result << "Shift=" << vec2str(p.shift);
|
||||
result << "Shift=" << vec2str(p.shift) << "_";
|
||||
result << "Format=" << std::get<1>(info.param);
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
std::vector<roll_test_params<T>> getRollParams() {
|
||||
std::vector<roll_test_input<T>> getRollParamsToCheckLogic() {
|
||||
return {
|
||||
// from reference tests
|
||||
{{4, 3, 1, 1}, // Input shape
|
||||
@@ -90,17 +95,35 @@ std::vector<roll_test_params<T>> getRollParams() {
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, // Input values
|
||||
{1, 0, 0, 0}, // Shift
|
||||
{10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9}}, // Expected values
|
||||
};
|
||||
}
|
||||
|
||||
// Returns the single 4-element-shape test case that the instantiations in this
// file combine with the formats4d list.
// Each entry is {input shape, input values, shift, expected values}.
template <class T>
|
||||
std::vector<roll_test_input<T>> getRollParamsToCheckLayouts() {
|
||||
return {
|
||||
// custom tests
|
||||
// 4d
|
||||
{{2, 3, 1, 2}, // Input shape
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}, // Input values
|
||||
{1, 2, 0, 5}, // Shift
|
||||
{10, 9, 12, 11, 8, 7, 4, 3, 6, 5, 2, 1}}, // Expected values
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
// Returns the single 5-element-shape test case that the instantiations in this
// file combine with the formats5d list.
// Each entry is {input shape, input values, shift, expected values}.
// Several shift entries (10, 23, 6) are larger than every entry of the input shape.
template <class T>
|
||||
std::vector<roll_test_input<T>> getRollParams5D() {
|
||||
return {
|
||||
// 5d
|
||||
{{1, 1, 3, 3, 2}, // Input shape
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}, // Input values
|
||||
{1, 2, 10, 23, 6}, // Shift
|
||||
{15, 16, 17, 18, 13, 14, 3, 4, 5, 6, 1, 2, 9, 10, 11, 12, 7, 8}}, // Expected values
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::vector<roll_test_input<T>> getRollParams6D() {
|
||||
return {
|
||||
// 6d
|
||||
{{2, 1, 1, 3, 2, 3}, // Input shape
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, //
|
||||
@@ -112,7 +135,7 @@ std::vector<roll_test_params<T>> getRollParams() {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::vector<roll_test_params<T>> getRollFloatingPointParams() {
|
||||
std::vector<roll_test_input<T>> getRollFloatingPointParams() {
|
||||
return {
|
||||
// from reference tests
|
||||
{{4, 3, 1, 1}, // Input shape
|
||||
@@ -141,7 +164,13 @@ std::vector<roll_test_params<T>> getRollFloatingPointParams() {
|
||||
40.5383f,
|
||||
-15.3859f,
|
||||
-4.5881f}}}, // Expected values
|
||||
{{4, 3, 1, 1}, // Input shape
|
||||
};
|
||||
}
|
||||
|
||||
template <class T>
|
||||
std::vector<roll_test_input<T>> getRollFloatingPointAdditionalLogic() {
|
||||
return {
|
||||
{{4, 3, 1, 1}, // Input shape
|
||||
{50.2907f,
|
||||
70.8054f,
|
||||
-68.3403f,
|
||||
@@ -178,22 +207,53 @@ std::vector<roll_test_params<T>> getRollFloatingPointParams() {
|
||||
};
|
||||
}
|
||||
|
||||
// Two-argument form of the suite macro: aliases roll_test<type>, defines a TEST_P
// that calls test(), and instantiates the suite over the values returned by func<type>().
// NOTE(review): a three-argument macro with the same name appears later in this text;
// since this is a diff rendered without +/- markers, this two-argument form is
// presumably the removed (pre-change) definition - confirm against the commit.
#define INSTANTIATE_ROLL_TEST_SUITE(type, func) \
|
||||
using roll_test_##type = roll_test<type>; \
|
||||
TEST_P(roll_test_##type, roll_##type) { \
|
||||
test(); \
|
||||
} \
|
||||
INSTANTIATE_TEST_SUITE_P(roll_smoke_##type, \
|
||||
roll_test_##type, \
|
||||
testing::ValuesIn(func<type>()), \
|
||||
roll_test_##type::PrintToStringParamName);
|
||||
// Formats with the "yx" suffix that the layout and floating-point test suites in
// this file are instantiated over: plain bfyx plus its blocked variants.
std::vector<format::type> formats4d = {format::bfyx,
|
||||
format::bs_fs_yx_bsv32_fsv32,
|
||||
format::bs_fs_yx_bsv32_fsv16,
|
||||
format::b_fs_yx_fsv32,
|
||||
format::b_fs_yx_fsv16,
|
||||
format::bs_fs_yx_bsv16_fsv16};
|
||||
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int8_t, getRollParams)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(uint8_t, getRollParams)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int32_t, getRollParams)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int64_t, getRollParams)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(FLOAT16, getRollFloatingPointParams)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(float, getRollFloatingPointParams)
|
||||
// Formats with the "zyx" suffix that the getRollParams5D suites are instantiated
// over: plain bfzyx plus blocked variants.
// NOTE(review): format::bs_fs_zyx_bsv16_fsv32 appears in the roll implementation's
// registered format list in this commit but is not listed here - confirm whether
// the omission is intentional.
std::vector<format::type> formats5d = {format::bfzyx,
|
||||
format::bs_fs_zyx_bsv32_fsv32,
|
||||
format::bs_fs_zyx_bsv32_fsv16,
|
||||
format::b_fs_zyx_fsv32,
|
||||
format::b_fs_zyx_fsv16,
|
||||
format::bs_fs_zyx_bsv16_fsv16};
|
||||
|
||||
// The getRollParams6D suites are instantiated over plain bfwzyx only.
std::vector<format::type> formats6d = {format::bfwzyx};
|
||||
|
||||
// Three-argument form of the suite macro.
//   type    - element type; used as roll_test<type> and pasted into generated names.
//   func    - name of a function template; func<type>() supplies the test cases and
//             the name is also pasted into the generated class/test/suite names.
//   formats - expression forwarded to testing::ValuesIn; combined with the test
//             cases through testing::Combine so every case runs with every format.
// Each expansion defines a distinct class roll_test_<type><func> derived from
// roll_test<type>, a TEST_P that calls test(), and the suite instantiation.
#define INSTANTIATE_ROLL_TEST_SUITE(type, func, formats) \
|
||||
class roll_test_##type##func : public roll_test<type> {}; \
|
||||
TEST_P(roll_test_##type##func, roll_##type##func) { \
|
||||
test(); \
|
||||
} \
|
||||
INSTANTIATE_TEST_SUITE_P(roll_smoke_##type##func, \
|
||||
roll_test_##type##func, \
|
||||
testing::Combine(testing::ValuesIn(func<type>()), testing::ValuesIn(formats)), \
|
||||
roll_test_##type##func::PrintToStringParamName);
|
||||
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int8_t, getRollParamsToCheckLogic, {format::bfyx})
|
||||
INSTANTIATE_ROLL_TEST_SUITE(uint8_t, getRollParamsToCheckLogic, {format::bfyx})
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int32_t, getRollParamsToCheckLogic, {format::bfyx})
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int64_t, getRollParamsToCheckLogic, {format::bfyx})
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int8_t, getRollParamsToCheckLayouts, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(uint8_t, getRollParamsToCheckLayouts, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int32_t, getRollParamsToCheckLayouts, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int64_t, getRollParamsToCheckLayouts, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int8_t, getRollParams5D, formats5d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(uint8_t, getRollParams5D, formats5d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int32_t, getRollParams5D, formats5d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int64_t, getRollParams5D, formats5d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int8_t, getRollParams6D, formats6d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(uint8_t, getRollParams6D, formats6d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int32_t, getRollParams6D, formats6d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(int64_t, getRollParams6D, formats6d)
|
||||
|
||||
INSTANTIATE_ROLL_TEST_SUITE(FLOAT16, getRollFloatingPointParams, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(float, getRollFloatingPointParams, formats4d)
|
||||
INSTANTIATE_ROLL_TEST_SUITE(FLOAT16, getRollFloatingPointAdditionalLogic, {format::bfyx})
|
||||
INSTANTIATE_ROLL_TEST_SUITE(float, getRollFloatingPointAdditionalLogic, {format::bfyx})
|
||||
|
||||
#undef INSTANTIATE_ROLL_TEST_SUITE
|
||||
|
||||
|
||||
Reference in New Issue
Block a user