Handle negative values in GroupedSliceToVSplitOptimization (#19495)

* Handle negative values in GroupedSliceToVSplitOptimization

CVS-118897

* change the way of getting slice inputs

* clamp value

---------

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
Mateusz Tabaka 2023-09-11 16:31:39 +02:00 committed by GitHub
parent 7e3e1e2480
commit d0dda74fc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 62 additions and 16 deletions

View File

@ -274,30 +274,47 @@ struct SliceWithAttrs {
};
bool slice_is_suitable_for_optimization(const std::shared_ptr<ov::op::v8::Slice>& op, SliceAttrs& attrs) {
const auto& data_rank = op->get_input_partial_shape(0).rank();
const auto& input_shape = op->get_input_partial_shape(0);
const auto& data_rank = input_shape.rank();
if (op->get_input_size() != 5 || data_rank.is_dynamic())
return false;
const auto rank = data_rank.get_length();
for (size_t i = 1; i < 5; ++i) {
auto input_as_constant = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(i));
if (!input_as_constant)
auto get_scalar = [](const std::shared_ptr<ov::Node>& node, int64_t& value) -> bool {
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
if (!constant)
return false;
if (shape_size(input_as_constant->get_shape()) != 1)
if (shape_size(constant->get_shape()) != 1)
return false;
value = constant->cast_vector<int64_t>()[0];
return true;
};
int64_t value = input_as_constant->cast_vector<int64_t>()[0];
enum { START = 1, STOP, STRIDE, AXIS };
if (((i == 1 || i == 2) && value < 0) || (i == 3 && value != 1))
return false;
else if (i == 1)
attrs.start = value;
else if (i == 2)
attrs.stop = value;
else if (i == 4)
attrs.axis = value >= 0 ? value : value + data_rank.get_length();
}
if (attrs.axis < 0 || op->get_input_partial_shape(0)[attrs.axis].is_dynamic())
int64_t stride;
if (!get_scalar(op->get_input_node_shared_ptr(STRIDE), stride) || stride != 1)
return false;
if (!get_scalar(op->get_input_node_shared_ptr(AXIS), attrs.axis))
return false;
attrs.axis = attrs.axis >= 0 ? attrs.axis : attrs.axis + rank;
if (input_shape[attrs.axis].is_dynamic())
return false;
const auto dimension = input_shape[attrs.axis].get_length();
for (int i = START; i <= STOP; i++) {
int64_t value;
if (!get_scalar(op->get_input_node_shared_ptr(i), value))
return false;
value = value >= 0 ? value : value + dimension;
value = std::max<int64_t>(std::min(value, dimension), 0);
if (i == START)
attrs.start = value;
else if (i == STOP)
attrs.stop = value;
}
return true;
}
@ -335,6 +352,9 @@ bool ov::pass::GroupedSliceToVSplitOptimization::run_on_model(const std::shared_
const auto& axis = output_with_axis.second;
auto attributes = source_to_op_with_attrs[output_with_axis];
if (attributes.size() < 2)
continue;
std::sort(attributes.begin(), attributes.end(), [](const SliceWithAttrs& lhs, const SliceWithAttrs& rhs) {
if (lhs.attrs.start == rhs.attrs.start)
return lhs.attrs.stop < rhs.attrs.stop;

View File

@ -1175,3 +1175,29 @@ TEST_F(TransformationTestsF, GroupedSliceToVSplitSameSourceDifferentAxis) {
model_ref = std::make_shared<ov::Model>(ov::NodeVector{concat_2}, ov::ParameterVector{data});
}
}
TEST_F(TransformationTestsF, GroupedSliceToVSplitNegativeStartStop) {
{
auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{-1, 5, -1, -1});
auto relu = std::make_shared<ov::opset8::Relu>(data);
auto slice_0 = make_slice(relu, -50, 1, 1, -3);
auto slice_1 = make_slice(relu, -4, -2, 1, 1);
auto slice_2 = make_slice(relu, -2, INT32_MAX, 1, 1);
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{slice_0, slice_2, slice_1}, 1);
model = std::make_shared<ov::Model>(ov::NodeVector{concat}, ov::ParameterVector{data});
manager.register_pass<ov::pass::GroupedSliceToVSplitOptimization>();
}
{
auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{-1, 5, -1, -1});
auto relu = std::make_shared<ov::opset8::Relu>(data);
auto vsplit = make_vsplit(relu, 1, {1, 2, 2});
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{vsplit[0], vsplit[2], vsplit[1]}, 1);
model_ref = std::make_shared<ov::Model>(ov::NodeVector{concat}, ov::ParameterVector{data});
}
}