add gather_sinking_reshape
This commit is contained in:
parent
02abf9b1f0
commit
31d7af368e
@ -15,6 +15,7 @@
|
||||
#include "transformations/gather_sinking_binary.hpp"
|
||||
#include "transformations/gather_sinking_fuse.hpp"
|
||||
#include "transformations/gather_sinking_transpose_reshape.hpp"
|
||||
#include "transformations/gather_sinking_reshape.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::pass::pattern;
|
||||
@ -34,6 +35,7 @@ GatherSinkingGeneralBackward::GatherSinkingGeneralBackward() {
|
||||
add_matcher<GatherSinkingUnaryBackward>();
|
||||
add_matcher<GatherSinkingBinaryBackward>();
|
||||
add_matcher<GatherSinkingTransposeReshapeBackward>();
|
||||
add_matcher<GatherSinkingReshapeBackward>();
|
||||
add_matcher<GatherSinkingFuse>();
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,121 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/gather_sinking_reshape.hpp"
|
||||
|
||||
#include <openvino/cc/ngraph/itt.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/gather_sinking_attr.hpp"
|
||||
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::opset10;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace ov::op::util;
|
||||
using namespace gather_sinking;
|
||||
using namespace ov::intel_gna::pass;
|
||||
using namespace ov::intel_gna::rt_info;
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
|
||||
// remove leading and trailing '1' from shape
|
||||
Shape StripShape(const Shape& shape) {
|
||||
auto if_not_eq_1 = [](const Shape::value_type& value) { return value != 1; };
|
||||
const auto start_it = std::find_if(shape.begin(), shape.end(), if_not_eq_1);
|
||||
if (start_it == shape.end())
|
||||
return {};
|
||||
const auto end_it = std::find_if(shape.rbegin(), shape.rend(), if_not_eq_1);
|
||||
if (end_it == shape.rend())
|
||||
return {};
|
||||
|
||||
Shape result;
|
||||
result.reserve(shape.size());
|
||||
std::copy(start_it, shape.begin() + std::distance(end_it, shape.rend()), std::back_inserter(result));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename InputIt, typename Predicate>
|
||||
int FindFirstIndex(InputIt begin, InputIt end, Predicate predicate) {
|
||||
const auto it = std::find_if(begin, end, predicate);
|
||||
if (it == end)
|
||||
return -1;
|
||||
return std::distance(begin, it);
|
||||
}
|
||||
|
||||
int GetLeftShift(const Shape& shape1, const Shape& shape2) {
|
||||
auto if_not_eq_1 = [](const Shape::value_type& value) { return value != 1; };
|
||||
const int index_1 = FindFirstIndex(shape1.begin(), shape1.end(), if_not_eq_1);
|
||||
const int index_2 = FindFirstIndex(shape2.begin(), shape2.end(), if_not_eq_1);
|
||||
|
||||
if (index_1 < 0 || index_2 < 0)
|
||||
return 0;
|
||||
return index_1 - index_2;
|
||||
}
|
||||
|
||||
bool IfGatherSinkingEnabled(const Output<Node>& output) {
|
||||
return is_gather_sinking_node(output.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
bool IsReshapeUnsqueeze(const Output<Node>& output) {
|
||||
NodePtr reshape = output.get_node_shared_ptr();
|
||||
const Shape input_shape = StripShape(reshape->get_input_shape(0));
|
||||
const Shape output_shape = StripShape(reshape->get_output_shape(0));
|
||||
return std::equal(input_shape.begin(), input_shape.end(), output_shape.begin());
|
||||
}
|
||||
|
||||
int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
|
||||
if (axis >= 0)
|
||||
return axis;
|
||||
return axis + rank;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GatherSinkingReshapeBackward::GatherSinkingReshapeBackward() {
|
||||
MATCHER_SCOPE(GatherSinkingReshapeBackward);
|
||||
auto reshape_const_label = wrap_type<Constant>();
|
||||
auto reshape_label = wrap_type<Reshape>({any_input(), reshape_const_label}, IsReshapeUnsqueeze);
|
||||
auto gather_indices_label = wrap_type<Constant>();
|
||||
auto gather_axis_label = wrap_type<Constant>();
|
||||
auto gather_label = wrap_type<Gather>({reshape_label, gather_indices_label, gather_axis_label}, IfGatherSinkingEnabled);
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
|
||||
auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
|
||||
auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
|
||||
auto reshape_const = as_type_ptr<Constant>(pattern_to_output.at(reshape_const_label).get_node_shared_ptr());
|
||||
auto reshape = as_type_ptr<Reshape>(pattern_to_output.at(reshape_label).get_node_shared_ptr());
|
||||
|
||||
const int left_shift = GetLeftShift(reshape->get_input_shape(0), reshape->get_output_shape(0));
|
||||
size_t gather_axis_value_current = ConvertAxisToPositive(gather_axis->cast_vector<int64_t>()[0],
|
||||
gather->get_input_shape(0).size());
|
||||
size_t gather_axis_value_new = gather_axis_value_current - left_shift;
|
||||
|
||||
auto gather_axis_new = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
|
||||
auto gather_indices_new = gather_indices->clone_with_new_inputs({});
|
||||
auto gather_new = std::make_shared<Gather>(reshape->input_value(0), gather_indices_new, gather_axis_new);
|
||||
|
||||
auto reshape_const_new = reshape_const->clone_with_new_inputs({});
|
||||
auto reshape_new = reshape->clone_with_new_inputs({gather_new, reshape_const_new});
|
||||
|
||||
replace_node_update_name(gather, reshape_new);
|
||||
copy_runtime_info(gather, {gather_new, gather_indices_new, gather_axis_new, reshape_new});
|
||||
|
||||
register_new_node(gather_new);
|
||||
register_new_node(reshape_new);
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(gather_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
@ -0,0 +1,22 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "transformations_visibility.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_gna {
|
||||
namespace pass {
|
||||
|
||||
class GatherSinkingReshapeBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingReshapeBackward", "0");
|
||||
GatherSinkingReshapeBackward();
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_gna
|
||||
} // namespace ov
|
Loading…
Reference in New Issue
Block a user