Files
openvino/model-optimizer/extensions/ops/interpolate.py
Vladimir Gavrilov dca30b4522 Extend MO for support of Interpolate-4 (#2026)
* Commit.

* Added opset4 version in the class Interpolate.

* Added class ONNXResize11Op to read ONNX Resize with opset version >= 11.

* Added support for Interpolate-4 into transformations TestInterpolateReshapeWA and InterpolateConcat.

* Added support for Interpolate-4 into transformation InterpolateWithConcat.

* Deleted redundant checks from the transformation UpsampleToResample.

* Reverted last changes.

* Changed the ONNX Resize extractor to support Interpolate-4.

* Added conversion of ONNXResize11Op into Interpolate-4.

* Added support for Interpolate-4 into the transformation InterpolateSequenceToInterpolate.

* Small fix for formatting.

* Wrote tests for the MO version of Interpolate-4 with shape_calculation_mode = sizes.

* Wrote tests for the infer function of Interpolate-4.

* Now transformations InterpolateWithConcat, InterpolateConcat, InterpolateReshapeWA skip Interpolate-4.

* Used create_op_with_const_inputs in the transformation InterpolateSequenceToInterpolate.

* The transformation ONNXResize11ToInterpolate4 was rewritten using find_and_replace_pattern.

* Now the dictionary infers (dictionary of infer functions of Interpolate) is a class static attribute.

* Deleted unused variable.

* Restored original logic of find_and_replace_pattern method of the class InterpolateReshapeWA.

* Used create_op_with_const_inputs() in the transformation InterpolateSequenceToInterpolate for opset1 case.

* Replaced resize_name by resize.soft_get('name', resize.id).

* Small fixes.

* Added two tests for Interpolate-4 infer function.

* Fixed the transformation ONNXResize11ToInterpolateV4 for the case when ONNXResize11 operation has 3 inputs.

* Added conversion of ONNXResize11 with tf_crop_and_resize_mode to ROIPooling + ONNXResize11.

* Fixed bugs in the transformation ONNXResize11ToInterpolateV4 and in the infer function of the operation ONNXResize11.

* Small changes.

* Renamed transformation that converts ONNXResize11 into ROIPooling + ONNXResize11 and fixed BOM-file.

* Fixed tests for the transformation InterpolateSequenceToInterpolate.

* Small change.

* Now the transformation InterpolateSequenceToInterpolate preserves output layer name.

* Deleted the transformation ONNXResize11ToTFCropAndResize.
2020-09-09 16:28:52 +03:00

175 lines
5.8 KiB
Python

"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op, PermuteAttrs
def infer_for_opset4(node: Node):
    """Infer the output shape of an Interpolate-4 node.

    Normalizes the 'pads_begin' / 'pads_end' attributes to int64 arrays whose
    length equals the rank of input 0 and stores them back on the node.
    Resolves the interpolated axes: the value of input port 3 if that port is
    connected, otherwise every axis of input 0. Then computes the output shape
    from the padded input shape. In 'sizes' mode the target extents come from
    input port 1. In any other mode they come from the scales on input port 2.

    :param node: Interpolate node with version 'opset4'
    """
    node_name = node.soft_get('name', node.id)
    connected_ports = [p for p in node.in_ports().values() if not p.disconnected()]
    assert len(connected_ports) in [3, 4], \
        "Interpolate-4 node {} must have 3 or 4 inputs".format(node_name)
    assert node.has_valid('mode')
    assert node.has_valid('shape_calculation_mode')

    src_shape = node.in_port(0).data.get_shape()
    assert src_shape is not None
    input_rank = len(src_shape)

    pads_begin = correct_pad(node.soft_get('pads_begin', [0]), input_rank)
    pads_end = correct_pad(node.soft_get('pads_end', [0]), input_rank)
    node['pads_begin'] = pads_begin
    node['pads_end'] = pads_end

    # Port 3 ('axes') is optional. Check connectivity rather than the port
    # count, so a created-but-disconnected port 3 does not crash get_source().
    if 3 in node.in_ports() and not node.in_port(3).disconnected():
        axes = node.in_port(3).get_source().data.get_value()
        assert axes is not None, \
            "Interpolate-4 node with name {} has None as 'axes' input".format(node_name)
    else:
        axes = list(range(0, input_rank))
    axes = int64_array(axes)

    # The target extents / scales are applied to the padded input extents.
    output_shape = src_shape + pads_begin + pads_end

    if node.shape_calculation_mode == 'sizes':
        dst_shape = node.in_port(1).data.get_value()
        assert dst_shape is not None, \
            "Interpolate-4 node with name {} has None as 'sizes' input".format(node_name)
        correct_scales_using_dst_shape(node, dst_shape, src_shape, axes)
        for i, axis in enumerate(axes):
            output_shape[axis] = dst_shape[i]
    else:
        scales = node.in_port(2).data.get_value()
        assert scales is not None, \
            "Interpolate-4 node with name {} has None as 'scales' input".format(node_name)
        for i, axis in enumerate(axes):
            # The small epsilon presumably guards against float rounding
            # pushing an exact product just below an integer before flooring.
            output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)

    node.out_port(0).data.set_shape(output_shape)
