codestyle

This commit is contained in:
Tikhonov Ivan 2023-03-08 17:42:25 +00:00
parent 11b500953d
commit 3eeaf7f9bd
6 changed files with 94 additions and 94 deletions

View File

@ -19,7 +19,8 @@ using namespace opset10;
ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
MATCHER_SCOPE(TransposeFuse);
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()}, transpose_sinking::HasSameOutputTransposeNodes);
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
transpose_sinking::HasSameOutputTransposeNodes);
auto transpose_2_label = pattern::wrap_type<Transpose>({transpose_1_label, pattern::wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@ -49,7 +50,6 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
if (transpose_order_type != transpose2_order->get_element_type())
transpose_order_type = element::i64;
if (is_ordered) {
for (const auto& out_transpose : transpose1->output(0).get_target_inputs()) {
ov::replace_output_update_name(out_transpose.get_node()->output(0), input);

View File

@ -10,10 +10,10 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transpose_sinking_test_utils.hpp"
using namespace std;
@ -154,7 +154,9 @@ FactoryPtr CreateReductionFactory(const std::string& type_name) {
class InterpolateFactory : public IFactory {
public:
explicit InterpolateFactory(const std::string& type_name, bool is_reference) : IFactory(type_name), m_is_reference(is_reference) {}
explicit InterpolateFactory(const std::string& type_name, bool is_reference)
: IFactory(type_name),
m_is_reference(is_reference) {}
NodePtr create(const OutputVector& parent_nodes) const override {
std::vector<size_t> pads_begin{1, 2, 3, 4};
std::vector<size_t> pads_end{1, 2, 3, 4};
@ -163,15 +165,16 @@ public:
pads_end = {4, 3, 2, 1};
}
const Interpolate::InterpolateAttrs attrs{Interpolate::InterpolateMode::NEAREST,
Interpolate::ShapeCalcMode::SCALES,
pads_begin,
pads_end,
Interpolate::CoordinateTransformMode::HALF_PIXEL,
Interpolate::NearestMode::ROUND_PREFER_FLOOR,
false,
-0.75};
Interpolate::ShapeCalcMode::SCALES,
pads_begin,
pads_end,
Interpolate::CoordinateTransformMode::HALF_PIXEL,
Interpolate::NearestMode::ROUND_PREFER_FLOOR,
false,
-0.75};
return std::make_shared<Interpolate>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], attrs);
}
private:
bool m_is_reference = false;
};
@ -277,22 +280,21 @@ public:
};
vector<FactoryPtr> unary_factories = {
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
CREATE_UNARY_FACTORY(Atanh), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Clamp),
CREATE_UNARY_FACTORY(Cos), CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Convert),
CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Floor), CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid),
CREATE_UNARY_FACTORY(HSwish), CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(LogicalNot),
CREATE_UNARY_FACTORY(Mish), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftPlus), CREATE_UNARY_FACTORY(SoftSign),
CREATE_UNARY_FACTORY(Sqrt), CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
CREATE_UNARY_FACTORY(Atanh), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Clamp),
CREATE_UNARY_FACTORY(Cos), CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Convert),
CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Floor), CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid),
CREATE_UNARY_FACTORY(HSwish), CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(LogicalNot),
CREATE_UNARY_FACTORY(Mish), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftPlus), CREATE_UNARY_FACTORY(SoftSign),
CREATE_UNARY_FACTORY(Sqrt), CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
vector<FactoryPtr> logical_unary_factories = {
CREATE_UNARY_FACTORY(IsFinite),
CREATE_UNARY_FACTORY(IsInf),
CREATE_UNARY_FACTORY(IsNaN)};
vector<FactoryPtr> logical_unary_factories = {CREATE_UNARY_FACTORY(IsFinite),
CREATE_UNARY_FACTORY(IsInf),
CREATE_UNARY_FACTORY(IsNaN)};
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
CREATE_BINARY_FACTORY(Divide),
@ -379,8 +381,12 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward, TransposeSinkingTestFixture, test_forward_unary(unary_factories, {1, 10}));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward, TransposeSinkingTestFixture, test_forward_unary(logical_unary_factories, {1}));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
TransposeSinkingTestFixture,
test_forward_unary(unary_factories, {1, 10}));
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward,
TransposeSinkingTestFixture,
test_forward_unary(logical_unary_factories, {1}));
auto test_forward_binary = []() {
TestCase test_case;
@ -636,7 +642,9 @@ auto test_forward_interpolate = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TransposeSinkingTestFixture, test_forward_interpolate());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward,
TransposeSinkingTestFixture,
test_forward_interpolate());
// ------------------ BACKWARD --------------------
@ -647,7 +655,7 @@ auto test_backward_unary = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
};
// Test model description:
@ -672,8 +680,8 @@ auto test_backward_binary = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
};
// Test model description:
@ -698,9 +706,9 @@ auto test_backward_concat = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatBackward);
test_case.num_main_ops = {1, 3};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
};
// Test model description:
@ -725,8 +733,8 @@ auto test_backward_split = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitBackward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 9, 55, 55}),
constant<int64_t>(element::i32, {}, {1}),
parameter(element::f32, {1, 9, 55, 55}),
constant<int64_t>(element::i32, {}, {1}),
};
// Test model description:
@ -739,7 +747,7 @@ auto test_backward_split = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -757,9 +765,9 @@ auto test_backward_pad = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
parameter(element::f32, {1, 3, 55, 55}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
};
// Test model description:
@ -784,10 +792,10 @@ auto test_backward_batch_to_space = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {128, 55, 3, 128}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
parameter(element::f32, {128, 55, 3, 128}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
};
// Reference model description:
@ -814,10 +822,10 @@ auto test_backward_space_to_batch = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {1, 8, 9, 64}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
parameter(element::f32, {1, 8, 9, 64}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
};
// Test model description:
@ -843,8 +851,8 @@ auto test_backward_reduction = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 4, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
parameter(element::f32, {32, 4, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
};
// Test model description:
@ -857,7 +865,7 @@ auto test_backward_reduction = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -867,7 +875,9 @@ auto test_backward_reduction = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward, TransposeSinkingTestFixture, test_backward_reduction());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward,
TransposeSinkingTestFixture,
test_backward_reduction());
auto test_backward_interpolate = []() {
TestCase test_case;
@ -876,10 +886,10 @@ auto test_backward_interpolate = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i32, {2}, {24, 160}),
constant<float>(element::f32, {2}, {0.5, 2.}),
constant<int64_t>(element::i32, {2}, {1, 2}),
parameter(element::f32, {1, 2, 48, 80}),
constant<int64_t>(element::i32, {2}, {24, 160}),
constant<float>(element::f32, {2}, {0.5, 2.}),
constant<int64_t>(element::i32, {2}, {1, 2}),
};
// Test model description:
@ -909,7 +919,9 @@ auto test_backward_interpolate = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TransposeSinkingTestFixture, test_backward_interpolate());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
TransposeSinkingTestFixture,
test_backward_interpolate());
}
}
} // namespace common
} // namespace transpose_sinking

