Fix validate for split, revert changes for concat, add BatchToSpace/SpaceToBatch

This commit is contained in:
Tikhonov Ivan 2023-02-23 17:12:54 +00:00
parent 20579455b7
commit 9d8016d1e6
5 changed files with 30 additions and 21 deletions

View File

@ -20,7 +20,8 @@
using namespace ov;
namespace {
std::vector<size_t> get_updated_order_forward(std::vector<size_t>& axes_values, std::vector<size_t>& order_values) {
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
const std::vector<size_t>& order_values) {
size_t buffer_size = order_values.size() - axes_values.size();
std::vector<size_t> aligned_order(buffer_size, 0);
std::vector<size_t> values_to_reduce(axes_values);
@ -33,14 +34,15 @@ std::vector<size_t> get_updated_order_forward(std::vector<size_t>& axes_values,
continue;
}
auto ub = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
aligned_order[j] = order_values[i] - (ub - values_to_reduce.begin());
auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
++j;
}
return aligned_order;
}
std::vector<size_t> get_updated_order_backward(std::vector<size_t>& axes_values, std::vector<size_t>& order_values) {
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
const std::vector<size_t>& order_values) {
size_t buffer_size = order_values.size() + axes_values.size();
std::vector<size_t> aligned_order(buffer_size);

View File

@ -23,10 +23,12 @@ using namespace transpose_sinking;
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
MATCHER_SCOPE(TransposeSinkingBinaryForward);
auto main_node_label =
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
});
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,
op::util::BinaryElementwiseLogical,
PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
@ -54,10 +56,12 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
auto main_node_label =
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
op::util::BinaryElementwiseComparison,
op::util::BinaryElementwiseLogical,
PRelu>([](const Output<Node>& output) -> bool {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
auto transpose_const_label = wrap_type<Constant>();

View File

@ -68,19 +68,18 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
});
/* auto transpose_const_label = wrap_type<Constant>();
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output);
});*/
auto transpose_label =
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
auto transpose = main_node->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
auto transpose_const = as_type_ptr<Constant>(transpose->input_value(1).get_node_shared_ptr());
auto concat_node = as_type_ptr<Concat>(main_node);
auto concat_axis = concat_node->get_concatenation_axis();
if (concat_axis < 0) {
@ -102,6 +101,6 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
return true;
};
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -12,6 +12,7 @@
#include "itt.hpp"
#include "transformations/common_optimizations/transpose_sinking.hpp"
#include "transformations/common_optimizations/transpose_sinking_batch_to_space.hpp"
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
@ -27,6 +28,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
add_matcher<ov::pass::TransposeSinkingSplitForward>();
add_matcher<ov::pass::TransposeSinkingPadForward>();
add_matcher<ov::pass::TransposeReduction>();
add_matcher<ov::pass::TransposeSinkingBatchToSpaceForward>();
add_matcher<ov::pass::TransposeFuse>();
}
@ -38,6 +40,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
add_matcher<ov::pass::TransposeSinkingPadBackward>();
add_matcher<ov::pass::TransposeReductionBackward>();
add_matcher<ov::pass::TransposeSinkingBatchToSpaceBackward>();
add_matcher<ov::pass::TransposeFuse>();
}

View File

@ -49,6 +49,7 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
main_node->input(2).replace_source_output(
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
main_node->validate_and_infer_types();
// insert Transpose for Pad output
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {