codestyle
This commit is contained in:
parent
11b500953d
commit
3eeaf7f9bd
@ -19,7 +19,8 @@ using namespace opset10;
|
||||
|
||||
ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
MATCHER_SCOPE(TransposeFuse);
|
||||
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()}, transpose_sinking::HasSameOutputTransposeNodes);
|
||||
auto transpose_1_label = pattern::wrap_type<Transpose>({pattern::any_input(), pattern::wrap_type<Constant>()},
|
||||
transpose_sinking::HasSameOutputTransposeNodes);
|
||||
auto transpose_2_label = pattern::wrap_type<Transpose>({transpose_1_label, pattern::wrap_type<Constant>()});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@ -49,7 +50,6 @@ ov::pass::TransposeSinkingFuse::TransposeSinkingFuse() {
|
||||
if (transpose_order_type != transpose2_order->get_element_type())
|
||||
transpose_order_type = element::i64;
|
||||
|
||||
|
||||
if (is_ordered) {
|
||||
for (const auto& out_transpose : transpose1->output(0).get_target_inputs()) {
|
||||
ov::replace_output_update_name(out_transpose.get_node()->output(0), input);
|
||||
|
@ -10,10 +10,10 @@
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
|
||||
#include "transpose_sinking_test_utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
@ -154,7 +154,9 @@ FactoryPtr CreateReductionFactory(const std::string& type_name) {
|
||||
|
||||
class InterpolateFactory : public IFactory {
|
||||
public:
|
||||
explicit InterpolateFactory(const std::string& type_name, bool is_reference) : IFactory(type_name), m_is_reference(is_reference) {}
|
||||
explicit InterpolateFactory(const std::string& type_name, bool is_reference)
|
||||
: IFactory(type_name),
|
||||
m_is_reference(is_reference) {}
|
||||
NodePtr create(const OutputVector& parent_nodes) const override {
|
||||
std::vector<size_t> pads_begin{1, 2, 3, 4};
|
||||
std::vector<size_t> pads_end{1, 2, 3, 4};
|
||||
@ -163,15 +165,16 @@ public:
|
||||
pads_end = {4, 3, 2, 1};
|
||||
}
|
||||
const Interpolate::InterpolateAttrs attrs{Interpolate::InterpolateMode::NEAREST,
|
||||
Interpolate::ShapeCalcMode::SCALES,
|
||||
pads_begin,
|
||||
pads_end,
|
||||
Interpolate::CoordinateTransformMode::HALF_PIXEL,
|
||||
Interpolate::NearestMode::ROUND_PREFER_FLOOR,
|
||||
false,
|
||||
-0.75};
|
||||
Interpolate::ShapeCalcMode::SCALES,
|
||||
pads_begin,
|
||||
pads_end,
|
||||
Interpolate::CoordinateTransformMode::HALF_PIXEL,
|
||||
Interpolate::NearestMode::ROUND_PREFER_FLOOR,
|
||||
false,
|
||||
-0.75};
|
||||
return std::make_shared<Interpolate>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], attrs);
|
||||
}
|
||||
|
||||
private:
|
||||
bool m_is_reference = false;
|
||||
};
|
||||
@ -277,22 +280,21 @@ public:
|
||||
};
|
||||
|
||||
vector<FactoryPtr> unary_factories = {
|
||||
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
|
||||
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
|
||||
CREATE_UNARY_FACTORY(Atanh), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Clamp),
|
||||
CREATE_UNARY_FACTORY(Cos), CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Convert),
|
||||
CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(Exp),
|
||||
CREATE_UNARY_FACTORY(Floor), CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid),
|
||||
CREATE_UNARY_FACTORY(HSwish), CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(LogicalNot),
|
||||
CREATE_UNARY_FACTORY(Mish), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftPlus), CREATE_UNARY_FACTORY(SoftSign),
|
||||
CREATE_UNARY_FACTORY(Sqrt), CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||
CREATE_UNARY_FACTORY(Abs), CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Acosh),
|
||||
CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh), CREATE_UNARY_FACTORY(Atan),
|
||||
CREATE_UNARY_FACTORY(Atanh), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Clamp),
|
||||
CREATE_UNARY_FACTORY(Cos), CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Convert),
|
||||
CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(Exp),
|
||||
CREATE_UNARY_FACTORY(Floor), CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid),
|
||||
CREATE_UNARY_FACTORY(HSwish), CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(LogicalNot),
|
||||
CREATE_UNARY_FACTORY(Mish), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
|
||||
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
|
||||
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftPlus), CREATE_UNARY_FACTORY(SoftSign),
|
||||
CREATE_UNARY_FACTORY(Sqrt), CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
|
||||
|
||||
vector<FactoryPtr> logical_unary_factories = {
|
||||
CREATE_UNARY_FACTORY(IsFinite),
|
||||
CREATE_UNARY_FACTORY(IsInf),
|
||||
CREATE_UNARY_FACTORY(IsNaN)};
|
||||
vector<FactoryPtr> logical_unary_factories = {CREATE_UNARY_FACTORY(IsFinite),
|
||||
CREATE_UNARY_FACTORY(IsInf),
|
||||
CREATE_UNARY_FACTORY(IsNaN)};
|
||||
|
||||
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
|
||||
CREATE_BINARY_FACTORY(Divide),
|
||||
@ -379,8 +381,12 @@ auto test_forward_unary = [](const vector<FactoryPtr>& factories, const vector<s
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward, TransposeSinkingTestFixture, test_forward_unary(unary_factories, {1, 10}));
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward, TransposeSinkingTestFixture, test_forward_unary(logical_unary_factories, {1}));
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_unary(unary_factories, {1, 10}));
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonLogicalUnaryForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_unary(logical_unary_factories, {1}));
|
||||
|
||||
auto test_forward_binary = []() {
|
||||
TestCase test_case;
|
||||
@ -636,7 +642,9 @@ auto test_forward_interpolate = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward, TransposeSinkingTestFixture, test_forward_interpolate());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateForward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_forward_interpolate());
|
||||
|
||||
// ------------------ BACKWARD --------------------
|
||||
|
||||
@ -647,7 +655,7 @@ auto test_backward_unary = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -672,8 +680,8 @@ auto test_backward_binary = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryBackward);
|
||||
test_case.num_main_ops = {1, 10};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -698,9 +706,9 @@ auto test_backward_concat = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatBackward);
|
||||
test_case.num_main_ops = {1, 3};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
parameter(element::f32, {1, 96, 55, 55}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -725,8 +733,8 @@ auto test_backward_split = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitBackward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
constant<int64_t>(element::i32, {}, {1}),
|
||||
parameter(element::f32, {1, 9, 55, 55}),
|
||||
constant<int64_t>(element::i32, {}, {1}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -739,7 +747,7 @@ auto test_backward_split = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -757,9 +765,9 @@ auto test_backward_pad = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.num_main_ops = {1, 2};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
parameter(element::f32, {1, 3, 55, 55}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -784,10 +792,10 @@ auto test_backward_batch_to_space = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
parameter(element::f32, {128, 55, 3, 128}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 2, 2}),
|
||||
};
|
||||
|
||||
// Reference model description:
|
||||
@ -814,10 +822,10 @@ auto test_backward_space_to_batch = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 8, 9, 64}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
parameter(element::f32, {1, 8, 9, 64}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
constant<int64_t>(element::i32, {4}, {1, 2, 3, 4}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -843,8 +851,8 @@ auto test_backward_reduction = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
parameter(element::f32, {32, 4, 2, 1}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 3}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -857,7 +865,7 @@ auto test_backward_reduction = []() {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] =
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
|
||||
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
|
||||
@ -867,7 +875,9 @@ auto test_backward_reduction = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward, TransposeSinkingTestFixture, test_backward_reduction());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_reduction());
|
||||
|
||||
auto test_backward_interpolate = []() {
|
||||
TestCase test_case;
|
||||
@ -876,10 +886,10 @@ auto test_backward_interpolate = []() {
|
||||
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingInterpolateBackward);
|
||||
test_case.num_main_ops = {1};
|
||||
test_case.inputs_to_main = {
|
||||
parameter(element::f32, {1, 2, 48, 80}),
|
||||
constant<int64_t>(element::i32, {2}, {24, 160}),
|
||||
constant<float>(element::f32, {2}, {0.5, 2.}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 2}),
|
||||
parameter(element::f32, {1, 2, 48, 80}),
|
||||
constant<int64_t>(element::i32, {2}, {24, 160}),
|
||||
constant<float>(element::f32, {2}, {0.5, 2.}),
|
||||
constant<int64_t>(element::i32, {2}, {1, 2}),
|
||||
};
|
||||
|
||||
// Test model description:
|
||||
@ -909,7 +919,9 @@ auto test_backward_interpolate = []() {
|
||||
return wrapper(test_case);
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward, TransposeSinkingTestFixture, test_backward_interpolate());
|
||||
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
|
||||
TransposeSinkingTestFixture,
|
||||
test_backward_interpolate());
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace common
|
||||
} // namespace transpose_sinking
|
@ -30,14 +30,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesForward
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -68,14 +66,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesBackwar
|
||||
|
||||
NodePtr in_op = X;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -108,14 +104,12 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestUnariesTransposesGeneral
|
||||
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_unary_ops; ++i) {
|
||||
auto ng_order0 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(in_op, ng_order0);
|
||||
|
||||
auto unary = std::make_shared<Tanh>(transpose0);
|
||||
|
||||
auto ng_order1 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
in_op = std::make_shared<Transpose>(unary, ng_order1);
|
||||
}
|
||||
|
||||
@ -153,8 +147,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestBinaryGeneral) {
|
||||
NodePtr in_op = transpose0;
|
||||
for (size_t i = 0; i < num_binary_ops; ++i) {
|
||||
auto in_constant = std::make_shared<Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<Transpose>(in_constant, ng_order1);
|
||||
|
||||
in_op = std::make_shared<Add>(in_op, transpose1);
|
||||
@ -199,8 +192,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestConcatGeneral) {
|
||||
concat_inputs.push_back(in_op);
|
||||
for (size_t j = 1; j < num_concat_inputs; ++j) {
|
||||
auto in_constant = std::make_shared<Constant>(input_type, input_shape, ov::Shape{1});
|
||||
auto ng_order1 =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose1 = std::make_shared<Transpose>(in_constant, ng_order1);
|
||||
concat_inputs.push_back(transpose1);
|
||||
}
|
||||
@ -385,8 +377,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
auto ng_order0 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 2, 3, 1});
|
||||
auto transpose0 = std::make_shared<Transpose>(node0, ng_order0);
|
||||
|
||||
auto reshape_const =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<Reshape>(transpose0, reshape_const, false);
|
||||
|
||||
auto ng_order1 = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{0, 3, 1, 2});
|
||||
@ -405,8 +396,7 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
|
||||
auto node0 = MakeAllNodesSubgraph(transpose0, 3, 3);
|
||||
|
||||
auto reshape_const =
|
||||
std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape_const = std::make_shared<Constant>(ov::element::u64, ov::Shape{4}, ov::Shape{1, 40, 55, 96});
|
||||
auto reshape = std::make_shared<Reshape>(node0, reshape_const, false);
|
||||
|
||||
auto node1 = MakeAllNodesSubgraph(reshape, 3, 3);
|
||||
@ -420,6 +410,6 @@ TEST_F(TransformationTestsF, TransposeSinkingGeneralTestMultipleTypes) {
|
||||
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace general
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
@ -835,6 +835,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingSplitBackwardRestrictTestSuite,
|
||||
|
||||
} // namespace backward
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace split
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
|
@ -83,5 +83,5 @@ std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShap
|
||||
return std::make_shared<Parameter>(el_type, ps);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
@ -62,12 +62,10 @@ std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_
|
||||
ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
|
||||
|
||||
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
|
||||
template<class T>
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type,
|
||||
const ov::Shape& shape,
|
||||
const std::vector<T>& value) {
|
||||
template <class T>
|
||||
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<T>& value) {
|
||||
return ov::opset10::Constant::create<T>(el_type, shape, value);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
Loading…
Reference in New Issue
Block a user