add gather_sinking_reshape

This commit is contained in:
Evgeny Kotov 2023-03-22 16:00:00 +01:00
parent 02abf9b1f0
commit 31d7af368e
3 changed files with 145 additions and 0 deletions

View File

@ -15,6 +15,7 @@
#include "transformations/gather_sinking_binary.hpp"
#include "transformations/gather_sinking_fuse.hpp"
#include "transformations/gather_sinking_transpose_reshape.hpp"
#include "transformations/gather_sinking_reshape.hpp"
using namespace ov;
using namespace ov::pass::pattern;
@ -34,6 +35,7 @@ GatherSinkingGeneralBackward::GatherSinkingGeneralBackward() {
add_matcher<GatherSinkingUnaryBackward>();
add_matcher<GatherSinkingBinaryBackward>();
add_matcher<GatherSinkingTransposeReshapeBackward>();
add_matcher<GatherSinkingReshapeBackward>();
add_matcher<GatherSinkingFuse>();
}

View File

@ -0,0 +1,121 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/gather_sinking_reshape.hpp"
#include <openvino/cc/ngraph/itt.hpp>
#include <transformations/utils/utils.hpp>
#include <utility>
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/gather_sinking_attr.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
using namespace ov;
using namespace ov::opset10;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace gather_sinking;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;
namespace {
using NodePtr = std::shared_ptr<ov::Node>;
// remove leading and trailing '1' from shape
Shape StripShape(const Shape& shape) {
auto if_not_eq_1 = [](const Shape::value_type& value) { return value != 1; };
const auto start_it = std::find_if(shape.begin(), shape.end(), if_not_eq_1);
if (start_it == shape.end())
return {};
const auto end_it = std::find_if(shape.rbegin(), shape.rend(), if_not_eq_1);
if (end_it == shape.rend())
return {};
Shape result;
result.reserve(shape.size());
std::copy(start_it, shape.begin() + std::distance(end_it, shape.rend()), std::back_inserter(result));
return result;
}
template <typename InputIt, typename Predicate>
int FindFirstIndex(InputIt begin, InputIt end, Predicate predicate) {
const auto it = std::find_if(begin, end, predicate);
if (it == end)
return -1;
return std::distance(begin, it);
}
int GetLeftShift(const Shape& shape1, const Shape& shape2) {
auto if_not_eq_1 = [](const Shape::value_type& value) { return value != 1; };
const int index_1 = FindFirstIndex(shape1.begin(), shape1.end(), if_not_eq_1);
const int index_2 = FindFirstIndex(shape2.begin(), shape2.end(), if_not_eq_1);
if (index_1 < 0 || index_2 < 0)
return 0;
return index_1 - index_2;
}
bool IfGatherSinkingEnabled(const Output<Node>& output) {
return is_gather_sinking_node(output.get_node_shared_ptr());
}
bool IsReshapeUnsqueeze(const Output<Node>& output) {
NodePtr reshape = output.get_node_shared_ptr();
const Shape input_shape = StripShape(reshape->get_input_shape(0));
const Shape output_shape = StripShape(reshape->get_output_shape(0));
return std::equal(input_shape.begin(), input_shape.end(), output_shape.begin());
}
int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
if (axis >= 0)
return axis;
return axis + rank;
}
} // namespace
GatherSinkingReshapeBackward::GatherSinkingReshapeBackward() {
MATCHER_SCOPE(GatherSinkingReshapeBackward);
auto reshape_const_label = wrap_type<Constant>();
auto reshape_label = wrap_type<Reshape>({any_input(), reshape_const_label}, IsReshapeUnsqueeze);
auto gather_indices_label = wrap_type<Constant>();
auto gather_axis_label = wrap_type<Constant>();
auto gather_label = wrap_type<Gather>({reshape_label, gather_indices_label, gather_axis_label}, IfGatherSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
auto reshape_const = as_type_ptr<Constant>(pattern_to_output.at(reshape_const_label).get_node_shared_ptr());
auto reshape = as_type_ptr<Reshape>(pattern_to_output.at(reshape_label).get_node_shared_ptr());
const int left_shift = GetLeftShift(reshape->get_input_shape(0), reshape->get_output_shape(0));
size_t gather_axis_value_current = ConvertAxisToPositive(gather_axis->cast_vector<int64_t>()[0],
gather->get_input_shape(0).size());
size_t gather_axis_value_new = gather_axis_value_current - left_shift;
auto gather_axis_new = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
auto gather_indices_new = gather_indices->clone_with_new_inputs({});
auto gather_new = std::make_shared<Gather>(reshape->input_value(0), gather_indices_new, gather_axis_new);
auto reshape_const_new = reshape_const->clone_with_new_inputs({});
auto reshape_new = reshape->clone_with_new_inputs({gather_new, reshape_const_new});
replace_node_update_name(gather, reshape_new);
copy_runtime_info(gather, {gather_new, gather_indices_new, gather_axis_new, reshape_new});
register_new_node(gather_new);
register_new_node(reshape_new);
return true;
};
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -0,0 +1,22 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "transformations_visibility.hpp"
namespace ov {
namespace intel_gna {
namespace pass {
class GatherSinkingReshapeBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingReshapeBackward", "0");
GatherSinkingReshapeBackward();
};
} // namespace pass
} // namespace intel_gna
} // namespace ov