Remove useless Slices (#19451)
Adjust UselessStridedSliceEraser to work with Slice nodes. Ticket: CVS-118895
This commit is contained in:
parent
f2167a9545
commit
3e8c0fac1b
@ -14,7 +14,7 @@ namespace ov {
|
|||||||
namespace pass {
|
namespace pass {
|
||||||
|
|
||||||
class TRANSFORMATIONS_API StridedSliceOptimization;
|
class TRANSFORMATIONS_API StridedSliceOptimization;
|
||||||
class TRANSFORMATIONS_API UselessStridedSliceEraser;
|
class TRANSFORMATIONS_API UselessSliceEraser;
|
||||||
class TRANSFORMATIONS_API SharedStridedSliceEraser;
|
class TRANSFORMATIONS_API SharedStridedSliceEraser;
|
||||||
class TRANSFORMATIONS_API GroupedStridedSliceOptimizer;
|
class TRANSFORMATIONS_API GroupedStridedSliceOptimizer;
|
||||||
class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
|
class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
|
||||||
@ -24,12 +24,12 @@ class TRANSFORMATIONS_API GroupedSliceToVSplitOptimization;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* @ingroup ie_transformation_common_api
|
* @ingroup ie_transformation_common_api
|
||||||
* @brief UselessStridedSliceEraser transformation removes StridedSlice operations
|
* @brief UselessSliceEraser transformation removes Slice/StridedSlice operations
|
||||||
* with equal input and output shapes.
|
* with equal input and output shapes.
|
||||||
*/
|
*/
|
||||||
class ov::pass::UselessStridedSliceEraser : public ov::pass::ModelPass {
|
class ov::pass::UselessSliceEraser : public ov::pass::ModelPass {
|
||||||
public:
|
public:
|
||||||
OPENVINO_RTTI("UselessStridedSliceEraser", "0");
|
OPENVINO_RTTI("UselessSliceEraser", "0");
|
||||||
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -21,8 +21,8 @@
|
|||||||
|
|
||||||
using namespace ov;
|
using namespace ov;
|
||||||
|
|
||||||
bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
bool ov::pass::UselessSliceEraser::run_on_model(const std::shared_ptr<ov::Model>& f) {
|
||||||
RUN_ON_FUNCTION_SCOPE(UselessStridedSliceEraser);
|
RUN_ON_FUNCTION_SCOPE(UselessSliceEraser);
|
||||||
bool rewritten = false;
|
bool rewritten = false;
|
||||||
for (auto& node : f->get_ordered_ops()) {
|
for (auto& node : f->get_ordered_ops()) {
|
||||||
// Recursively apply transformation for sub-graph based operations
|
// Recursively apply transformation for sub-graph based operations
|
||||||
@ -31,19 +31,21 @@ bool ov::pass::UselessStridedSliceEraser::run_on_model(const std::shared_ptr<ov:
|
|||||||
rewritten |= run_on_model(sub_graph);
|
rewritten |= run_on_model(sub_graph);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto ss = std::dynamic_pointer_cast<ov::op::v1::StridedSlice>(node);
|
bool is_slice = ov::is_type<ov::op::v1::StridedSlice>(node) || ov::is_type<ov::op::v8::Slice>(node);
|
||||||
if (!ss || ss->get_output_partial_shape(0).is_dynamic() || ss->get_input_partial_shape(0).is_dynamic())
|
if (!is_slice || node->get_output_partial_shape(0).is_dynamic() ||
|
||||||
|
node->get_input_partial_shape(0).is_dynamic())
|
||||||
continue;
|
continue;
|
||||||
if (ss->input(0).get_shape() != ss->output(0).get_shape())
|
if (node->get_input_shape(0) != node->get_output_shape(0))
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(ss->input_value(3).get_node_shared_ptr());
|
auto stridesNode = std::dynamic_pointer_cast<ov::op::v0::Constant>(node->get_input_node_shared_ptr(3));
|
||||||
if (stridesNode) {
|
if (stridesNode) {
|
||||||
auto strides = stridesNode->cast_vector<int64_t>();
|
auto strides = stridesNode->cast_vector<int64_t>();
|
||||||
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) {
|
if (!std::any_of(strides.begin(), strides.end(), [](int64_t strd) {
|
||||||
return strd < 0;
|
return strd < 0;
|
||||||
}))
|
})) {
|
||||||
rewritten |= replace_output_update_name(ss->output(0), ss->input_value(0));
|
rewritten |= replace_output_update_name(node->output(0), node->input_value(0));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return rewritten;
|
return rewritten;
|
||||||
@ -404,7 +406,7 @@ bool ov::pass::StridedSliceOptimization::run_on_model(const std::shared_ptr<ov::
|
|||||||
|
|
||||||
bool rewritten = false;
|
bool rewritten = false;
|
||||||
if (m_use_shapes) {
|
if (m_use_shapes) {
|
||||||
rewritten = UselessStridedSliceEraser().run_on_model(f);
|
rewritten = UselessSliceEraser().run_on_model(f);
|
||||||
// Execution of other passes is also needed even if 'rewritten' is already 'true'
|
// Execution of other passes is also needed even if 'rewritten' is already 'true'
|
||||||
rewritten = SharedStridedSliceEraser().run_on_model(f) || rewritten;
|
rewritten = SharedStridedSliceEraser().run_on_model(f) || rewritten;
|
||||||
rewritten = GroupedStridedSliceOptimizer().run_on_model(f) || rewritten;
|
rewritten = GroupedStridedSliceOptimizer().run_on_model(f) || rewritten;
|
||||||
|
@ -119,6 +119,43 @@ TEST_F(TransformationTestsF, OptimizeSS_SkipUselessDeletionRevertCase) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, UselessSlice) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||||
|
auto relu = std::make_shared<opset1::Relu>(data);
|
||||||
|
auto begin = opset1::Constant::create(element::i64, Shape{2}, {0, 0});
|
||||||
|
auto end = opset1::Constant::create(element::i64, Shape{2}, {5, 5});
|
||||||
|
auto stride = opset1::Constant::create(element::i64, Shape{2}, {1, 1});
|
||||||
|
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
|
||||||
|
|
||||||
|
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
|
||||||
|
|
||||||
|
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
|
||||||
|
manager.register_pass<ov::pass::UselessSliceEraser>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||||
|
auto relu = std::make_shared<opset1::Relu>(data);
|
||||||
|
model_ref = std::make_shared<ov::Model>(relu, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, NegativeUselessSliceWithNegativeStrides) {
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||||
|
auto relu = std::make_shared<opset1::Relu>(data);
|
||||||
|
auto begin = opset1::Constant::create(element::i64, Shape{2}, {4, 0});
|
||||||
|
auto end = opset1::Constant::create(element::i64, Shape{2}, {INT32_MIN, 5});
|
||||||
|
auto stride = opset1::Constant::create(element::i64, Shape{2}, {-1, 1});
|
||||||
|
auto axis = opset1::Constant::create(element::i64, Shape{2}, {1, 3});
|
||||||
|
|
||||||
|
auto slice = std::make_shared<opset8::Slice>(relu, begin, end, stride, axis);
|
||||||
|
|
||||||
|
model = std::make_shared<ov::Model>(slice, ParameterVector{data});
|
||||||
|
manager.register_pass<ov::pass::UselessSliceEraser>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, OptimizeSS_Usefull_Test) {
|
TEST_F(TransformationTestsF, OptimizeSS_Usefull_Test) {
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
auto data = std::make_shared<opset1::Parameter>(element::f32, Shape{5, 5, 5, 5});
|
||||||
|
Loading…
Reference in New Issue
Block a user