fix TSSqueeze/TSUnsqueeze transformations
This commit is contained in:
parent
7a6988a4a6
commit
346796af9c
@ -79,6 +79,7 @@
|
||||
#include <transformations/smart_reshape/reshape_sinking.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "transformations/transpose_sinking/ts_general.hpp"
|
||||
|
||||
bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Function>& f) {
|
||||
RUN_ON_FUNCTION_SCOPE(MOCTransformations);
|
||||
@ -112,6 +113,9 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph::Fu
|
||||
// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults should be called together.
|
||||
using namespace ov::pass;
|
||||
REGISTER_PASS(manager, EliminateScatterUpdate)
|
||||
REGISTER_PASS(manager, TransposeSinkingGeneral)
|
||||
manager.register_pass<Serialize>("/home/tikhonov/OpenVINO/tmp/test_model/ser_test.xml",
|
||||
"/home/tikhonov/OpenVINO/tmp/test_model/ser_test.bin");
|
||||
REGISTER_PASS(manager, RemoveConcatZeroDimInput)
|
||||
REGISTER_PASS(manager, Validate)
|
||||
// todo: ticket 96960
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
@ -37,16 +38,21 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() < input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < new_shape.size(); j++) {
|
||||
const auto input_dim = static_cast<int64_t>(input_shape[j]);
|
||||
if (new_shape[i] == input_dim) {
|
||||
i++;
|
||||
} else if (new_shape[i] != input_dim && input_dim != 1) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
const auto input_dim = static_cast<int64_t>(input_shape[i]);
|
||||
if (j < new_shape.size() && new_shape[j] == input_dim) {
|
||||
j++;
|
||||
} else if (input_dim != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(j);
|
||||
result_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
if (j != new_shape.size()) {
|
||||
// not all new_shape values are in input_shape
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Squeeze
|
||||
// todo: move this checks in the pattern
|
||||
@ -152,7 +158,10 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
MATCHER_SCOPE(TSSqueezeBackward);
|
||||
|
||||
auto squeeze_label = wrap_type<Squeeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()});
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({squeeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@ -220,7 +229,6 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
|
||||
replace_node(transpose, new_squeeze);
|
||||
copy_runtime_info({transpose, squeeze}, {new_transpose, new_squeeze});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_squeeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(squeeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
|
@ -11,6 +11,7 @@
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
@ -37,16 +38,20 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() > input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
|
||||
const auto input_dim = static_cast<int64_t>(input_shape[i]);
|
||||
if (input_dim == new_shape[j]) {
|
||||
i++;
|
||||
} else if (input_dim != new_shape[j] && new_shape[j] != 1) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < new_shape.size(); ++i) {
|
||||
if (j < input_shape.size() && static_cast<int64_t>(input_shape[j]) == new_shape[i]) {
|
||||
j++;
|
||||
} else if (new_shape[i] != 1) {
|
||||
return false;
|
||||
} else {
|
||||
result_axes.push_back(j);
|
||||
result_axes.push_back(i);
|
||||
}
|
||||
}
|
||||
if (j != input_shape.size()) {
|
||||
// not all input_shape values are in new_shape
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
// another reshape type, not Unsqueeze
|
||||
// todo: move this checks in the pattern
|
||||
@ -60,15 +65,14 @@ std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_n
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
|
||||
for (size_t i = 0, j = 0; i < to_shape.size(); ++i) {
|
||||
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
|
||||
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
|
||||
to_shape[i] = 1;
|
||||
j++;
|
||||
continue;
|
||||
} else if (k < input_shape.size()) {
|
||||
to_shape[i] = input_shape[k];
|
||||
k++;
|
||||
}
|
||||
to_shape[i] = shape_to_add.top();
|
||||
shape_to_add.pop();
|
||||
}
|
||||
return to_shape;
|
||||
}
|
||||
@ -131,7 +135,10 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
|
||||
auto unsqueeze_label =
|
||||
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
|
||||
auto transpose_label =
|
||||
wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()}, [](const Output<Node>& output) -> bool {
|
||||
return has_static_rank()(output) && is_sinking_node(output);
|
||||
});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_map();
|
||||
@ -178,16 +185,16 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
auto new_transpose_order = std::make_shared<Constant>(transpose_order->get_element_type(),
|
||||
Shape{transpose_order_values.size()},
|
||||
transpose_order_values);
|
||||
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
new_values = unsqueeze_axes_to_shape(unsqueeze, new_values);
|
||||
new_values = unsqueeze_axes_to_shape(new_transpose, new_values);
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
||||
replace_node(transpose, new_unsqueeze);
|
||||
copy_runtime_info({transpose, unsqueeze}, {new_transpose, new_unsqueeze});
|
||||
UpdateForwardSinkingAbility(new_transpose);
|
||||
new_unsqueeze->set_friendly_name(transpose->get_friendly_name());
|
||||
new_transpose->set_friendly_name(unsqueeze->get_friendly_name());
|
||||
register_new_node(new_transpose);
|
||||
|
Loading…
Reference in New Issue
Block a user