View File

@ -30,14 +30,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto unary = std::make_shared<Tanh>(transpose0);
auto ng_order1 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<Transpose>(unary, ng_order1);
}
@ -68,14 +66,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar
NodePtr in_op = X;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto unary = std::make_shared<Tanh>(transpose0);
auto ng_order1 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<Transpose>(unary, ng_order1);
}
@ -108,14 +104,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_unary_ops; ++i) {
auto ng_order0 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
auto unary = std::make_shared<Tanh>(transpose0);
auto ng_order1 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
in_op = std::make_shared<Transpose>(unary, ng_order1);
}
@ -153,8 +147,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
NodePtr in_op = transpose0;
for (size_t i = 0; i < num_binary_ops; ++i) {
auto in_constant = std::make_shared<Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(in_constant, ng_order1);
in_op = std::make_shared<Add>(in_op, transpose1);
@ -199,8 +192,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
concat_inputs.push_back(in_op);
for (size_t j = 1; j < num_concat_inputs; ++j) {
auto in_constant = std::make_shared<Constant>(input_type, input_shape, ov::Shape{1});
auto ng_order1 =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose1 = std::make_shared<Transpose>(in_constant, ng_order1);
concat_inputs.push_back(transpose1);
}
@ -385,8 +377,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);
auto reshape_const =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
@ -405,8 +396,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
auto reshape_const =
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
auto reshape = std::make_shared<Reshape>(node0, reshape_const, false);
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
@ -420,6 +410,6 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
}
}
}
}
} // namespace general
} // namespace testing
} // namespace transpose_sinking

View File

@ -835,6 +835,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite,
} // namespace backward
}
}
}
} // namespace split
} // namespace testing
} // namespace transpose_sinking

View File

@ -83,5 +83,5 @@ std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShap
return std::make_shared<Parameter>(el_type, ps);
}
}
}
} // namespace testing
} // namespace transpose_sinking

View File

@ -62,12 +62,10 @@ std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
template<class T>
std::shared_ptr<ov::Node> constant(ov::element::Type el_type,
const ov::Shape& shape,
const std::vector<T>& value) {
template <class T>
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
return ov::opset10::Constant::create<T>(el_type, shape, value);
}
}
}
} // namespace testing
} // namespace transpose_sinking