Fix validate for split, revert changes for concat, add BatchToSpace/SpaceToBatch
This commit is contained in:
parent
20579455b7
commit
9d8016d1e6
@ -20,7 +20,8 @@
|
||||
using namespace ov;
|
||||
|
||||
namespace {
|
||||
std::vector<size_t> get_updated_order_forward(std::vector<size_t>& axes_values, std::vector<size_t>& order_values) {
|
||||
std::vector<size_t> get_updated_order_forward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() - axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size, 0);
|
||||
std::vector<size_t> values_to_reduce(axes_values);
|
||||
@ -33,14 +34,15 @@ std::vector<size_t> get_updated_order_forward(std::vector<size_t>& axes_values,
|
||||
continue;
|
||||
}
|
||||
|
||||
auto ub = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (ub - values_to_reduce.begin());
|
||||
auto lb = std::lower_bound(values_to_reduce.begin(), values_to_reduce.end(), order_values[i]);
|
||||
aligned_order[j] = order_values[i] - (lb - values_to_reduce.begin());
|
||||
++j;
|
||||
}
|
||||
return aligned_order;
|
||||
}
|
||||
|
||||
std::vector<size_t> get_updated_order_backward(std::vector<size_t>& axes_values, std::vector<size_t>& order_values) {
|
||||
std::vector<size_t> get_updated_order_backward(const std::vector<size_t>& axes_values,
|
||||
const std::vector<size_t>& order_values) {
|
||||
size_t buffer_size = order_values.size() + axes_values.size();
|
||||
std::vector<size_t> aligned_order(buffer_size);
|
||||
|
||||
|
@ -23,10 +23,12 @@ using namespace transpose_sinking;
|
||||
ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryForward);
|
||||
|
||||
auto main_node_label =
|
||||
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||
});
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
op::util::BinaryElementwiseLogical,
|
||||
PRelu>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && IfNodeHasTransposeInputs(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
@ -54,10 +56,12 @@ ov::pass::TransposeSinkingBinaryForward::TransposeSinkingBinaryForward() {
|
||||
ov::pass::TransposeSinkingBinaryBackward::TransposeSinkingBinaryBackward() {
|
||||
MATCHER_SCOPE(TransposeSinkingBinaryBackward);
|
||||
|
||||
auto main_node_label =
|
||||
wrap_type<op::util::BinaryElementwiseArithmetic, PRelu>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
auto main_node_label = wrap_type<op::util::BinaryElementwiseArithmetic,
|
||||
op::util::BinaryElementwiseComparison,
|
||||
op::util::BinaryElementwiseLogical,
|
||||
PRelu>([](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
|
@ -68,19 +68,18 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
return has_static_rank()(output) && HasSameOutputTransposeNodes(output);
|
||||
});
|
||||
|
||||
/* auto transpose_const_label = wrap_type<Constant>();
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output);
|
||||
});*/
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({main_node_label, transpose_const_label}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
auto transpose = main_node->output(0).get_target_inputs().begin()->get_node()->shared_from_this();
|
||||
auto transpose_const = as_type_ptr<Constant>(transpose->input_value(1).get_node_shared_ptr());
|
||||
auto concat_node = as_type_ptr<Concat>(main_node);
|
||||
auto concat_axis = concat_node->get_concatenation_axis();
|
||||
if (concat_axis < 0) {
|
||||
@ -102,6 +101,6 @@ ov::pass::TransposeSinkingConcatBackward::TransposeSinkingConcatBackward() {
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(main_node_label, matcher_name);
|
||||
auto m = std::make_shared<Matcher>(transpose_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
@ -12,6 +12,7 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_batch_to_space.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_binary.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_concat.hpp"
|
||||
#include "transformations/common_optimizations/transpose_sinking_pad.hpp"
|
||||
@ -27,6 +28,7 @@ ov::pass::TransposeSinkingGeneralForward::TransposeSinkingGeneralForward() {
|
||||
add_matcher<ov::pass::TransposeSinkingSplitForward>();
|
||||
add_matcher<ov::pass::TransposeSinkingPadForward>();
|
||||
add_matcher<ov::pass::TransposeReduction>();
|
||||
add_matcher<ov::pass::TransposeSinkingBatchToSpaceForward>();
|
||||
add_matcher<ov::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
@ -38,6 +40,7 @@ ov::pass::TransposeSinkingGeneralBackward::TransposeSinkingGeneralBackward() {
|
||||
add_matcher<ov::pass::TransposeSinkingSplitBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingPadBackward>();
|
||||
add_matcher<ov::pass::TransposeReductionBackward>();
|
||||
add_matcher<ov::pass::TransposeSinkingBatchToSpaceBackward>();
|
||||
add_matcher<ov::pass::TransposeFuse>();
|
||||
}
|
||||
|
||||
|
@ -49,6 +49,7 @@ ov::pass::TransposeSinkingPadForward::TransposeSinkingPadForward() {
|
||||
main_node->input(2).replace_source_output(
|
||||
ChangeValuesOrder(main_node->input_value(2), reversed_transpose_order, axis));
|
||||
|
||||
main_node->validate_and_infer_types();
|
||||
// insert Transpose for Pad output
|
||||
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
|
||||
for (auto& new_node : sink_forward::InsertOutputTransposes(main_node, transpose_input_info)) {
|
||||
|
Loading…
Reference in New Issue
Block a user