This commit is contained in:
Evgeny Kotov 2023-02-21 13:29:05 +01:00
parent bd35e57c60
commit 860ae33f38
2 changed files with 70 additions and 3 deletions

View File

@ -12,12 +12,79 @@ namespace ov {
namespace intel_gna {
namespace pass {
/**
* @brief Moves Gather layer forward from the start to the end of the graph
* through the BinaryElementwiseArithmetic operations. Gather layer is moved
* from the Binary input to the Binary output. Reversed Gather layer is moved
* to another Binary input.
* Reversed Gather is the layer that together one after another with the original Gather
* layer gives a subgraph that does nothing.
* Reversed Gather layer is expected to be pushed by backward GatherSinking to another
* Model input or a constant. After all sinking operations we hope to find all Gather
* layers on Model inputs and execute them on CPU or before constants and fold them.
* Transformation restrictions:
* - Gather has only 1D indices
* - all nodes have static ranks
*
* Any1 Any1 Any2
* | | |
* Gather Any2 | Reversed-Gather
* | | => | |
* Binary Binary
* | |
* Any3 Gather
* |
* Any3
*
* Any1 Any2 Any1
* | | |
* Any2 Gather Reversed-Gather |
* | | => | |
* Binary Binary
* | |
* Any3 Gather
* |
* Any3
*
* All GatherSinking tranformations are designed to work in 2 steps:
* - forward push
* - backward push
* Add flag into Gather layer rt_info that prevents backward sinking if the next layer
* after Gather does not support by GatherSinking transformations. That is done to
* prevent backward pushing the layer that already pushed forward through the graph.
*/
class GatherSinkingBinaryForward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingBinaryForward", "0");
GatherSinkingBinaryForward();
};
/**
* @brief Moves Gather layer backward from the end to the start of the graph
*
* Any1 Any2 Any1 Any2
* | | | |
* Binary => Gather Gather
* | | |
* Gather Binary
* | |
* Any3 Any3
*
* Any1 Any2 Any1 Any2
* | | | |
* Binary => Gather Gather
* | | | |
* Gather Gather Binary
* | | | |
* Any3 Any4 Any3 Any4
*
* Moves Gather layer backward only if:
* - Gather is not marked as non-sinkable
* - all Binary consumers are Gather layers
* - all that Gather layers equal each other
* - Gather has only 1D indices
* - all nodes have static ranks
*/
class GatherSinkingBinaryBackward : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("GatherSinkingBinaryBackward", "0");

View File

@ -195,8 +195,8 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
const auto max_input_rank = GetMaxInputRank(main_node);
if (max_input_rank < 0)
return;
return;
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
auto input_node = main_node->input_value(i);
if (i == gather_input_info.input_idx) {
@ -389,7 +389,7 @@ GatherInfo GetGatherInfo(Node* node) {
}
Node* FindFirstConsumer(NodePtr node) {
for (auto& output : node->outputs()) {
for (auto output : node->outputs()) {
auto inputs = output.get_target_inputs();
if (inputs.empty())
continue;