TransposeSinking - add support for ShapeOf (#19471)
* TransposeSinking - add support for ShapeOf It's often in the model that transposes have two consumers: ShapeOf and another transpose and with ShapeOf absent - the transpose pair could be eliminated. This patch's approach is to propagate a transpose through ShapeOf by creating "ShapeOf->Gather" subgraph and replacing ShapeOf's input with transpose input. Ticket: CVS-118896 * enhance docs
This commit is contained in:
parent
38cf4764cb
commit
02d6c1cb5d
@ -0,0 +1,51 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "transformations/transpose_sinking/ts_base.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
namespace transpose_sinking {
|
||||
|
||||
class TRANSFORMATIONS_API TSShapeOfForward;
|
||||
|
||||
} // namespace transpose_sinking
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief TSShapeOfForward transformation sinks Transpose through ShapeOf in the forward direction.
|
||||
*
|
||||
* It replaces:
|
||||
*
|
||||
* +---------+
|
||||
* |Transpose|
|
||||
* +---------+
|
||||
* |
|
||||
* v
|
||||
* +---------+
|
||||
* | ShapeOf |
|
||||
* +---------+
|
||||
*
|
||||
* with the following:
|
||||
*
|
||||
* +---------+
|
||||
* | ShapeOf |
|
||||
* +---------+
|
||||
* |
|
||||
* v
|
||||
* +------+
|
||||
* |Gather|
|
||||
* +------+
|
||||
*
|
||||
*/
|
||||
class ov::pass::transpose_sinking::TSShapeOfForward : public ov::pass::transpose_sinking::TSForwardBase {
|
||||
public:
|
||||
OPENVINO_RTTI("TSShapeOfForward", "0");
|
||||
TSShapeOfForward();
|
||||
};
|
@ -18,6 +18,7 @@
|
||||
#include "transformations/transpose_sinking/ts_interpolate.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reduction.hpp"
|
||||
#include "transformations/transpose_sinking/ts_reset_no_sinking_attribute.hpp"
|
||||
#include "transformations/transpose_sinking/ts_shape_of.hpp"
|
||||
#include "transformations/transpose_sinking/ts_slice.hpp"
|
||||
#include "transformations/transpose_sinking/ts_split.hpp"
|
||||
#include "transformations/transpose_sinking/ts_squeeze.hpp"
|
||||
@ -40,6 +41,7 @@ TSGeneralForward::TSGeneralForward() {
|
||||
add_matcher<TSInterpolateForward>();
|
||||
add_matcher<TSSliceForward>();
|
||||
add_matcher<TSGatherForward>();
|
||||
add_matcher<TSShapeOfForward>();
|
||||
add_matcher<TSFuse>();
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,31 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_shape_of.hpp"
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/util/shape_of_base.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
|
||||
ov::pass::transpose_sinking::TSShapeOfForward::TSShapeOfForward() {
|
||||
MATCHER_SCOPE(TSShapeOfForward);
|
||||
|
||||
create_pattern<op::util::ShapeOfBase>(true);
|
||||
auto sinking_transformation = [=](const std::shared_ptr<Node>& main_node,
|
||||
const utils::TransposeInputsInfo& transpose_info) -> bool {
|
||||
main_node->input(0).replace_source_output(transpose_info.transpose->input_value(0));
|
||||
auto shape_of_consumers = main_node->get_output_target_inputs(0);
|
||||
const auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
const auto gather = std::make_shared<op::v8::Gather>(main_node, transpose_info.transpose_const, axis);
|
||||
for (auto& input : shape_of_consumers)
|
||||
input.replace_source_output(gather);
|
||||
copy_runtime_info(main_node, gather);
|
||||
utils::SwapOutputNames(main_node->output(0), gather->output(0));
|
||||
utils::SwapFriendlyNames(main_node, gather);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
transpose_sinking(matcher_name, sinking_transformation);
|
||||
}
|
@ -0,0 +1,70 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/transpose_sinking/ts_shape_of.hpp"
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "transformations/transpose_sinking/ts_fuse.hpp"
|
||||
|
||||
using namespace ov;
|
||||
|
||||
using TSShapeOfForward = TransformationTestsF;
|
||||
|
||||
TEST_F(TSShapeOfForward, v0ShapeOf) {
|
||||
{
|
||||
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
|
||||
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
|
||||
auto transpose1 = std::make_shared<op::v1::Transpose>(param, order1);
|
||||
auto shape_of = std::make_shared<op::v0::ShapeOf>(transpose1);
|
||||
auto order2 = op::v0::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1});
|
||||
auto transpose2 = std::make_shared<op::v1::Transpose>(transpose1, order2);
|
||||
auto abs = std::make_shared<op::v0::Abs>(transpose2);
|
||||
model = std::make_shared<Model>(NodeVector{shape_of, abs}, ParameterVector{param});
|
||||
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSShapeOfForward>();
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSFuse>();
|
||||
}
|
||||
|
||||
{
|
||||
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
|
||||
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
|
||||
auto shape_of = std::make_shared<op::v0::ShapeOf>(param);
|
||||
auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto gather = std::make_shared<op::v8::Gather>(shape_of, order1, axis);
|
||||
auto abs = std::make_shared<op::v0::Abs>(param);
|
||||
model_ref = std::make_shared<Model>(NodeVector{gather, abs}, ParameterVector{param});
|
||||
}
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
TEST_F(TSShapeOfForward, v3ShapeOf) {
|
||||
{
|
||||
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
|
||||
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
|
||||
auto transpose1 = std::make_shared<op::v1::Transpose>(param, order1);
|
||||
auto shape_of = std::make_shared<op::v3::ShapeOf>(transpose1);
|
||||
auto order2 = op::v0::Constant::create(element::i32, Shape{4}, {0, 2, 3, 1});
|
||||
auto transpose2 = std::make_shared<op::v1::Transpose>(transpose1, order2);
|
||||
auto abs = std::make_shared<op::v0::Abs>(transpose2);
|
||||
model = std::make_shared<Model>(NodeVector{shape_of, abs}, ParameterVector{param});
|
||||
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSShapeOfForward>();
|
||||
manager.register_pass<ov::pass::transpose_sinking::TSFuse>();
|
||||
}
|
||||
|
||||
{
|
||||
auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{2, 3, 4, 5});
|
||||
auto order1 = op::v0::Constant::create(element::i32, Shape{4}, {0, 3, 1, 2});
|
||||
auto shape_of = std::make_shared<op::v3::ShapeOf>(param);
|
||||
auto axis = op::v0::Constant::create(element::i32, Shape{}, {0});
|
||||
auto gather = std::make_shared<op::v8::Gather>(shape_of, order1, axis);
|
||||
auto abs = std::make_shared<op::v0::Abs>(param);
|
||||
model_ref = std::make_shared<Model>(NodeVector{gather, abs}, ParameterVector{param});
|
||||
}
|
||||
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
Loading…
Reference in New Issue
Block a user