openvino/model-optimizer/extensions/front/tf/RollRealImagPack.py
Vladimir Gavrilov 0c288d506c MO support for operations DFT and IDFT (#5197)
* Written MO classes for DFT and IDFT operations.

* Added class to read TF (I)FFT operations.

* Written extractors for TF operations FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D.

* Written MO Roll operation and TF Roll operation extractor.

* Started to write needed transformations.

* Written transformation StridedSlices + Complex + Roll + (i)FFTxD + Roll + (Imag, Real) + Pack -> Roll + (I)DFT + Roll.

* Written transformation for Complex + ComplexAbs.

* Written correction of axes of Roll.

* Small fix.

* Small fix.

* Some fixes.

* Some changes.

* Now TF Roll is read as TFRoll. Written insertion of Transposes before and after (I)DFT.

* Small fix.

* Written tests for the transformation TFRollToRoll.

* Added comments to some transformations.

* Deleted redundant import.

* Written tests for the transformation TransposeDFT.

* Fixes in MO IR Reader to read/write (I)DFT.

* Fixes in the list of supported TF layers.

* Started to write tests for SSliceComplexRolledFFTPackBlockReplacement transformation.

* Written tests for the MO transformation SSliceComplexRolledFFTPackBlockReplacement.

* Written tests for the MO transformation ComplexAbs.

* Tests for transformations were moved into unit_tests directory.

* All extractors for (I)FFTxD are in one file now.

* Deleted redundant transformations.

* Fixed extractor for TF Roll: now this operation is read as MO Roll.

* Added comments to TFFFT operation.

* The method insert_transpose of classes TransposeDFT and LayoutChangeForGatherND was moved into a separate function in the file model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py.

* Fixed comment for the transformation TransposeDFT.

* Small fix.

* Some fixes.

* Deleted shape infer function for the operation TFFFT. Sorted imports in complex_abs.py.

* Small fixes.

* Deleted redundant import.

* Fixes in some asserts.

* Small fix.

* Added names for created nodes in the transformation ComplexAbs.

* Added comments to the method canonicalize_axes.

* The transformation SSliceComplexRolledFFTPackBlockReplacement was split into the sequence of transformations SSliceComplexRollReplacement -> RollRealImagPackReplacement -> TFFFTToDFT.

* Written tests for the transformation SSliceComplexRollReplacement.

* Written tests for the transformation RollRealImagPackReplacement.

* Written tests for the transformation TFFFTToDFT.

* Deleted commented code.

* Fixed types of constants in the transformation ComplexAbs.

* Written tests for canonicalization of signal_size value.

* Deleted 'Replacement' from names of files and classes.

* Used comparison of ids, not names.

* replace_sub_graph was replaced with find_and_replace_pattern.

* Now the transformation RollRealImagPack is executed before the transformation model-optimizer/extensions/front/Pack.py.

* The body of the function create_dft_from_tffft is a part of the transformation TFFFTToDFT body now.

* Now the method correct_roll_axes of classes RollRealImagPack and SSliceComplexRoll was moved into a function in mo/front/tf/graph_utils.py.

* Small changes.

* Added comment before mark_input_as_in_correct_layout(roll, 2).

* Now the function correct_roll_axes generates a sub-graph in the input port 2 of Roll.

* Corrected tests for the transformation SSliceComplexRoll.

* Corrected tests for the transformation RollRealImagPack.

* Deleted commented code.

* Some renaming.

* Added decomposition of the separate operation ComplexAbs (without Complex before it).

* Added comment to the transformation ComplexAbsAfterComplex.

* Optimized imports for the transformation TFFFTToDFT.

* The transformation SSliceComplexRoll was split into the sequence SSliceComplex -> CorrectRollAxes and disabled.

* Written tests for the transformation ComplexAbs.

* Written tests for the transformation SSliceComplex.

* Written tests for the transformation CorrectRollAxes.

* Deleted the transformation SSliceComplexRoll.

* Deleted renaming nodes.

* Fixed comment.

* Small fixes.

* Small fix.

* The attribute need_correction was renamed to input_rank_changed.

* Small fixes.

* Deleted commented code.

* Now we iterate over all complex_node.out_port(0).get_connection().get_destinations() input ports and mark the corresponding nodes with the marker attribute.

* Added the attribute 'in_ports_count' into the class FFTBase.

* Tests for the transformation TransposeDFT were rewritten using helper functions.

* Now the transformation RollRealImagPack uses existing Roll node instead of creating new one.

* Small fixes.

* Fix in the documentation.

* Written class to read MxNet (I)FFT operations. Written corresponding extractors.

* Corrected shape infer function for MXFFT operation. Written transformation to convert MXFFT to (I)DFT.

* Fixed shape infer function.

* Fixed the conversion MXFFT to (I)DFT.

* Written tests for the transformation MXFFTToDFT.

* The function correct_roll_axes was replaced with more generic function add_constant_to_negative_values.

* Fixes in classes TFFFT, FFTBase, DFT, IDFT, MXFFT.

* Added asserts in constructors of operations TFFFT and MXFFT.

* Refactored transformation MXFFTToDFT: conversions of DFT and IDFT were moved into separate functions.

* Moved some commented code.

* Fixed BOM file.

* Written function convert_ifft_to_dft.

* Started to rewrite tests for MXFFTToDFT transformations, in the case is_inverse=False.

* Small fixes.

* Fixes in the transformation RollRealImagPack.

* Renamed the test class for the transformation SSliceComplex.

* Fixes in the function compare_graphs. Now we get all output nodes of an op node, and these output nodes are sorted by name.

* Fixed tests for the transformation MXFFTToDFT.

* Fix in the transformation ThresholdedReluDecomposition: added disconnect for trelu input port.

* Fixes in test for the transformation TFSliceToSlice.

* Small fix in the transformation ObjectDetectionAPIPreprocessor2Replacement.

* Small fix in comment.

* Optimized imports.

* Used remove_node in the transformation ThresholdedReluDecomposition and remove_nodes_from in the transformation RollRealImagPack, instead of ports disconnection.

* Deleted commented code.

* Deleted test case test_slice_replacer_begin_with_2_inputs.
2021-05-07 09:44:24 +03:00
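The commit replaces correct_roll_axes with the more generic add_constant_to_negative_values, which RollRealImagPack calls with the constant -1 on Roll's input port 2 (the axes). According to the commit message, the real function builds a sub-graph on that port; the sketch below only mirrors the arithmetic on a constant NumPy array, and the name add_constant_to_negative_values_sketch is hypothetical, not part of MO.

```python
import numpy as np


def add_constant_to_negative_values_sketch(values, constant):
    """Hypothetical stand-in: add `constant` to every negative entry of `values`
    and leave non-negative entries unchanged (arithmetic only, no graph nodes)."""
    values = np.asarray(values, dtype=np.int64)
    return np.where(values < 0, values + constant, values)


# Roll axes as written for the complex tensor. Once the data carry an extra
# trailing dimension of size 2, negative axes must move one step further left.
axes = np.array([0, -1, 2, -3], dtype=np.int64)
print(add_constant_to_negative_values_sketch(axes, np.int64(-1)).tolist())  # -> [0, -2, 2, -4]
```

Positive axes need no change because the extra dimension is appended after all existing ones.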


# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.graph_utils import add_constant_to_negative_values
from mo.graph.graph import Graph


class RollRealImagPack(FrontReplacementSubgraph):
    """
    Some TF models contain Roll for complex data as a part of the sub-graph

              input   shift   axes
                |       |      |
              --------------------
                      Roll
                        |
              --------------------
              |                  |
             Real               Imag
              |                  |
              -------      -------
                    |      |
                     Pack
                       |
                     SomeOp

    This sub-graph can be replaced with the sub-graph

              input   shift   axes
                |       |      |
              --------------------
                      Roll
                        |
                     SomeOp

    After such a replacement, however, the axes of Roll must be corrected, because the input data are now
    real: the real and imaginary parts are stored in an extra innermost dimension of size 2. Namely, if Roll
    has negative axes, 1 must be subtracted from each of them.
    """
    enabled = True

    def run_after(self):
        from extensions.front.tf.SSliceComplex import SSliceComplex
        return [SSliceComplex]

    def run_before(self):
        from extensions.front.Pack import Pack
        return [Pack]

    def pattern(self):
        return dict(
            nodes=[
                ('unroll', dict(op='Roll')),
                ('real', dict(op='Real')),
                ('imag', dict(op='Imag')),
                ('pack', dict(op='Pack')),
            ],
            edges=[
                ('unroll', 'real', {'in': 0}),
                ('unroll', 'imag', {'in': 0}),
                ('real', 'pack', {'in': 0}),
                ('imag', 'pack', {'in': 1}),
            ])

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        unroll = match['unroll']
        # Port 2 of Roll holds the axes: shift negative axes by -1 to account for the trailing dimension.
        add_constant_to_negative_values(unroll, 2, int64_array(-1))
        pack = match['pack']
        # Consumers of Pack now read directly from the existing Roll node.
        pack.out_port(0).get_connection().set_source(unroll.out_port(0))
        graph.remove_nodes_from([match['real'].id, match['imag'].id])
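The docstring claims the two sub-graphs are equivalent once negative Roll axes are decremented. The NumPy check below is a sketch of that claim, not MO code; it assumes (labeled as such) that the complex input reaches Roll as a real tensor whose last dimension of size 2 holds the real and imaginary parts, and that Pack stacks Real and Imag along a new last axis.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 5)) + 1j * rng.standard_normal((3, 4, 5))
shift = (1, -2)
axes = (0, -1)  # axes as written for the complex tensor

# Original sub-graph: Roll on complex data, then Real/Imag, then Pack
# (assumption: Pack stacks along a new last axis).
rolled = np.roll(x, shift, axis=axes)
original = np.stack([rolled.real, rolled.imag], axis=-1)

# Replacement: the input is already real with a trailing dimension of size 2
# (assumption about the upstream layout), so Roll runs on it directly with
# every negative axis decremented by one.
packed_input = np.stack([x.real, x.imag], axis=-1)
corrected_axes = tuple(a - 1 if a < 0 else a for a in axes)
replacement = np.roll(packed_input, shift, axis=corrected_axes)

print(np.array_equal(original, replacement))
```

Roll only permutes elements, so the comparison is exact; using np.array_equal rather than a tolerance-based check is deliberate.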