* Written MO classes for DFT and IDFT operations.
* Added class to read TF (I)FFT operations.
* Written extractors for TF operations FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D.
* Written MO Roll operation and TF Roll operation extractor.
* Started to write needed transformations.
* Written transformation StridedSlices + Complex + Roll + (i)FFTxD + Roll + (Imag, Real) + Pack -> Roll + (I)DFT + Roll.
* Written transformation for Complex + ComplexAbs.
* Written correction of axes of Roll.
* Small fix.
* Small fix.
* Some fixes.
* Some changes.
* Now TF Roll is read as TFRoll. Written inserting Transposes before and after (I)DFT.
* Small fix.
* Written tests for the transformation TFRollToRoll.
* Added comments to some transformations.
* Deleted redundant import.
* Written tests for the transformation TransposeDFT.
* Fixes in MO IR Reader to read/write (I)DFT.
* Fixes in the list of supported TF layers.
* Started to write tests for SSliceComplexRolledFFTPackBlockReplacement transformation.
* Written tests for the MO transformation SSliceComplexRolledFFTPackBlockReplacement.
* Written tests for the MO transformation ComplexAbs.
* Tests for transformations were moved into unit_tests directory.
* All extractors for (I)FFTxD are in one file now.
* Deleted redundant transformations.
* Fixed extractor for TF Roll: now this operation is read as MO Roll.
* Added comments to TFFFT operation.
* The method insert_transpose of classes TransposeDFT and LayoutChangeForGatherND was moved into the separate function in the file model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py.
* Fixed comment for the transformation TransposeDFT.
* Small fix.
* Some fixes.
* Deleted shape infer function for the operation TFFFT. Sorted imports in complex_abs.py.
* Small fixes.
* Deleted redundant import.
* Fixes in some asserts.
* Small fix.
* Added names for created nodes in the transformation ComplexAbs.
* Added comments to the method canonicalize_axes.
* The transformation SSliceComplexRolledFFTPackBlockReplacement was split into the sequence of transformations SSliceComplexRollReplacement -> RollRealImagPackReplacement -> TFFFTToDFT.
* Written tests for the transformation SSliceComplexRollReplacement.
* Written tests for the transformation RollRealImagPackReplacement.
* Written tests for the transformation TFFFTToDFT.
* Deleted commented code.
* Fixed types of constants in the transformation ComplexAbs.
* Written tests for canonicalization of signal_size value.
* Deleted 'Replacement' from names of files and classes.
* Used comparison of ids, not names.
* replace_sub_graph was replaced with find_and_replace_pattern.
* Now the transformation RollRealImagPack is executed before running transformation model-optimizer/extensions/front/Pack.py.
* The body of the function create_dft_from_tffft is a part of the transformation TFFFTToDFT body now.
* The method correct_roll_axes of classes RollRealImagPack and SSliceComplexRoll was moved to a function in mo/front/tf/graph_utils.py.
* Small changes.
* Added comment before mark_input_as_in_correct_layout(roll, 2).
* Now the function correct_roll_axes generates a sub-graph in input port 2 of Roll.
* Corrected tests for the transformation SSliceComplexRoll.
* Corrected tests for the transformation RollRealImagPack.
* Deleted commented code.
* Some renaming.
* Added decomposition of the separate operation ComplexAbs (without Complex before it).
* Added comment to the transformation ComplexAbsAfterComplex.
* Optimized imports for the transformation TFFFTToDFT.
* The transformation SSliceComplexRoll was split into the sequence SSliceComplex -> CorrectRollAxes and disabled.
* Written tests for the transformation ComplexAbs.
* Written tests for the transformation SSliceComplex.
* Written tests for the transformation CorrectRollAxes.
* Deleted the transformation SSliceComplexRoll.
* Deleted renaming nodes.
* Fixed comment.
* Small fixes.
* Small fix.
* The attribute need_correction was renamed to input_rank_changed.
* Small fixes.
* Deleted commented code.
* Now we iterate over all complex_node.out_port(0).get_connection().get_destinations() input ports and mark the corresponding nodes with the marker attribute.
* Added the attribute 'in_ports_count' into the class FFTBase.
* Tests for the transformation TransposeDFT were rewritten using helper functions.
* Now the transformation RollRealImagPack uses the existing Roll node instead of creating a new one.
* Small fixes.
* Fix in the documentation.
* Written class to read MxNet (I)FFT operations. Written corresponding extractors.
* Corrected shape infer function for MXFFT operation. Written transformation to convert MXFFT to (I)DFT.
* Fixed shape infer function.
* Fixed the conversion MXFFT to (I)DFT.
* Written tests for the transformation MXFFTToDFT.
* The function correct_roll_axes was replaced with the more generic function add_constant_to_negative_values.
* Fixes in classes TFFFT, FFTBase, DFT, IDFT, MXFFT.
* Added asserts in constructors of operations TFFFT and MXFFT.
* Refactored transformation MXFFTToDFT: the DFT and IDFT conversions were moved into separate functions.
* Moved some commented code.
* Fixed BOM file.
* Written function convert_ifft_to_dft.
* Started to rewrite tests for MXFFTToDFT transformations, in the case is_inverse=False.
* Small fixes.
* Fixes in the transformation RollRealImagPack.
* Renamed the test class for the transformation SSliceComplex.
* Fixes in the function compare_graphs. Now we get all output nodes of op node, and these output nodes are sorted by names.
* Fixed tests for the transformation MXFFTToDFT.
* Fix in the transformation ThresholdedReluDecomposition: added disconnect for trelu input port.
* Fixes in test for the transformation TFSliceToSlice.
* Small fix in the transformation ObjectDetectionAPIPreprocessor2Replacement.
* Small fix in comment.
* Optimized imports.
* Used remove_node in the transformation ThresholdedReluDecomposition and remove_nodes_from in the transformation RollRealImagPack, instead of ports disconnection.
* Deleted commented code.
* Deleted test case test_slice_replacer_begin_with_2_inputs.
74 lines
2.3 KiB
Python
```python
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.front.subgraph_matcher import SubgraphMatch
from mo.front.tf.graph_utils import add_constant_to_negative_values
from mo.graph.graph import Graph


class RollRealImagPack(FrontReplacementSubgraph):
    """
    Some TF models contain Roll for complex data, as a part of the sub-graph

        input   shift   axes
          |       |      |
        -------------------
                Roll
        -------------------
                 |
        -------------------
        |                 |
      Real               Imag
        |                 |
        -------     -------
              |     |
               Pack
                 |
               SomeOp

    This sub-graph can be replaced with the sub-graph

        input   shift   axes
          |       |      |
        -------------------
                Roll
        -------------------
                 |
               SomeOp

    But after such a replacement, we should correct the axes of Roll, because the input data are now real
    and have an extra innermost dimension. Namely, if there are negative axes for Roll, we need to subtract 1
    from such axes indices.
    """
    enabled = True

    def run_after(self):
        from extensions.front.tf.SSliceComplex import SSliceComplex
        return [SSliceComplex]

    def run_before(self):
        from extensions.front.Pack import Pack
        return [Pack]

    def pattern(self):
        return dict(
            nodes=[
                ('unroll', dict(op='Roll')),
                ('real', dict(op='Real')),
                ('imag', dict(op='Imag')),
                ('pack', dict(op='Pack')),
            ],
            edges=[
                ('unroll', 'real', {'in': 0}),
                ('unroll', 'imag', {'in': 0}),
                ('real', 'pack', {'in': 0}),
                ('imag', 'pack', {'in': 1}),
            ])

    def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
        unroll = match['unroll']
        # Roll now works on real data with an extra innermost dimension,
        # so every negative value in its axes input (port 2) is shifted by -1.
        add_constant_to_negative_values(unroll, 2, int64_array(-1))
        pack = match['pack']
        # Feed all consumers of Pack directly from the output of Roll.
        pack.out_port(0).get_connection().set_source(unroll.out_port(0))
        # Real and Imag are no longer used.
        graph.remove_nodes_from([match['real'].id, match['imag'].id])
```
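The rewrite described in the docstring relies on an identity: rolling a complex tensor and then packing its real and imaginary parts gives the same result as packing first and rolling the packed real tensor with corrected axes. The NumPy sketch below checks that identity numerically. It is an illustration under stated assumptions, not Model Optimizer code: it assumes that Pack stacks Real and Imag along a new innermost axis (axis -1), so the packed tensor has shape [..., 2]. The helper `correct_negative_axes` is hypothetical; it computes on concrete values what `add_constant_to_negative_values(unroll, 2, int64_array(-1))` is meant to achieve for the axes input of Roll, without building any graph nodes.

```python
import numpy as np


def correct_negative_axes(axes, constant=-1):
    # Hypothetical helper: add `constant` to negative axes only,
    # leaving non-negative axes unchanged.
    axes = np.asarray(axes, dtype=np.int64)
    return np.where(axes < 0, axes + constant, axes)


def original_subgraph(x, shift, axes):
    # Roll on complex data, then Real/Imag packed along a new last axis
    # (assumption: Pack uses axis=-1).
    rolled = np.roll(x, tuple(shift), tuple(axes))
    return np.stack([rolled.real, rolled.imag], axis=-1)


def replaced_subgraph(x, shift, axes):
    # Roll applied directly to the packed [..., 2] real tensor,
    # with negative axes shifted by -1 to skip the new last dimension.
    packed = np.stack([x.real, x.imag], axis=-1)
    new_axes = tuple(int(a) for a in correct_negative_axes(axes))
    return np.roll(packed, tuple(shift), new_axes)


rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4, 5)) + 1j * rng.standard_normal((3, 4, 5))

# Mixed, purely negative and purely non-negative axes all agree.
for shift, axes in [((1, -2), (0, -1)), ((2,), (-2,)), ((3, 1), (1, 2))]:
    assert np.array_equal(original_subgraph(x, shift, axes),
                          replaced_subgraph(x, shift, axes))

# Without the correction, axis -1 would be the real/imag dimension.
packed = np.stack([x.real, x.imag], axis=-1)
uncorrected = np.roll(packed, 1, -1)
assert not np.array_equal(uncorrected, original_subgraph(x, (1,), (-1,)))

print(correct_negative_axes([0, -1, 2, -3]).tolist())  # [0, -2, 2, -4]
```

The last assertion shows why the axis correction is needed: without it, axis -1 of the packed tensor is the size-2 real/imaginary dimension, so a shift of 1 along it swaps the real and imaginary parts instead of rolling along the intended data axis. Non-negative axes need no change because the extra dimension is appended at the end.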