fix gather_sinking_matmul
This commit is contained in:
parent
06d6fbf0e8
commit
7b95c90df8
@ -22,76 +22,60 @@ using namespace ov::intel_gna::pass;
|
||||
using namespace ov::intel_gna::rt_info;
|
||||
|
||||
namespace {
|
||||
/*
|
||||
Reverts gather indices in a such way that reverted and initial gather will do nothing if
|
||||
stays after another.
|
||||
Works only with positive form (no negative indices).
|
||||
*/
|
||||
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
|
||||
std::vector<int64_t> out(indexes.size());
|
||||
for (size_t i = 0; i < indexes.size(); i++) {
|
||||
out.at(indexes[i]) = i;
|
||||
}
|
||||
return out;
|
||||
bool Has2dInputs(const Output<Node>& output) {
|
||||
auto node = output.get_node_shared_ptr();
|
||||
auto input_left_rank = node->get_input_partial_shape(0).rank();
|
||||
auto input_right_rank = node->get_input_partial_shape(0).rank();
|
||||
return (input_left_rank.is_static() && input_right_rank.is_static() &&
|
||||
input_left_rank.get_length() == 2 && input_right_rank.get_length() == 2);
|
||||
}
|
||||
|
||||
bool HasGatherInputs(const Output<Node>& output) {
|
||||
return !GetFirstGatherInput(output.get_node_shared_ptr()).isEmpty();
|
||||
}
|
||||
|
||||
bool IsSinked(const Output<Node>& output) {
|
||||
return Has2dInputs(output) && HasGatherInputs(output);
|
||||
}
|
||||
|
||||
int64_t Swap2DNegativeAxis(int64_t axis) {
|
||||
if (axis == -1)
|
||||
return -2;
|
||||
return -1;
|
||||
}
|
||||
|
||||
size_t GetAnotherMatMulIndex(size_t input_idx) {
|
||||
if (!input_idx)
|
||||
return 1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
bool IsMatMulInputTransposed(const std::shared_ptr<MatMul>& matmul, size_t input_idx) {
|
||||
if (!input_idx)
|
||||
return matmul->get_transpose_a();
|
||||
return matmul->get_transpose_b();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
GatherSinkingMatmulForward::GatherSinkingMatmulForward() {
|
||||
MATCHER_SCOPE(GatherSinkingMatmulForward);
|
||||
auto gather_indices_label = wrap_type<Constant>();
|
||||
auto gather_axis_label = wrap_type<Constant>();
|
||||
auto gather_label = wrap_type<Gather>({any_input(), gather_indices_label, gather_axis_label});
|
||||
auto matmul_label = wrap_type<MatMul>({gather_label, any_input()});
|
||||
|
||||
auto matmul_label = wrap_type<MatMul>({any_input(), any_input()}, IsSinked);
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto gather_indices = as_type_ptr<Constant>(pattern_to_output.at(gather_indices_label).get_node_shared_ptr());
|
||||
auto gather_axis = as_type_ptr<Constant>(pattern_to_output.at(gather_axis_label).get_node_shared_ptr());
|
||||
auto gather = as_type_ptr<Gather>(pattern_to_output.at(gather_label).get_node_shared_ptr());
|
||||
auto matmul = as_type_ptr<MatMul>(pattern_to_output.at(matmul_label).get_node_shared_ptr());
|
||||
GatherInputsInfo gather_input_info = GetFirstGatherInput(matmul);
|
||||
|
||||
std::cout << "[EMUTEX DEBUG] GatherSinkingMatmulForward gather " << gather->get_friendly_name() << " matmul " << matmul->get_friendly_name() << std::endl;
|
||||
std::cout << "[EMUTEX DEBUG] GatherSinkingMatmulForward gather axis " << gather_axis->cast_vector<int64_t>()[0] << std::endl;
|
||||
|
||||
auto gather_parent = matmul->input_value(0 /* TODO */).get_node()->input_value(0);;
|
||||
// insert input gather
|
||||
#if 0
|
||||
size_t gather_axis_value_current = ConvertAxisToPositive(gather_axis->cast_vector<int64_t>()[0],
|
||||
gather->get_input_shape(0).size());
|
||||
#endif
|
||||
const size_t gather_axis_value_new = 0; // TODO
|
||||
|
||||
auto gather_axis_new1 = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value_new);
|
||||
auto gather_indices_values = ReverseGatherIndexes(gather_indices->cast_vector<int64_t>());
|
||||
auto gather_indices_new1 = std::make_shared<Constant>(element::i64, Shape{gather_indices_values.size()}, gather_indices_values);
|
||||
auto gather_new1 = std::make_shared<Gather>(matmul->input_value(1) /* TODO */, gather_indices_new1, gather_axis_new1);
|
||||
|
||||
matmul->input(1 /* TODO */).replace_source_output(gather_new1->output(0));
|
||||
|
||||
// remove input gather
|
||||
matmul->input(0 /* TODO */).replace_source_output(gather_parent);
|
||||
|
||||
// insert output gather
|
||||
auto matmul_consumers = matmul->output(0).get_target_inputs();
|
||||
|
||||
auto gather_axis_new2 = gather_axis->clone_with_new_inputs({});
|
||||
auto gather_indices_new2 = gather_indices->clone_with_new_inputs({});
|
||||
auto gather_new2 = std::make_shared<Gather>(matmul->output(0), gather_indices_new2, gather_axis_new2);
|
||||
|
||||
for (auto& consumer : matmul_consumers) {
|
||||
consumer.replace_source_output(gather_new2);
|
||||
}
|
||||
|
||||
SwapFriendlyNames(gather_new2, matmul);
|
||||
|
||||
copy_runtime_info(gather, {gather_new1, gather_indices_new1, gather_axis_new1, gather_new2, gather_indices_new2, gather_axis_new2});
|
||||
|
||||
register_new_node(gather_new1);
|
||||
register_new_node(gather_new2);
|
||||
|
||||
gather_sinking::UpdateForwardGatherSinkingAbility(gather_new2);
|
||||
int64_t gather_negative_axis =
|
||||
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
|
||||
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
|
||||
gather_negative_axis = Swap2DNegativeAxis(gather_negative_axis);
|
||||
if (IsMatMulInputTransposed(matmul, GetAnotherMatMulIndex(gather_input_info.input_idx)))
|
||||
gather_negative_axis = Swap2DNegativeAxis(gather_negative_axis);
|
||||
|
||||
sink_forward::UpdateInputGather(matmul, gather_input_info, &gather_negative_axis);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
@ -103,12 +103,7 @@ int64_t NormalizeNegativeGatherAxis(int64_t axis, ov::Rank::value_type gather_in
|
||||
return axis - gather_input_rank;
|
||||
}
|
||||
|
||||
/*
|
||||
Gets gather axis in negative form
|
||||
*/
|
||||
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
|
||||
return NormalizeNegativeGatherAxis(axis->cast_vector<int64_t>()[0], gather_input_rank);
|
||||
}
|
||||
|
||||
|
||||
int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
|
||||
if (axis >= 0)
|
||||
@ -116,18 +111,6 @@ int64_t ConvertAxisToPositive(int64_t axis, ov::Rank::value_type rank) {
|
||||
return axis + rank;
|
||||
}
|
||||
|
||||
/*
|
||||
Reverts gather indices in a such way that reverted and initial gather will do nothing if
|
||||
stays after another.
|
||||
Works only with positive form (no negative indices).
|
||||
*/
|
||||
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
|
||||
std::vector<int64_t> out(indexes.size());
|
||||
for (size_t i = 0; i < indexes.size(); i++) {
|
||||
out.at(indexes[i]) = i;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
size_t GetDimByAxis(const Shape& shape, int64_t axis) {
|
||||
if (axis < 0)
|
||||
@ -152,6 +135,26 @@ Shape Broadcast(const Shape& shape, ov::Rank::value_type rank) {
|
||||
|
||||
} // namespace
|
||||
|
||||
/*
|
||||
Gets gather axis in negative form
|
||||
*/
|
||||
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<Constant>& axis, ov::Rank::value_type gather_input_rank) {
|
||||
return NormalizeNegativeGatherAxis(axis->cast_vector<int64_t>()[0], gather_input_rank);
|
||||
}
|
||||
|
||||
/*
|
||||
Reverts gather indices in a such way that reverted and initial gather will do nothing if
|
||||
stays after another.
|
||||
Works only with positive form (no negative indices).
|
||||
*/
|
||||
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes) {
|
||||
std::vector<int64_t> out(indexes.size());
|
||||
for (size_t i = 0; i < indexes.size(); i++) {
|
||||
out.at(indexes[i]) = i;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void SwapOutputNames(Output<Node> output1, Output<Node> output2) {
|
||||
const auto node2_output_names = output2.get_names();
|
||||
output2.set_names(output1.get_names());
|
||||
@ -179,12 +182,15 @@ namespace sink_forward {
|
||||
* Input nodes can have different shapes. That shapes can have smaller or larger ranks. To manage it we need
|
||||
* to find max input shape rank and broadcast all input shapes to it.
|
||||
*/
|
||||
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info) {
|
||||
void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_info, const int64_t* a_gather_negative_axis) {
|
||||
if (gather_input_info.isEmpty() || HasDynamicRankInput(main_node))
|
||||
return;
|
||||
|
||||
const int64_t gather_negative_axis =
|
||||
GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
|
||||
int64_t gather_negative_axis = {};
|
||||
if (a_gather_negative_axis)
|
||||
gather_negative_axis = *a_gather_negative_axis;
|
||||
else
|
||||
gather_negative_axis = GetNormalizedNegativeGatherAxis(gather_input_info.axis_const,
|
||||
gather_input_info.gather->get_input_partial_shape(0).rank().get_length());
|
||||
|
||||
const std::vector<int64_t> gather_indices = GetNormalizedGatherIndices(gather_input_info.indices_const);
|
||||
|
@ -61,12 +61,21 @@ void SwapFriendlyNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||
*/
|
||||
void SwapNames(std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>);
|
||||
|
||||
/*
|
||||
Reverts gather indices in a such way that reverted and initial gather will do nothing if
|
||||
stays after another.
|
||||
Works only with positive form (no negative indices).
|
||||
*/
|
||||
std::vector<int64_t> ReverseGatherIndexes(const std::vector<int64_t>& indexes);
|
||||
|
||||
int64_t GetNormalizedNegativeGatherAxis(const std::shared_ptr<ov::opset9::Constant>& axis, ov::Rank::value_type gather_input_rank);
|
||||
|
||||
namespace sink_forward {
|
||||
/**
|
||||
* @brief Inserts reversed Gather on @args main_node inputs. Removes input Gather specified in @arg
|
||||
* transpose_input_info
|
||||
*/
|
||||
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&);
|
||||
void UpdateInputGather(std::shared_ptr<ov::Node> main_node, const GatherInputsInfo&, const int64_t* a_gather_negative_axis = nullptr);
|
||||
|
||||
/**
|
||||
* @brief Removes @arg input node
|
||||
|
Loading…
Reference in New Issue
Block a user