* Refactored the infer function and the supported_attrs function of the Interpolate layer.
* Small change.
* Deleted unneeded checks in the transformations ResizeToInterpolate2D and ResizeToInterpolate3D.
* Small fix in the extractor of ONNX Resize.
* Now the extractor of TF ResizeBilinear generates Interpolate-1 again, because in the final version of the Interpolate-4 specification 'axis' is an input, not an attribute.
* Now the extractor of TF ResizeNearest generates Interpolate-1 again, because in the final version of the Interpolate-4 specification 'axis' is an input, not an attribute.
* Added the static method get_axis to the class Interpolate.
* Refactored the class CanBeFused in the transformation InterpolateSequenceToInterpolate.
* Fixed the transformation InterpolateSequenceToInterpolate according to the latest version of the Interpolate-4 specification.
* Started to write support for Interpolate-4 in the transformation InterpolateWithConcat.
* Added support for Interpolate-4 to the transformation InterpolateWithConcat.
* Added support for Interpolate-4 to the transformation InterpolateConcat.
* Added support for Interpolate-4 to the transformation InterpolateReshapeWA.
* Added support for Interpolate-4 to the transformation InterpolateTranspose.
* Started to add tests for the opset4 case of the transformation InterpolateSequenceToInterpolate.
* Added a test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_1_opset4_case).
* Added a test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_4_opset4_case).
* Added another test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_5_opset4_case).
* Added another test for InterpolateSequenceToInterpolate (test_3d_interpolate_sequence_1_opset4_case).
* Finished adding tests for the opset4 case of InterpolateSequenceToInterpolate.
* Small change.
* Now the opset is only opset1 or opset4 in the transformation InterpolateTranspose.
* Small fixes in the transformations ResizeToInterpolate2D and ResizeToInterpolate3D.
* Deleted reading of unused ONNX attributes.
* Fixed the docstring of the transformation InterpolateV1ToInterpolateV4.
* Added the node name to the assert about the axes input.
* Fixes in the definition of the operation ONNXResize11.
* Now Interpolate-4 cannot have 'extension' as its opset.
* Now the transformation InterpolateV1ToInterpolateV4 uses find_and_replace_pattern instead of replace_sub_graph.
* Fixed tests for the transformations InterpolateReshapeWA and InterpolateConcat.
* Fixed some tests.
* Rewrote the Interpolate-4 operation class according to the new version of the documentation.
* Some fixes in the ONNXResize11 operation class.
* Now the transformation ONNXResize11ToInterpolate generates Interpolate-4 with 4 inputs.
* Now the transformation UpsampleToResample generates Interpolate-4 with 4 inputs.
* Now the transformation NearestNeighborUpsampling generates Interpolate-4 with 4 inputs.
* Now the transformations ResizeToInterpolate2D and ResizeToInterpolate3D generate Interpolate-4 with 4 inputs.
* Now the transformation SplitConcatPairToInterpolate generates Interpolate-4 with 4 inputs.
* Now the transformation UnsqueezeTileReshapeBlockToInterpolate generates Interpolate-4 with 4 inputs.
* Now the transformation InterpolateV1ToInterpolateV4 generates Interpolate-4 with 4 inputs.
* Some fixes.
* Fixed the transformation InterpolateSequenceToInterpolate according to the new version of the Interpolate-4 specification.
* Fixed typos.
* Added shape_calculation_mode to supported_attrs.
* Small fixes.
* Added the operation ONNXResize10 and the transformation ONNXResize10ToInterpolate4.
* Fixed the function correct_scales_using_dst_shape.
* Some fixes in InterpolateSequenceToInterpolate.
* Fixed a bug in the method __call__ of the class CanBeFused: now self.accumulated_axes is correctly cleared in all cases.
* Small change.
* Fixed tests for the transformation SplitConcatPairToInterpolate.
* Now the transformations InterpolateWithConcat, InterpolateReshapeWA, and InterpolateConcat support Interpolate-4.
* Fixed the transformation InterpolateTranspose for the case of Interpolate-4.
* Wrote the back transformation InterpolateV4AxesCorrection to convert the 'axes' input of Interpolate-4 from NHWC to NCHW layout.
* Added PermuteInput to the Interpolate-4 infer function.
* Fixed typos.
* Deleted the transformation InterpolateAxesCorrection.
* Now Interpolate-4 permutes the axis, not the shape, in input port 3.
* Small fix.
* Some fixes.
* Fixed a bug in the transformation UpsampleToResample.
* Added some debug prints.
* Added more debug prints.
* Now the ONNX Upsample-9 operation is read as ONNXResize10.
* Small fix.
* Small fixes.
* Fixed tests for the transformation SplitConcatPairToInterpolate.
* Deleted debug prints.
* Deleted some debug prints.
* Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests.
* Small fix in the transformation InterpolateSequenceToInterpolate.
* Started to write an nGraph transformation to convert Interpolate-1 to Interpolate-4.
* Deleted redundant files.
* Small fixes.
* Small fix.
* Wrote a draft of the transformation Interpolate-1 -> Interpolate-4.
* Small fix.
* Now the ONNX Importer reads Resize-10 as Interpolate-4.
* Fixes in the test onnx_model_resize10_import_only.
* Small fix in the test for the conversion Interpolate-1 -> Interpolate-4.
* Small fixes.
* Fixed NGraphReaderTests for Interpolate.
* Some fixes.
* Deleted the class for the Resample operation.
* Fix in the transformation NearestNeighborUpsampling: fixed the precision of the input 'scales' of the generated Interpolate-4.
* Fixed a typo.
* Now the TF operation ResizeBilinear is read as the internal MO operation TFResizeBilinear. This internal operation is converted into Interpolate-4.
* Small fix in the BOM file.
* Added checks for the existence of attributes of the TF ResizeBilinear operation.
* Small fixes in the conversion of the internal MO operation TFResizeBilinear to Interpolate-4.
* Small fixes.
* Small fixes.
* Now the transformation ONNXResize10ToInterpolateV4 calculates the sizes input as input_shape * (scales + epsilon).
* Added the internal MO operation TFResizeNearestNeighbor.
* Fixes in the transformation SplitConcatPairToInterpolate and its tests.
* Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests.
* Wrote the transformation that converts the internal operation TFResizeNearestNeighbor into Interpolate-4.
* Now MO reads the TF operation ResizeNearestNeighbor as the internal MO operation TFResizeNearestNeighbor.
* Small fix.
* Now the specification of Interpolate-4 clarifies that the mode linear_onnx supports only 2D or 4D input tensors.
* Small fix.
* Some fixes.
* Moved the transformation ONNXResize10ToInterpolateV4 to the front stage.
* Deleted the infer function and the supported_attrs function of the ONNXResize10 operation.
* Deleted supported_attrs() for TFResizeBilinear and TFResizeNearestNeighbor.
* Some fixes.
* Fixes in the shape infer function of the nGraph operation Interpolate-4. Now the 'axes' input can be non-constant. In such a case, all elements of the output shape are Dimension::dynamic().
* Deleted corner-case processing in the transformations TFResizeBilinearToInterpolateV4 and TFResizeNearestNeighborToInterpolateV4.
* Rewrote the function replace_resize_bilinear.
* Wrote the internal MO operation TFResize that covers the TF operations ResizeBilinear and ResizeNearestNeighbor.
* Now the TF operations ResizeBilinear and ResizeNearestNeighbor are read as the internal operation TFResize in MO. The transformations TFResizeNearestNeighborToInterpolateV4 and TFResizeBilinearToInterpolateV4 are merged into one transformation, TFResizeToInterpolateV4.
* Some changes in the shape infer function of the nGraph op Interpolate-4.
* Small fix.
* Some changes.
* The transformation TFResizeToInterpolateV4 is moved to the front stage.
* Deleted a redundant assert.
* Deleted the transformations ResizeToInterpolate2D and ResizeToInterpolate3D.
* Some renaming.
* Small change.
* Deleted .copy() in the shape infer function of the internal operation TFResize.
* Small fix.
* Small fixes.
* Added a comment about the case when the input 'axes' of Interpolate-4 is non-constant.
* Wrote a test for the Interpolate-4 shape infer, for the case when the input 'axes' is non-constant and shape_calculation_mode = scales.
* Some fixes.
* Small fixes.
* Small fix.
* Added yet another test for the case of a non-constant 'axes' input of Interpolate-4 (when shape_calculation_mode = sizes).
* Added a comment.
* Small fix.
* Reverted changes for InterpolateWithConcat.
* Added type checks for all inputs of the nGraph operation Interpolate-4.
* Added u32 and u64 to the supported element types of the sizes and axes inputs of the nGraph operation Interpolate-4.
* Fixed some functional tests.
* Some changes.
* Added the helper function float32_array.
* Now the MO transformation InterpolateV1ToInterpolate preserves layer names.
* Small fix.
* Small fix.
* Reverted a change.
* Small fixes.
* Small fix.
* Small fix.
* Small fix.
* Small fix.
* Reverted changes in the nGraph reader tests for Interpolate-1.
* Some reverts.
* Fixed some copyright years.
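The ONNXResize10ToInterpolateV4 bullet adds an epsilon to the scales before multiplying. This protects against products that fall just below an integer before they are truncated. The sketch below assumes an epsilon of 1e-5, which is the constant `infer_for_opset4` adds in its scales branch. It also assumes the result is floored to an integer. Neither detail is stated in the bullet itself. The helper name `sizes_from_scales` is hypothetical and not part of Model Optimizer.

```python
import math
from typing import List, Sequence

EPSILON = 1.0e-5  # assumed value; the commit message does not state it


def sizes_from_scales(input_shape: Sequence[int], scales: Sequence[float],
                      epsilon: float = EPSILON) -> List[int]:
    """Hypothetical helper: sizes = floor(input_shape * (scales + epsilon))."""
    if len(input_shape) != len(scales):
        raise ValueError("input_shape and scales must have the same length")
    return [math.floor(dim * (scale + epsilon)) for dim, scale in zip(input_shape, scales)]


# Without the epsilon, 100 * 0.29 evaluates to 28.999999999999996 in IEEE-754
# double precision, so flooring it gives 28 instead of the intended 29.
print(math.floor(100 * 0.29))                                      # 28
print(sizes_from_scales([1, 3, 100, 100], [1.0, 1.0, 0.29, 2.0]))  # [1, 3, 29, 200]
```

Because the epsilon is multiplied by each dimension, its effect grows with the dimension. At this value, a dimension of 200000 with a scale of exactly 1.0 would come out as at least 200001. Adding the epsilon after the multiplication keeps the correction constant, as `infer_for_opset4` does with `math.floor(scales[i] * output_shape[axis] + 1.0e-5)`.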
185 lines
6.3 KiB
Python
"""
 Copyright (C) 2018-2021 Intel Corporation

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

      http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
"""

import math

import numpy as np

from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import bool_to_str
from mo.graph.graph import Node, Graph
from mo.graph.perm_inputs import PermuteInputs
from mo.ops.op import Op, PermuteAttrs


def infer_for_opset4(node: Node):
    """Shape inference for Interpolate-4: data, sizes, scales and an optional 'axes' input."""
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) in [3, 4], \
        "Interpolate-4 node {} must have 3 or 4 inputs".format(node.soft_get('name', node.id))
    assert node.has_valid('mode')
    assert node.has_valid('shape_calculation_mode')
    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None

    input_rank = len(src_shape)

    # Normalize pads to exactly input_rank elements
    pads_begin = correct_pad(node.soft_get('pads_begin', [0]), input_rank)
    pads_end = correct_pad(node.soft_get('pads_end', [0]), input_rank)
    node['pads_begin'] = pads_begin
    node['pads_end'] = pads_end

    # Without a connected 'axes' input (port 3), all axes are interpolated
    if not node.is_in_port_connected(3):
        axes = list(range(0, input_rank))
    else:
        axes = node.in_port(3).get_source().data.get_value()
        assert axes is not None, \
            "Interpolate-4 node with name {} has None as 'axes' input".format(node.soft_get('name', node.id))

    axes = int64_array(axes)
    output_shape = src_shape + pads_begin + pads_end
    if node.shape_calculation_mode == 'sizes':
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None
        correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
        for i, axis in enumerate(axes):
            output_shape[axis] = dst_shape[i]
    else:
        scales = node.in_port(2).data.get_value()
        assert scales is not None
        # The small epsilon keeps products such as 2.9999999 from being floored to 2
        for i, axis in enumerate(axes):
            output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)

    if node.is_in_port_connected(3):
        # The 'axes' values follow the layout of input 0, so register them for permutation
        PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'axis')

    node.out_port(0).data.set_shape(output_shape)


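# Illustrative sketch, not used by Model Optimizer: the output-shape arithmetic of
# infer_for_opset4 on plain Python lists, so it can be checked without building a graph.
# The helper name interpolate4_output_shape_sketch is hypothetical. As in infer_for_opset4,
# the pads are added first; they are assumed to already have `rank` elements (see correct_pad).
# In 'sizes' mode the interpolated axes take their values from `sizes`. Otherwise they come
# from floor(scale * padded_dim + 1e-5). Passing axes=None means "all axes", like a
# disconnected input port 3.
def interpolate4_output_shape_sketch(src_shape, pads_begin, pads_end, mode,
                                     sizes=None, scales=None, axes=None):
    import math

    if axes is None:
        axes = list(range(len(src_shape)))
    # The padded shape is the starting point for every axis
    output_shape = [dim + begin + end for dim, begin, end in zip(src_shape, pads_begin, pads_end)]
    for i, axis in enumerate(axes):
        if mode == 'sizes':
            output_shape[axis] = sizes[i]
        else:
            output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)
    return output_shape


# Usage: scale H and W of an NCHW tensor by 2 after padding W by one element on each side:
# interpolate4_output_shape_sketch([1, 3, 10, 10], [0, 0, 0, 1], [0, 0, 0, 1], 'scales',
#                                  scales=[2.0, 2.0], axes=[2, 3])  returns [1, 3, 20, 24]

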
def infer_for_opset1(node: Node):
    """Shape inference for Interpolate-1: target shape from input 1, axes from the 'axes' attribute."""
    assert len([p for p in node.in_ports().values() if not p.disconnected()]) == 2
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None
    dst_shape = node.in_port(1).data.get_value()
    assert dst_shape is not None

    # Only the listed axes change; all other dimensions are copied from the input
    output_shape = src_shape.copy()
    for ind, axis in enumerate(node.axes):
        output_shape[axis] = dst_shape[ind]

    node.out_port(0).data.set_shape(output_shape)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])


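# Illustrative sketch, not used by Model Optimizer: infer_for_opset1 keeps every
# dimension of the input shape and overwrites only the dimensions listed in the 'axes'
# attribute with the target shape from input port 1. Pads do not affect the result.
# With NumPy the loop collapses into one fancy-indexing assignment. The helper name
# interpolate1_output_shape_sketch is hypothetical.
def interpolate1_output_shape_sketch(src_shape, axes, target_shape):
    import numpy as np

    # Copy the source shape, like src_shape.copy() in infer_for_opset1
    output_shape = np.array(src_shape, dtype=np.int64)
    # Overwrite all interpolated axes at once
    output_shape[np.array(axes, dtype=np.int64)] = target_shape
    return output_shape


# Usage: interpolate1_output_shape_sketch([1, 3, 10, 10], [2, 3], [32, 48]).tolist()
# returns [1, 3, 32, 48]

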
def pad_attribute_to_str(node: Node, attr: str):
    # Serialize a pads attribute as a comma-separated string, e.g. [0, 0, 1, 1] -> '0,0,1,1'
    return ','.join(map(str, node[attr])) if node.has_valid(attr) else None


def correct_pad(pad, rank):
    # Zero-fill a short pads list on the right, or truncate a long one, to exactly `rank` elements
    pad_len = len(pad)
    if pad_len < rank:
        return np.pad(pad, (0, rank - pad_len), 'constant').astype(np.int64)
    elif pad_len > rank:
        return np.array(pad[: rank]).astype(np.int64)
    else:
        return np.array(pad, dtype=np.int64)


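# Illustrative sketch, not used by Model Optimizer: correct_pad normalizes a pads
# attribute to exactly `rank` elements. This is what turns the default [0] passed by
# infer_for_opset4 into a full row of zeros. The same behavior on plain lists follows;
# the name correct_pad_sketch is hypothetical.
def correct_pad_sketch(pad, rank):
    pad = [int(p) for p in pad]
    # Append zeros when too short ([0] * negative is []), then cut to `rank` when too long
    return (pad + [0] * (rank - len(pad)))[:rank]


# Usage: correct_pad_sketch([0], 4) returns [0, 0, 0, 0];
# correct_pad_sketch([1, 2, 3, 4, 5], 4) returns [1, 2, 3, 4]

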
def correct_scales_using_dst_shape(node, dst_shape, src_shape, axes):
    # In 'sizes' mode, recompute the 'scales' input (port 2) from the target sizes
    # when it is missing or has a different length than the sizes input
    scales_value = node.in_port(2).data.get_value()
    if scales_value is None or len(scales_value) != len(dst_shape):
        corrected_scales = np.zeros(len(dst_shape))
        for i, axis in enumerate(list(axes)):
            corrected_scales[i] = dst_shape[i] / src_shape[axis]
        node.in_port(2).data.set_value(corrected_scales)


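# Illustrative sketch, not used by Model Optimizer: in 'sizes' mode,
# correct_scales_using_dst_shape rebuilds the 'scales' input as the ratio between each
# requested size and the corresponding unpadded source dimension. It does this only when
# the current 'scales' value is missing or has the wrong length. The name
# scales_from_sizes_sketch is hypothetical.
def scales_from_sizes_sketch(current_scales, sizes, src_shape, axes):
    if current_scales is not None and len(current_scales) == len(sizes):
        # An existing value of the right length is kept as is
        return list(current_scales)
    return [sizes[i] / src_shape[axis] for i, axis in enumerate(axes)]


# Usage: scales_from_sizes_sketch(None, [20, 15], [1, 3, 10, 10], [2, 3]) returns [2.0, 1.5]

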
class Interpolate(Op):
    """Model Optimizer operation for Interpolate-1 and Interpolate-4."""
    op = 'Interpolate'
    enabled = False
    infers = {
        'opset1': infer_for_opset1,
        'opset4': infer_for_opset4
    }

    def __init__(self, graph: Graph, attrs: dict):
        self.attributes_for_opsets = {
            'opset1': [
                ('axes', lambda node: ','.join(map(str, node.axes))),
                ('antialias', lambda node: bool_to_str(node, 'antialias')),
                ('align_corners', lambda node: bool_to_str(node, 'align_corners')),
                'mode', 'pads_begin', 'pads_end',
            ],
            'opset4': [
                'mode', 'nearest_mode', 'cube_coeff', 'coordinate_transformation_mode',
                'shape_calculation_mode',
                ('antialias', lambda node: bool_to_str(node, 'antialias')),
                ('pads_begin', lambda node: pad_attribute_to_str(node, 'pads_begin')),
                ('pads_end', lambda node: pad_attribute_to_str(node, 'pads_end')),
            ]
        }

        mandatory_props = {
            'op': self.op,
            'type': self.op,
            'version': 'opset1',

            'axes': None,
            'mode': None,
            'align_corners': 0,
            'antialias': 0,
            'pads_begin': 0,
            'pads_end': 0,

            'infer': self.infer,

            'force_precision_in_ports': {1: 'int64'},
            'in_ports_count': 2,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

    def supported_attrs(self):
        # Unknown opsets fall back to the opset1 attribute list
        opset = self.get_opset()
        key = opset if opset in self.attributes_for_opsets else 'opset1'
        return self.attributes_for_opsets[key]

    def infer(self, node: Node):
        # Dispatch to the opset-specific shape inference; unknown opsets fall back to opset1
        opset = self.get_opset()
        key = opset if opset in self.infers else 'opset1'
        self.infers[key](node)

    @staticmethod
    def get_axes(node: Node) -> np.ndarray:
        """Return the interpolated axes: the 'axes' attribute for opset1, input port 3 (or all axes) otherwise."""
        opset = node.get_opset()
        if opset == 'opset1':
            interp_axes = node.soft_get('axes', None)
            return interp_axes if interp_axes is None else int64_array(interp_axes)

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None
        input_rank = len(src_shape)

        # Without a connected 'axes' input (port 3), all axes are interpolated
        if not node.is_in_port_connected(3):
            axes = list(range(0, input_rank))
        else:
            axes = node.in_port(3).get_source().data.get_value()
        return int64_array(axes)
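

# Illustrative sketch, not used by Model Optimizer: for a model from an NHWC framework
# such as TensorFlow, the values returned by Interpolate.get_axes refer to NHWC positions.
# infer_for_opset4 registers input port 3 with PermuteInputs so that MO can adapt them to
# its NCHW layout. The helper below shows the underlying index mapping for non-negative
# axes. It assumes the data permutation [0, rank - 1, 1, ..., rank - 2] (NHWC -> NCHW, or
# NDHWC -> NCDHW for rank 5). The name nhwc_axes_to_nchw_sketch and this exact rule are
# assumptions, not the PermuteInputs implementation.
def nhwc_axes_to_nchw_sketch(axes, rank=4):
    # Order in which NHWC axes are taken to form NCHW: [0, rank - 1, 1, ..., rank - 2]
    nhwc_to_nchw = [0, rank - 1] + list(range(1, rank - 1))
    # Invert the permutation: NHWC axis `old_axis` ends up at NCHW position `new_pos`
    new_position = [0] * rank
    for new_pos, old_axis in enumerate(nhwc_to_nchw):
        new_position[old_axis] = new_pos
    # Keep the original order so each axis still lines up with its sizes/scales entry
    return [new_position[axis] for axis in axes]


# Usage: nhwc_axes_to_nchw_sketch([1, 2]) returns [2, 3] (H and W move from NHWC
# positions 1 and 2 to NCHW positions 2 and 3); nhwc_axes_to_nchw_sketch([3]) returns [1].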