fix transformation; add unit test (#18314)

Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
Evgeny Kotov 2023-07-05 14:36:46 +02:00 committed by GitHub
parent f313dde66b
commit 6a0c6a1a60
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 4 deletions

View File

@ -12,6 +12,7 @@
#include "openvino/op/space_to_batch.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/op/util/pad_base.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
@ -25,7 +26,7 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_node) {
if (as_type_ptr<ov::op::v1::Pad>(main_node)) {
if (as_type_ptr<ov::op::util::PadBase>(main_node)) {
return {1, 2};
} else if (as_type_ptr<ov::op::v1::BatchToSpace>(main_node) || as_type_ptr<ov::op::v1::SpaceToBatch>(main_node)) {
return {1, 2, 3};
@ -38,7 +39,7 @@ std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_nod
TSDataMovementForward::TSDataMovementForward() {
MATCHER_SCOPE(TSDataMovementForward);
create_pattern<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
create_pattern<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
true,
{0});
@ -74,7 +75,7 @@ TSDataMovementBackward::TSDataMovementBackward() {
MATCHER_SCOPE(TSDataMovementBackward);
auto main_node_label =
wrap_type<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
wrap_type<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
[](const Output<Node>& output) -> bool {
return has_static_rank()(output) && CheckTransposeConsumers(output);
});

View File

@ -114,13 +114,30 @@ class PadFactory : public IFactory {
public:
explicit PadFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Pad>(parent_nodes[0], parent_nodes[1], parent_nodes[2], ov::op::PadMode::CONSTANT);
return std::make_shared<ov::op::v1::Pad>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
ov::op::PadMode::CONSTANT);
}
};
FactoryPtr CreatePadFactory(const std::string& type_name) {
return std::make_shared<PadFactory>(type_name);
}
class Pad12Factory : public IFactory {
public:
explicit Pad12Factory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<ov::op::v12::Pad>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
ov::op::PadMode::CONSTANT);
}
};
FactoryPtr CreatePad12Factory(const std::string& type_name) {
return std::make_shared<Pad12Factory>(type_name);
}
class BatchToSpaceFactory : public IFactory {
public:
explicit BatchToSpaceFactory(const std::string& type_name) : IFactory(type_name) {}
@ -253,6 +270,9 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
#undef CREATE_PAD_FACTORY
#define CREATE_PAD_FACTORY(type_name) CreatePadFactory(#type_name)
#undef CREATE_PAD12_FACTORY
#define CREATE_PAD12_FACTORY(type_name) CreatePad12Factory(#type_name)
#undef CREATE_BATCH_TO_SPACE_FACTORY
#define CREATE_BATCH_TO_SPACE_FACTORY(type_name) CreateBatchToSpaceFactory(#type_name)
@ -538,6 +558,34 @@ auto test_forward_pad = []() {
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad());
auto test_negative_forward_pad = []() {
TestCase test_case;
// Initialize common attributes
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
};
// Test model description:
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
test_case.model.main_op = {CREATE_PAD12_FACTORY(Pad12)};
test_case.model.model_template = create_model;
// Reference model description:
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
test_case.model_ref.main_op = {CREATE_PAD12_FACTORY(Pad12)};
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
test_case.model_ref.model_template = create_model;
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonNegativePad12Forward, TSTestFixture, test_negative_forward_pad());
auto test_forward_batch_to_space = []() {
TestCase test_case;