TransposeSinking - add support for ShapeOf (#19471)

* TransposeSinking - add support for ShapeOf

It's often in the model that transposes have two consumers: ShapeOf and
another transpose and with ShapeOf absent - the transpose pair could be eliminated.
This patch's approach is to propagate a transpose through ShapeOf by creating
"ShapeOf->Gather" subgraph and replacing ShapeOf's input with transpose input.

Ticket: CVS-118896

* enhance docs
This commit is contained in:
Mateusz Tabaka 2023-08-30 14:45:48 +02:00 committed by GitHub
parent 38cf4764cb
commit 02d6c1cb5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 154 additions and 0 deletions

View File

@ -0,0 +1,51 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "transformations/transpose_sinking/ts_base.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {
namespace transpose_sinking {
class TRANSFORMATIONS_API TSShapeOfForward;
} // namespace transpose_sinking
} // namespace pass
} // namespace ov
/**
* @ingroup ie_transformation_common_api
* @brief TSShapeOfForward transformation sinks Transpose through ShapeOf in the forward direction.
*
* It replaces:
*
* +---------+
* |Transpose|
* +---------+
* |
* v
* +---------+
* | ShapeOf |
* +---------+
*
* with the following:
*
* +---------+
* | ShapeOf |
* +---------+
* |
* v
* +------+
* |Gather|
* +------+
*
*/
class ov::pass::transpose_sinking::TSShapeOfForward : public ov::pass::transpose_sinking::TSForwardBase {
public:
OPENVINO_RTTI("TSShapeOfForward", "0");
TSShapeOfForward();
};

View File

@ -18,6 +18,7 @@
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp"
#include "transformations/transpose_sinking/ts_shape_of.hpp"
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
@ -40,6 +41,7 @@ TSGeneralForward::TSGeneralForward() {
add_matcher<TSInterpolateForward>();
add_matcher<TSSliceForward>();
add_matcher<TSGatherForward>();
add_matcher<TSShapeOfForward>();
add_matcher<TSFuse>();
}

View File

@ -0,0 +1,31 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_shape_of.hpp"
#include "itt.hpp"
#include "openvino/op/util/shape_of_base.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
ov::pass::transpose_sinking::TSShapeOfForward::TSShapeOfForward() {
MATCHER_SCOPE(TSShapeOfForward);
create_pattern<op::util::ShapeOfBase>(true);
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
const utils::TransposeInputsInfo& transpose_info) -> bool {
main_node->input(0).replace_source_output(transpose_info.transpose->input_value(0));
auto shape_of_consumers = main_node->get_output_target_inputs(0);
const auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
const auto gather = std::make_shared<op::v8::Gather>(main_node, transpose_info.transpose_const, axis);
for (auto& input : shape_of_consumers)
input.replace_source_output(gather);
copy_runtime_info(main_node, gather);
utils::SwapOutputNames(main_node->output(0), gather->output(0));
utils::SwapFriendlyNames(main_node, gather);
return true;
};
transpose_sinking(matcher_name, sinking_transformation);
}

View File

@ -0,0 +1,70 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/transpose_sinking/ts_shape_of.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "transformations/transpose_sinking/ts_fuse.hpp"
using namespace ov;
using TSShapeOfForward = TransformationTestsF;
TEST_F(TSShapeOfForward, v0ShapeOf) {
{
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
auto transpose1 = std::make_shared<op::v1::Transpose>(param, order1);
auto shape_of = std::make_shared<op::v0::ShapeOf>(transpose1);
auto order2 = op::v0::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1});
auto transpose2 = std::make_shared<op::v1::Transpose>(transpose1, order2);
auto abs = std::make_shared<op::v0::Abs>(transpose2);
model = std::make_shared<Model>(NodeVector{shape_of, abs}, ParameterVector{param});
manager.register_pass<ov::pass::transpose_sinking::TSShapeOfForward>();
manager.register_pass<ov::pass::transpose_sinking::TSFuse>();
}
{
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
auto shape_of = std::make_shared<op::v0::ShapeOf>(param);
auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<op::v8::Gather>(shape_of, order1, axis);
auto abs = std::make_shared<op::v0::Abs>(param);
model_ref = std::make_shared<Model>(NodeVector{gather, abs}, ParameterVector{param});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
TEST_F(TSShapeOfForward, v3ShapeOf) {
{
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
auto transpose1 = std::make_shared<op::v1::Transpose>(param, order1);
auto shape_of = std::make_shared<op::v3::ShapeOf>(transpose1);
auto order2 = op::v0::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1});
auto transpose2 = std::make_shared<op::v1::Transpose>(transpose1, order2);
auto abs = std::make_shared<op::v0::Abs>(transpose2);
model = std::make_shared<Model>(NodeVector{shape_of, abs}, ParameterVector{param});
manager.register_pass<ov::pass::transpose_sinking::TSShapeOfForward>();
manager.register_pass<ov::pass::transpose_sinking::TSFuse>();
}
{
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
auto shape_of = std::make_shared<op::v3::ShapeOf>(param);
auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
auto gather = std::make_shared<op::v8::Gather>(shape_of, order1, axis);
auto abs = std::make_shared<op::v0::Abs>(param);
model_ref = std::make_shared<Model>(NodeVector{gather, abs}, ParameterVector{param});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}