codestyle
This commit is contained in:
parent
849dc70763
commit
7a6988a4a6
@ -4,9 +4,9 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "transformations_visibility.hpp"
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
|
@ -103,7 +103,6 @@ TSSliceBackward::TSSliceBackward() {
|
||||
auto new_axis = std::make_shared<Gather>(data, indices, axis);
|
||||
main_node->input(4).replace_source_output(new_axis);
|
||||
|
||||
const auto& interpolate = std::dynamic_pointer_cast<Slice>(main_node);
|
||||
main_node->validate_and_infer_types();
|
||||
return true;
|
||||
};
|
||||
|
@ -23,13 +23,13 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
namespace {
|
||||
|
||||
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
result_axes.clear();
|
||||
auto reduction_axes_values = reshape_to_shape->cast_vector<int64_t>();
|
||||
// supported the case if Reshape is equal to Unsqueeze
|
||||
const auto &new_shape = reduction_axes_values;
|
||||
const auto &input_pshape = reshape->get_input_partial_shape(0);
|
||||
const auto& new_shape = reduction_axes_values;
|
||||
const auto& input_pshape = reshape->get_input_partial_shape(0);
|
||||
// todo: support dynamic case
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
@ -37,7 +37,7 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
|
||||
const auto input_shape = input_pshape.to_shape();
|
||||
if (new_shape.size() > input_shape.size()) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size();j++) {
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); j++) {
|
||||
const auto input_dim = static_cast<int64_t>(input_shape[i]);
|
||||
if (input_dim == new_shape[j]) {
|
||||
i++;
|
||||
@ -55,8 +55,9 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const std::shared_ptr<Node>& input_node,
|
||||
std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node->input(0).get_shape(); // check is static
|
||||
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
std::stack<size_t, std::vector<size_t>> shape_to_add(input_shape);
|
||||
@ -99,14 +100,14 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
auto ts_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
||||
ts_order_values = GetOrderBeforeReduction(non_negative_axes, ts_order_values);
|
||||
auto new_transpose_order = Constant::create(transpose_order->get_element_type(),
|
||||
{ts_order_values.size()},
|
||||
ts_order_values);
|
||||
auto new_transpose_order =
|
||||
Constant::create(transpose_order->get_element_type(), {ts_order_values.size()}, ts_order_values);
|
||||
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), unsqueeze->input_value(1)});
|
||||
auto new_transpose = transpose->clone_with_new_inputs({new_unsqueeze, new_transpose_order});
|
||||
@ -128,7 +129,8 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
MATCHER_SCOPE(TSUnsqueezeBackward);
|
||||
|
||||
auto unsqueeze_label = wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto unsqueeze_label =
|
||||
wrap_type<Unsqueeze, Reshape>({any_input(), wrap_type<Constant>()}, HasSameOutputTransposeNodes);
|
||||
auto transpose_label = wrap_type<Transpose>({unsqueeze_label, wrap_type<Constant>()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
|
||||
@ -150,7 +152,8 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
}
|
||||
} else {
|
||||
auto rank = unsqueeze->get_output_partial_shape(0).rank();
|
||||
non_negative_axes = normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
non_negative_axes =
|
||||
normalize_axes(unsqueeze->get_friendly_name(), unsqueeze_axes->cast_vector<int64_t>(), rank);
|
||||
}
|
||||
|
||||
auto transpose_order_values = transpose_order->cast_vector<size_t>();
|
||||
|
Loading…
Reference in New Issue
Block a user