Merge branch 'itikhono/ts/slice' of https://github.com/itikhono/openvino into itikhono/ts/slice

This commit is contained in:
Ivan 2023-03-17 16:09:35 +04:00
commit 83ab2cc5f6
6 changed files with 79 additions and 60 deletions

View File

@ -4,9 +4,9 @@
#pragma once
#include "../../transformations_visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "../../transformations_visibility.hpp"
namespace ov {
namespace pass {
@ -15,7 +15,7 @@ namespace transpose_sinking {
class TRANSFORMATIONS_API TSSliceForward;
class TRANSFORMATIONS_API TSSliceBackward;
}
} // namespace transpose_sinking
} // namespace pass
} // namespace ov

View File

@ -15,10 +15,10 @@
#include "transformations/transpose_sinking/ts_fuse.hpp"
#include "transformations/transpose_sinking/ts_interpolate.hpp"
#include "transformations/transpose_sinking/ts_reduction.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/transpose_sinking/ts_split.hpp"
#include "transformations/transpose_sinking/ts_squeeze.hpp"
#include "transformations/transpose_sinking/ts_unary.hpp"
#include "transformations/transpose_sinking/ts_unsqueeze.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov::pass::transpose_sinking;

View File

@ -24,7 +24,7 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
bool get_keep_dims(const std::shared_ptr<Node>& reduction) {
auto arithmetic_reduce = as_type_ptr<ov::op::util::ArithmeticReductionKeepDims>(reduction);
auto logical_reduce = as_type_ptr<ov::op::util::LogicalReductionKeepDims>(reduction);
@ -36,14 +36,14 @@ bool get_keep_dims(const std::shared_ptr<Node> &reduction) {
return keep_dims;
}
}
} // namespace
TSReductionForward::TSReductionForward() {
MATCHER_SCOPE(TSReductionForward);
auto transpose_label = wrap_type<Transpose>({any_input(), wrap_type<Constant>()});
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{transpose_label, wrap_type<Constant>()});
{transpose_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@ -72,7 +72,8 @@ TSReductionForward::TSReductionForward() {
transpose_order_values = GetOrderAfterReduction(non_negative_axes, transpose_order_values);
}
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
auto new_const = Constant::create(reduction_axes->get_element_type(), reduction_axes->get_shape(), new_values);
@ -96,7 +97,8 @@ TSReductionBackward::TSReductionBackward() {
MATCHER_SCOPE(TSReductionBackward);
auto reduce_label = wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims>(
{any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
{any_input(), wrap_type<Constant>()},
HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({reduce_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
@ -112,13 +114,14 @@ TSReductionBackward::TSReductionBackward() {
auto rank = reduction->get_input_partial_shape(0).rank();
auto non_negative_axes =
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
normalize_axes(reduction->get_friendly_name(), reduction_axes->cast_vector<int64_t>(), rank);
auto transpose_order_values = transpose_order->cast_vector<size_t>();
if (!keep_dims) {
transpose_order_values = GetOrderBeforeReduction(non_negative_axes, transpose_order_values);
}
auto reversed_order_values = ReverseTransposeOrder(transpose_order_values);
auto new_transpose_order = Constant::create(transpose_order->get_element_type(), {transpose_order_values.size()},
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
std::vector<size_t> new_values;

View File

@ -4,15 +4,14 @@
#include "transformations/transpose_sinking/ts_slice.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "itt.hpp"
#include "openvino/op/util/op_types.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/util/common_util.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
using namespace ov;
using namespace ov::opset10;

View File

@ -28,8 +28,8 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
result_axes.clear();
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
// supported the case if Reshape is equal to Squeeze
const auto &new_shape = reduction_axes_values;
const auto &input_pshape = reshape->get_input_partial_shape(0);
const auto& new_shape = reduction_axes_values;
const auto& input_pshape = reshape->get_input_partial_shape(0);
// todo: support dynamic case
if (input_pshape.is_dynamic()) {
return false;
@ -58,7 +58,7 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
std::vector<size_t> to_shape;
std::sort(squeeze_axes.begin(), squeeze_axes.end());
const auto& input_shape = input_node->input(0).get_shape(); // check is static
const auto& input_shape = input_node->input(0).get_shape(); // check is static
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
++j;
@ -69,7 +69,7 @@ std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_nod
return to_shape;
}
}
} // namespace
TSSqueezeForward::TSSqueezeForward() {
MATCHER_SCOPE(TSSqueezeForward);
@ -97,7 +97,8 @@ TSSqueezeForward::TSSqueezeForward() {
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
}
// if 2nd input to squeeze is empty then all '1' dims will be deleted.
@ -173,7 +174,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
}
} else {
auto rank = squeeze->get_input_partial_shape(0).rank();
non_negative_axes = normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(squeeze->get_friendly_name(), squeeze_axes->cast_vector<int64_t>(), rank);
}
bool squeeze_all_dims = false;
@ -211,7 +213,8 @@ TSSqueezeBackward::TSSqueezeBackward() {
if (squeeze_all_dims) {
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
} else {
auto new_const = std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
auto new_const =
std::make_shared<Constant>(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, new_const});
}

View File

@ -191,7 +191,11 @@ class SliceFactory : public IFactory {
public:
explicit SliceFactory(const std::string& type_name) : IFactory(type_name) {}
NodePtr create(const OutputVector& parent_nodes) const override {
return std::make_shared<Slice>(parent_nodes[0], parent_nodes[1], parent_nodes[2], parent_nodes[3], parent_nodes[4]);
return std::make_shared<Slice>(parent_nodes[0],
parent_nodes[1],
parent_nodes[2],
parent_nodes[3],
parent_nodes[4]);
}
};
@ -687,8 +691,8 @@ auto test_forward_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -701,7 +705,7 @@ auto test_forward_squeeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{3, 1});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -721,8 +725,8 @@ auto test_forward_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -735,7 +739,7 @@ auto test_forward_unsqueeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{0, 2});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -762,11 +766,11 @@ auto test_forward_slice = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSliceForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
@ -806,8 +810,8 @@ auto test_forward_reshape_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 1, 5, 1, 4}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
parameter(element::f32, {6, 1, 5, 1, 4}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
};
// Test model description:
@ -820,7 +824,7 @@ auto test_forward_reshape_squeeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 5, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
@ -831,7 +835,9 @@ auto test_forward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_squeeze());
auto test_forward_reshape_unsqueeze = []() {
TestCase test_case;
@ -840,8 +846,8 @@ auto test_forward_reshape_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeForward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 5, 4}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
parameter(element::f32, {6, 5, 4}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
};
// Test model description:
@ -863,7 +869,9 @@ auto test_forward_reshape_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward, TransposeSinkingTestFixture, test_forward_reshape_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeForward,
TransposeSinkingTestFixture,
test_forward_reshape_unsqueeze());
// ------------------ BACKWARD --------------------
auto test_backward_unary = []() {
@ -1141,7 +1149,6 @@ INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonInterpolateBackward,
TransposeSinkingTestFixture,
test_backward_interpolate());
auto test_backward_squeeze = []() {
TestCase test_case;
@ -1149,8 +1156,8 @@ auto test_backward_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
parameter(element::f32, {32, 1, 2, 1}),
constant<int64_t>(element::i32, {2}, {1, 3}),
};
// Test model description:
@ -1182,8 +1189,8 @@ auto test_backward_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
parameter(element::f32, {32, 3, 2, 1}),
constant<int64_t>(element::i32, {2}, {0, 2}),
};
// Test model description:
@ -1196,7 +1203,7 @@ auto test_backward_unsqueeze = []() {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{5, 3});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -1206,7 +1213,9 @@ auto test_backward_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_unsqueeze());
auto test_backward_slice = []() {
TestCase test_case;
@ -1215,11 +1224,11 @@ auto test_backward_slice = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSliceBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
parameter(element::f32, {6, 4, 5, 3}),
constant<int64_t>(element::i32, {3}, {1, 2, 3}),
constant<int64_t>(element::i32, {3}, {0, 4, 11}),
constant<int64_t>(element::i32, {3}, {1, 2, -1}),
constant<int64_t>(element::i32, {3}, {0, 1, 2}),
};
// Test model description:
@ -1257,8 +1266,8 @@ auto test_backward_reshape_squeeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSSqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {4, 1, 5, 1, 6}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
parameter(element::f32, {4, 1, 5, 1, 6}),
constant<int64_t>(element::i32, {3}, {4, 5, 6}),
};
// Test model description:
@ -1281,7 +1290,9 @@ auto test_backward_reshape_squeeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_squeeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeSqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_squeeze());
auto test_backward_reshape_unsqueeze = []() {
TestCase test_case;
@ -1290,8 +1301,8 @@ auto test_backward_reshape_unsqueeze = []() {
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
test_case.num_main_ops = {1};
test_case.inputs_to_main = {
parameter(element::f32, {4, 5, 6}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
parameter(element::f32, {4, 5, 6}),
constant<int64_t>(element::i32, {5}, {4, 1, 5, 1, 6}),
};
// Test model description:
@ -1303,8 +1314,9 @@ auto test_backward_reshape_unsqueeze = []() {
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
out_vec[1].get_shape(),
std::vector<int64_t>{6, 1, 5, 1, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {1}}};
@ -1314,7 +1326,9 @@ auto test_backward_reshape_unsqueeze = []() {
return wrapper(test_case);
};
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward, TransposeSinkingTestFixture, test_backward_reshape_unsqueeze());
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
TransposeSinkingTestFixture,
test_backward_reshape_unsqueeze());
} // namespace common
} // namespace testing
} // namespace transpose_sinking