This commit is contained in:
Tikhonov Ivan 2023-03-07 07:34:40 +00:00
parent 3c5f62c013
commit 176686318f
11 changed files with 151 additions and 174 deletions

View File

@ -18,8 +18,8 @@ class TRANSFORMATIONS_API TransposeSinkingFuse;
/**
* @ingroup ie_transformation_common_api
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input or
* fuses them to single Transpose if input gets changed
* @brief TransposeSinkingFuse transformation eliminates 2 consecutive Transposes if they result in no changes to input
* or fuses them to single Transpose if input gets changed
*/
class ov::pass::TransposeSinkingFuse : public ov::pass::MatcherPass {
public:

View File

@ -13,11 +13,11 @@
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
#include "transformations/common_optimizations/transpose_sinking_interpolate.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/common_optimizations/transpose_sinking_fuse.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/utils/utils.hpp"
ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {

View File

@ -9,9 +9,9 @@
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/util/arithmetic_reductions_keep_dims.hpp"
#include "openvino/op/util/logical_reduction_keep_dims.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/common_optimizations/transpose_sinking_utils.hpp"
#include "transformations/utils/utils.hpp"

View File

@ -23,8 +23,7 @@ ov::pass::TransposeSinkingStridedSliceForward::TransposeSinkingStridedSliceForwa
MATCHER_SCOPE(TransposeSinkingStridedSliceForward);
auto const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({any_input(), const_label});
auto main_node_label =
wrap_type<StridedSlice>({transpose_label, any_input(), any_input(), any_input()});
auto main_node_label = wrap_type<StridedSlice>({transpose_label, any_input(), any_input(), any_input()});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_node = m.get_pattern_map();

View File

@ -65,25 +65,25 @@ FactoryPtr CreateBinaryFactory(const std::string& type_name) {
* Unsqueeze insertion
*/
std::vector<FactoryPtr> binary_elementwise_factories = {CREATE_BINARY_FACTORY(Add),
CREATE_BINARY_FACTORY(Divide),
CREATE_BINARY_FACTORY(Maximum),
CREATE_BINARY_FACTORY(Minimum),
CREATE_BINARY_FACTORY(Mod),
CREATE_BINARY_FACTORY(Multiply),
CREATE_BINARY_FACTORY(Power),
CREATE_BINARY_FACTORY(SquaredDifference),
CREATE_BINARY_FACTORY(Subtract)};
CREATE_BINARY_FACTORY(Divide),
CREATE_BINARY_FACTORY(Maximum),
CREATE_BINARY_FACTORY(Minimum),
CREATE_BINARY_FACTORY(Mod),
CREATE_BINARY_FACTORY(Multiply),
CREATE_BINARY_FACTORY(Power),
CREATE_BINARY_FACTORY(SquaredDifference),
CREATE_BINARY_FACTORY(Subtract)};
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
CREATE_BINARY_FACTORY(Divide),
CREATE_BINARY_FACTORY(Maximum),
CREATE_BINARY_FACTORY(Minimum),
CREATE_BINARY_FACTORY(Mod),
CREATE_BINARY_FACTORY(Multiply),
CREATE_BINARY_FACTORY(Power),
CREATE_BINARY_FACTORY(SquaredDifference),
CREATE_BINARY_FACTORY(Subtract),
CREATE_BINARY_FACTORY(PRelu)};
CREATE_BINARY_FACTORY(Divide),
CREATE_BINARY_FACTORY(Maximum),
CREATE_BINARY_FACTORY(Minimum),
CREATE_BINARY_FACTORY(Mod),
CREATE_BINARY_FACTORY(Multiply),
CREATE_BINARY_FACTORY(Power),
CREATE_BINARY_FACTORY(SquaredDifference),
CREATE_BINARY_FACTORY(Subtract),
CREATE_BINARY_FACTORY(PRelu)};
std::vector<size_t> binary_operations_numbers = {1, 10};
@ -150,9 +150,7 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
} // namespace one_input_transpose
namespace double_transpose {
std::shared_ptr<Model> CreateFunction(FactoryPtr binary_factory,
size_t num_binary_ops,
element::Type input_type) {
std::shared_ptr<Model> CreateFunction(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type) {
const Shape input_shape{1, 96, 55, 55};
auto X = std::make_shared<Parameter>(input_type, input_shape);
@ -198,8 +196,8 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
return std::make_shared<Model>(ov::OutputVector{transpose0}, ov::ParameterVector{X});
}
using CreateGraphBinaryTwoTransposeInputsF = std::function<
std::shared_ptr<Model>(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type)>;
using CreateGraphBinaryTwoTransposeInputsF =
std::function<std::shared_ptr<Model>(FactoryPtr binary_factory, size_t num_binary_ops, element::Type input_type)>;
using TestBinaryTwoTransposeInputsParams =
std::tuple<FactoryPtr,
@ -1071,9 +1069,8 @@ std::shared_ptr<Model> CreateReferenceFunction(FactoryPtr binary_factory,
} // namespace backward
using CreateGraphF = std::function<std::shared_ptr<Model>(FactoryPtr binary_factory,
element::Type input_type,
size_t binary_transpose_input_idx)>;
using CreateGraphF = std::function<
std::shared_ptr<Model>(FactoryPtr binary_factory, element::Type input_type, size_t binary_transpose_input_idx)>;
struct CreateGraphFunctionDesc {
CreateGraphFunctionDesc() = default;

View File

@ -2,19 +2,18 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transpose_sinking_test_utils.hpp"
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_data_movement.hpp"
#include "transformations/common_optimizations/transpose_sinking_reduction.hpp"
#include "transformations/common_optimizations/transpose_sinking_split.hpp"
#include "transformations/common_optimizations/transpose_sinking_unary.hpp"
#include "transpose_sinking_test_utils.hpp"
using namespace std;
using namespace ov;
@ -89,7 +88,6 @@ FactoryPtr CreateConcatRefFactory(const std::string& type_name) {
return std::make_shared<ConcatFactoryRef>(type_name);
}
class SplitFactory : public IFactory {
public:
explicit SplitFactory(const std::string& type_name) : IFactory(type_name) {}
@ -194,7 +192,8 @@ struct Preprocessing {
struct TestCase;
struct ModelDescription;
using TestParams = tuple<size_t /* idx num_main_ops */, size_t /* idx main_op */, TestCase>;
using CreateGraphF = function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
using CreateGraphF =
function<shared_ptr<ov::Model>(size_t main_op_idx, const ModelDescription&, size_t, const OutputVector&)>;
// Describes a model to test.
// Expects to be used in such a scenario:
@ -242,16 +241,16 @@ public:
};
vector<FactoryPtr> unary_factories = {
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
CREATE_UNARY_FACTORY(Clamp), CREATE_UNARY_FACTORY(Elu), CREATE_UNARY_FACTORY(SoftPlus),
CREATE_UNARY_FACTORY(LogicalNot), CREATE_UNARY_FACTORY(Convert), CREATE_UNARY_FACTORY(Abs),
CREATE_UNARY_FACTORY(Acos), CREATE_UNARY_FACTORY(Asin), CREATE_UNARY_FACTORY(Asinh),
CREATE_UNARY_FACTORY(Atan), CREATE_UNARY_FACTORY(Ceiling), CREATE_UNARY_FACTORY(Cos),
CREATE_UNARY_FACTORY(Cosh), CREATE_UNARY_FACTORY(Erf), CREATE_UNARY_FACTORY(Exp),
CREATE_UNARY_FACTORY(Gelu), CREATE_UNARY_FACTORY(HSigmoid), CREATE_UNARY_FACTORY(HSwish),
CREATE_UNARY_FACTORY(Log), CREATE_UNARY_FACTORY(Negative), CREATE_UNARY_FACTORY(Relu),
CREATE_UNARY_FACTORY(Sigmoid), CREATE_UNARY_FACTORY(Sign), CREATE_UNARY_FACTORY(Sin),
CREATE_UNARY_FACTORY(Sinh), CREATE_UNARY_FACTORY(SoftSign), CREATE_UNARY_FACTORY(Sqrt),
CREATE_UNARY_FACTORY(Tan), CREATE_UNARY_FACTORY(Tanh)};
std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
CREATE_BINARY_FACTORY(Divide),
@ -264,19 +263,19 @@ std::vector<FactoryPtr> binary_factories = {CREATE_BINARY_FACTORY(Add),
CREATE_BINARY_FACTORY(Subtract),
CREATE_BINARY_FACTORY(PRelu)};
std::vector<FactoryPtr> reduction_factories = {CREATE_BINARY_FACTORY(ReduceMax),
CREATE_BINARY_FACTORY(ReduceMin),
CREATE_BINARY_FACTORY(ReduceMean),
CREATE_BINARY_FACTORY(ReduceSum),
CREATE_BINARY_FACTORY(ReduceProd),
//CREATE_BINARY_FACTORY(ReduceLogicalOr),
//CREATE_BINARY_FACTORY(ReduceLogicalAnd),
CREATE_BINARY_FACTORY(ReduceL1),
CREATE_BINARY_FACTORY(ReduceL2),
//CREATE_BINARY_FACTORY(Squeeze),
//CREATE_BINARY_FACTORY(Unsqueeze),
};
std::vector<FactoryPtr> reduction_factories = {
CREATE_BINARY_FACTORY(ReduceMax),
CREATE_BINARY_FACTORY(ReduceMin),
CREATE_BINARY_FACTORY(ReduceMean),
CREATE_BINARY_FACTORY(ReduceSum),
CREATE_BINARY_FACTORY(ReduceProd),
// CREATE_BINARY_FACTORY(ReduceLogicalOr),
// CREATE_BINARY_FACTORY(ReduceLogicalAnd),
CREATE_BINARY_FACTORY(ReduceL1),
CREATE_BINARY_FACTORY(ReduceL2),
// CREATE_BINARY_FACTORY(Squeeze),
// CREATE_BINARY_FACTORY(Unsqueeze),
};
TEST_P(TransposeSinkingTestFixture, CompareFunctions) {
int num_main_ops_idx;
@ -311,8 +310,8 @@ shared_ptr<ov::Model> create_model(size_t main_node_idx,
auto outputs = model_desc.preprocess_outputs_of_main.apply(main_node->outputs());
return make_shared<ov::Model>(outputs, filter_parameters(inputs_to_main));
}
}
}
} // namespace common
} // namespace transpose_sinking
auto wrapper = [](const TestCase& test_case) {
OPENVINO_ASSERT(test_case.model.main_op.size() == test_case.model_ref.main_op.size(),
@ -330,7 +329,7 @@ auto test_forward_unary = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryForward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {1, 96, 55, 55}),
};
// Test model description:
@ -346,9 +345,7 @@ auto test_forward_unary = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward,
TransposeSinkingTestFixture,
test_forward_unary());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnaryForward, TransposeSinkingTestFixture, test_forward_unary());
auto test_forward_binary = []() {
TestCase test_case;
@ -357,8 +354,8 @@ auto test_forward_binary = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingBinaryForward);
test_case.num_main_ops = {1, 10};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {55, 55, 96, 1}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {55, 55, 96, 1}),
};
// Test model description:
@ -375,10 +372,7 @@ auto test_forward_binary = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward,
TransposeSinkingTestFixture,
test_forward_binary());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonBinaryForward, TransposeSinkingTestFixture, test_forward_binary());
auto test_forward_concat = []() {
TestCase test_case;
@ -387,9 +381,9 @@ auto test_forward_concat = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingConcatForward);
test_case.num_main_ops = {1, 3};
test_case.inputs_to_main = {
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {55, 55, 96, 1}),
parameter(element::f32, {55, 55, 96, 1}),
parameter(element::f32, {1, 96, 55, 55}),
parameter(element::f32, {55, 55, 96, 1}),
parameter(element::f32, {55, 55, 96, 1}),
};
// Test model description:
@ -406,9 +400,7 @@ auto test_forward_concat = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward,
TransposeSinkingTestFixture,
test_forward_concat());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonConcatForward, TransposeSinkingTestFixture, test_forward_concat());
auto test_forward_split = []() {
TestCase test_case;
@ -417,8 +409,8 @@ auto test_forward_split = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingSplitForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 9, 55, 55}),
constant(element::i32, {}, {2}),
parameter(element::f32, {1, 9, 55, 55}),
constant(element::i32, {}, {2}),
};
// Test model description:
@ -430,7 +422,8 @@ auto test_forward_split = []() {
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{1});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -441,9 +434,7 @@ auto test_forward_split = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward,
TransposeSinkingTestFixture,
test_forward_split());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonSplitForward, TransposeSinkingTestFixture, test_forward_split());
auto test_forward_pad = []() {
TestCase test_case;
@ -452,9 +443,9 @@ auto test_forward_pad = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {1, 3, 55, 55}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
parameter(element::f32, {1, 3, 55, 55}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
};
// Test model description:
@ -471,9 +462,7 @@ auto test_forward_pad = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward,
TransposeSinkingTestFixture,
test_forward_pad());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonPadForward, TransposeSinkingTestFixture, test_forward_pad());
auto test_forward_batch_to_space = []() {
TestCase test_case;
@ -482,10 +471,10 @@ auto test_forward_batch_to_space = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.num_main_ops = {1, 2};
test_case.inputs_to_main = {
parameter(element::f32, {128, 55, 3, 128}),
constant(element::i32, {4}, {1, 2, 2, 2}),
constant(element::i32, {4}, {1, 2, 2, 2}),
constant(element::i32, {4}, {1, 2, 2, 2}),
parameter(element::f32, {128, 55, 3, 128}),
constant(element::i32, {4}, {1, 2, 2, 2}),
constant(element::i32, {4}, {1, 2, 2, 2}),
constant(element::i32, {4}, {1, 2, 2, 2}),
};
// Test model description:
@ -513,10 +502,10 @@ auto test_forward_space_to_batch = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingDataMovementForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {64, 9, 8, 1}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
parameter(element::f32, {64, 9, 8, 1}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
constant(element::i32, {4}, {1, 2, 3, 4}),
};
// Test model description:
@ -544,8 +533,8 @@ auto test_forward_reduction = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingReductionForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 4, 2, 1}),
constant(element::i32, {2}, {1, 3}),
parameter(element::f32, {32, 4, 2, 1}),
constant(element::i32, {2}, {1, 3}),
};
// Test model description:
@ -557,7 +546,8 @@ auto test_forward_reduction = []() {
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{2, 0});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -568,6 +558,4 @@ auto test_forward_reduction = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward,
TransposeSinkingTestFixture,
test_forward_reduction());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReductionForward, TransposeSinkingTestFixture, test_forward_reduction());

