[GPU] Strided slice primitive params update (#11339)

This commit is contained in:
Vladimir Paramuzov 2022-03-31 08:54:05 +03:00 committed by GitHub
parent 3e58ccbce7
commit f5f93cfbeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 43 additions and 43 deletions

View File

@ -38,10 +38,10 @@ struct strided_slice : public primitive_base<strided_slice> {
const primitive_id& begin_id,
const primitive_id& end_id,
const primitive_id& strides_id,
std::vector<uint8_t> begin_mask,
std::vector<uint8_t> end_mask,
std::vector<uint8_t> new_axis_mask,
std::vector<uint8_t> shrink_axis_mask,
std::vector<int64_t> begin_mask,
std::vector<int64_t> end_mask,
std::vector<int64_t> new_axis_mask,
std::vector<int64_t> shrink_axis_mask,
const ov::Shape out_size,
const primitive_id& ext_prim_id = "",
const padding& output_padding = padding())
@ -53,13 +53,13 @@ struct strided_slice : public primitive_base<strided_slice> {
out_size(out_size) {}
/// @brief Array of bits, that provide replace begin[i] to max possible range in that dimension.
std::vector<uint8_t> begin_mask;
std::vector<int64_t> begin_mask;
/// @brief Array of bits, that provide replace end[i] to max possible range in that dimension.
std::vector<uint8_t> end_mask;
std::vector<int64_t> end_mask;
/// @brief Array of bits, that provide adding a new length 1 dimension at ith position in the output tensor.
std::vector<uint8_t> new_axis_mask;
std::vector<int64_t> new_axis_mask;
/// @brief Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
std::vector<uint8_t> shrink_axis_mask;
std::vector<int64_t> shrink_axis_mask;
/// @brief Size of output tensor
ov::Shape out_size;
};

View File

@ -52,12 +52,28 @@ public:
params.striding_params.push_back(sizes);
}
params.end_mask = arg.get_primitive()->end_mask;
auto begin_mask_ = arg.get_primitive()->begin_mask;
auto end_mask_ = arg.get_primitive()->end_mask;
auto new_axis_mask_ = arg.get_primitive()->new_axis_mask;
auto shrink_axis_mask_ = arg.get_primitive()->shrink_axis_mask;
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
// Plugin requires inverted mask values. Consider changing primitive impl to be aligned with the spec.
for (auto& b : begin_mask) {
b = 1 - b;
}
for (auto& e : end_mask) {
e = 1 - e;
}
params.end_mask = end_mask;
pad_vector_to_size(params.end_mask, dims_num, 1);
params.begin_mask = arg.get_primitive()->begin_mask;
params.begin_mask = begin_mask;
pad_vector_to_size(params.begin_mask, dims_num, 1);
params.new_axis_mask = arg.get_primitive()->new_axis_mask;
params.shrink_axis_mask = arg.get_primitive()->shrink_axis_mask;
params.new_axis_mask = new_axis_mask;
params.shrink_axis_mask = shrink_axis_mask;
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0);
std::vector<size_t> logical_dims = params.inputs[0].LogicalDims();

View File

@ -20,7 +20,8 @@ layout strided_slice_inst::calc_output_layout(strided_slice_node const& node) {
auto desc = node.get_primitive();
auto input_layout = node.input(0).get_output_layout();
auto output_format = format::get_default_format(desc->out_size.size());
std::vector<tensor::value_type> dims_converted(desc->out_size.begin(), desc->out_size.end());
auto out_shape = desc->out_size;
std::vector<tensor::value_type> dims_converted(out_shape.begin(), out_shape.end());
// extend shape to 4d
for (size_t i = dims_converted.size(); i < 4; i++) {
dims_converted.push_back(1);

View File

@ -234,33 +234,16 @@ static void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v
return;
} while (false);
auto end_mask_ = op->get_end_mask();
auto begin_mask_ = op->get_begin_mask();
auto new_axis_mask_ = op->get_new_axis_mask();
auto shrink_axis_mask_ = op->get_shrink_axis_mask();
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
// Plugin requires inverted mask values. Consider changing primitive impl to be aligned with the spec.
for (auto& b : begin_mask) {
b = 1 - b;
}
for (auto& e : end_mask) {
e = 1 - e;
}
auto stridedSlicePrim = cldnn::strided_slice(layerName,
inputPrimitives[0],
inputPrimitives[1],
inputPrimitives[2],
inputPrimitives[3],
begin_mask,
end_mask,
new_axis_mask,
shrink_axis_mask,
op->get_output_shape(0),
op->get_begin_mask(),
op->get_end_mask(),
op->get_new_axis_mask(),
op->get_shrink_axis_mask(),
op->get_output_partial_shape(0).to_shape(),
op->get_friendly_name());
p.AddPrimitive(stridedSlicePrim);

View File

@ -159,7 +159,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x2x2_ignore) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 2}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 2}));
network network(engine, topology);
@ -218,7 +218,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x2x2_ignore) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 2}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 2}));
network network(engine, topology);
@ -393,7 +393,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x4x3_stride) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 3}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 3}));
network network(engine, topology);
@ -456,7 +456,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x4x3_stride) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 3}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 3}));
network network(engine, topology);
@ -532,7 +532,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x4x4_part_stride) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0, 0, 1}, {}, {}, {}, {1, 2, 4, 2}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 1, 1, 0}, {}, {}, {}, {1, 2, 4, 2}));
network network(engine, topology);
@ -615,7 +615,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x4x4_part_stride) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0, 0, 1}, {}, {}, {}, {1, 2, 4, 2}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 1, 1, 0}, {}, {}, {}, {1, 2, 4, 2}));
network network(engine, topology);
@ -897,7 +897,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x1x1) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0,1}, {}, {}, {}, {2, 2, 1, 1}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0}, {}, {}, {}, {2, 2, 1, 1}));
network network(engine, topology);
@ -950,7 +950,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x1x1) {
topology.add(data("input2", begin));
topology.add(data("input3", end));
topology.add(data("input4", strides));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0,1}, {}, {}, {}, {2, 2, 1, 1}));
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0}, {}, {}, {}, {2, 2, 1, 1}));
network network(engine, topology);