Handle negative values in GroupedSliceToVSplitOptimization (#19495)
* Handle negative values in GroupedSliceToVSplitOptimization CVS-118897 * change the way of getting slice inputs * clamp value --------- Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
parent
7e3e1e2480
commit
d0dda74fc2
@ -274,30 +274,47 @@ struct SliceWithAttrs {
|
||||
};
|
||||
|
||||
bool slice_is_suitable_for_optimization(const std::shared_ptr<ov::op::v8::Slice>& op, SliceAttrs& attrs) {
|
||||
const auto& data_rank = op->get_input_partial_shape(0).rank();
|
||||
const auto& input_shape = op->get_input_partial_shape(0);
|
||||
const auto& data_rank = input_shape.rank();
|
||||
if (op->get_input_size() != 5 || data_rank.is_dynamic())
|
||||
return false;
|
||||
const auto rank = data_rank.get_length();
|
||||
|
||||
for (size_t i = 1; i < 5; ++i) {
|
||||
auto input_as_constant = ov::as_type_ptr<ov::op::v0::Constant>(op->get_input_node_shared_ptr(i));
|
||||
if (!input_as_constant)
|
||||
auto get_scalar = [](const std::shared_ptr<ov::Node>& node, int64_t& value) -> bool {
|
||||
auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
|
||||
if (!constant)
|
||||
return false;
|
||||
if (shape_size(input_as_constant->get_shape()) != 1)
|
||||
if (shape_size(constant->get_shape()) != 1)
|
||||
return false;
|
||||
value = constant->cast_vector<int64_t>()[0];
|
||||
return true;
|
||||
};
|
||||
|
||||
int64_t value = input_as_constant->cast_vector<int64_t>()[0];
|
||||
enum { START = 1, STOP, STRIDE, AXIS };
|
||||
|
||||
if (((i == 1 || i == 2) && value < 0) || (i == 3 && value != 1))
|
||||
return false;
|
||||
else if (i == 1)
|
||||
attrs.start = value;
|
||||
else if (i == 2)
|
||||
attrs.stop = value;
|
||||
else if (i == 4)
|
||||
attrs.axis = value >= 0 ? value : value + data_rank.get_length();
|
||||
}
|
||||
if (attrs.axis < 0 || op->get_input_partial_shape(0)[attrs.axis].is_dynamic())
|
||||
int64_t stride;
|
||||
if (!get_scalar(op->get_input_node_shared_ptr(STRIDE), stride) || stride != 1)
|
||||
return false;
|
||||
if (!get_scalar(op->get_input_node_shared_ptr(AXIS), attrs.axis))
|
||||
return false;
|
||||
attrs.axis = attrs.axis >= 0 ? attrs.axis : attrs.axis + rank;
|
||||
|
||||
if (input_shape[attrs.axis].is_dynamic())
|
||||
return false;
|
||||
const auto dimension = input_shape[attrs.axis].get_length();
|
||||
|
||||
for (int i = START; i <= STOP; i++) {
|
||||
int64_t value;
|
||||
if (!get_scalar(op->get_input_node_shared_ptr(i), value))
|
||||
return false;
|
||||
value = value >= 0 ? value : value + dimension;
|
||||
value = std::max<int64_t>(std::min(value, dimension), 0);
|
||||
if (i == START)
|
||||
attrs.start = value;
|
||||
else if (i == STOP)
|
||||
attrs.stop = value;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -335,6 +352,9 @@ bool ov::pass::GroupedSliceToVSplitOptimization::run_on_model(const std::shared_
|
||||
const auto& axis = output_with_axis.second;
|
||||
auto attributes = source_to_op_with_attrs[output_with_axis];
|
||||
|
||||
if (attributes.size() < 2)
|
||||
continue;
|
||||
|
||||
std::sort(attributes.begin(), attributes.end(), [](const SliceWithAttrs& lhs, const SliceWithAttrs& rhs) {
|
||||
if (lhs.attrs.start == rhs.attrs.start)
|
||||
return lhs.attrs.stop < rhs.attrs.stop;
|
||||
|
@ -1175,3 +1175,29 @@ TEST_F(TransformationTestsF, GroupedSliceToVSplitSameSourceDifferentAxis) {
|
||||
model_ref = std::make_shared<ov::Model>(ov::NodeVector{concat_2}, ov::ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GroupedSliceToVSplitNegativeStartStop) {
|
||||
{
|
||||
auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{-1, 5, -1, -1});
|
||||
auto relu = std::make_shared<ov::opset8::Relu>(data);
|
||||
|
||||
auto slice_0 = make_slice(relu, -50, 1, 1, -3);
|
||||
auto slice_1 = make_slice(relu, -4, -2, 1, 1);
|
||||
auto slice_2 = make_slice(relu, -2, INT32_MAX, 1, 1);
|
||||
|
||||
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{slice_0, slice_2, slice_1}, 1);
|
||||
|
||||
model = std::make_shared<ov::Model>(ov::NodeVector{concat}, ov::ParameterVector{data});
|
||||
manager.register_pass<ov::pass::GroupedSliceToVSplitOptimization>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<ov::opset8::Parameter>(ov::element::f32, ov::PartialShape{-1, 5, -1, -1});
|
||||
auto relu = std::make_shared<ov::opset8::Relu>(data);
|
||||
|
||||
auto vsplit = make_vsplit(relu, 1, {1, 2, 2});
|
||||
|
||||
auto concat = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{vsplit[0], vsplit[2], vsplit[1]}, 1);
|
||||
|
||||
model_ref = std::make_shared<ov::Model>(ov::NodeVector{concat}, ov::ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user