Remove useless Slices (#19451)
Adjust UselessStridedSliceEraser to work with Slice nodes. Ticket: CVS-118895
This commit is contained in:
parent
f2167a9545
commit
3e8c0fac1b
@ -14,7 +14,7 @@ namespace ov {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API StridedSliceOptimization;
|
||||
class TRANSFORMATIONS_API UselessStridedSliceEraser;
|
||||
class TRANSFORMATIONS_API UselessSliceEraser;
|
||||
class TRANSFORMATIONS_API SharedStridedSliceEraser;
|
||||
class TRANSFORMATIONS_API GroupedStridedSliceOptimizer;
|
||||
class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
|
||||
@ -24,12 +24,12 @@ class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief UselessStridedSliceEraser transformation removes StridedSlice operations
|
||||
* @brief UselessSliceEraser transformation removes Slice/StridedSlice operations
|
||||
* with equal input and output shapes.
|
||||
*/
|
||||
class ov::pass::UselessStridedSliceEraser : public ov::pass::ModelPass {
|
||||
class ov::pass::UselessSliceEraser : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("UselessStridedSliceEraser", "0");
|
||||
OPENVINO_RTTI("UselessSliceEraser", "0");
|
||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||
};
|
||||
|
||||
|
@ -21,8 +21,8 @@
|
||||
|
||||
using namespace ov;
|
||||
|
||||
bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(UselessStridedSliceEraser);
|
||||
bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(UselessSliceEraser);
|
||||
bool rewritten = false;
|
||||
for (auto& node : f->get_ordered_ops()) {
|
||||
// Recursively apply transformation for sub-graph based operations
|
||||
@ -31,19 +31,21 @@ bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov:
|
||||
rewritten |= run_on_model(sub_graph);
|
||||
}
|
||||
}
|
||||
auto ss = std::dynamic_pointer_cast<ov::op::v1::StridedSlice>(node);
|
||||
if (!ss || ss->get_output_partial_shape(0).is_dynamic() || ss->get_input_partial_shape(0).is_dynamic())
|
||||
bool is_slice = ov::is_type<ov::op::v1::StridedSlice>(node) || ov::is_type<ov::op::v8::Slice>(node);
|
||||
if (!is_slice || node->get_output_partial_shape(0).is_dynamic() ||
|
||||
node->get_input_partial_shape(0).is_dynamic())
|
||||
continue;
|
||||
if (ss->input(0).get_shape() != ss->output(0).get_shape())
|
||||
if (node->get_input_shape(0) != node->get_output_shape(0))
|
||||
continue;
|
||||
|
||||
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(ss->input_value(3).get_node_shared_ptr());
|
||||
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(node->get_input_node_shared_ptr(3));
|
||||
if (stridesNode) {
|
||||
auto strides = stridesNode->cast_vector<int64_t>();
|
||||
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) {
|
||||
return strd < 0;
|
||||
}))
|
||||
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
||||
})) {
|
||||
rewritten |= replace_output_update_name(node->output(0), node->input_value(0));
|
||||
}
|
||||
}
|
||||
}
|
||||
return rewritten;
|
||||
@ -404,7 +406,7 @@ bool ov::pass::StridedSliceOptimization::run_on_model(const std::shared_ptr<ov::
|
||||
|
||||
bool rewritten = false;
|
||||
if (m_use_shapes) {
|
||||
rewritten = UselessStridedSliceEraser().run_on_model(f);
|
||||
rewritten = UselessSliceEraser().run_on_model(f);
|
||||
// Execution of other passes is also needed even if 'rewritten' is already 'true'
|
||||
rewritten = SharedStridedSliceEraser().run_on_model(f) || rewritten;
|
||||
rewritten = GroupedStridedSliceOptimizer().run_on_model(f) || rewritten;
|
||||
|
@ -119,6 +119,43 @@ TEST_F(TransformationTestsF, OptimizeSS_SkipUselessDeletionRevertCase) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, UselessSlice) {
|
||||
{
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||
auto relu = std::make_shared<opset1::Relu>(data);
|
||||
auto begin = opset1::Constant::create(element::i64, Shape{2}, {0, 0});
|
||||
auto end = opset1::Constant::create(element::i64, Shape{2}, {5, 5});
|
||||
auto stride = opset1::Constant::create(element::i64, Shape{2}, {1, 1});
|
||||
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
|
||||
|
||||
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
|
||||
|
||||
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
|
||||
manager.register_pass<ov::pass::UselessSliceEraser>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||
auto relu = std::make_shared<opset1::Relu>(data);
|
||||
model_ref = std::make_shared<ov::Model>(relu, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, NegativeUselessSliceWithNegativeStrides) {
|
||||
{
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||
auto relu = std::make_shared<opset1::Relu>(data);
|
||||
auto begin = opset1::Constant::create(element::i64, Shape{2}, {4, 0});
|
||||
auto end = opset1::Constant::create(element::i64, Shape{2}, {INT32_MIN, 5});
|
||||
auto stride = opset1::Constant::create(element::i64, Shape{2}, {-1, 1});
|
||||
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
|
||||
|
||||
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
|
||||
|
||||
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
|
||||
manager.register_pass<ov::pass::UselessSliceEraser>();
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, OptimizeSS_Usefull_Test) {
|
||||
{
|
||||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||
|
Loading…
Reference in New Issue
Block a user