* Written MO classes for DFT and IDFT operations.
* Added class to read TF (I)FFT operations.
* Written extractors for TF operations FFT, FFT2D, FFT3D, IFFT, IFFT2D, IFFT3D.
* Written MO Roll operation and TF Roll operation extractor.
* Started to write needed transformations.
* Written transformation StridedSlices + Complex + Roll + (i)FFTxD + Roll + (Imag, Real) + Pack -> Roll + (I)DFT + Roll.
* Written transformation for Complex + ComplexAbs.
* Written correction of axes of Roll.
* Small fix.
* Small fix.
* Some fixes.
* Some changes.
* Now TF Roll is read as TFRoll. Written insertion of Transposes before and after (I)DFT.
* Small fix.
* Written tests for the transformation TFRollToRoll.
* Added comments to some transformations.
* Deleted redundant import.
* Written tests for the transformation TransposeDFT.
* Fixes in MO IR Reader to read/write (I)DFT.
* Fixes in the list of supported TF layers.
* Started to write tests for the SSliceComplexRolledFFTPackBlockReplacement transformation.
* Written tests for the MO transformation SSliceComplexRolledFFTPackBlockReplacement.
* Written tests for the MO transformation ComplexAbs.
* Tests for transformations were moved into the unit_tests directory.
* All extractors for (I)FFTxD are in one file now.
* Deleted redundant transformations.
* Fixed extractor for TF Roll: now this operation is read as MO Roll.
* Added comments to TFFFT operation.
* The method insert_transpose of classes TransposeDFT and LayoutChangeForGatherND was moved into a separate function in the file model-optimizer/extensions/middle/InsertLayoutPropagationTransposes.py.
* Fixed comment for the transformation TransposeDFT.
* Small fix.
* Some fixes.
* Deleted shape infer function for the operation TFFFT. Sorted imports in complex_abs.py.
* Small fixes.
* Deleted redundant import.
* Fixes in some asserts.
* Small fix.
* Added names for created nodes in the transformation ComplexAbs.
* Added comments to the method canonicalize_axes.
* The transformation SSliceComplexRolledFFTPackBlockReplacement was split into the sequence of transformations SSliceComplexRollReplacement -> RollRealImagPackReplacement -> TFFFTToDFT.
* Written tests for the transformation SSliceComplexRollReplacement.
* Written tests for the transformation RollRealImagPackReplacement.
* Written tests for the transformation TFFFTToDFT.
* Deleted commented code.
* Fixed types of constants in the transformation ComplexAbs.
* Written tests for canonicalization of signal_size value.
* Deleted 'Replacement' from names of files and classes.
* Used comparison of ids, not names.
* replace_sub_graph was replaced with find_and_replace_pattern.
* Now the transformation RollRealImagPack is executed before the transformation model-optimizer/extensions/front/Pack.py.
* The body of the function create_dft_from_tffft is now part of the body of the transformation TFFFTToDFT.
* The method correct_roll_axes of classes RollRealImagPack and SSliceComplexRoll was moved to a function in mo/front/tf/graph_utils.py.
* Small changes.
* Added comment before mark_input_as_in_correct_layout(roll, 2).
* Now the function correct_roll_axes generates a sub-graph in the input port 2 of Roll.
* Corrected tests for the transformation SSliceComplexRoll.
* Corrected tests for the transformation RollRealImagPack.
* Deleted commented code.
* Some renaming.
* Added decomposition of the separate operation ComplexAbs (without Complex before it).
* Added comment to the transformation ComplexAbsAfterComplex.
* Optimized imports for the transformation TFFFTToDFT.
* The transformation SSliceComplexRoll was split into the sequence SSliceComplex -> CorrectRollAxes and disabled.
* Written tests for the transformation ComplexAbs.
* Written tests for the transformation SSliceComplex.
* Written tests for the transformation CorrectRollAxes.
* Deleted the transformation SSliceComplexRoll.
* Deleted renaming of nodes.
* Fixed comment.
* Small fixes.
* Small fix.
* The attribute need_correction was renamed to input_rank_changed.
* Small fixes.
* Deleted commented code.
* Now we iterate over all complex_node.out_port(0).get_connection().get_destinations() input ports and mark the corresponding nodes with the marker attribute.
* Added the attribute 'in_ports_count' into the class FFTBase.
* Tests for the transformation TransposeDFT were rewritten using helper functions.
* Now the transformation RollRealImagPack uses the existing Roll node instead of creating a new one.
* Small fixes.
* Fix in the documentation.
* Written class to read MxNet (I)FFT operations. Written corresponding extractors.
* Corrected shape infer function for MXFFT operation. Written transformation to convert MXFFT to (I)DFT.
* Fixed shape infer function.
* Fixed the conversion of MXFFT to (I)DFT.
* Written tests for the transformation MXFFTToDFT.
* The function correct_roll_axes was replaced with the more generic function add_constant_to_negative_values.
* Fixes in classes TFFFT, FFTBase, DFT, IDFT, MXFFT.
* Added asserts in constructors of operations TFFFT and MXFFT.
* Refactored transformation MXFFTToDFT: the DFT and IDFT conversions were moved into separate functions.
* Moved some commented code.
* Fixed BOM file.
* Written function convert_ifft_to_dft.
* Started to rewrite tests for the MXFFTToDFT transformation for the case is_inverse=False.
* Small fixes.
* Fixes in the transformation RollRealImagPack.
* Renamed the test class for the transformation SSliceComplex.
* Fixes in the function compare_graphs. Now we get all output nodes of an op node, and these output nodes are sorted by name.
* Fixed tests for the transformation MXFFTToDFT.
* Fix in the transformation ThresholdedReluDecomposition: added disconnect for trelu input port.
* Fixes in test for the transformation TFSliceToSlice.
* Small fix in the transformation ObjectDetectionAPIPreprocessor2Replacement.
* Small fix in comment.
* Optimized imports.
* Used remove_node in the transformation ThresholdedReluDecomposition and remove_nodes_from in the transformation RollRealImagPack instead of disconnecting ports.
* Deleted commented code.
* Deleted test case test_slice_replacer_begin_with_2_inputs.
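Several items above rely on the Model Optimizer convention of storing a complex tensor as a real tensor whose trailing dimension of size 2 holds (real, imag). The test file in this commit shows this with an input of shape [3, 100, 100, 2], whose last dimension is sliced into real and imaginary parts. Under that assumption, a ComplexAbs decomposition reduces to sqrt(real**2 + imag**2) over the last axis. The following is a minimal pure-Python sketch of that arithmetic only. `complex_abs_decomposed` is a hypothetical name, and the sketch does not reproduce the actual ComplexAbs graph rewrite.

```python
import math


def complex_abs_decomposed(pairs):
    """Hypothetical helper: magnitude of each [real, imag] pair stored along a trailing axis of size 2.

    Assumes the decomposition computes sqrt(real**2 + imag**2), i.e. squares the
    elements, sums them over the last axis and takes a square root.
    """
    return [math.sqrt(re ** 2 + im ** 2) for re, im in pairs]


# Each inner list is one complex number stored as [real, imag].
pairs = [[3.0, 4.0], [0.0, -2.0], [1.0, 1.0]]
magnitudes = complex_abs_decomposed(pairs)

# Cross-check against Python's built-in magnitude of complex numbers.
reference = [abs(complex(re, im)) for re, im in pairs]
```

Comparing against the built-in `abs` on Python complex numbers is a cheap way to validate this kind of decomposition before it is expressed as graph operations.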
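The correction of Roll axes (later generalized as add_constant_to_negative_values) can be motivated with a small check. The sketch assumes that a TF complex tensor of rank r becomes a real tensor of rank r + 1 with a trailing (real, imag) dimension. Under that assumption, non-negative Roll axes keep their meaning, while negative axes, which count from the end, must be shifted by -1 to skip the new trailing dimension. `shift_negative_axes` and `normalize_axis` are hypothetical pure-Python helpers, not the function in mo/front/tf/graph_utils.py, which works on graph ports.

```python
def shift_negative_axes(axes, constant):
    """Hypothetical stand-in: add `constant` to every negative axis and leave the others unchanged."""
    return [a + constant if a < 0 else a for a in axes]


def normalize_axis(axis, rank):
    """Map a possibly negative axis to its non-negative index for a tensor of the given rank."""
    return axis + rank if axis < 0 else axis


complex_rank = 3          # rank of the TF complex tensor, e.g. [batch, height, width]
tf_axes = [-1, -2, 0]     # Roll axes as written in the TF model

# In the MO representation the tensor gains a trailing dimension of size 2 (assumption stated above).
mo_axes = shift_negative_axes(tf_axes, -1)

# Every corrected axis must still point at the same dimension of the data.
for tf_axis, mo_axis in zip(tf_axes, mo_axes):
    assert normalize_axis(tf_axis, complex_rank) == normalize_axis(mo_axis, complex_rank + 1)
```

Only negative values need changing, which is why a generic "add a constant to negative values" helper is enough for this correction.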
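The items about canonicalize_axes and the canonicalization of signal_size can be illustrated under the same layout assumption. First, because the last input dimension of (I)DFT holds the (real, imag) pair, negative axes are counted from rank - 1 rather than from rank. Second, a signal_size entry of -1 means "use the full input size along that axis." Both rules are assumptions about the (I)DFT semantics, and the two functions below are hypothetical free-function sketches, not the methods of the MO DFT classes.

```python
def canonicalize_dft_axes(axes, input_rank):
    """Hypothetical sketch: make (I)DFT axes non-negative.

    The last input dimension stores (real, imag), so negative axes are counted
    from input_rank - 1 instead of input_rank.
    """
    return [a + input_rank - 1 if a < 0 else a for a in axes]


def canonicalize_signal_size(signal_size, axes, input_shape):
    """Hypothetical sketch: replace -1 entries by the input size along the matching axis."""
    canonical_axes = canonicalize_dft_axes(axes, len(input_shape))
    return [input_shape[a] if s == -1 else s for s, a in zip(signal_size, canonical_axes)]


input_shape = [3, 100, 100, 2]  # complex tensor [3, 100, 100] stored with a trailing (real, imag) pair
axes = [-2, -1]                 # the last two complex dimensions
canonical_axes = canonicalize_dft_axes(axes, len(input_shape))       # [1, 2]
signal_size = canonicalize_signal_size([-1, 64], axes, input_shape)  # [100, 64]
```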
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import unittest

from extensions.front.tf.SSliceComplex import SSliceComplex
from mo.front.common.partial_infer.utils import int64_array
from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph


# Pattern handled by SSliceComplex: two StridedSlices take x[..., 0] (real part) and x[..., 1] (imaginary part)
# of the same input, and Complex combines them. Abs only serves as a consumer of the Complex output.
graph_node_attrs = {
    'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'strided_slice_real': {
        'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]), 'ellipsis_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0, 1]),
    },
    'strided_slice_imag': {
        'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]), 'ellipsis_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0, 1]),
    },
    'complex': {'kind': 'op', 'op': 'Complex'},
    'real_begin': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 0])
    },
    'imag_begin': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 1])
    },
    'real_end': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 1])
    },
    'imag_end': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 2])
    },
    'real_strides': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([1, 1])
    },
    'imag_strides': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([1, 1])
    },
    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
    'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
}

graph_edges = [
    ('placeholder', 'strided_slice_real', {'out': 0, 'in': 0}),
    ('placeholder', 'strided_slice_imag', {'out': 0, 'in': 0}),
    ('strided_slice_real', 'complex', {'in': 0}),
    ('strided_slice_imag', 'complex', {'in': 1}),
    ('complex', 'abs'),
    ('abs', 'output'),
    ('real_begin', 'strided_slice_real', {'in': 1}),
    ('imag_begin', 'strided_slice_imag', {'in': 1}),
    ('real_end', 'strided_slice_real', {'in': 2}),
    ('imag_end', 'strided_slice_imag', {'in': 2}),
    ('real_strides', 'strided_slice_real', {'in': 3}),
    ('imag_strides', 'strided_slice_imag', {'in': 3}),
]


# Expected result: the slices and Complex are removed, so the input feeds Abs directly.
ref_graph_node_attrs = {
    'placeholder': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
    'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
}

ref_graph_edges = [
    ('placeholder', 'abs'),
    ('abs', 'output'),
]


# Negative case: the slices read from two different Parameters, so the pattern must not match
# and the graph must stay unchanged.
non_transformed_graph_node_attrs = {
    'placeholder_0': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'placeholder_1': {'shape': int64_array([3, 100, 100, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
    'strided_slice_real': {
        'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]), 'ellipsis_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0, 1]),
    },
    'strided_slice_imag': {
        'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice', 'begin_mask': int64_array([1]),
        'end_mask': int64_array([1]), 'ellipsis_mask': int64_array([1]), 'new_axis_mask': int64_array([0]),
        'shrink_axis_mask': int64_array([0, 1]),
    },
    'complex': {'kind': 'op', 'op': 'Complex'},
    'real_begin': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 0])
    },
    'imag_begin': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 1])
    },
    'real_end': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 1])
    },
    'imag_end': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([0, 2])
    },
    'real_strides': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([1, 1])
    },
    'imag_strides': {
        'type': 'Const', 'kind': 'op', 'op': 'Const', 'shape': int64_array([2]), 'value': int64_array([1, 1])
    },
    'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
    'output': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
}

non_transformed_graph_edges = [
    ('placeholder_0', 'strided_slice_real', {'out': 0, 'in': 0}),
    ('placeholder_1', 'strided_slice_imag', {'out': 0, 'in': 0}),
    ('strided_slice_real', 'complex', {'in': 0}),
    ('strided_slice_imag', 'complex', {'in': 1}),
    ('complex', 'abs'),
    ('abs', 'output'),
    ('real_begin', 'strided_slice_real', {'in': 1}),
    ('imag_begin', 'strided_slice_imag', {'in': 1}),
    ('real_end', 'strided_slice_real', {'in': 2}),
    ('imag_end', 'strided_slice_imag', {'in': 2}),
    ('real_strides', 'strided_slice_real', {'in': 3}),
    ('imag_strides', 'strided_slice_imag', {'in': 3}),
]


class SSliceComplexTest(unittest.TestCase):
    def test_replacement(self):
        graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
        graph.stage = 'front'
        SSliceComplex().find_and_replace_pattern(graph)
        ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
        (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)

    def test_nonreplacement(self):
        graph = build_graph(nodes_attrs=non_transformed_graph_node_attrs, edges=non_transformed_graph_edges)
        # The reference is an untouched copy of the input graph.
        ref_graph = build_graph(nodes_attrs=non_transformed_graph_node_attrs, edges=non_transformed_graph_edges)
        graph.stage = 'front'
        SSliceComplex().find_and_replace_pattern(graph)
        (flag, resp) = compare_graphs(graph, ref_graph, 'output', check_op_attrs=True)
        self.assertTrue(flag, resp)
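Why the reference graph in test_replacement can connect the Parameter straight to Abs: the two StridedSlices select x[..., 0] and x[..., 1]. They use begin [0, 0] and [0, 1], end [0, 1] and [0, 2], an ellipsis on the first element, and a shrink on the second. Assuming Complex stacks its two inputs along a new trailing axis of size 2, the result reproduces x. The pure-Python sketch below checks this on nested lists. `slice_last` and `pack_last` are hypothetical helpers that emulate the two operations; they are not Model Optimizer code.

```python
def slice_last(x, index):
    """Emulate x[..., index] on a nested list, i.e. a StridedSlice that shrinks the last axis."""
    if isinstance(x[0], list):
        return [slice_last(row, index) for row in x]
    return x[index]


def pack_last(real, imag):
    """Emulate the assumed Complex behavior: stack real and imag along a new trailing axis of size 2."""
    if isinstance(real, list):
        return [pack_last(r, i) for r, i in zip(real, imag)]
    return [real, imag]


# A tensor of shape [2, 2, 2]: the last dimension holds (real, imag).
x = [[[1.0, 2.0], [3.0, 4.0]],
     [[5.0, 6.0], [7.0, 8.0]]]

real = slice_last(x, 0)   # like strided_slice_real
imag = slice_last(x, 1)   # like strided_slice_imag
rebuilt = pack_last(real, imag)
```

Because `rebuilt` equals `x`, the slice-and-pack chain is an identity on this layout, and removing it does not change what Abs receives.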
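The two tests differ only in whether both slices read from the same producer. The hypothetical predicate below summarizes the conditions the fixtures encode: constant begin/end/strides, the ellipsis and shrink masks, and a shared input compared by node id. It illustrates the test data and is not the matching logic of SSliceComplex itself.

```python
# Slice parameters used by the fixtures (hypothetical dict layout, mirroring the Const values and masks).
REAL_SLICE = {'begin': [0, 0], 'end': [0, 1], 'strides': [1, 1], 'ellipsis_mask': [1], 'shrink_axis_mask': [0, 1]}
IMAG_SLICE = {'begin': [0, 1], 'end': [0, 2], 'strides': [1, 1], 'ellipsis_mask': [1], 'shrink_axis_mask': [0, 1]}


def is_packed_slice_pair(real_slice, imag_slice, real_source_id, imag_source_id):
    """Hypothetical check: True if the slices take x[..., 0] and x[..., 1] of the same producer node."""
    return (
        real_slice == REAL_SLICE
        and imag_slice == IMAG_SLICE
        and real_source_id == imag_source_id
    )


# Mirrors test_replacement: both slices read from 'placeholder'.
matched = is_packed_slice_pair(REAL_SLICE, IMAG_SLICE, 'placeholder', 'placeholder')
# Mirrors test_nonreplacement: the slices read from two different Parameters.
not_matched = is_packed_slice_pair(REAL_SLICE, IMAG_SLICE, 'placeholder_0', 'placeholder_1')
```

Comparing producer ids rather than display names matches the commit note about using comparison of ids.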