[GPU] Fix strided slice kernel with begin/end/stride as inputs (#16302)
This commit is contained in:
parent
d59d8ba3a2
commit
28d3e1087e
@ -10,17 +10,15 @@ inline void FUNC(get_slice_step)(OPTIONAL_SHAPE_INFO_ARG
|
|||||||
int* step_batch, int* step_feature,
|
int* step_batch, int* step_feature,
|
||||||
int* step_z, int* step_y, int* step_x)
|
int* step_z, int* step_y, int* step_x)
|
||||||
{
|
{
|
||||||
|
const uint batch_index = 0;
|
||||||
|
const uint feature_index = 1;
|
||||||
#ifdef OUTPUT_LAYOUT_BFYX
|
#ifdef OUTPUT_LAYOUT_BFYX
|
||||||
const uint batch_index = STRIDE_GET_INDEX(0, 0, 0, 0);
|
const uint y_index = 2;
|
||||||
const uint feature_index = STRIDE_GET_INDEX(1, 0, 0, 0);
|
const uint x_index = 3;
|
||||||
const uint y_index = STRIDE_GET_INDEX(2, 0, 0, 0);
|
|
||||||
const uint x_index = STRIDE_GET_INDEX(3, 0, 0, 0);
|
|
||||||
#elif OUTPUT_LAYOUT_BFZYX
|
#elif OUTPUT_LAYOUT_BFZYX
|
||||||
const uint batch_index = STRIDE_GET_INDEX(0, 0, 0, 0, 0);
|
const uint z_index = 2;
|
||||||
const uint feature_index = STRIDE_GET_INDEX(1, 0, 0, 0, 0);
|
const uint y_index = 3;
|
||||||
const uint z_index = STRIDE_GET_INDEX(2, 0, 0, 0, 0);
|
const uint x_index = 4;
|
||||||
const uint y_index = STRIDE_GET_INDEX(3, 0, 0, 0, 0);
|
|
||||||
const uint x_index = STRIDE_GET_INDEX(4, 0, 0, 0, 0);
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
*step_batch = batch_index < STRIDE_DIMS ? stride[batch_index] : 1;
|
*step_batch = batch_index < STRIDE_DIMS ? stride[batch_index] : 1;
|
||||||
@ -59,17 +57,15 @@ inline void FUNC(get_slice_end)(OPTIONAL_SHAPE_INFO_ARG
|
|||||||
const uint out_z_num = INPUT0_SIZE_Z;
|
const uint out_z_num = INPUT0_SIZE_Z;
|
||||||
const uint out_y_num = INPUT0_SIZE_Y;
|
const uint out_y_num = INPUT0_SIZE_Y;
|
||||||
const uint out_x_num = INPUT0_SIZE_X;
|
const uint out_x_num = INPUT0_SIZE_X;
|
||||||
|
const uint batch_index = 0;
|
||||||
|
const uint feature_index = 1;
|
||||||
#ifdef OUTPUT_LAYOUT_BFYX
|
#ifdef OUTPUT_LAYOUT_BFYX
|
||||||
const uint batch_index = END_GET_INDEX(0, 0, 0, 0);
|
const uint y_index = 2;
|
||||||
const uint feature_index = END_GET_INDEX(1, 0, 0, 0);
|
const uint x_index = 3;
|
||||||
const uint y_index = END_GET_INDEX(2, 0, 0, 0);
|
|
||||||
const uint x_index = END_GET_INDEX(3, 0, 0, 0);
|
|
||||||
#elif OUTPUT_LAYOUT_BFZYX
|
#elif OUTPUT_LAYOUT_BFZYX
|
||||||
const uint batch_index = END_GET_INDEX(0, 0, 0, 0, 0);
|
const uint z_index = 2;
|
||||||
const uint feature_index = END_GET_INDEX(1, 0, 0, 0, 0);
|
const uint y_index = 3;
|
||||||
const uint z_index = END_GET_INDEX(2, 0, 0, 0, 0);
|
const uint x_index = 4;
|
||||||
const uint y_index = END_GET_INDEX(3, 0, 0, 0, 0);
|
|
||||||
const uint x_index = END_GET_INDEX(4, 0, 0, 0, 0);
|
|
||||||
#endif
|
#endif
|
||||||
END_TYPE batch = batch_index < END_DIMS ? end[batch_index] : 0;
|
END_TYPE batch = batch_index < END_DIMS ? end[batch_index] : 0;
|
||||||
END_TYPE feature = feature_index < END_DIMS ? end[feature_index] : 0;
|
END_TYPE feature = feature_index < END_DIMS ? end[feature_index] : 0;
|
||||||
@ -176,17 +172,15 @@ inline void FUNC(get_slice_begin)(OPTIONAL_SHAPE_INFO_ARG
|
|||||||
const uint out_z_num = INPUT0_SIZE_Z;
|
const uint out_z_num = INPUT0_SIZE_Z;
|
||||||
const uint out_y_num = INPUT0_SIZE_Y;
|
const uint out_y_num = INPUT0_SIZE_Y;
|
||||||
const uint out_x_num = INPUT0_SIZE_X;
|
const uint out_x_num = INPUT0_SIZE_X;
|
||||||
|
const uint batch_index = 0;
|
||||||
|
const uint feature_index = 1;
|
||||||
#ifdef OUTPUT_LAYOUT_BFYX
|
#ifdef OUTPUT_LAYOUT_BFYX
|
||||||
const uint batch_index = STRIDE_GET_INDEX(0, 0, 0, 0);
|
const uint y_index = 2;
|
||||||
const uint feature_index = STRIDE_GET_INDEX(1, 0, 0, 0);
|
const uint x_index = 3;
|
||||||
const uint y_index = STRIDE_GET_INDEX(2, 0, 0, 0);
|
|
||||||
const uint x_index = STRIDE_GET_INDEX(3, 0, 0, 0);
|
|
||||||
#elif OUTPUT_LAYOUT_BFZYX
|
#elif OUTPUT_LAYOUT_BFZYX
|
||||||
const uint batch_index = STRIDE_GET_INDEX(0, 0, 0, 0, 0);
|
const uint z_index = 2;
|
||||||
const uint feature_index = STRIDE_GET_INDEX(1, 0, 0, 0, 0);
|
const uint y_index = 3;
|
||||||
const uint z_index = STRIDE_GET_INDEX(2, 0, 0, 0, 0);
|
const uint x_index = 4;
|
||||||
const uint y_index = STRIDE_GET_INDEX(3, 0, 0, 0, 0);
|
|
||||||
const uint x_index = STRIDE_GET_INDEX(4, 0, 0, 0, 0);
|
|
||||||
#endif
|
#endif
|
||||||
BEGIN_TYPE batch = batch_index < BEGIN_DIMS ? begin[batch_index] : 0;
|
BEGIN_TYPE batch = batch_index < BEGIN_DIMS ? begin[batch_index] : 0;
|
||||||
BEGIN_TYPE feature = feature_index < BEGIN_DIMS ? begin[feature_index] : 0;
|
BEGIN_TYPE feature = feature_index < BEGIN_DIMS ? begin[feature_index] : 0;
|
||||||
|
@ -620,6 +620,59 @@ public:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_2x2x2_all_dynamic_bcast() {
|
||||||
|
auto& engine = get_test_engine();
|
||||||
|
auto input_lay = layout{ ov::PartialShape::dynamic(3), data_types::f32, format::bfyx };
|
||||||
|
auto input = engine.allocate_memory({ ov::PartialShape{ 2, 2, 2 }, data_types::f32, format::bfyx });
|
||||||
|
auto begin = engine.allocate_memory({ ov::PartialShape{ 1 }, data_types::i64, format::bfyx });
|
||||||
|
auto end = engine.allocate_memory({ ov::PartialShape{ 1 }, data_types::i64, format::bfyx });
|
||||||
|
auto strides = engine.allocate_memory({ ov::PartialShape{ 1 }, data_types::i64, format::bfyx });
|
||||||
|
|
||||||
|
set_values(input, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f});
|
||||||
|
set_values<int64_t>(begin, {1});
|
||||||
|
set_values<int64_t>(end, {2});
|
||||||
|
set_values<int64_t>(strides, {1});
|
||||||
|
|
||||||
|
topology topology;
|
||||||
|
topology.add(input_layout("input", input_lay));
|
||||||
|
topology.add(data("input2", begin));
|
||||||
|
topology.add(data("input3", end));
|
||||||
|
topology.add(data("input4", strides));
|
||||||
|
topology.add(strided_slice("strided_slice", input_info("input"), input_info("input2"), input_info("input3"), input_info("input4"), {}, {}, {}, {}, {}, {}));
|
||||||
|
|
||||||
|
ExecutionConfig config;
|
||||||
|
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
|
||||||
|
network network(engine, topology, config);
|
||||||
|
|
||||||
|
network.set_input_data("input", input);
|
||||||
|
|
||||||
|
auto inst = network.get_primitive("strided_slice");
|
||||||
|
auto impl = inst->get_impl();
|
||||||
|
ASSERT_TRUE(impl != nullptr);
|
||||||
|
ASSERT_TRUE(impl->is_dynamic());
|
||||||
|
|
||||||
|
auto outputs = network.execute();
|
||||||
|
|
||||||
|
ASSERT_EQ(outputs.size(), size_t(1));
|
||||||
|
ASSERT_EQ(outputs.begin()->first, "strided_slice");
|
||||||
|
|
||||||
|
auto output = outputs.at("strided_slice").get_memory();
|
||||||
|
|
||||||
|
ov::PartialShape expected_shape{1, 2, 2};
|
||||||
|
|
||||||
|
ASSERT_EQ(output->get_layout().get_partial_shape(), expected_shape);
|
||||||
|
|
||||||
|
std::vector<float> answers = {
|
||||||
|
4.0f, 5.0f, 6.0f, 7.0f
|
||||||
|
};
|
||||||
|
|
||||||
|
cldnn::mem_lock<float> output_ptr(output, get_test_stream());
|
||||||
|
|
||||||
|
for (size_t i = 0; i < answers.size(); ++i) {
|
||||||
|
ASSERT_EQ(answers[i], output_ptr[i]) << " i = " << i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void test_2x2x2x1x1_2_negative_all_dynamic_begin() {
|
void test_2x2x2x1x1_2_negative_all_dynamic_begin() {
|
||||||
auto& engine = get_test_engine();
|
auto& engine = get_test_engine();
|
||||||
auto input = engine.allocate_memory({ ov::PartialShape{ 2, 2, 2 }, data_types::f32, format::bfyx });
|
auto input = engine.allocate_memory({ ov::PartialShape{ 2, 2, 2 }, data_types::f32, format::bfyx });
|
||||||
@ -1593,7 +1646,7 @@ public:
|
|||||||
ASSERT_TRUE(are_equal(answers[i], output_ptr[i]));
|
ASSERT_TRUE(are_equal(answers[i], output_ptr[i]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
class strided_slice_gpu_i8: public ::testing::Test {
|
class strided_slice_gpu_i8: public ::testing::Test {
|
||||||
public:
|
public:
|
||||||
@ -1822,6 +1875,10 @@ TEST_F(strided_slice_gpu, test_2x2x2x1x1_2_negative_all_dynamic_begin) {
|
|||||||
this->test_2x2x2x1x1_2_negative_all_dynamic_begin();
|
this->test_2x2x2x1x1_2_negative_all_dynamic_begin();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(strided_slice_gpu, test_2x2x2_all_dynamic_bcast) {
|
||||||
|
this->test_2x2x2_all_dynamic_bcast();
|
||||||
|
}
|
||||||
|
|
||||||
#ifdef RUN_ALL_MODEL_CACHING_TESTS
|
#ifdef RUN_ALL_MODEL_CACHING_TESTS
|
||||||
TEST_F(strided_slice_gpu, test_2x2x2x2_full_cached) {
|
TEST_F(strided_slice_gpu, test_2x2x2x2_full_cached) {
|
||||||
this->test_2x2x2x2_full(true);
|
this->test_2x2x2x2_full(true);
|
||||||
|
Loading…
Reference in New Issue
Block a user