Added TSSlice transformation to TSGeneral, created TransposeSinkingGeneral alias in ov::pass namespace

This commit is contained in:
Ivan 2023-03-17 16:20:06 +04:00
parent 83ab2cc5f6
commit 958f000e02
4 changed files with 8 additions and 2 deletions

View File

@ -16,6 +16,9 @@ class TRANSFORMATIONS_API TSGeneralBackward;
class TRANSFORMATIONS_API TSGeneral;
} // namespace transpose_sinking
using TransposeSinkingGeneral = ov::pass::transpose_sinking::TSGeneral;
} // namespace pass
} // namespace ov

View File

@ -15,6 +15,7 @@
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
@ -34,6 +35,7 @@ TSGeneralForward::TSGeneralForward() {
add_matcher<TSSqueezeForward>();
add_matcher<TSUnsqueezeForward>();
add_matcher<TSInterpolateForward>();
add_matcher<TSSliceForward>();
add_matcher<TSFuse>();
}
@ -48,6 +50,7 @@ TSGeneralBackward::TSGeneralBackward() {
add_matcher<TSSqueezeBackward>();
add_matcher<TSUnsqueezeBackward>();
add_matcher<TSInterpolateBackward>();
add_matcher<TSSliceBackward>();
add_matcher<TSFuse>();
}

View File

@ -248,7 +248,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
{
// perform transpose sinking and reverse infer if the model contains only OpenVINO operations
ov::pass::Manager manager;
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::pass::ReverseShapeAndTypeInfer>();
manager.run_passes(model);
}

View File

@ -268,7 +268,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& function) const {
manager.register_pass<ov::frontend::tensorflow_lite::pass::TFLQuantizeResolver>();
manager.register_pass<ov::frontend::tensorflow_lite::pass::Rfft2dSimplifier>();
manager.register_pass<ov::pass::TransposeSinking>();
manager.register_pass<ov::pass::transpose_sinking::TSGeneral>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.run_passes(function);
}