def infer_for_opset1(node: Node):
    """Infer the output shape of an Interpolate-1 node.

    The output shape is a copy of the input-0 shape. For each axis listed in
    the node's 'axes' attribute, the extent is replaced by the corresponding
    element of the constant value on input port 1. Afterwards 'axes' is
    registered for permutation relative to input 0.

    :param node: Interpolate node with version 'opset1'
    """
    connected_count = sum(1 for port in node.in_ports().values() if not port.disconnected())
    assert connected_count == 2
    assert node.has_valid('mode')
    assert node.has_valid('axes')

    input_shape = node.in_port(0).data.get_shape()
    assert input_shape is not None
    target_sizes = node.in_port(1).data.get_value()
    assert target_sizes is not None

    new_shape = input_shape.copy()
    for position, axis in enumerate(node.axes):
        new_shape[axis] = target_sizes[position]
    node.out_port(0).data.set_shape(new_shape)

    PermuteAttrs.create_permute_attrs(node, attrs=[('axes', 'input:0')])
def pad_attribute_to_str(node: Node, attr: str):
    """Serialize a pad attribute as a comma-separated string.

    :param node: node to read the attribute from
    :param attr: attribute name, e.g. 'pads_begin' or 'pads_end'
    :return: the attribute's elements converted with str() and joined by ','
             when node.has_valid(attr) is true, otherwise None
    """
    if not node.has_valid(attr):
        return None
    return ','.join(str(value) for value in node[attr])
def correct_pad(pad, rank):
    """Fit a pad specification to exactly `rank` elements.

    Shorter pads are extended with trailing zeros. Longer pads are truncated
    to their first `rank` elements. A scalar (e.g. the default 0 that
    Interpolate's mandatory props assign to 'pads_begin' / 'pads_end') is
    treated as a one-element pad instead of failing on len().

    :param pad: scalar or 1-D sequence/array of pad values
    :param rank: required number of elements
    :return: a new int64 numpy array of length `rank` (never a view of `pad`)
    """
    # ndmin=1 promotes scalars to 1-D; np.array always copies, so the result
    # never aliases the caller's array.
    pad = np.array(pad, dtype=np.int64, ndmin=1)
    if len(pad) < rank:
        return np.pad(pad, (0, rank - len(pad)), 'constant')
    return pad[:rank]
