[GPU] Strided slice primitive params update (#11339)
This commit is contained in:
parent
3e58ccbce7
commit
f5f93cfbeb
@ -38,10 +38,10 @@ struct strided_slice : public primitive_base<strided_slice> {
|
||||
const primitive_id& begin_id,
|
||||
const primitive_id& end_id,
|
||||
const primitive_id& strides_id,
|
||||
std::vector<uint8_t> begin_mask,
|
||||
std::vector<uint8_t> end_mask,
|
||||
std::vector<uint8_t> new_axis_mask,
|
||||
std::vector<uint8_t> shrink_axis_mask,
|
||||
std::vector<int64_t> begin_mask,
|
||||
std::vector<int64_t> end_mask,
|
||||
std::vector<int64_t> new_axis_mask,
|
||||
std::vector<int64_t> shrink_axis_mask,
|
||||
const ov::Shape out_size,
|
||||
const primitive_id& ext_prim_id = "",
|
||||
const padding& output_padding = padding())
|
||||
@ -53,13 +53,13 @@ struct strided_slice : public primitive_base<strided_slice> {
|
||||
out_size(out_size) {}
|
||||
|
||||
/// @brief Array of bits, that provide replace begin[i] to max possible range in that dimension.
|
||||
std::vector<uint8_t> begin_mask;
|
||||
std::vector<int64_t> begin_mask;
|
||||
/// @brief Array of bits, that provide replace end[i] to max possible range in that dimension.
|
||||
std::vector<uint8_t> end_mask;
|
||||
std::vector<int64_t> end_mask;
|
||||
/// @brief Array of bits, that provide adding a new length 1 dimension at ith position in the output tensor.
|
||||
std::vector<uint8_t> new_axis_mask;
|
||||
std::vector<int64_t> new_axis_mask;
|
||||
/// @brief Array of bits, that provide shrinks the dimensionality by 1, taking on the value at index begin[i].
|
||||
std::vector<uint8_t> shrink_axis_mask;
|
||||
std::vector<int64_t> shrink_axis_mask;
|
||||
/// @brief Size of output tensor
|
||||
ov::Shape out_size;
|
||||
};
|
||||
|
@ -52,12 +52,28 @@ public:
|
||||
params.striding_params.push_back(sizes);
|
||||
}
|
||||
|
||||
params.end_mask = arg.get_primitive()->end_mask;
|
||||
auto begin_mask_ = arg.get_primitive()->begin_mask;
|
||||
auto end_mask_ = arg.get_primitive()->end_mask;
|
||||
auto new_axis_mask_ = arg.get_primitive()->new_axis_mask;
|
||||
auto shrink_axis_mask_ = arg.get_primitive()->shrink_axis_mask;
|
||||
|
||||
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
|
||||
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
|
||||
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
|
||||
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
|
||||
// Plugin requires inverted mask values. Consider changing primitive impl to be aligned with the spec.
|
||||
for (auto& b : begin_mask) {
|
||||
b = 1 - b;
|
||||
}
|
||||
for (auto& e : end_mask) {
|
||||
e = 1 - e;
|
||||
}
|
||||
params.end_mask = end_mask;
|
||||
pad_vector_to_size(params.end_mask, dims_num, 1);
|
||||
params.begin_mask = arg.get_primitive()->begin_mask;
|
||||
params.begin_mask = begin_mask;
|
||||
pad_vector_to_size(params.begin_mask, dims_num, 1);
|
||||
params.new_axis_mask = arg.get_primitive()->new_axis_mask;
|
||||
params.shrink_axis_mask = arg.get_primitive()->shrink_axis_mask;
|
||||
params.new_axis_mask = new_axis_mask;
|
||||
params.shrink_axis_mask = shrink_axis_mask;
|
||||
pad_vector_to_size(params.shrink_axis_mask, dims_num, 0);
|
||||
|
||||
std::vector<size_t> logical_dims = params.inputs[0].LogicalDims();
|
||||
|
@ -20,7 +20,8 @@ layout strided_slice_inst::calc_output_layout(strided_slice_node const& node) {
|
||||
auto desc = node.get_primitive();
|
||||
auto input_layout = node.input(0).get_output_layout();
|
||||
auto output_format = format::get_default_format(desc->out_size.size());
|
||||
std::vector<tensor::value_type> dims_converted(desc->out_size.begin(), desc->out_size.end());
|
||||
auto out_shape = desc->out_size;
|
||||
std::vector<tensor::value_type> dims_converted(out_shape.begin(), out_shape.end());
|
||||
// extend shape to 4d
|
||||
for (size_t i = dims_converted.size(); i < 4; i++) {
|
||||
dims_converted.push_back(1);
|
||||
|
@ -234,33 +234,16 @@ static void CreateStridedSliceOp(Program& p, const std::shared_ptr<ngraph::op::v
|
||||
return;
|
||||
} while (false);
|
||||
|
||||
auto end_mask_ = op->get_end_mask();
|
||||
auto begin_mask_ = op->get_begin_mask();
|
||||
auto new_axis_mask_ = op->get_new_axis_mask();
|
||||
auto shrink_axis_mask_ = op->get_shrink_axis_mask();
|
||||
std::vector<uint8_t> begin_mask(begin_mask_.begin(), begin_mask_.end());
|
||||
std::vector<uint8_t> end_mask(end_mask_.begin(), end_mask_.end());
|
||||
std::vector<uint8_t> new_axis_mask(new_axis_mask_.begin(), new_axis_mask_.end());
|
||||
std::vector<uint8_t> shrink_axis_mask(shrink_axis_mask_.begin(), shrink_axis_mask_.end());
|
||||
|
||||
// Plugin requires inverted mask values. Consider changing primitive impl to be aligned with the spec.
|
||||
for (auto& b : begin_mask) {
|
||||
b = 1 - b;
|
||||
}
|
||||
for (auto& e : end_mask) {
|
||||
e = 1 - e;
|
||||
}
|
||||
|
||||
auto stridedSlicePrim = cldnn::strided_slice(layerName,
|
||||
inputPrimitives[0],
|
||||
inputPrimitives[1],
|
||||
inputPrimitives[2],
|
||||
inputPrimitives[3],
|
||||
begin_mask,
|
||||
end_mask,
|
||||
new_axis_mask,
|
||||
shrink_axis_mask,
|
||||
op->get_output_shape(0),
|
||||
op->get_begin_mask(),
|
||||
op->get_end_mask(),
|
||||
op->get_new_axis_mask(),
|
||||
op->get_shrink_axis_mask(),
|
||||
op->get_output_partial_shape(0).to_shape(),
|
||||
op->get_friendly_name());
|
||||
|
||||
p.AddPrimitive(stridedSlicePrim);
|
||||
|
@ -159,7 +159,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x2x2_ignore) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 2}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 2}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -218,7 +218,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x2x2_ignore) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 2}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 2}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -393,7 +393,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x4x3_stride) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 3}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 3}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -456,7 +456,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x4x3_stride) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 0, 0, 0}, {0, 0, 0, 0}, {}, {}, {2, 2, 2, 3}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 1, 1, 1}, {1, 1, 1, 1}, {}, {}, {2, 2, 2, 3}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -532,7 +532,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x4x4_part_stride) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0, 0, 1}, {}, {}, {}, {1, 2, 4, 2}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 1, 1, 0}, {}, {}, {}, {1, 2, 4, 2}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -615,7 +615,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x4x4_part_stride) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0, 0, 1}, {}, {}, {}, {1, 2, 4, 2}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0, 1, 1, 0}, {}, {}, {}, {1, 2, 4, 2}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -897,7 +897,7 @@ TEST(strided_slice_gpu_f32_i32, test_2x2x1x1) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0,1}, {}, {}, {}, {2, 2, 1, 1}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0}, {}, {}, {}, {2, 2, 1, 1}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
@ -950,7 +950,7 @@ TEST(strided_slice_gpu_f32_i64, test_2x2x1x1) {
|
||||
topology.add(data("input2", begin));
|
||||
topology.add(data("input3", end));
|
||||
topology.add(data("input4", strides));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {0,1}, {}, {}, {}, {2, 2, 1, 1}));
|
||||
topology.add(strided_slice("strided_slice", "input", "input2", "input3", "input4", {1, 0}, {}, {}, {}, {2, 2, 1, 1}));
|
||||
|
||||
network network(engine, topology);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user