Fix code style on master (#14620)
This commit is contained in:
parent
befbae28ca
commit
6507c7e8bd
@ -2,18 +2,14 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
#include <functional>
|
||||
#include <openvino/frontend/manager.hpp>
|
||||
#include <openvino/opsets/opset9.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/transpose_sinking_general.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using namespace testing;
|
||||
@ -30,12 +26,14 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -66,12 +64,14 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -105,12 +105,14 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<ov::opset9::Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<ov::opset9::Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -148,7 +150,8 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<ov::opset9::Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order1 =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<ov::opset9::Transpose>(in_constant, ng_order1);
|
||||
|
||||
in_op = std::make_shared<ov::opset9::Add>(in_op, transpose1);
|
||||
@ -234,7 +237,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
|
||||
class IFactory {
|
||||
public:
|
||||
virtual ~IFactory() = default;
|
||||
virtual NodePtr create(const ov::OutputVector & parent) = 0;
|
||||
virtual NodePtr create(const ov::OutputVector& parent) = 0;
|
||||
|
||||
virtual size_t getNumInputs() const = 0;
|
||||
virtual size_t getNumOuputs() const = 0;
|
||||
@ -244,7 +247,7 @@ using FactoryPtr = std::shared_ptr<IFactory>;
|
||||
|
||||
class UnaryFactory : public IFactory {
|
||||
public:
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
NodePtr create(const ov::OutputVector& parent) override {
|
||||
return std::make_shared<ov::opset9::Sinh>(parent.front());
|
||||
}
|
||||
|
||||
@ -252,13 +255,17 @@ public:
|
||||
return std::make_shared<UnaryFactory>();
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 1; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
size_t getNumInputs() const override {
|
||||
return 1;
|
||||
}
|
||||
size_t getNumOuputs() const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class BinaryFactory : public IFactory {
|
||||
public:
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
NodePtr create(const ov::OutputVector& parent) override {
|
||||
return std::make_shared<ov::opset9::Add>(parent[0], parent[1]);
|
||||
}
|
||||
|
||||
@ -266,17 +273,19 @@ public:
|
||||
return std::make_shared<BinaryFactory>();
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 2; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
size_t getNumInputs() const override {
|
||||
return 2;
|
||||
}
|
||||
size_t getNumOuputs() const override {
|
||||
return 1;
|
||||
}
|
||||
};
|
||||
|
||||
class SplitFactory : public IFactory {
|
||||
public:
|
||||
SplitFactory(size_t axis) : axis_(axis) {}
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64,
|
||||
ov::Shape{},
|
||||
axis_);
|
||||
NodePtr create(const ov::OutputVector& parent) override {
|
||||
auto split_axis_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{}, axis_);
|
||||
return std::make_shared<ov::opset9::Split>(parent.front(), split_axis_const, 2);
|
||||
}
|
||||
|
||||
@ -284,8 +293,13 @@ public:
|
||||
return std::make_shared<SplitFactory>(axis);
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 1; }
|
||||
size_t getNumOuputs() const override { return 2; }
|
||||
size_t getNumInputs() const override {
|
||||
return 1;
|
||||
}
|
||||
size_t getNumOuputs() const override {
|
||||
return 2;
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t axis_;
|
||||
};
|
||||
@ -293,7 +307,7 @@ private:
|
||||
class ConcatFactory : public IFactory {
|
||||
public:
|
||||
ConcatFactory(size_t axis) : axis_(axis) {}
|
||||
NodePtr create(const ov::OutputVector & parent) override {
|
||||
NodePtr create(const ov::OutputVector& parent) override {
|
||||
return std::make_shared<ov::opset9::Concat>(parent, axis_);
|
||||
}
|
||||
|
||||
@ -301,17 +315,26 @@ public:
|
||||
return std::make_shared<ConcatFactory>(axis);
|
||||
}
|
||||
|
||||
size_t getNumInputs() const override { return 2; }
|
||||
size_t getNumOuputs() const override { return 1; }
|
||||
size_t getNumInputs() const override {
|
||||
return 2;
|
||||
}
|
||||
size_t getNumOuputs() const override {
|
||||
return 1;
|
||||
}
|
||||
|
||||
private:
|
||||
const size_t axis_;
|
||||
};
|
||||
|
||||
/*
|
||||
/*
|
||||
Each node pair should be started with input size = 1 node and finished with node output size = 1
|
||||
Insert Split/Concat to fullfill that.
|
||||
*/
|
||||
NodePtr CreateNodePair(FactoryPtr factory_first, FactoryPtr factory_second, NodePtr parent, size_t split_axis, size_t concat_axis) {
|
||||
NodePtr CreateNodePair(FactoryPtr factory_first,
|
||||
FactoryPtr factory_second,
|
||||
NodePtr parent,
|
||||
size_t split_axis,
|
||||
size_t concat_axis) {
|
||||
NodePtr input = parent;
|
||||
if (factory_first->getNumInputs() != 1) {
|
||||
input = SplitFactory(split_axis).create(input->outputs());
|
||||
@ -333,9 +356,9 @@ NodePtr CreateNodePair(FactoryPtr factory_first, FactoryPtr factory_second, Node
|
||||
}
|
||||
|
||||
NodePtr MakeAllNodesSubgraph(NodePtr parent, size_t split_axis, size_t concat_axis) {
|
||||
std::vector<FactoryPtr> factories = { UnaryFactory::createFactory(),
|
||||
SplitFactory::createFactory(split_axis),
|
||||
ConcatFactory::createFactory(concat_axis) };
|
||||
std::vector<FactoryPtr> factories = {UnaryFactory::createFactory(),
|
||||
SplitFactory::createFactory(split_axis),
|
||||
ConcatFactory::createFactory(concat_axis)};
|
||||
NodePtr in_op = parent;
|
||||
for (int i = 0; i < factories.size(); ++i) {
|
||||
for (int j = 0; j < factories.size(); ++j) {
|
||||
@ -358,7 +381,8 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
auto ng_order0 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<ov::opset9::Transpose>(node0, ng_order0);
|
||||
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(transpose0, reshape_const, false);
|
||||
|
||||
auto ng_order1 = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
@ -377,7 +401,8 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
|
||||
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
|
||||
|
||||
auto reshape_const = std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape_const =
|
||||
std::make_shared<ov::opset9::Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<ov::opset9::Reshape>(node0, reshape_const, false);
|
||||
|
||||
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
|
||||
|
Loading…
Reference in New Issue
Block a user