openvino/model-optimizer/extensions/back/StridedSliceMasksNormalizer.py
Pavel Esir 22169a05b9 [MO] StridedSlice improvements (#4139)
* fix ss

* successfully converted

* successfully ran moved infer and normalizer unit-tests

* successfully rewritten StridedSlice infer unittests

* int64 array

* Successfully converted crash-when-loading, xj_feauture and toy nets (cherry-picked maxpoolV4 and tf_broadcast_ext)

* successfully moved PermuteAttrs to general mechanism

* successfully converted xj_feauture and crash when loading with the new rewritten SS infer

* fixed get_shape_from_slice and moved to common utils

* fixed extending masks and some other

* some refactoring

* fixed extending masks in extractor, fixed licence year and some other code clearing

* corrected a couple of unittests

* fixed permute for rank-5 slice and rank-4 inputs

* WIP

* Added comments

* fixed StridedSlice in ProposalMutation.py

* rechecked shape_infer unittests, added some new cases

* added shape_infer unit-tests after StridedSliceNormalizer pass and Permute unit-tests

* corrected unittests

* Applied review comments

* general permutations for inputs implemented, corrected ellipsis unrolling when shrink_axis is at the beginning, some other corrections

* removed code duplication in infer and normalizer, moved 'slices' attr normalizing to StridedSliceNormalizer.py

* removed some code duplication and other minor improvements

* Added tests

* minor corrections

* wider range of unittests added (froze the number)

* review comments applied

* enabled skipped unit-test

* comment corrections

* applied review comments: changed op -> type, added some asserts, corrected comments and other minor corrections

* sorted inputs, updated Supported_Frameworks_Layers.md, some minor
2021-02-16 11:48:49 +03:00


"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from mo.back.replacement import BackReplacementPattern
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph


class StridedSliceMasksNormalizer(BackReplacementPattern):
    enabled = True
    force_clean_up = True

    def run_after(self):
        # Schedule this pass after the back transformations listed below.
        from extensions.back.ConvolutionNormalizer import DeconvolutionNormalizer
        from extensions.back.CropToStridedSlice import CropToStridedSlice
        return [CropToStridedSlice, DeconvolutionNormalizer]

    def find_and_replace_pattern(self, graph: Graph):
        for node in graph.get_op_nodes(type='StridedSlice'):
            assert node.has_valid('begin_mask')
            assert node.has_valid('end_mask')
            # Flip every mask element (0 -> 1, 1 -> 0) and store the result as an int64 array.
            node.begin_mask = int64_array([1 - i for i in node.begin_mask])
            node.end_mask = int64_array([1 - i for i in node.end_mask])
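The pass above visits every StridedSlice node and replaces each element `i` of `begin_mask` and `end_mask` with `1 - i`. The following is a minimal, framework-free sketch of that logic so it can be run without Model Optimizer installed. `ToyNode` and `invert_strided_slice_masks` are hypothetical names introduced only for illustration; plain Python lists stand in for the NumPy int64 arrays that `int64_array` produces, and the explicit type check stands in for `graph.get_op_nodes(type='StridedSlice')`.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyNode:
    """Hypothetical stand-in for a Model Optimizer graph node."""
    type: str
    begin_mask: Optional[List[int]] = None
    end_mask: Optional[List[int]] = None

    def has_valid(self, name: str) -> bool:
        # Rough analogue of Node.has_valid: the attribute is present and not None.
        return getattr(self, name, None) is not None


def invert_strided_slice_masks(nodes: List[ToyNode]) -> None:
    """Flip each 0/1 element of begin_mask and end_mask on StridedSlice nodes."""
    for node in nodes:
        if node.type != 'StridedSlice':
            continue  # plays the role of graph.get_op_nodes(type='StridedSlice')
        assert node.has_valid('begin_mask')
        assert node.has_valid('end_mask')
        # Same element-wise flip as the pass: 0 -> 1, 1 -> 0.
        node.begin_mask = [1 - i for i in node.begin_mask]
        node.end_mask = [1 - i for i in node.end_mask]


nodes = [
    ToyNode('StridedSlice', begin_mask=[1, 0, 1], end_mask=[0, 0, 1]),
    ToyNode('Relu'),  # not a StridedSlice, so it is left untouched
]
invert_strided_slice_masks(nodes)
print(nodes[0].begin_mask, nodes[0].end_mask)  # [0, 1, 0] [1, 1, 0]
```

Because `1 - (1 - i) == i`, applying the flip twice restores the original masks, so a transformation like this must not be applied to the same graph more than once.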