Remove useless Slices (#19451)

Adjust UselessStridedSliceEraser to work with Slice nodes.

Ticket: CVS-118895
This commit is contained in:
Mateusz Tabaka 2023-08-30 11:28:33 +02:00 committed by GitHub
parent f2167a9545
commit 3e8c0fac1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 52 additions and 13 deletions

View File

@ -14,7 +14,7 @@ namespace ov {
namespace pass {
class TRANSFORMATIONS_API StridedSliceOptimization;
class TRANSFORMATIONS_API UselessStridedSliceEraser;
class TRANSFORMATIONS_API UselessSliceEraser;
class TRANSFORMATIONS_API SharedStridedSliceEraser;
class TRANSFORMATIONS_API GroupedStridedSliceOptimizer;
class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
@ -24,12 +24,12 @@ class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
/**
* @ingroup ie_transformation_common_api
* @brief UselessStridedSliceEraser transformation removes StridedSlice operations
* @brief UselessSliceEraser transformation removes Slice/StridedSlice operations
* with equal input and output shapes.
*/
class ov::pass::UselessStridedSliceEraser : public ov::pass::ModelPass {
class ov::pass::UselessSliceEraser : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("UselessStridedSliceEraser", "0");
OPENVINO_RTTI("UselessSliceEraser", "0");
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
};

View File

@ -21,8 +21,8 @@
using namespace ov;
bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(UselessStridedSliceEraser);
bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(UselessSliceEraser);
bool rewritten = false;
for (auto& node : f->get_ordered_ops()) {
// Recursively apply transformation for sub-graph based operations
@ -31,19 +31,21 @@ bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov:
rewritten |= run_on_model(sub_graph);
}
}
auto ss = std::dynamic_pointer_cast<ov::op::v1::StridedSlice>(node);
if (!ss || ss->get_output_partial_shape(0).is_dynamic() || ss->get_input_partial_shape(0).is_dynamic())
bool is_slice = ov::is_type<ov::op::v1::StridedSlice>(node) || ov::is_type<ov::op::v8::Slice>(node);
if (!is_slice || node->get_output_partial_shape(0).is_dynamic() ||
node->get_input_partial_shape(0).is_dynamic())
continue;
if (ss->input(0).get_shape() != ss->output(0).get_shape())
if (node->get_input_shape(0) != node->get_output_shape(0))
continue;
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(ss->input_value(3).get_node_shared_ptr());
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(node->get_input_node_shared_ptr(3));
if (stridesNode) {
auto strides = stridesNode->cast_vector<int64_t>();
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) {
return strd < 0;
}))
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
})) {
rewritten |= replace_output_update_name(node->output(0), node->input_value(0));
}
}
}
return rewritten;
@ -404,7 +406,7 @@ bool ov::pass::StridedSliceOptimization::run_on_model(const std::shared_ptr<ov::
bool rewritten = false;
if (m_use_shapes) {
rewritten = UselessStridedSliceEraser().run_on_model(f);
rewritten = UselessSliceEraser().run_on_model(f);
// Execution of other passes is also needed even if 'rewritten' is already 'true'
rewritten = SharedStridedSliceEraser().run_on_model(f) || rewritten;
rewritten = GroupedStridedSliceOptimizer().run_on_model(f) || rewritten;

View File

@ -119,6 +119,43 @@ TEST_F(TransformationTestsF, OptimizeSS_SkipUselessDeletionRevertCase) {
}
}
TEST_F(TransformationTestsF, UselessSlice) {
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
auto relu = std::make_shared<opset1::Relu>(data);
auto begin = opset1::Constant::create(element::i64, Shape{2}, {0, 0});
auto end = opset1::Constant::create(element::i64, Shape{2}, {5, 5});
auto stride = opset1::Constant::create(element::i64, Shape{2}, {1, 1});
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
manager.register_pass<ov::pass::UselessSliceEraser>();
}
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
auto relu = std::make_shared<opset1::Relu>(data);
model_ref = std::make_shared<ov::Model>(relu, ParameterVector{data});
}
}
TEST_F(TransformationTestsF, NegativeUselessSliceWithNegativeStrides) {
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
auto relu = std::make_shared<opset1::Relu>(data);
auto begin = opset1::Constant::create(element::i64, Shape{2}, {4, 0});
auto end = opset1::Constant::create(element::i64, Shape{2}, {INT32_MIN, 5});
auto stride = opset1::Constant::create(element::i64, Shape{2}, {-1, 1});
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
manager.register_pass<ov::pass::UselessSliceEraser>();
}
}
TEST_F(TransformationTestsF, OptimizeSS_Usefull_Test) {
{
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});