def correct_scales_using_dst_shape(node, dst_shape, src_shape, axes):
    """Return scales consistent with the target sizes of an Interpolate node.

    If input port 2 already holds a scales value with one element per target
    size, that value is returned unchanged. Otherwise the scales are derived
    as dst_shape[i] / src_shape[axes[i]]. They are kept as float32 ratios
    (not floored), since a fractional scale such as 1.5 is meaningful.

    The graph is not modified; the caller decides what to do with the result.

    :param node: Interpolate node whose input port 2 carries the scales
    :param dst_shape: target extents, one per element of `axes`
    :param src_shape: shape of input 0
    :param axes: interpolated axes, aligned with `dst_shape`
    :return: the scales array in effect for the node
    """
    scales_value = node.in_port(2).data.get_value()
    if scales_value is not None and len(scales_value) == len(dst_shape):
        return scales_value

    corrected_scales = np.zeros(len(dst_shape), dtype=np.float32)
    for i, axis in enumerate(axes):
        corrected_scales[i] = dst_shape[i] / src_shape[axis]
    return corrected_scales
class Interpolate(Op):
    """Model Optimizer operation 'Interpolate' covering the opset1 and opset4 versions.

    Serialized attributes and shape inference are selected by opset version.
    Any version without a dedicated entry falls back to the 'opset1' handling.
    """
    op = 'Interpolate'
    enabled = False
    # Shape inference functions keyed by opset version.
    infers = {
        'opset1': infer_for_opset1,
        'opset4': infer_for_opset4
    }

    def __init__(self, graph: Graph, attrs: dict):
        # Attributes serialized for each opset version.
        self.attributes_for_opsets = {
            'opset1': [
                ('axes', lambda node: ','.join(map(str, node.axes))),
                'mode', 'align_corners', 'antialias', 'pads_begin', 'pads_end',
            ],
            'opset4': [
                'mode', 'antialias', 'nearest_mode', 'cube_coeff', 'coordinate_transformation_mode',
                'shape_calculation_mode',
                ('pads_begin', lambda node: pad_attribute_to_str(node, 'pads_begin')),
                ('pads_end', lambda node: pad_attribute_to_str(node, 'pads_end')),
            ]
        }

        mandatory_props = {
            'op': self.op,
            'type': self.op,
            'version': 'opset1',

            'axes': None,
            'mode': None,
            'align_corners': 0,
            'antialias': 0,
            'pads_begin': 0,
            'pads_end': 0,

            'infer': self.infer,

            'force_precision_in_ports': {1: 'int64'},
            'in_ports_count': 2,
            'out_ports_count': 1,
        }
        super().__init__(graph, mandatory_props, attrs)

    def supported_attrs(self):
        """Return the serialized attribute list for this op's opset version."""
        opset = self.get_opset()
        key = opset if opset in self.attributes_for_opsets else 'opset1'
        return self.attributes_for_opsets[key]

    def infer(self, node: Node):
        """Dispatch shape inference by the node's own 'version' attribute.

        The node's version, not the op instance's, decides which infer function
        runs. The two can differ when the version is passed in create_node()
        attrs or changed on the node after creation.
        """
        opset = node.get_opset()
        key = opset if opset in self.infers else 'opset1'
        self.infers[key](node)

    @staticmethod
    def get_axes(node: Node) -> np.ndarray:
        """Return the interpolated axes of an Interpolate node.

        For opset1 the axes come from the 'axes' attribute, and None is returned
        when it is absent. For other versions they come from input port 3 when
        that port is connected; otherwise every axis of input 0 is interpolated.
        """
        opset = node.get_opset()
        if opset == 'opset1':
            interp_axes = node.soft_get('axes', None)
            return interp_axes if interp_axes is None else int64_array(interp_axes)

        src_shape = node.in_port(0).data.get_shape()
        assert src_shape is not None
        input_rank = len(src_shape)

        # Check connectivity rather than the port count: a created but
        # disconnected port 3 would otherwise crash on get_source().
        if 3 in node.in_ports() and not node.in_port(3).disconnected():
            axes = node.in_port(3).get_source().data.get_value()
        else:
            axes = list(range(0, input_rank))
        return int64_array(axes)