View File

@ -11,8 +11,8 @@
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "transpose_sinking_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;

View File

@ -9,11 +9,11 @@
#include <openvino/pass/manager.hpp>
#include <transformations/common_optimizations/transpose_sinking_data_movement.hpp>
#include <transformations/common_optimizations/transpose_sinking_utils.hpp>
#include "transpose_sinking_test_utils.hpp"
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
using namespace std;
using namespace ov;

View File

@ -2,18 +2,19 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "transpose_sinking_test_utils.hpp"
#include <openvino/frontend/manager.hpp>
#include <openvino/opsets/opset10.hpp>
#include <openvino/pass/manager.hpp>
#include "transpose_sinking_test_utils.hpp"
#include "gtest/gtest.h"
using namespace std;
using namespace ov;
using namespace ov::opset10;
shared_ptr <Node> create_main_node(const OutputVector &inputs, size_t num_ops, const FactoryPtr &creator) {
shared_ptr<Node> create_main_node(const OutputVector& inputs, size_t num_ops, const FactoryPtr& creator) {
OutputVector current_inputs = inputs;
for (size_t i = 0; i < num_ops; ++i) {
auto op = creator->create(current_inputs);
@ -22,7 +23,7 @@ shared_ptr <Node> create_main_node(const OutputVector &inputs, size_t num_ops, c
return current_inputs[0].get_node_shared_ptr();
}
ParameterVector filter_parameters(const OutputVector &out_vec) {
ParameterVector filter_parameters(const OutputVector& out_vec) {
ParameterVector parameters;
for (const auto& out : out_vec) {
auto node = out.get_node_shared_ptr();
@ -36,7 +37,7 @@ ParameterVector filter_parameters(const OutputVector &out_vec) {
OutputVector set_transpose_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto &out = out_vec[idx];
const auto& out = out_vec[idx];
auto rank = out.get_partial_shape().rank().get_length();
vector<int64_t> axes(rank);
iota(axes.begin(), axes.end(), 0);
@ -51,7 +52,7 @@ OutputVector set_transpose_for(const vector<size_t>& idxs, const OutputVector& o
OutputVector set_gather_for(const vector<size_t>& idxs, const OutputVector& out_vec) {
OutputVector result = out_vec;
for (const auto& idx : idxs) {
const auto &out = out_vec[idx];
const auto& out = out_vec[idx];
vector<int64_t> axes(out.get_shape()[0]);
iota(axes.begin(), axes.end(), 0);
reverse(axes.begin(), axes.end());
@ -63,7 +64,7 @@ OutputVector set_gather_for(const vector<size_t>& idxs, const OutputVector& out_
return result;
}
std::string to_string(const Shape &shape) {
std::string to_string(const Shape& shape) {
ostringstream result;
result << "{";
for (size_t idx = 0; idx < shape.size(); ++idx) {
@ -75,10 +76,10 @@ std::string to_string(const Shape &shape) {
return result.str();
}
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape &ps) {
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const PartialShape& ps) {
return std::make_shared<Parameter>(el_type, ps);
}
shared_ptr<ov::Node> constant(ov::element::Type el_type, const Shape &shape, const vector<int64_t> &value) {
shared_ptr<ov::Node> constant(ov::element::Type el_type, const Shape& shape, const vector<int64_t>& value) {
return make_shared<Constant>(el_type, shape, value);
}

View File

@ -64,7 +64,8 @@ ov::ParameterVector filter_parameters(const ov::OutputVector& out_vec);
std::shared_ptr<ov::Node> create_main_node(const ov::OutputVector& inputs, size_t num_ops, const FactoryPtr& creator);
std::shared_ptr<ov::Node> parameter(ov::element::Type el_type, const ov::PartialShape& ps);
std::shared_ptr<ov::Node> constant(ov::element::Type el_type, const ov::Shape& shape, const std::vector<int64_t>& value);
std::shared_ptr<ov::Node> constant(ov::element::Type el_type,
const ov::Shape& shape,
const std::vector<int64_t>& value);

View File

@ -9,26 +9,24 @@
#include <openvino/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "transpose_sinking_test_utils.hpp"
#include "gtest/gtest.h"
#include "transpose_sinking_test_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using NodePtr = std::shared_ptr<ov::Node>;
using CreateGraphF = std::function<std::shared_ptr<ov::Model>(FactoryPtr unary_factory,
size_t num_unary_ops,
const Shape& input_shape,
element::Type input_type)>;
using CreateGraphF = std::function<std::shared_ptr<
ov::Model>(FactoryPtr unary_factory, size_t num_unary_ops, const Shape& input_shape, element::Type input_type)>;
using TestParams = std::tuple<FactoryPtr,
PassFactoryPtr,
size_t, /* num_unary_ops */
CreateGraphF, /* model_factory */
CreateGraphF, /* reference_model_factory */
Shape, /* input shape */
element::Type>; /* input type */
PassFactoryPtr,
size_t, /* num_unary_ops */
CreateGraphF, /* model_factory */
CreateGraphF, /* reference_model_factory */
Shape, /* input shape */
element::Type>; /* input type */
class TransposeSinkingUnaryTestFixture : public ::testing::WithParamInterface<TestParams>, public TransformationTestsF {
public:
@ -60,7 +58,6 @@ public:
}
};
namespace transpose_sinking {
namespace testing {
namespace unary {
@ -404,12 +401,12 @@ struct TestCase {
auto wrapper = [](const TestCase& test_case) {
return ::testing::Combine(::testing::ValuesIn(test_case.main_node),
::testing::Values(test_case.transformation),
::testing::ValuesIn(test_case.num_main_ops),
::testing::Values(test_case.test_model),
::testing::Values(test_case.ref_model),
::testing::Values(test_case.input_shape),
::testing::Values(test_case.type));
::testing::Values(test_case.transformation),
::testing::ValuesIn(test_case.num_main_ops),
::testing::Values(test_case.test_model),
::testing::Values(test_case.ref_model),
::testing::Values(test_case.input_shape),
::testing::Values(test_case.type));
};
auto test_forward = []() {
@ -454,7 +451,8 @@ auto test_backward_multiple_consumers_reshape = []() {
test_case.transformation = CREATE_PASS_FACTORY(TransposeSinkingUnaryBackward);
test_case.num_main_ops = {1, 10};
test_case.test_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeAfter;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;;
test_case.ref_model = mult_consumers_last_node::with_reshape::CreateFunctionTransposeBefore;
;
test_case.input_shape = {1, 96, 55, 55};
test_case.type = element::f32;
return wrapper(test_case);
@ -519,9 +517,9 @@ auto test_forward_multiple_consumers_first_node = []() {
test_case.type = element::f32;
return wrapper(test_case);
};
}
}
}
} // namespace unary
} // namespace testing
} // namespace transpose_sinking
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardTestSuite,
TransposeSinkingUnaryTestFixture,
@ -533,44 +531,37 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardTestSuite,
transpose_sinking::testing::unary::test_backward(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteLastNodeReshape,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_reshape(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteLastNodeEltwise,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_eltwise(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryBackwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_backward_multiple_transposes_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(
TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);
INSTANTIATE_TEST_SUITE_P(TransposeSinkingUnaryForwardMultTransposeConsumersTestSuiteFirstNode,
TransposeSinkingUnaryTestFixture,
transpose_sinking::testing::unary::test_forward_multiple_consumers_first_node(),
TransposeSinkingUnaryTestFixture::get_test_name);