fix transformation; add unit test (#18314)
Co-authored-by: Ivan Tikhonov <ivan.tikhonov@intel.com>
This commit is contained in:
parent
f313dde66b
commit
6a0c6a1a60
@ -12,6 +12,7 @@
|
||||
#include "openvino/op/space_to_batch.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "openvino/op/util/op_types.hpp"
|
||||
#include "openvino/op/util/pad_base.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/util/common_util.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
@ -25,7 +26,7 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
namespace {
|
||||
|
||||
std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_node) {
|
||||
if (as_type_ptr<ov::op::v1::Pad>(main_node)) {
|
||||
if (as_type_ptr<ov::op::util::PadBase>(main_node)) {
|
||||
return {1, 2};
|
||||
} else if (as_type_ptr<ov::op::v1::BatchToSpace>(main_node) || as_type_ptr<ov::op::v1::SpaceToBatch>(main_node)) {
|
||||
return {1, 2, 3};
|
||||
@ -38,7 +39,7 @@ std::vector<size_t> get_indices_by_op_type(const std::shared_ptr<Node>& main_nod
|
||||
|
||||
TSDataMovementForward::TSDataMovementForward() {
|
||||
MATCHER_SCOPE(TSDataMovementForward);
|
||||
create_pattern<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
create_pattern<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
true,
|
||||
{0});
|
||||
|
||||
@ -74,7 +75,7 @@ TSDataMovementBackward::TSDataMovementBackward() {
|
||||
MATCHER_SCOPE(TSDataMovementBackward);
|
||||
|
||||
auto main_node_label =
|
||||
wrap_type<ov::op::v1::Pad, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
wrap_type<op::util::PadBase, ov::op::v1::BatchToSpace, ov::op::v1::SpaceToBatch, ov::op::v0::ReverseSequence>(
|
||||
[](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && CheckTransposeConsumers(output);
|
||||
});
|
||||
|
@ -114,13 +114,30 @@ class PadFactory : public IFactory {
|
||||
public:
|
||||
explicit PadFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<Pad>(parent_nodes[0], parent_nodes[1], parent_nodes[2], ov::op::PadMode::CONSTANT);
|
||||
return std::make_shared<ov::op::v1::Pad>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
ov::op::PadMode::CONSTANT);
|
||||
}
|
||||
};
|
||||
FactoryPtr CreatePadFactory(const std::string& type_name) {
|
||||
return std::make_shared<PadFactory>(type_name);
|
||||
}
|
||||
|
||||
class Pad12Factory : public IFactory {
|
||||
public:
|
||||
explicit Pad12Factory(const std::string& type_name) : IFactory(type_name) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
return std::make_shared<ov::op::v12::Pad>(parent_nodes[0],
|
||||
parent_nodes[1],
|
||||
parent_nodes[2],
|
||||
ov::op::PadMode::CONSTANT);
|
||||
}
|
||||
};
|
||||
FactoryPtr CreatePad12Factory(const std::string& type_name) {
|
||||
return std::make_shared<Pad12Factory>(type_name);
|
||||
}
|
||||
|
||||
class BatchToSpaceFactory : public IFactory {
|
||||
public:
|
||||
explicit BatchToSpaceFactory(const std::string& type_name) : IFactory(type_name) {}
|
||||
@ -253,6 +270,9 @@ FactoryPtr CreateFakeQuantizeFactory(const std::string& type_name) {
|
||||
#undef CREATE_PAD_FACTORY
|
||||
#define CREATE_PAD_FACTORY(type_name) CreatePadFactory(#type_name)
|
||||
|
||||
#undef CREATE_PAD12_FACTORY
|
||||
#define CREATE_PAD12_FACTORY(type_name) CreatePad12Factory(#type_name)
|
||||
|
||||
#undef CREATE_BATCH_TO_SPACE_FACTORY
|
||||
#define CREATE_BATCH_TO_SPACE_FACTORY(type_name) CreateBatchToSpaceFactory(#type_name)
|
||||
|
||||
@ -538,6 +558,34 @@ auto test_forward_pad = []() {
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TSTestFixture, test_forward_pad());
|
||||
|
||||
auto test_negative_forward_pad = []() {
|
||||
TestCase test_case;
|
||||
|
||||
// Initialize common attributes
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TSDataMovementForward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, -2, -3, -4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
test_case.model.preprocess_inputs_to_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model.main_op = {CREATE_PAD12_FACTORY(Pad12)};
|
||||
test_case.model.model_template = create_model;
|
||||
|
||||
// Reference model description:
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_gather_for}, {{1, 2}}};
|
||||
test_case.model_ref.main_op = {CREATE_PAD12_FACTORY(Pad12)};
|
||||
test_case.model_ref.preprocess_outputs_of_main = {{set_transpose_for}, {{0}}};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonNegativePad12Forward, TSTestFixture, test_negative_forward_pad());
|
||||
|
||||
auto test_forward_batch_to_space = []() {
|
||||
TestCase test_case;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user