fix TSSqueeze/TSUnsqueeze transformations in case of Reshape op

This commit is contained in:
Tikhonov Ivan 2023-03-21 09:41:48 +00:00
parent a56a0768f1
commit 8f8f0e821b
3 changed files with 20 additions and 22 deletions

View File

@ -46,11 +46,8 @@ TSSliceForward::TSSliceForward() {
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
main_node->input(4).replace_source_output(
ChangeValuesOrder(main_node->input_value(4), transpose_axis_order, axis));
main_node->validate_and_infer_types();
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
@ -90,19 +87,14 @@ TSSliceBackward::TSSliceBackward() {
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(main_node, transpose);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
auto data =
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
main_node->input(4).replace_source_output(
ChangeValuesOrder(main_node->input_value(4), reversed_transpose_order, axis));
main_node->validate_and_infer_types();
return true;
};

View File

@ -61,10 +61,10 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> squeeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> squeeze_axes) {
std::vector<size_t> squeeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> squeeze_axes) {
std::vector<size_t> to_shape;
std::sort(squeeze_axes.begin(), squeeze_axes.end());
const auto& input_shape = input_node->input(0).get_shape(); // check is static
const auto& input_shape = input_node.get_shape(); // check is static
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
++j;
@ -133,7 +133,7 @@ TSSqueezeForward::TSSqueezeForward() {
transpose_order_values);
if (as_type_ptr<Reshape>(squeeze)) {
new_values = squeeze_axes_to_shape(transpose, new_values);
new_values = squeeze_axes_to_shape(transpose->input_value(0), new_values);
}
auto new_const = Constant::create(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
@ -213,12 +213,12 @@ TSSqueezeBackward::TSSqueezeBackward() {
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{transpose_order_values.size()},
transpose_order_values);
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(squeeze)) {
new_values = squeeze_axes_to_shape(squeeze, new_values);
new_values = squeeze_axes_to_shape(new_transpose->output(0), new_values);
}
std::shared_ptr<Node> new_squeeze;
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
if (squeeze_all_dims) {
new_squeeze = squeeze->clone_with_new_inputs({new_transpose, squeeze->input_value(1)});
} else {

View File

@ -60,9 +60,8 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node,
std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> unsqueeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node.get_shape(); // check is static
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
@ -113,7 +112,14 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
auto new_transpose_order =
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
std::shared_ptr<Node> new_unsqueeze;
if (as_type_ptr<Reshape>(unsqueeze)) {
auto new_values = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes);
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
}
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
replace_node(unsqueeze, new_transpose);
@ -188,7 +194,7 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(unsqueeze)) {
new_values = unsqueeze_axes_to_shape(new_transpose, new_values);
new_values = unsqueeze_axes_to_shape(new_transpose->output(0), new_values);
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});