openvino/model-optimizer/extensions/ops/roll.py
Vladimir Gavrilov 0c288d506c MO support for operations DFT and IDFT (#5197)
* Written MO classes for DFT and IDFT operations.

* Added class to read TF (I)FFT operations.

* Written extractors for TF operations FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D.

* Written MO Roll operation and TF Roll operation extractor.

* Started to write needed transformations.

* Written transformation StridedSlices + Complex + Roll + (i)FFTxD + Roll + (Imag, Real) + Pack -> Roll + (I)DFT + Roll.

* Written transformation for Complex + ComplexAbs.

* Written correction of axes of Roll.

* Small fix.

* Small fix.

* Some fixes.

* Some changes.

* Now TF Roll is read as TFRoll. Written inserting Transposes before and after (I)DFT.

* Small fix.

* Written tests for the transformation TFRollToRoll.

* Added comments to some transformations.

* Deleted redundant import.

* Written tests for the transformation TransposeDFT.

* Fixes in MO IR Reader to read/write (I)DFT.

* Fixes in the list of supported TF layers.

* Started to write tests for SSliceComplexRolledFFTPackBlockReplacement transformation.

* Written tests for the MO transformation SSliceComplexRolledFFTPackBlockReplacement.

* Written tests for the MO transformation ComplexAbs.

* Tests for transformations were moved into unit_tests directory.

* All extractors for (I)FFTxD are in one file now.

* Deleted redundant transformations.

* Fixed extractor for TF Roll: now this operation is read as MO Roll.

* Added comments to TFFFT operation.

* The method insert_transpose of classes TransposeDFT and LayoutChangeForGatherND was moved into the separate function in the file model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py.

* Fixed comment for the transformation TransposeDFT.

* Small fix.

* Some fixes.

* Deleted shape infer function for the operation TFFFT. Sorted imports in complex_abs.py.

* Small fixes.

* Deleted redundant import.

* Fixes in some asserts.

* Small fix.

* Added names for created nodes in the transformation ComplexAbs.

* Added comments to the method canonicalize_axes.

* The transformation SSliceComplexRolledFFTPackBlockReplacement was split into the sequence of transformations SSliceComplexRollReplacement -> RollRealImagPackReplacement -> TFFFTToDFT.

* Written tests for the transformation SSliceComplexRollReplacement.

* Written tests for the transformation RollRealImagPackReplacement.

* Written tests for the transformation TFFFTToDFT.

* Deleted commented code.

* Fixed types of constants in the transformation ComplexAbs.

* Written tests for canonicalization of signal_size value.

* Deleted 'Replacement' from names of files and classes.

* Used comparison of ids, not names.

* replace_sub_graph was replaced with find_and_replace_pattern.

* Now the transformation RollRealImagPack is executed before running transformation model-optimizer/extensions/front/Pack.py.

* The body of the function create_dft_from_tffft is a part of the transformation TFFFTToDFT body now.

* Now method correct_roll_axes of classes RollRealImagPack and SSliceComplexRoll is moved to the function in mo/front/tf/graph_utils.py.

* Small changes.

* Added comment before mark_input_as_in_correct_layout(roll, 2).

* Now the function correct_roll_axes generates a sub-graph in the input port 2 of Roll.

* Corrected tests for the transformation SSliceComplexRoll.

* Corrected tests for the transformation RollRealImagPack.

* Deleted commented code.

* Some renaming.

* Added decomposition of the separate operation ComplexAbs (without Complex before it).

* Added comment to the transformation ComplexAbsAfterComplex.

* Optimized imports for the transformation TFFFTToDFT.

* The transformation SSliceComplexRoll was split into the sequence SSliceComplex -> CorrectRollAxes and disabled.

* Written tests for the transformation ComplexAbs.

* Written tests for the transformation SSliceComplex.

* Written tests for the transformation CorrectRollAxes.

* Deleted the transformation SSliceComplexRoll.

* Deleted renaming nodes.

* Fixed comment.

* Small fixes.

* Small fix.

* The attribute need_correction was renamed as input_rank_changed.

* Small fixes.

* Deleted commented code.

* Now we iterate over all complex_node.out_port(0).get_connection().get_destinations() input ports and mark the corresponding nodes with the marker attribute.

* Added the attribute 'in_ports_count' into the class FFTBase.

* Tests for the transformation TransposeDFT were rewritten using helper functions.

* Now the transformation RollRealImagPack uses existing Roll node instead of creating new one.

* Small fixes.

* Fix in the documentation.

* Written class to read MxNet (I)FFT operations. Written corresponding extractors.

* Corrected shape infer function for MXFFT operation. Written transformation to convert MXFFT to (I)DFT.

* Fixed shape infer function.

* Fixed the conversion MXFFT to (I)DFT.

* Written tests for the transformation MXFFTToDFT.

* The function correct_roll_axes was replaced with more generic function add_constant_to_negative_values.

* Fixes in classes TFFFT, FFTBase, DFT, IDFT, MXFFT.

* Added asserts in constructors of operations TFFFT and MXFFT.

* Refactored transformation MXFFTToDFT: conversions of DFT and IDFT were moved into separate functions.

* Moved some commented code.

* Fixed BOM file.

* Written function convert_ifft_to_dft.

* Started to rewrite tests for MXFFTToDFT transformations, in the case is_inverse=False.

* Small fixes.

* Fixes in the transformation RollRealImagPack.

* Renaming tests class for the transformation SSliceComplex.

* Fixes in the function compare_graphs. Now we get all output nodes of op node, and these output nodes are sorted by names.

* Fixed tests for the transformation MXFFTToDFT.

* Fix in the transformation ThresholdedReluDecomposition: added disconnect for trelu input port.

* Fixes in test for the transformation TFSliceToSlice.

* Small fix in the transformation ObjectDetectionAPIPreprocessor2Replacement.

* Small fix in comment.

* Optimized imports.

* Used remove_node in the transformation ThresholdedReluDecomposition and remove_nodes_from in the transformation RollRealImagPack, instead of ports disconnection.

* Deleted commented code.

* Deleted test case test_slice_replacer_begin_with_2_inputs.
2021-05-07 09:44:24 +03:00


# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.front.common.partial_infer.elemental import copy_shape_infer
from mo.graph.graph import Graph, Node
from mo.graph.perm_inputs import PermuteInputs
from mo.ops.op import Op


class Roll(Op):
    """
    Roll operation that shifts elements of a tensor along specified axes.
    """
    op = 'Roll'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': self.op,
            'op': self.op,
            'version': 'opset7',
            'infer': roll_infer,
            'in_ports_count': 3,
            'out_ports_count': 1
        }, attrs)


class AttributedRoll(Op):
    """ Roll operation that shifts elements of a tensor along specified axes.
    This operation uses the same semantics as Roll but with shift and axes specified as attributes.
    Shift and axes are specified as attributes in MxNet.
    """
    op = 'AttributedRoll'
    enabled = False

    def __init__(self, graph: Graph, attrs: dict):
        super().__init__(graph, {
            'type': None,
            'op': self.op,
            'infer': None,
            'in_ports_count': 3,
            'out_ports_count': 1,
            'shift': None,
            'axes': None
        }, attrs)


def roll_infer(node: Node):
    """
    Shape inference for Roll: the output shape is a copy of the input shape.
    The axes input (port 2) is registered for permutation as axis values of input 0.
    """
    PermuteInputs().set_input_permutation(node.in_node(2), node, 'input:0', 'axis')
    copy_shape_infer(node)
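
The docstrings above say only that Roll "shifts elements of a tensor along specified axes". The following is a minimal standalone sketch of that behaviour, assuming the usual cyclic-shift semantics (elements pushed past the end wrap around to the start). The helpers `roll_1d` and `roll_2d` are hypothetical names used for illustration only. They are not part of Model Optimizer and do not use its graph API. The negative-axis handling mirrors the idea of adding a constant (the rank) to negative values mentioned in the commit log.

```python
from typing import List, Sequence


def roll_1d(values: Sequence, shift: int) -> List:
    """Cyclically shift a 1-D sequence by `shift` positions (hypothetical helper).

    A positive shift moves elements toward higher indices; a negative shift
    moves them toward lower indices.
    """
    n = len(values)
    if n == 0:
        return []
    # Python's % always yields a value in [0, n), so negative shifts wrap correctly.
    k = shift % n
    return list(values[n - k:]) + list(values[:n - k])


def roll_2d(rows: Sequence[Sequence], shift: int, axis: int) -> List[List]:
    """Cyclically shift a 2-D list of lists along one axis (hypothetical helper)."""
    rank = 2
    # Canonicalize a negative axis by adding the rank, e.g. -1 becomes 1.
    if axis < 0:
        axis += rank
    if axis == 0:
        # Rolling along axis 0 reorders whole rows.
        return [list(row) for row in roll_1d(rows, shift)]
    if axis == 1:
        # Rolling along axis 1 shifts the elements inside every row.
        return [roll_1d(row, shift) for row in rows]
    raise ValueError("axis out of range for a 2-D input")


if __name__ == "__main__":
    matrix = [[1, 2, 3], [4, 5, 6]]
    print(roll_1d([1, 2, 3, 4, 5], 2))
    print(roll_2d(matrix, 1, -1))
    print(roll_2d(matrix, 1, 0))
```

In this sketch the result always has the same dimensions as the input, which is consistent with `roll_infer` copying the input shape to the output.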