fix TSSqueeze/TSUnsqueeze transformations

This commit is contained in:
Tikhonov Ivan 2023-03-17 15:33:27 +00:00
parent 7a6988a4a6
commit 346796af9c
3 changed files with 42 additions and 23 deletions

View File

@ -79,6 +79,7 @@
#include <transformations/smart_reshape/reshape_sinking.hpp>
#include "itt.hpp"
#include "transformations/transpose_sinking/ts_general.hpp"
bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
RUN_ON_FUNCTION_SCOPE(MOCTransformations);
@ -112,6 +113,9 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults should be called together.
using namespace ov::pass;
REGISTER_PASS(manager, EliminateScatterUpdate)
REGISTER_PASS(manager, TransposeSinkingGeneral)
manager.register_pass<Serialize>("/home/tikhonov/OpenVINO/tmp/test_model/ser_test.xml",
"/home/tikhonov/OpenVINO/tmp/test_model/ser_test.bin");
REGISTER_PASS(manager, RemoveConcatZeroDimInput)
REGISTER_PASS(manager, Validate)
// todo: ticket 96960

View File

@ -11,6 +11,7 @@
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
@ -37,16 +38,21 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
const auto input_shape = input_pshape.to_shape();
if (new_shape.size() < input_shape.size()) {
for (size_t i = 0, j = 0; i < new_shape.size(); j++) {
const auto input_dim = static_cast<int64_t>(input_shape[j]);
if (new_shape[i] == input_dim) {
i++;
} else if (new_shape[i] != input_dim && input_dim != 1) {
size_t j = 0;
for (size_t i = 0; i < input_shape.size(); i++) {
const auto input_dim = static_cast<int64_t>(input_shape[i]);
if (j < new_shape.size() && new_shape[j] == input_dim) {
j++;
} else if (input_dim != 1) {
return false;
} else {
result_axes.push_back(j);
result_axes.push_back(i);
}
}
if (j != new_shape.size()) {
// not all new_shape values are in input_shape
return false;
}
} else {
// another reshape type, not Squeeze
// todo: move this checks in the pattern
@ -152,7 +158,10 @@ TSSqueezeBackward::TSSqueezeBackward() {
MATCHER_SCOPE(TSSqueezeBackward);
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()});
auto transpose_label =
wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@ -220,7 +229,6 @@ TSSqueezeBackward::TSSqueezeBackward() {
replace_node(transpose, new_squeeze);
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
UpdateForwardSinkingAbility(new_transpose);
new_squeeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(squeeze->get_friendly_name());
register_new_node(new_transpose);

View File

@ -11,6 +11,7 @@
#include "openvino/core/validation_util.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
#include "transformations/utils/utils.hpp"
@ -37,16 +38,20 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const auto input_shape = input_pshape.to_shape();
if (new_shape.size() > input_shape.size()) {
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
const auto input_dim = static_cast<int64_t>(input_shape[i]);
if (input_dim == new_shape[j]) {
i++;
} else if (input_dim != new_shape[j] && new_shape[j] != 1) {
size_t j = 0;
for (size_t i = 0; i < new_shape.size(); ++i) {
if (j < input_shape.size() && static_cast<int64_t>(input_shape[j]) == new_shape[i]) {
j++;
} else if (new_shape[i] != 1) {
return false;
} else {
result_axes.push_back(j);
result_axes.push_back(i);
}
}
if (j != input_shape.size()) {
// not all input_shape values are in new_shape
return false;
}
} else {
// another reshape type, not Unsqueeze
// todo: move this checks in the pattern
@ -60,15 +65,14 @@ std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_n
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
for (size_t i = 0, j = 0; i < to_shape.size(); ++i) {
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
to_shape[i] = 1;
j++;
continue;
} else if (k < input_shape.size()) {
to_shape[i] = input_shape[k];
k++;
}
to_shape[i] = shape_to_add.top();
shape_to_add.pop();
}
return to_shape;
}
@ -131,7 +135,10 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
auto unsqueeze_label =
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
auto transpose_label =
wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
return has_static_rank()(output) && is_sinking_node(output);
});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pattern_to_output = m.get_pattern_map();
@ -178,16 +185,16 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
Shape{transpose_order_values.size()},
transpose_order_values);
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(unsqueeze)) {
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
new_values = unsqueeze_axes_to_shape(new_transpose, new_values);
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
replace_node(transpose, new_unsqueeze);
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
UpdateForwardSinkingAbility(new_transpose);
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
register_new_node(new_transpose);