fix gather_sinking_matmul

This commit is contained in:
Evgeny Kotov 2023-03-24 13:54:49 +01:00
parent 06d6fbf0e8
commit 7b95c90df8
3 changed files with 79 additions and 80 deletions

View File

@ -22,76 +22,60 @@ using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::rt_info;
namespace {
/*
Reverts gather indices in a such way that reverted and initial gather will do nothing if
stays after another.
Works only with positive form (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
std::vector<int64_t> out(indexes.size());
for (size_t i = 0; i < indexes.size(); i++) {
out.at(indexes[i]) = i;
}
return out;
bool Has2dInputs(const Output<Node>& output) {
auto node = output.get_node_shared_ptr();
auto input_left_rank = node->get_input_partial_shape(0).rank();
auto input_right_rank = node->get_input_partial_shape(0).rank();
return (input_left_rank.is_static() && input_right_rank.is_static() &&
input_left_rank.get_length() == 2 && input_right_rank.get_length() == 2);
}
bool HasGatherInputs(const Output<Node>& output) {
return !GetFirstGatherInput(output.get_node_shared_ptr()).isEmpty();
}
bool IsSinked(const Output<Node>& output) {
return Has2dInputs(output) && HasGatherInputs(output);
}
int64_t Swap2DNegativeAxis(int64_t axis) {
if (axis == -1)
return -2;
return -1;
}
size_t GetAnotherMatMulIndex(size_t input_idx) {
if (!input_idx)
return 1;
return 0;
}
bool IsMatMulInputTransposed(const std::shared_ptr<MatMul>& matmul, size_t input_idx) {
if (!input_idx)
return matmul->get_transpose_a();
return matmul->get_transpose_b();
}
} // namespace
GatherSinkingMatmulForward::GatherSinkingMatmulForward() {
MATCHER_SCOPE(GatherSinkingMatmulForward);
auto gather_indices_label = wrap_type<Constant>();
auto gather_axis_label = wrap_type<Constant>();
auto gather_label = wrap_type<Gather>({any_input(), gather_indices_label, gather_axis_label});
auto matmul_label = wrap_type<MatMul>({gather_label, any_input()});
auto matmul_label = wrap_type<MatMul>({any_input(), any_input()}, IsSinked);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
auto matmul = as_type_ptr<MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
GatherInputsInfo gather_input_info = GetFirstGatherInput(matmul);
std::cout << "[EMUTEX DEBUG] GatherSinkingMatmulForward gather " << gather->get_friendly_name() << " matmul " << matmul->get_friendly_name() << std::endl;
std::cout << "[EMUTEX DEBUG] GatherSinkingMatmulForward gather axis " << gather_axis->cast_vector<int64_t>()[0] << std::endl;
auto gather_parent = matmul->input_value(0 /* TODO */).get_node()->input_value(0);;
// insert input gather
#if 0
size_t gather_axis_value_current = ConvertAxisToPositive(gather_axis->cast_vector<int64_t>()[0],
gather->get_input_shape(0).size());
#endif
const size_t gather_axis_value_new = 0; // TODO
auto gather_axis_new1 = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
auto gather_indices_values = ReverseGatherIndexes(gather_indices->cast_vector<int64_t>());
auto gather_indices_new1 = std::make_shared<Constant>(element::i64, Shape{gather_indices_values.size()}, gather_indices_values);
auto gather_new1 = std::make_shared<Gather>(matmul->input_value(1) /* TODO */, gather_indices_new1, gather_axis_new1);
matmul->input(1 /* TODO */).replace_source_output(gather_new1->output(0));
// remove input gather
matmul->input(0 /* TODO */).replace_source_output(gather_parent);
// insert output gather
auto matmul_consumers = matmul->output(0).get_target_inputs();
auto gather_axis_new2 = gather_axis->clone_with_new_inputs({});
auto gather_indices_new2 = gather_indices->clone_with_new_inputs({});
auto gather_new2 = std::make_shared<Gather>(matmul->output(0), gather_indices_new2, gather_axis_new2);
for (auto& consumer : matmul_consumers) {
consumer.replace_source_output(gather_new2);
}
SwapFriendlyNames(gather_new2, matmul);
copy_runtime_info(gather, {gather_new1, gather_indices_new1, gather_axis_new1, gather_new2, gather_indices_new2, gather_axis_new2});
register_new_node(gather_new1);
register_new_node(gather_new2);
gather_sinking::UpdateForwardGatherSinkingAbility(gather_new2);
int64_t gather_negative_axis =
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
gather_negative_axis = Swap2DNegativeAxis(gather_negative_axis);
if (IsMatMulInputTransposed(matmul, GetAnotherMatMulIndex(gather_input_info.input_idx)))
gather_negative_axis = Swap2DNegativeAxis(gather_negative_axis);
sink_forward::UpdateInputGather(matmul, gather_input_info, &gather_negative_axis);
return true;
};

View File

@ -103,12 +103,7 @@ int64_t NormalizeNegativeGatherAxis(int64_t axis, ov::Rank::value_type gather_in
return axis - gather_input_rank;
}
/*
Gets gather axis in negative form
*/
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
return NormalizeNegativeGatherAxis(axis->cast_vector<int64_t>()[0], gather_input_rank);
}
int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
if (axis >= 0)
@ -116,18 +111,6 @@ int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
return axis + rank;
}
/*
Reverts gather indices in a such way that reverted and initial gather will do nothing if
stays after another.
Works only with positive form (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
std::vector<int64_t> out(indexes.size());
for (size_t i = 0; i < indexes.size(); i++) {
out.at(indexes[i]) = i;
}
return out;
}
size_t GetDimByAxis(const Shape& shape, int64_t axis) {
if (axis < 0)
@ -152,6 +135,26 @@ Shape Broadcast(const Shape& shape, ov::Rank::value_type rank) {
} // namespace
/*
Gets gather axis in negative form
*/
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
return NormalizeNegativeGatherAxis(axis->cast_vector<int64_t>()[0], gather_input_rank);
}
/*
Reverts gather indices in a such way that reverted and initial gather will do nothing if
stays after another.
Works only with positive form (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
std::vector<int64_t> out(indexes.size());
for (size_t i = 0; i < indexes.size(); i++) {
out.at(indexes[i]) = i;
}
return out;
}
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
const auto node2_output_names = output2.get_names();
output2.set_names(output1.get_names());
@ -179,12 +182,15 @@ namespace sink_forward {
* Input nodes can have different shapes. That shapes can have smaller or larger ranks. To manage it we need
* to find max input shape rank and broadcast all input shapes to it.
*/
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info, const int64_t* a_gather_negative_axis) {
if (gather_input_info.isEmpty() || HasDynamicRankInput(main_node))
return;
const int64_t gather_negative_axis =
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
int64_t gather_negative_axis = {};
if (a_gather_negative_axis)
gather_negative_axis = *a_gather_negative_axis;
else
gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
const std::vector<int64_t> gather_indices = GetNormalizedGatherIndices(gather_input_info.indices_const);

View File

@ -61,12 +61,21 @@ void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
*/
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
/*
Reverts gather indices in a such way that reverted and initial gather will do nothing if
stays after another.
Works only with positive form (no negative indices).
*/
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes);
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<ov::opset9::Constant>& axis, ov::Rank::value_type gather_input_rank);
namespace sink_forward {
/**
* @brief Inserts reversed Gather on @args main_node inputs. Removes input Gather specified in @arg
* transpose_input_info
*/
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&);
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&, const int64_t* a_gather_negative_axis = nullptr);
/**
* @brief Removes @arg input node