add docs
This commit is contained in:
parent
bd35e57c60
commit
860ae33f38
@ -12,12 +12,79 @@ namespace ov {
|
||||
namespace intel_gna {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @brief Moves Gather layer forward from the start to the end of the graph
|
||||
* through the BinaryElementwiseArithmetic operations. Gather layer is moved
|
||||
* from the Binary input to the Binary output. Reversed Gather layer is moved
|
||||
* to another Binary input.
|
||||
* Reversed Gather is the layer that together one after another with the original Gather
|
||||
* layer gives a subgraph that does nothing.
|
||||
* Reversed Gather layer is expected to be pushed by backward GatherSinking to another
|
||||
* Model input or a constant. After all sinking operations we hope to find all Gather
|
||||
* layers on Model inputs and execute them on CPU or before constants and fold them.
|
||||
* Transformation restrictions:
|
||||
* - Gather has only 1D indices
|
||||
* - all nodes have static ranks
|
||||
*
|
||||
* Any1 Any1 Any2
|
||||
* | | |
|
||||
* Gather Any2 | Reversed-Gather
|
||||
* | | => | |
|
||||
* Binary Binary
|
||||
* | |
|
||||
* Any3 Gather
|
||||
* |
|
||||
* Any3
|
||||
*
|
||||
* Any1 Any2 Any1
|
||||
* | | |
|
||||
* Any2 Gather Reversed-Gather |
|
||||
* | | => | |
|
||||
* Binary Binary
|
||||
* | |
|
||||
* Any3 Gather
|
||||
* |
|
||||
* Any3
|
||||
*
|
||||
* All GatherSinking tranformations are designed to work in 2 steps:
|
||||
* - forward push
|
||||
* - backward push
|
||||
* Add flag into Gather layer rt_info that prevents backward sinking if the next layer
|
||||
* after Gather does not support by GatherSinking transformations. That is done to
|
||||
* prevent backward pushing the layer that already pushed forward through the graph.
|
||||
*/
|
||||
class GatherSinkingBinaryForward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingBinaryForward", "0");
|
||||
GatherSinkingBinaryForward();
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Moves Gather layer backward from the end to the start of the graph
|
||||
*
|
||||
* Any1 Any2 Any1 Any2
|
||||
* | | | |
|
||||
* Binary => Gather Gather
|
||||
* | | |
|
||||
* Gather Binary
|
||||
* | |
|
||||
* Any3 Any3
|
||||
*
|
||||
* Any1 Any2 Any1 Any2
|
||||
* | | | |
|
||||
* Binary => Gather Gather
|
||||
* | | | |
|
||||
* Gather Gather Binary
|
||||
* | | | |
|
||||
* Any3 Any4 Any3 Any4
|
||||
*
|
||||
* Moves Gather layer backward only if:
|
||||
* - Gather is not marked as non-sinkable
|
||||
* - all Binary consumers are Gather layers
|
||||
* - all that Gather layers equal each other
|
||||
* - Gather has only 1D indices
|
||||
* - all nodes have static ranks
|
||||
*/
|
||||
class GatherSinkingBinaryBackward : public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherSinkingBinaryBackward", "0");
|
||||
|
@ -195,8 +195,8 @@ void UpdateInputGather(NodePtr main_node, const GatherInputsInfo& gather_input_i
|
||||
|
||||
const auto max_input_rank = GetMaxInputRank(main_node);
|
||||
if (max_input_rank < 0)
|
||||
return;
|
||||
|
||||
return;
|
||||
|
||||
for (size_t i = 0; i < main_node->get_input_size(); ++i) {
|
||||
auto input_node = main_node->input_value(i);
|
||||
if (i == gather_input_info.input_idx) {
|
||||
@ -389,7 +389,7 @@ GatherInfo GetGatherInfo(Node* node) {
|
||||
}
|
||||
|
||||
Node* FindFirstConsumer(NodePtr node) {
|
||||
for (auto& output : node->outputs()) {
|
||||
for (auto output : node->outputs()) {
|
||||
auto inputs = output.get_target_inputs();
|
||||
if (inputs.empty())
|
||||
continue;
|
||||
|
Loading…
Reference in New Issue
Block a user