codestyle

This commit is contained in:
Tikhonov Ivan 2023-03-17 08:50:15 +00:00
parent 849dc70763
commit 7a6988a4a6
3 changed files with 17 additions and 15 deletions

View File

@ -4,9 +4,9 @@
#pragma once
#include "transformations_visibility.hpp"
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pass.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace pass {

View File

@ -103,7 +103,6 @@ TSSliceBackward::TSSliceBackward() {
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
const auto& interpolate = std::dynamic_pointer_cast<Slice>(main_node);
main_node->validate_and_infer_types();
return true;
};

View File

@ -23,13 +23,13 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
result_axes.clear();
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
// supported the case if Reshape is equal to Unsqueeze
const auto &new_shape = reduction_axes_values;
const auto &input_pshape = reshape->get_input_partial_shape(0);
const auto& new_shape = reduction_axes_values;
const auto& input_pshape = reshape->get_input_partial_shape(0);
// todo: support dynamic case
if (input_pshape.is_dynamic()) {
return false;
@ -37,7 +37,7 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const auto input_shape = input_pshape.to_shape();
if (new_shape.size() > input_shape.size()) {
for (size_t i = 0, j = 0; i < input_shape.size();j++) {
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
const auto input_dim = static_cast<int64_t>(input_shape[i]);
if (input_dim == new_shape[j]) {
i++;
@ -55,8 +55,9 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node,
std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node->input(0).get_shape(); // check is static
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
@ -99,14 +100,14 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
}
auto ts_order_values = transpose_order->cast_vector<size_t>();
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
{ts_order_values.size()},
ts_order_values);
auto new_transpose_order =
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
@ -128,7 +129,8 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
TSUnsqueezeBackward::TSUnsqueezeBackward() {
MATCHER_SCOPE(TSUnsqueezeBackward);
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto unsqueeze_label =
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
@ -150,7 +152,8 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
}
} else {
auto rank = unsqueeze->get_output_partial_shape(0).rank();
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
non_negative_axes =
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
}
auto transpose_order_values = transpose_order->cast_vector<size_t>();