Completely transition of MO to Interpolate-4 (#970)

* Refactored infer function and function supported_attrs for the layer Interpolate.

* Small change.

* Deleted unneeded checks in transformations ResizeToInterpolate2D and ResizeToInterpolate3D.

* Small fix in the extractor of ONNX Resize.

* Now the extractor of TF ResizeBilinear generates Interpolate-1 again, because 'axis' in final version of Interpolate-4 specification is an input but is not attribute.

* Now the extractor of TF ResizeNearest generates Interpolate-1 again, because 'axis' in final version of Interpolate-4 specification is an input but is not attribute.

* Added static method get_axis into class Interpolate.

* Refactored class CanBeFused in the transformation InterpolateSequenceToInterpolate.

* Fixed transformation InterpolateSequenceToInterpolate according to the last version of the specification of Interpolate-4.

* Started to write support of Interpolate-4 in the transformation InterpolateWithConcat.

* Added support for Interpolate-4 into the transformation InterpolateWithConcat.

* Added support for Interpolate-4 into the transformation InterpolateConcat.

* Added support for Interpolate-4 into the transformation InterpolateReshapeWA.

* Added support for Interpolate-4 into the transformation InterpolateTranspose.

* Started to add test for opset4 case of the transformation InterpolateSequenceToInterpolate.

* Added test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_1_opset4_case).

* Added test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_4_opset4_case).

* Added another test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_5_opset4_case).

* Added another test for InterpolateSequenceToInterpolate (test_3d_interpolate_sequence_1_opset4_case).

* Finished addition of tests for opset4 case of InterpolateSequenceToInterpolate.

* Small change.

* Now opset is only opset1 or opset4 in the transformation InterpolateTranspose.

* Small fixes in transformations ResizeToInterpolate2D and ResizeToInterpolate3D.

* Deleted reading of unused ONNX attributes.

* Fixed docstring of the transformation InterpolateV1ToInterpolateV4.

* Added node name in assert about axes input.

* Fixes in the definition of the operation ONNXResize11.

* Now Interpolate-4 cannot have 'extension' as opset.

* Now the transformation InterpolateV1ToInterpolateV4 uses find_and_replace_pattern but not replace_sub_graph.

* Fixed tests for transformations InterpolateReshapeWA and InterpolateConcat.

* Fixed some tests.

* Rewritten operation Interpolate-4 class according to new variant of documentation.

* Some fixes in ONNXResize11 operation class.

* Now the transformation ONNXResize11ToInterpolate generates Interpolate-4 with 4 inputs.

* Now the transformation UpsampleToResample generates Interpolate-4 with 4 inputs.

* Now the transformation NearestNeighborUpsampling generates Interpolate-4 with 4 inputs.

* Now transformations ResizeToInterpolate2D and ResizeToInterpolate3D generate Interpolate-4 with 4 inputs.

* Now the transformation SplitConcatPairToInterpolate generates Interpolate-4 with 4 inputs.

* Now the transformation UnsqueezeTileReshapeBlockToInterpolate generates Interpolate-4 with 4 inputs.

* Now the transformation InterpolateV1ToInterpolateV4 generates Interpolate-4 with 4 inputs.

* Some fixes.

* Fixed the transformation InterpolateSequenceToInterpolate according to new variant of Interpolate-4 specification.

* Fixed typos.

* Added shape_calculation_mode to supported_attrs.

* Small fixes.

* Added operation ONNXResize10 and the transformation ONNXResize10ToInterpolate4.

* Fixed function correct_scales_using_dst_shape.

* Some fixes in InterpolateSequenceToInterpolate.

* Fixed bug in the method __call__ of the class CanBeFused: now self.accumulated_axes is correctly cleared in all cases.

* Small change.

* Fixed tests for the transformation SplitConcatPairToInterpolate.

* Now transformations InterpolateWithConcat, InterpolateReshapeWA, InterpolateConcat support Interpolate-4.

* Fixed the transformation InterpolateTranspose for the case of Interpolate-4.

* Written the back transformation InterpolateV4AxesCorrection to convert 'axes' input of Interpolate-4 from NHWC to NCHW layout.

* Added PermuteInput in Interpolate-4 infer.

* Fixed typos.

* Deleted the transformation InterpolateAxesCorrection.

* Now Interpolate-4 permutes axis, not shape in input port 3.

* Small fix.

* Some fix.

* Fixed bug in the transformation UpsampleToResample.

* Added some debug prints.

* Added more debug prints.

* Now ONNX Upsample-9 operation is read as ONNXResize10.

* Small fix.

* Small fixes.

* Fixed tests for the transformation SplitConcatPairToInterpolate.

* Deleted debug prints.

* Deleted some debug prints.

* Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests.

* Small fix in the transformation InterpolateSequenceToInterpolate.

* Started to write nGraph transformation to convert Interpolate-1 to Interpolate-4.

* Deleted redundant files.

* Small fixes.

* Small fix.

* Written draft of the transformation Interpolate-1 -> Interpolate-4.

* Small fix.

* Now ONNX Importer reads Resize-10 as Interpolate-4.

* Fixes in the test onnx_model_resize10_import_only.

* Small fix in the test for the conversion Interpolate-1 -> Interpolate-4.

* Small fixes.

* Fixed NGraphReaderTests for Interpolate.

* Some fixes.

* Deleted class for Resample operation.

* Fix in the transformation NearestNeighborUpsampling: fixed precision of the input 'scales' of generated Interpolate-4.

* Fixed typo.

* Now the TF operations ResizeBilinear is readed as internal MO operation TFResizeBilinear. This internal operation is converted into Interpolate-4.

* Small fix in BOM-file.

* Added checks of existence of attributes of TF  ResizeBilinear operation.

* Small fixes in the conversion of the internal MO operation  TFResizeBilinear to Interpolate-4.

* Small fixes.

* Small fixes.

* Now the transformation ONNXResize10ToInterpolateV4 calculates sizes input as input_shape * (scales + epsilon).

* Added the internal MO operation TFResizeNearestNeighbor.

* Fixes in the transformation SplitConcatPairToInterpolate and its tests.

* Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests.

* Written the transformation that converts the internal operation TFResizeNearestNeighbor into Interpolate-4.

* Now MO reads the TF operation ResizeNearestNeighbor as the internal MO operation TFResizeNearestNeighbor.

* Small fix.

* Now the specification of Interpolate-4 clarifies that the mode linear_onnx supports only 2D or 4D input tensors.

* Small fix.

* Some fixes.

* Moved the transformation ONNXResize10ToInterpolateV4 to the front stage.

* Deleted infer function and function supported_attrs for ONNXResize10 operation.

* Deleted supported_attrs() for TFResizeBilinear and TFResizeNearestNeighbor.

* Some fixes.

* Fixes in the shape infer function of the nGraph operation Interpolate-4. Now 'axes' input can be non-constant.  In the such case, all elements of the output shape are Dimension::dynamic().

* Deleted corner cases processing in transformations TFResizeBilinearToInterpolateV4 and TFResizeNearestNeighborToInterpolateV4.

* Rewritten the function replace_resize_bilinear.

* Written inner MO operation TFResize that covers TF operations ResizeBilinear and ResizeNearestNeighbor.

* Now TF operations ResizeBilinear and ResizeNearestNeighbor are read as an internal operation TFResize in MO. Transformations TFResizeNearestNeighborToInterpolateV4 and TFResizeBilinearToInterpolateV4 are fused into one transformation TFResizeToInterpolateV4.

* Some changes in the shape infer function of nGraph op Interpolate-4.

* Small fix.

* Some changes.

* The transformation TFResizeToInterpolateV4 is moved to the front stage.

* Deleted redundant assert.

* Deleted transformations ResizeToInterpolate2D and ResizeToInterpolate3D.

* Some renaming.

* Small change.

* Deleted .copy() in the shape infer function of the internal operation TFResize.

* Small fix.

* Small fixes.

* Added comment about the case when the input 'axes' of Interpolate-4 is non-constant.

* Written test for Interpolate-4 shape infer, for the case when the input 'axes' is non-constant and shape_calculation_mode = scales.

* Some fixes.

* Small fixes.

* Small fix.

* Added yet another test for the case of non-constant 'axes' input of Interpolate-4 (when shape_calculation_mode = sizes).

* Added some comment.

* Small fix.

* Reverted changes for InterpolateWithConcat.

* Added type checks for all inputs of nGraph operation Interpolate-4.

* Added u32 and u64 to supported element types of sizes and axes inputs of nGraph operation Interpolate-4.

* Fixed some functional tests.

* Some changes.

* Added helper function float32_array.

* Now the MO transformation InterpolateV1ToInterpolate preserves names of layers.

* Small fix.

* Small fix.

* Reverted some change.

* Small fixes.

* Small fix.

* Small fix.

* Small fix.

* Small fix.

* Reverted changes in the nGraph reader tests for Interpolate-1.

* Some revert.

* Fixed some copyright year.
This commit is contained in:
Vladimir Gavrilov
2021-02-05 19:20:26 +03:00
committed by GitHub
parent bd3884b602
commit c1136cd7b0
34 changed files with 1156 additions and 564 deletions

View File

@@ -15,6 +15,7 @@
* **Type**: string
* **Default value**: none
* **Required**: *yes*
* **Note**: Only 2D and 4D tensors with `axes = {0, 1}` and `axes = {2, 3}` respectively are supported for `"mode" == "linear_onnx"`.
* *shape_calculation_mode*

View File

@@ -17,52 +17,59 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "Convert
ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() {
MATCHER_SCOPE(ConvertInterpolate1ToInterpolate4);
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_rank()), pattern::any_input()});
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_shape()), pattern::any_input()});
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
auto interpolate1 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
if (!interpolate1)
auto interpolationV0 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
if (!interpolationV0) {
return false;
}
auto interpolate_attrs = interpolate1->get_attrs();
auto input_shape_rank = interpolate1->input(0).get_partial_shape().rank().get_length();
auto& inp_partial_shape = interpolationV0->get_input_partial_shape(0);
auto& out_shape = interpolationV0->get_output_shape(0);
auto attrsV0 = interpolationV0->get_attrs();
// attrs
auto mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode();
if (interpolate_attrs.mode == "nearest") {
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::nearest;
} else if (interpolate_attrs.mode == "cubic") {
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::cubic;
} else if (interpolate_attrs.mode == "linear") {
if (input_shape_rank < 5) {
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
} else if (input_shape_rank == 5) {
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear;
} else {
return false;
}
std::vector<float> scales(attrsV0.axes.size(), 1.0f);
auto inp_shape = inp_partial_shape.to_shape();
size_t i = 0;
for (std::size_t axis : attrsV0.axes) {
scales[i] = static_cast<float>(out_shape.at(axis))/inp_shape.at(axis);
i++;
}
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
ngraph::opset4::Interpolate::InterpolateAttrs attrsV4;
if (attrsV0.mode == "nearest") {
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::nearest;
} else if (attrsV0.mode == "linear") {
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::linear;
} else if (attrsV0.mode == "cubic") {
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::cubic;
} else if (attrsV0.mode == "linear_onnx") {
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::linear_onnx;
} else {
return false;
}
auto nearest_mode_v4 = ngraph::op::v4::Interpolate::NearestMode::floor;
auto shape_calculation_mode_v4 = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
auto coordinate_transformation_mode_v4 = interpolate_attrs.align_corners ? ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners :
ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric;
auto interpolate4_attr = ngraph::op::v4::Interpolate::InterpolateAttrs(mode_v4, shape_calculation_mode_v4,
interpolate_attrs.pads_begin, interpolate_attrs.pads_end,
coordinate_transformation_mode_v4, nearest_mode_v4, interpolate_attrs.antialias, -0.75);
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor;
attrsV4.pads_begin = attrsV0.pads_begin;
attrsV4.pads_end = attrsV0.pads_end;
attrsV4.antialias = attrsV0.antialias;
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
attrsV4.cube_coeff = -0.75f;
if (attrsV0.align_corners) {
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
}
// input
auto axes = interpolate_attrs.axes.to_vector();
auto axes_node = ngraph::opset4::Constant::create(element::i64, {axes.size()}, axes);
auto default_scales = std::vector<float>(axes.size(), 1.f);
auto scales_node = ngraph::opset4::Constant::create(element::f32, {axes.size()}, default_scales);
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
scalesConstant, axisConstant, attrsV4);
auto interpolate4 = std::make_shared<ngraph::opset4::Interpolate>(interpolate1->input_value(0), interpolate1->input_value(1),
scales_node, axes_node, interpolate4_attr);
interpolate4->set_friendly_name(interpolate1->get_friendly_name());
ngraph::copy_runtime_info(interpolate1, interpolate4);
ngraph::replace_node(interpolate1, interpolate4);
interpolateV4->set_friendly_name(interpolationV0->get_friendly_name());
ngraph::copy_runtime_info(interpolationV0, interpolateV4);
ngraph::replace_node(interpolationV0, interpolateV4);
return true;
};

View File

@@ -399,4 +399,4 @@ TEST_F(NGraphReaderTests, ReadInterpolate4Network) {
fdata[2] = 2.0;
fdata[3] = 2.0;
});
}
}

View File

@@ -49,7 +49,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
{
auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40});
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 1.f, 1.f});
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 4.0f / 3.0f, 4.0f / 3.0f});
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{4}, {0, 1, 2, 3});
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
@@ -94,10 +94,10 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
{
auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40});
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {1.f, 1.f});
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear,
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor,
false, -0.75);

View File

@@ -85,7 +85,7 @@ void InterpolateTransformation::validate() {
if (attributes.mode == "nearest") {
ASSERT_EQ("ScaleShiftIE", typeName);
} else {
ASSERT_EQ("Interp", typeName);
ASSERT_TRUE("Interp" == typeName || "Interpolate" == typeName);
}
}

View File

@@ -144,6 +144,7 @@ extensions/front/input_cut.py
extensions/front/instance_normalization.py
extensions/front/interpolate_reshape.py
extensions/front/InterpolateNormalizer.py
extensions/front/InterpolateV1ToInterpolate.py
extensions/front/kaldi/__init__.py
extensions/front/kaldi/add_permute_after_convolution.py
extensions/front/kaldi/add_reshape_around_convolution.py
@@ -298,6 +299,7 @@ extensions/front/onnx/normalize_ext.py
extensions/front/onnx/normalize_l2_normalize.py
extensions/front/onnx/one_hot_ext.py
extensions/front/onnx/one_hot_normalize.py
extensions/front/onnx/ONNXResize10ToInterpolate.py
extensions/front/onnx/pad_converter.py
extensions/front/onnx/pad_ext.py
extensions/front/onnx/parameter_ext.py
@@ -317,7 +319,6 @@ extensions/front/onnx/reduce_ext.py
extensions/front/onnx/remove_filtering_boxes_by_size.py
extensions/front/onnx/reshape_ext.py
extensions/front/onnx/resize_ext.py
extensions/front/onnx/resize_to_interpolate.py
extensions/front/onnx/reverse_sequence_ext.py
extensions/front/onnx/rnn_ext.py
extensions/front/onnx/roialign_ext.py
@@ -483,6 +484,7 @@ extensions/front/tf/SwitchMergeOptimization.py
extensions/front/tf/TensorArrayExtractors.py
extensions/front/tf/TensorArrayGatherV3.py
extensions/front/tf/tensorflow_custom_operations_config_update.py
extensions/front/tf/TFResizeToInterpolate.py
extensions/front/tf/TFSliceToSlice.py
extensions/front/tf/tile_ext.py
extensions/front/tf/topk_ext.py
@@ -571,7 +573,7 @@ extensions/middle/MulFakeQuantizeFuse.py
extensions/middle/MXNetRNNSequenceNormalize.py
extensions/middle/MXNetSplitMultiLayers.py
extensions/middle/MXTileReplacer.py
extensions/middle/ONNXResize11ToInterpolateV4.py
extensions/middle/ONNXResize11ToInterpolate.py
extensions/middle/ONNXRNNSequenceNormalize.py
extensions/middle/PartialInfer.py
extensions/middle/pass_separator.py
@@ -677,6 +679,7 @@ extensions/ops/non_zero.py
extensions/ops/normalize.py
extensions/ops/normalize_l2.py
extensions/ops/one_hot.py
extensions/ops/ONNXResize10.py
extensions/ops/ONNXResize11.py
extensions/ops/pack.py
extensions/ops/parameter.py
@@ -697,7 +700,6 @@ extensions/ops/rank.py
extensions/ops/ReduceOps.py
extensions/ops/regionyolo.py
extensions/ops/reorgyolo.py
extensions/ops/resample.py
extensions/ops/resize.py
extensions/ops/resize_factor_utils.py
extensions/ops/Reverse.py
@@ -733,6 +735,7 @@ extensions/ops/TensorArrayScatter.py
extensions/ops/TensorArraySize.py
extensions/ops/TensorArrayWrite.py
extensions/ops/TensorIterator_ops.py
extensions/ops/TFResize.py
extensions/ops/topk.py
extensions/ops/topkrois_onnx.py
extensions/ops/transpose.py

View File

@@ -13,6 +13,7 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.elementwise import Mul
@@ -54,7 +55,6 @@ class InterpolateConcat(BackReplacementPattern):
\ /
Concat(axis=1)
shape=[1, 7, 60, 160]
"""
enabled = True
graph_condition = [lambda graph: not graph.graph['cmd_params'].static_shape]
@@ -89,7 +89,7 @@ class InterpolateConcat(BackReplacementPattern):
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate', version='opset1'):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
continue
dsts = interpolate.out_port(0).get_destinations()
@@ -101,14 +101,12 @@ class InterpolateReshapeWA(BackReplacementPattern):
"""
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph.
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation
BEFORE:
input Const
shape=[1, 3, 30, 40] value=[60, 160]
\ /
Interpolate(axes=(2, 3))
shape=[1, 3, 60, 160]
AFTER:
input
shape=[1, 3, 30, 40]
@@ -151,6 +149,6 @@ class InterpolateReshapeWA(BackReplacementPattern):
interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
for interpolate in graph.get_op_nodes(type='Interpolate', version='opset1'):
for interpolate in graph.get_op_nodes(type='Interpolate'):
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const':
self.make_interpolate_reshapeable(interpolate)
self.make_interpolate_reshapeable(interpolate)

View File

@@ -0,0 +1,69 @@
"""
Copyright (C) 2018-2021 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementPattern
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, rename_nodes
def correct_pad(pad):
return int64_array([pad] if not isinstance(pad, list) else pad)
class InterpolateV1ToInterpolate(FrontReplacementPattern):
"""
This transformation replaces the operation Interpolate-1 with the operation Interpolate-4.
"""
enabled = True
def run_after(self):
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
return [InterpolateNormalizer]
def find_and_replace_pattern(self, graph: Graph):
for node in graph.get_op_nodes(op='Interpolate', version='opset1'):
transformation_mode = 'align_corners' if int(node.soft_get('align_corners', 0)) else 'half_pixel'
interpolate1_name = node.soft_get('name', node.id)
interpolate4 = create_op_with_const_inputs(graph, Interpolate,
{
2: np.array([1.0, 1.0]),
3: int64_array(node.axes)
},
{
'mode': node.mode,
'antialias': node.antialias,
'coordinate_transformation_mode': transformation_mode,
'pads_begin': correct_pad(node.soft_get('pads_begin', 0)),
'pads_end': correct_pad(node.soft_get('pads_end', 0)),
'nearest_mode': 'round_prefer_floor',
'cube_coeff': -0.75,
'shape_calculation_mode': 'sizes',
'version': 'opset4',
'in_ports_count': 4,
})
interpolate1_input_connection = node.in_port(0).get_connection()
interpolate1_input_connection.set_destination(interpolate4.in_port(0))
sizes_connection = node.in_port(1).get_connection()
sizes_connection.set_destination(interpolate4.in_port(1))
node.out_port(0).get_connection().set_source(interpolate4.out_port(0))
rename_nodes([(node, interpolate1_name + '/delete'), (interpolate4, interpolate1_name)])

View File

@@ -0,0 +1,135 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
import numpy as np
from extensions.ops.activation_ops import Floor
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Add, Mul
from extensions.ops.interpolate import Interpolate
from extensions.ops.range import Range
from extensions.ops.rank import Rank
from mo.front.common.partial_infer.utils import int64_array, float_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.passes.convert_data_type import data_type_str_to_np
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
def replace_resize(graph: Graph, resize: Node):
log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
"is triggered for node {}.".format(resize.soft_get('name', resize.id)))
resize_name = resize.soft_get('name', resize.id)
rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
range_node = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
{'name': resize_name + '/axes'})
sizes_ss = create_op_with_const_inputs(graph, StridedSlice,
{1: int64_array([2]),
2: int64_array([0]),
3: int64_array([1])},
{'name': resize_name + '/sizes_ss',
'begin_mask': int64_array([1]),
'end_mask': int64_array([0]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])})
scales_ss = create_op_with_const_inputs(graph, StridedSlice,
{1: int64_array([2]),
2: int64_array([0]),
3: int64_array([1])},
{'name': resize_name + '/scales_ss',
'begin_mask': int64_array([1]),
'end_mask': int64_array([0]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])})
rank_node.out_port(0).connect(range_node.in_port(1))
interpolate_node = Interpolate(graph, {'version': 'opset4',
'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
'coordinate_transformation_mode': 'asymmetric',
'cube_coeff': -0.75,
'nearest_mode': 'simple',
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'antialias': 0,
'shape_calculation_mode': 'scales',
'in_ports_count': 4}).create_node()
range_node.out_port(0).connect(interpolate_node.in_port(3))
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()
# When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g.,
# scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because
# input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40.
# Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small
# floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.333333, 2.0],
# input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because
# input_shape[2] * scales[2] + 1.0e-5 = 39.99991.
# Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)).
add_node = create_op_with_const_inputs(graph, Add,
{1: float_array([1.0e-5])},
{'name': resize_name + '/Add'})
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node([mul_node])
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node([floor_node])
cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
scales_ss.out_port(0).connect(interpolate_node.in_port(2))
connection_of_resize_input = resize.in_port(0).get_connection()
connection_of_resize_input.set_destination(interpolate_node.in_port(0))
connection_of_scales = resize.in_port(1).get_connection()
connection_of_scales.set_destination(scales_ss.in_port(0))
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
connection_of_resize_input.get_source().connect(rank_node.in_port(0))
connection_of_scales.get_source().connect(add_node.in_port(0))
rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
class ONNXResize10ToInterpolate(FrontReplacementOp):
"""
The transformation replaces ONNX Resize 10 with Interpolate-4.
"""
op = 'ONNXResize10'
enabled = True
def run_after(self):
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
return [InterpolateNormalizer]
def replace_sub_graph(self, graph: Graph, match: dict):
resize = match['op']
replace_resize(graph, resize)

View File

@@ -14,7 +14,7 @@
limitations under the License.
"""
from extensions.ops.upsample import UpsampleOp
from extensions.ops.ONNXResize10 import ONNXResize10
from extensions.ops.ONNXResize11 import ONNXResize11Op
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
@@ -43,5 +43,5 @@ class ResizeExtractor(FrontExtractorOp):
ONNXResize11Op.update_node_stat(node, attrs)
else:
mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
UpsampleOp.update_node_stat(node, {'mode': mode})
ONNXResize10.update_node_stat(node, {'mode': mode})
return cls.enabled

View File

@@ -1,192 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
import numpy as np
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph
from mo.ops.const import Const
class ResizeToInterpolate2D(FrontReplacementSubgraph):
enabled = True
def pattern(self):
return dict(
nodes=[
('input', dict()),
('shape_1', dict(op='ShapeOf')),
('shape_2', dict(op='ShapeOf')),
('shape_3', dict(op='ShapeOf')),
('gather_1', dict(type='Gather')),
('gather_2', dict(type='Gather')),
('mul_1', dict(op='Mul')),
('mul_2', dict(op='Mul')),
('unsqueeze_1', dict(op='ExpandDims')),
('unsqueeze_2', dict(op='ExpandDims')),
('slice', dict(op='Slice')),
('slice_start', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([2])))),
('slice_end', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([4])))),
('concat_1', dict(op='Concat')),
('cast_1', dict(op='Cast')),
('cast_2', dict(op='Cast')),
('div', dict(op='Div')),
('concat_2', dict(op='Concat')),
('resize', dict(op='Upsample')),
],
edges=[
('input', 'resize', {'in': 0}),
('input', 'shape_1', {'in': 0}),
('input', 'shape_2', {'in': 0}),
('input', 'shape_3', {'in': 0}),
('shape_1', 'gather_1', {'in': 0}),
('shape_2', 'gather_2', {'in': 0}),
('shape_3', 'slice', {'in': 0}),
('slice_start', 'slice', {'in': 1}),
('slice_end', 'slice', {'in': 2}),
('gather_1', 'mul_1', {'in': 0}),
('gather_2', 'mul_2', {'in': 0}),
('mul_1', 'unsqueeze_1', {'in': 0}),
('mul_2', 'unsqueeze_2', {'in': 0}),
('unsqueeze_1', 'concat_1', {'in': 0}),
('unsqueeze_2', 'concat_1', {'in': 1}),
('concat_1', 'cast_1', {'in': 0}),
('slice', 'cast_2', {'in': 0}),
('cast_1', 'div', {'in': 0}),
('cast_2', 'div', {'in': 1}),
('div', 'concat_2', {'in': 1}),
('concat_2', 'resize', {'in': 1}),
])
def replace_sub_graph(self, graph: Graph, match: dict):
resize_node = match['resize']
if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value:
log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
return
interpolate_node = Interpolate(graph, {'name': resize_node.name + '/Interpolate',
'mode': resize_node.mode, 'axes': int64_array([2, 3])}).create_node()
scale = match['mul_1'].in_node(1).value
scale_value = int64_array([scale, scale])
scale_const = Const(graph, {'value': scale_value, 'name': resize_node.name + '/Scale'}).create_node()
interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
scale_const.out_port(0).connect(interpolated_shape.in_port(1))
resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
class ResizeToInterpolate3D(FrontReplacementSubgraph):
enabled = True
def pattern(self):
return dict(
nodes=[
('input', dict()),
('shape_1', dict(op='ShapeOf')),
('shape_2', dict(op='ShapeOf')),
('shape_3', dict(op='ShapeOf')),
('shape_4', dict(op='ShapeOf')),
('gather_1', dict(type='Gather')),
('gather_2', dict(type='Gather')),
('gather_3', dict(type='Gather')),
('mul_1', dict(op='Mul')),
('mul_2', dict(op='Mul')),
('mul_3', dict(op='Mul')),
('cast_1', dict(op='Cast')),
('cast_2', dict(op='Cast')),
('cast_3', dict(op='Cast')),
('unsqueeze_1', dict(op='ExpandDims')),
('unsqueeze_2', dict(op='ExpandDims')),
('unsqueeze_3', dict(op='ExpandDims')),
('floor_1', dict(op='Floor')),
('floor_2', dict(op='Floor')),
('floor_3', dict(op='Floor')),
('slice', dict(op='Slice')),
('slice_start', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([2])))),
('slice_end', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([5])))),
('concat_1', dict(op='Concat')),
('cast_4', dict(op='Cast')),
('cast_5', dict(op='Cast')),
('div', dict(op='Div')),
('concat_2', dict(op='Concat')),
('resize', dict(op='Upsample')),
],
edges=[
('input', 'resize', {'in': 0}),
('input', 'shape_1', {'in': 0}),
('input', 'shape_2', {'in': 0}),
('input', 'shape_3', {'in': 0}),
('input', 'shape_4', {'in': 0}),
('shape_1', 'gather_1', {'in': 0}),
('shape_2', 'gather_2', {'in': 0}),
('shape_3', 'gather_3', {'in': 0}),
('shape_4', 'slice', {'in': 0}),
('slice_start', 'slice', {'in': 1}),
('slice_end', 'slice', {'in': 2}),
('gather_1', 'mul_1', {'in': 0}),
('gather_2', 'mul_2', {'in': 0}),
('gather_3', 'mul_3', {'in': 0}),
('mul_1', 'cast_1', {'in': 0}),
('mul_2', 'cast_2', {'in': 0}),
('mul_3', 'cast_3', {'in': 0}),
('cast_1', 'floor_1', {'in': 0}),
('cast_2', 'floor_2', {'in': 0}),
('cast_3', 'floor_3', {'in': 0}),
('floor_1', 'unsqueeze_1', {'in': 0}),
('floor_2', 'unsqueeze_2', {'in': 0}),
('floor_3', 'unsqueeze_3', {'in': 0}),
('unsqueeze_1', 'concat_1', {'in': 0}),
('unsqueeze_2', 'concat_1', {'in': 1}),
('unsqueeze_3', 'concat_1', {'in': 2}),
('concat_1', 'cast_4', {'in': 0}),
('slice', 'cast_5', {'in': 0}),
('cast_4', 'div', {'in': 0}),
('cast_5', 'div', {'in': 1}),
('div', 'concat_2', {'in': 1}),
('concat_2', 'resize', {'in': 1}),
])
def replace_sub_graph(self, graph: Graph, match: dict):
resize_node = match['resize']
if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value or \
match['mul_1'].in_node(1).value != match['mul_3'].in_node(1).value:
log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
return
interpolate_node = Interpolate(graph, {'name': resize_node.name + '/Interpolate',
'mode': resize_node.mode, 'axes': int64_array([2, 3, 4])}).create_node()
scale = match['mul_1'].in_node(1).value
scale_value = int64_array([scale, scale, scale])
scale_const = Const(graph, {'value': scale_value, 'name': resize_node.name + '/Scale'}).create_node()
interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
scale_const.out_port(0).connect(interpolated_shape.in_port(1))
resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))

View File

@@ -18,9 +18,10 @@ import math
import numpy as np
from extensions.ops.ONNXResize10 import ONNXResize10
from extensions.ops.upsample import UpsampleOp
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
from mo.utils.error import Error
@@ -30,42 +31,47 @@ class UpsampleFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
scales = onnx_attr(node, 'scales', 'floats', dst_type=lambda x: np.array(x, dtype=np.float32))
width_scale = onnx_attr(node, 'width_scale', 'f')
height_scale = onnx_attr(node, 'height_scale', 'f')
onnx_opset_version = get_onnx_opset_version(node)
if onnx_opset_version is not None and onnx_opset_version >= 9:
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
ONNXResize10.update_node_stat(node, {'mode': mode})
else:
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
scales = onnx_attr(node, 'scales', 'floats', dst_type=lambda x: np.array(x, dtype=np.float32))
width_scale = onnx_attr(node, 'width_scale', 'f')
height_scale = onnx_attr(node, 'height_scale', 'f')
supported_modes = ['nearest', 'linear']
if mode not in supported_modes:
raise Error(
'Error decoding Upsample node {}, mode = {} is not in the list of supported modes {}.',
node.name,
mode,
supported_modes
)
if scales is not None:
if scales.shape != (4,):
supported_modes = ['nearest', 'linear']
if mode not in supported_modes:
raise Error(
'Upsample scales attribute is wrong for node {}. Only 4D scales are supported.',
'Error decoding Upsample node {}, mode = {} is not in the list of supported modes {}.',
node.name,
mode,
supported_modes
)
if scales is not None:
if scales.shape != (4,):
raise Error(
'Upsample scales attribute is wrong for node {}. Only 4D scales are supported.',
node.name
)
if math.fabs(scales[0] - 1) > 1e-5 or math.fabs(scales[1] - 1) > 1e-5:
raise Error(
'Upsampling of batch and feature dimensions is not supported for node {}.',
node.name
)
height_scale = scales[2]
width_scale = scales[3]
if (width_scale is None or height_scale is None) and len(node.in_nodes()) != 2:
raise Error(
'One/both of widths_scale = {} and height_scale = {} is not defined for Upsample node {}.',
width_scale,
height_scale,
node.name
)
if math.fabs(scales[0] - 1) > 1e-5 or math.fabs(scales[1] - 1) > 1e-5:
raise Error(
'Upsampling of batch and feature dimensions is not supported for node {}.',
node.name
)
height_scale = scales[2]
width_scale = scales[3]
if (width_scale is None or height_scale is None) and len(node.in_nodes()) != 2:
raise Error(
'One/both of widths_scale = {} and height_scale = {} is not defined for Upsample node {}.',
width_scale,
height_scale,
node.name
)
UpsampleOp.update_node_stat(node, {'mode': mode, 'height_scale': height_scale,
'width_scale': width_scale})
UpsampleOp.update_node_stat(node, {'mode': mode, 'height_scale': height_scale,
'width_scale': width_scale})
return cls.enabled

View File

@@ -15,6 +15,7 @@
"""
import numpy as np
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.replacement import FrontReplacementSubgraph
from mo.graph.graph import Graph
@@ -35,8 +36,7 @@ class InterpolateTranspose(FrontReplacementSubgraph):
('interpolate',
{
'kind': 'op',
'op': 'Interpolate',
'axes': lambda axes: axes is not None and np.array_equal(axes, int64_array([1, 2]))
'op': 'Interpolate'
}),
('transpose_1', {'kind': 'op', 'op': 'Transpose'}),
('transpose_1_order',
@@ -70,8 +70,18 @@ class InterpolateTranspose(FrontReplacementSubgraph):
transpose_1 = match['transpose_1']
transpose_2 = match['transpose_2']
axes = Interpolate.get_axes(interpolate)
if axes is None or not np.array_equal(axes, int64_array([1, 2])):
return
# because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
interpolate.axes = int64_array([2, 3])
opset = interpolate.get_opset()
assert opset in ['opset1', 'opset4'], \
'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))
if opset == 'opset1':
interpolate.axes = int64_array([2, 3])
else:
interpolate.in_port(3).data.set_value(int64_array([2, 3]))
transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0))
transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0))

View File

@@ -0,0 +1,133 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
import numpy as np
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Div
from extensions.ops.interpolate import Interpolate
from mo.front.common.layout import get_height_dim, get_width_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.front.common.replacement import FrontReplacementOp
from mo.front.tf.graph_utils import create_op_with_const_inputs
from mo.graph.graph import Graph, Node, rename_nodes
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
resize_name = resize.soft_get('name', resize.id)
log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, resize_name))
num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
assert num_of_inputs == 2, \
"Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)
attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
"the attribute align_corners must be False"
assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
attrs_msg.format(resize_name, resize.op)
shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()
layout = graph.graph['layout']
height_dim = get_height_dim(layout, 4)
width_dim = get_width_dim(layout, 4)
ss = create_op_with_const_inputs(graph, StridedSlice,
{1: int64_array([height_dim]),
2: int64_array([width_dim + 1]),
3: int64_array([1])
},
{'name': resize_name + '/StridedSlice',
'begin_mask': int64_array([1]),
'end_mask': int64_array([1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])
})
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
size_to_float.out_port(0).connect(div_node.in_port(0))
shape_to_float.out_port(0).connect(div_node.in_port(1))
ss.out_port(0).connect(shape_to_float.in_port(0))
shape.out_port(0).connect(ss.in_port(0))
align_corners = resize.align_corners
half_pixel_centers = resize.half_pixel_centers
nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
if align_corners:
coordinate_transformation_mode = 'align_corners'
if interpolation_mode == 'nearest':
nearest_mode = 'round_prefer_ceil'
elif half_pixel_centers:
coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
else:
coordinate_transformation_mode = 'asymmetric'
interpolate4 = create_op_with_const_inputs(graph, Interpolate,
{
3: int64_array([height_dim, width_dim])
},
{
'name': resize_name + '/interpolate_4',
'mode': interpolation_mode,
'antialias': False,
'coordinate_transformation_mode': coordinate_transformation_mode,
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'nearest_mode': nearest_mode,
'cube_coeff': -0.75,
'shape_calculation_mode': 'sizes',
'version': 'opset4',
'in_ports_count': 4,
})
resize_input_connection = resize.in_port(0).get_connection()
resize_input_connection.set_destination(interpolate4.in_port(0))
resize_input_connection.get_source().connect(shape.in_port(0))
div_node.out_port(0).connect(interpolate4.in_port(2))
sizes_connection = resize.in_port(1).get_connection()
sizes_connection.set_destination(interpolate4.in_port(1))
sizes_connection.get_source().connect(size_to_float.in_port(0))
resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
class TFResizeToInterpolate(FrontReplacementOp):
"""
The transformation replaces TFResize with Interpolate-4.
"""
op = 'TFResize'
enabled = True
def run_after(self):
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
return [InterpolateNormalizer]
def replace_sub_graph(self, graph: Graph, match: dict):
resize = match['op']
replace_tf_resize(graph, resize, resize.mode)

View File

@@ -70,17 +70,30 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
return
axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
resample_op = Interpolate(graph, {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes})
reshape2_name = match['reshape_2'].name
resample_op = Interpolate(graph,
{'mode': 'nearest', 'antialias': 0, 'pads_begin': int64_array([0]),
'pads_end': int64_array([0]), 'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75, 'version': 'opset4',
'name': reshape2_name + '/Resample', 'shape_calculation_mode': 'scales',
'in_ports_count': 4})
resample_node = resample_op.create_node([match['op']])
const = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
'name': resample_node.name + '/target_shape'}).create_node()
axes_node = Const(graph,
{
'name': resample_node.name + '/axes',
'value': int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
}).create_node()
sizes_node = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
'name': resample_node.name + '/target_shape'}).create_node()
scales_node = Const(graph, {'value': np.array([height_scale, width_scale], dtype=np.float32),
'name': resample_node.name + '/scales'}).create_node()
match['reshape_2'].replace_node(resample_node)
resample_node.add_input_port(1, skip_if_exist=True)
assert resample_node.in_port(1).disconnected()
const.out_port(0).connect(resample_node.in_port(1))
sizes_node.out_port(0).connect(resample_node.in_port(1))
scales_node.out_port(0).connect(resample_node.in_port(2))
axes_node.out_port(0).connect(resample_node.in_port(3))
graph.remove_nodes_from([node.id for node in match.values() if node.id != match['op'].id])

View File

@@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from extensions.ops.TFResize import TFResize
from mo.front.extractor import FrontExtractorOp
@@ -24,10 +24,18 @@ class ResizeBilinearFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
mapping_rule = {
'align_corners': int(node.pb.attr['align_corners'].b),
'mode': 'linear',
'axes': int64_array([1, 2]),
align_corners = False
if 'align_corners' in node.pb.attr:
align_corners = node.pb.attr['align_corners'].b
half_pixel_centers = False
if 'half_pixel_centers' in node.pb.attr:
half_pixel_centers = node.pb.attr['half_pixel_centers'].b
attrs = {
'align_corners': align_corners,
'half_pixel_centers': half_pixel_centers,
'mode': 'linear'
}
Interpolate.update_node_stat(node, mapping_rule)
TFResize.update_node_stat(node, attrs)
return cls.enabled

View File

@@ -13,8 +13,8 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from extensions.ops.TFResize import TFResize
from mo.front.extractor import FrontExtractorOp
@@ -24,10 +24,18 @@ class ResizeNearestNeighborFrontExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
mapping_rule = {
'mode': 'nearest',
'antialias': 0,
'axes': int64_array([1, 2]),
align_corners = False
if 'align_corners' in node.pb.attr:
align_corners = node.pb.attr['align_corners'].b
half_pixel_centers = False
if 'half_pixel_centers' in node.pb.attr:
half_pixel_centers = node.pb.attr['half_pixel_centers'].b
attrs = {
'align_corners': align_corners,
'half_pixel_centers': half_pixel_centers,
'mode': 'nearest'
}
Interpolate.update_node_stat(node, mapping_rule)
TFResize.update_node_stat(node, attrs)
return cls.enabled

View File

@@ -40,12 +40,12 @@ def is_next(first: Node, second: Node) -> bool:
:param second: another Interpolate layer
:return: True, if 'first' is an predecessor of 'second', and False otherwise.
"""
if not node_has_one_consumer(first):
return False
dests = first.out_port(0).get_destinations()
if len(dests) != 1:
return False
return second.id == dests[0].node.id
if node_has_one_consumer(first):
return second.id == dests[0].node.id
elif first.soft_get('maybe_part_of_sequence', False):
return len(dests) == 2 and second.id in [d.node.id for d in dests]
return False
class CanBeFused:
@@ -131,7 +131,8 @@ class CanBeFused:
:param second: the second of fused nodes
:return: True, if nodes can be fused, and False otherwise
"""
if not self._compare_attributes(first, second):
if not (is_next(first, second) and self._compare_attributes(first, second)):
self.accumulated_axes = set()
return False
fst_axes = set([a for a in Interpolate.get_axes(first)])
@@ -250,7 +251,7 @@ def replace_sequence(seq: List[Node], graph: Graph):
last_interp_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
rename_nodes([(last_interp_node, last_interp_node_name + '/delete_'), (interp_node, last_interp_node_name)])
rename_nodes([(last_interp_node, last_interp_node_name + '/delete'), (interp_node, last_interp_node_name)])
class InterpolateSequenceToInterpolate(MiddleReplacementPattern):
@@ -267,6 +268,6 @@ class InterpolateSequenceToInterpolate(MiddleReplacementPattern):
log.debug('Enabled replacement of a sequence of Interpolate layers with one Interpolate layer.')
interps = [n for n in graph.pseudo_topological_sort() if n.kind == 'op' and n.op == 'Interpolate']
fuser = CanBeFused()
sequences = group_by_with_binary_predicate(interps, lambda prev, x: is_next(prev, x) and fuser(prev, x))
sequences = group_by_with_binary_predicate(interps, fuser)
for seq in sequences:
replace_sequence(seq, graph)

View File

@@ -84,7 +84,7 @@ def replace_resize(graph: Graph, resize: Node):
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])})
axes_node = Const(graph,
{'name': resize_name + '/axis_',
{'name': resize_name + '/axis',
'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()
shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'
@@ -101,21 +101,21 @@ def replace_resize(graph: Graph, resize: Node):
'in_ports_count': 4}).create_node()
axes_node.out_port(0).connect(interpolate_node.in_port(3))
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf_'}).create_node()
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()
add_node = create_op_with_const_inputs(graph, Add,
{1: float_array([1.0e-5])},
{'name': resize_name + '/Add_'})
{'name': resize_name + '/Add'})
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
if num_of_inputs == 3:
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
mul_node = Mul(graph, {'name': resize_name + '/Mul_'}).create_node()
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
cast_add_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
floor_node = Floor(graph, {'name': resize_name + '/Floor_'}).create_node()
floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
mul_node.out_port(0).connect(add_node.in_port(0))
add_node.out_port(0).connect(floor_node.in_port(0))
floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
@@ -134,7 +134,7 @@ def replace_resize(graph: Graph, resize: Node):
else:
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
div_node = Div(graph, {'name': resize_name + '/Div_'}).create_node()
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
@@ -156,7 +156,7 @@ def replace_resize(graph: Graph, resize: Node):
resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
class ONNXResize11ToInterpolate4(MiddleReplacementPattern):
class ONNXResize11ToInterpolate(MiddleReplacementPattern):
"""
The transformation replaces ONNX Resize 11 with Interpolate-4.
"""

View File

@@ -15,8 +15,11 @@
"""
import logging as log
import numpy as np
from typing import Optional
from extensions.ops.activation_ops import Floor
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
@@ -74,20 +77,21 @@ def get_split_scale(split: Node) -> int:
def replace_interpolate_pattern(graph: Graph, match: dict):
split = match['split']
scale = int64_array([get_split_scale(split)])
scale = np.array([get_split_scale(split)], dtype=np.float32)
axis = int(split.in_port(1).get_connection().get_source().node.value)
split_node_name = split.name
axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()
shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
scales_node.out_port(0).connect(mul_node.in_port(1))
strided_slice_node = create_op_with_const_inputs(graph,
StridedSlice,
{1: int64_array([axis]), 2: int64_array([axis + 1])},
{
'name': split_node_name + '/StridedSlice_',
'name': split_node_name + '/StridedSlice',
'begin_mask': int64_array([1]),
'end_mask': int64_array([1]),
'new_axis_mask': int64_array([0]),
@@ -96,12 +100,28 @@ def replace_interpolate_pattern(graph: Graph, match: dict):
})
shape_node.out_port(0).connect(strided_slice_node.in_port(0))
strided_slice_node.out_port(0).connect(mul_node.in_port(0))
cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
axes=int64_array([axis]),
mode='nearest')).create_node()
mul_node.out_port(0).connect(interp_node.in_port(1))
strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
interp_node = Interpolate(graph,
dict(name=split_node_name + '/Interpolate',
mode='nearest',
antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
in_ports_count=4, maybe_part_of_sequence=True)).create_node()
floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
mul_node.out_port(0).connect(floor_node.in_port(0))
floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
scales_node.out_port(0).connect(interp_node.in_port(2))
axis_node.out_port(0).connect(interp_node.in_port(3))
match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))

View File

@@ -39,7 +39,11 @@ graph_node_attrs_for_2d_spatial_case = {
'op': 'Const',
'type': 'Const'
},
'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
'split_axis_const_data': {
'value': np.array(3, dtype=np.int64),
'shape': np.array(3, dtype=np.int64).shape,
'kind': 'data'
},
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
@@ -66,7 +70,11 @@ graph_node_attrs_for_3d_spatial_case = {
'op': 'Const',
'type': 'Const'
},
'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
'split_axis_const_data': {
'value': np.array(4, dtype=np.int64),
'shape': np.array(4, dtype=np.int64).shape,
'kind': 'data'
},
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
@@ -99,7 +107,7 @@ graph_edges = [
]
ref_graph_edges = [
ref_graph_edges_opset4 = [
('placeholder', 'placeholder_data'),
('placeholder_data', 'interpolate', {'in': 0}),
('placeholder_data', 'shape'),
@@ -110,17 +118,100 @@ ref_graph_edges = [
('slice_end', 'slice_end_data'),
('slice_end_data', 'sslice', {'in': 2}),
('sslice', 'sslice_data'),
('sslice_data', 'cast_shape_to_float'),
('cast_shape_to_float', 'cast_shape_to_float_data'),
('scales', 'scales_data'),
('sslice_data', 'mul', {'in': 0}),
('scales_data', 'mul', {'in': 1}),
('axes', 'axes_data'),
('cast_shape_to_float_data', 'mul', {'in': 0}),
('scales_data', 'mul', {'in': 1, 'out': 0}),
('mul', 'mul_data'),
('mul_data', 'interpolate', {'in': 1}),
('mul_data', 'floor'),
('floor', 'floor_data'),
('floor_data', 'cast_mul_to_float'),
('cast_mul_to_float', 'cast_mul_to_float_data'),
('cast_mul_to_float_data', 'interpolate', {'in': 1}),
('scales_data', 'interpolate', {'in': 2, 'out': 0}),
('axes_data', 'interpolate', {'in': 3}),
('interpolate', 'interpolate_data'),
('interpolate_data', 'abs'),
('abs', 'abs_data'),
('abs_data', 'output'),
]
ref_graph_node_attrs_for_2d_spatial_case_1_opset4 = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_data': {
'value': None,
'shape': int64_array([1, 100, 120, 150]),
'kind': 'data',
'data_type': None
},
'interpolate': {
'type': 'Interpolate',
'kind': 'op',
'op': 'Interpolate',
'mode': 'nearest',
'antialias': 0,
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor',
'cube_coeff': -0.75,
'version': 'opset4',
'shape_calculation_mode': 'scales'
},
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
'slice_begin': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'slice_begin_data': {'kind': 'data', 'shape': int64_array([1]), 'value': int64_array([3])},
'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4]), 'shape': int64_array([1])},
'slice_end_data': {'kind': 'data', 'value': int64_array([4]), 'shape': int64_array([1])},
'sslice': {
'kind': 'op',
'type': 'StridedSlice',
'op': 'StridedSlice',
'begin_mask': int64_array([1]),
'end_mask': int64_array([1]),
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0]),
},
'sslice_data': {'kind': 'data', 'shape': None},
'scales': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': np.array([2], dtype=np.float32),
'shape': int64_array([1])
},
'scales_data': {'kind': 'data', 'shape': None},
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
'axes': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'axes_data': {'kind': 'data', 'shape': None},
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None},
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
'floor_data': {'kind': 'data', 'shape': None},
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
'output': {'kind': 'op', 'op': 'Result'},
}
ref_graph_node_attrs_for_2d_spatial_case_1 = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
@@ -168,6 +259,14 @@ ref_graph_node_attrs_for_2d_spatial_case_1 = {
'shape': int64_array([1])
},
'scales_data': {'kind': 'data', 'shape': None},
'axes': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'axes_data': {'kind': 'data', 'shape': None},
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None},
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
@@ -188,8 +287,15 @@ ref_graph_node_attrs_for_2d_spatial_case_2 = {
'type': 'Interpolate',
'kind': 'op',
'op': 'Interpolate',
'axes': int64_array([2]),
'mode': 'nearest'
'mode': 'nearest',
'antialias': 0,
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor',
'cube_coeff': -0.75,
'version': 'opset4',
'shape_calculation_mode': 'scales'
},
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
@@ -218,12 +324,26 @@ ref_graph_node_attrs_for_2d_spatial_case_2 = {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([2]),
'value': np.array([2], dtype=np.float32),
'shape': int64_array([1])
},
'scales_data': {'kind': 'data', 'shape': None},
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
'axes': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'axes_data': {'kind': 'data', 'shape': None},
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None},
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
'floor_data': {'kind': 'data', 'shape': None},
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
@@ -243,8 +363,15 @@ ref_graph_node_attrs_for_3d_spatial_case_1 = {
'type': 'Interpolate',
'kind': 'op',
'op': 'Interpolate',
'axes': int64_array([4]),
'mode': 'nearest'
'mode': 'nearest',
'antialias': 0,
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor',
'cube_coeff': -0.75,
'version': 'opset4',
'shape_calculation_mode': 'scales'
},
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
@@ -273,12 +400,26 @@ ref_graph_node_attrs_for_3d_spatial_case_1 = {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([2]),
'value': np.array([2], dtype=np.float32),
'shape': int64_array([1])
},
'scales_data': {'kind': 'data', 'shape': None},
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
'axes': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'axes_data': {'kind': 'data', 'shape': None},
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None},
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
'floor_data': {'kind': 'data', 'shape': None},
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
@@ -298,8 +439,15 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
'type': 'Interpolate',
'kind': 'op',
'op': 'Interpolate',
'axes': int64_array([3]),
'mode': 'nearest'
'mode': 'nearest',
'antialias': 0,
'pads_begin': int64_array([0]),
'pads_end': int64_array([0]),
'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor',
'cube_coeff': -0.75,
'version': 'opset4',
'shape_calculation_mode': 'scales'
},
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
@@ -328,12 +476,26 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([2]),
'value': np.array([2], dtype=np.float32),
'shape': int64_array([1])
},
'scales_data': {'kind': 'data', 'shape': None},
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
'axes': {
'type': 'Const',
'op': 'Const',
'kind': 'op',
'value': int64_array([3]),
'shape': int64_array([1])
},
'axes_data': {'kind': 'data', 'shape': None},
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None},
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
'floor_data': {'kind': 'data', 'shape': None},
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
@@ -348,8 +510,8 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
edges=graph_edges
)
ref_graph = build_graph(
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
edges=ref_graph_edges
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1_opset4,
edges=ref_graph_edges_opset4
)
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
@@ -367,7 +529,11 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
'op': 'Const',
'type': 'Const'
},
'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
'split_axis_const_data': {
'value': np.array(2, dtype=np.int64),
'shape': np.array(2, dtype=np.int64).shape,
'kind': 'data'
},
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
@@ -378,7 +544,10 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
)
ref_graph = build_graph(
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
edges=ref_graph_edges
edges=ref_graph_edges_opset4,
update_attributes={
'axes': {'shape': int64_array([1]), 'value': int64_array([2])}
}
)
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
@@ -391,7 +560,10 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
)
ref_graph = build_graph(
nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
edges=ref_graph_edges
edges=ref_graph_edges_opset4,
update_attributes={
'axes': {'shape': int64_array([1]), 'value': int64_array([4])}
}
)
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
@@ -409,7 +581,11 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
'op': 'Const',
'type': 'Const'
},
'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
'split_axis_const_data': {
'value': np.array(3, dtype=np.int64),
'shape': np.array(3, dtype=np.int64).shape,
'kind': 'data'
},
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
@@ -420,7 +596,7 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
)
ref_graph = build_graph(
nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
edges=ref_graph_edges
edges=ref_graph_edges_opset4
)
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')

View File

@@ -15,12 +15,14 @@
"""
import logging as log
import numpy as np
from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
from extensions.ops.activation_ops import Floor
from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Mul
from extensions.ops.interpolate import Interpolate
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Graph
from mo.front.common.partial_infer.utils import int64_array, float32_array
from mo.graph.graph import Graph, rename_nodes
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.shape import Shape
@@ -53,6 +55,7 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
force_shape_inference = True
def run_before(self):
from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
return [InterpolateSequenceToInterpolate]
def pattern(self):
@@ -91,19 +94,20 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
if len(input_shape_of_unsqueeze) not in {4, 5}:
return
scale = int64_array([second_input_of_tile.value[d_idx]])
scale = float32_array([second_input_of_tile.value[d_idx]])
axis = d_idx - 1
axis_node = Const(graph, {'name': unsqueeze_name + '/axis', 'value': int64_array([axis])}).create_node()
shape_node = Shape(graph, dict(name=unsqueeze_name + '/Shape_')).create_node()
scales_node = Const(graph, dict(name=unsqueeze_name + '/scales_', value=scale)).create_node()
mul_node = Mul(graph, dict(name=unsqueeze_name + '/Mul_')).create_node()
shape_node = Shape(graph, dict(name=unsqueeze_name + '/Shape')).create_node()
scales_node = Const(graph, dict(name=unsqueeze_name + '/scales', value=scale)).create_node()
mul_node = Mul(graph, dict(name=unsqueeze_name + '/Mul')).create_node()
scales_node.out_port(0).connect(mul_node.in_port(1))
slice_begin = Const(graph, dict(name=unsqueeze_name + '/slice_begin_', value=int64_array([axis]))).create_node()
slice_end = Const(graph, dict(name=unsqueeze_name + '/slice_end_', value=int64_array([axis + 1]))).create_node()
slice_begin = Const(graph, dict(name=unsqueeze_name + '/slice_begin', value=int64_array([axis]))).create_node()
slice_end = Const(graph, dict(name=unsqueeze_name + '/slice_end', value=int64_array([axis + 1]))).create_node()
strided_slice_node = StridedSlice(graph,
{'name': unsqueeze_name + '/StridedSlice_',
{'name': unsqueeze_name + '/StridedSlice',
'begin_mask': int64_array([1]),
'end_mask': int64_array([1]),
'new_axis_mask': int64_array([0]),
@@ -113,14 +117,36 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
shape_node.out_port(0).connect(strided_slice_node.in_port(0))
slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
slice_end.out_port(0).connect(strided_slice_node.in_port(2))
strided_slice_node.out_port(0).connect(mul_node.in_port(0))
interp_node = Interpolate(graph, dict(name=unsqueeze_name + '/Interpolate_',
axes=int64_array([axis]),
mode='nearest')).create_node()
mul_node.out_port(0).connect(interp_node.in_port(1))
cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
match['reshape'].out_port(0).get_connection().set_source(interp_node.out_port(0))
strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
interp_node = Interpolate(graph,
dict(mode='nearest',
antialias=0, pads_begin=int64_array([0]),
pads_end=int64_array([0]), coordinate_transformation_mode='half_pixel',
nearest_mode='round_prefer_floor', cube_coeff=-0.75,
version='opset4', shape_calculation_mode='scales',
in_ports_count=4,
maybe_part_of_sequence=True)).create_node()
floor_node = Floor(graph, {'name': unsqueeze_name + '/Floor'}).create_node()
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
mul_node.out_port(0).connect(floor_node.in_port(0))
floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
scales_node.out_port(0).connect(interp_node.in_port(2))
axis_node.out_port(0).connect(interp_node.in_port(3))
reshape_node = match['reshape']
reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
reshape_name = reshape_node.soft_get('name', reshape_node.id)
rename_nodes([(reshape_node, reshape_name + '/delete'), (interp_node, reshape_name)])
unsqueeze_connection = match['unsqueeze'].in_port(0).get_connection()
before_unsqueeze = unsqueeze_connection.get_source().node

View File

@@ -15,6 +15,7 @@
"""
import unittest
import numpy as np
from extensions.middle.UnsqueezeTileReshapeBlockToInterpolate import UnsqueezeTileReshapeBlockToInterpolate
from mo.front.common.partial_infer.utils import int64_array
@@ -112,7 +113,7 @@ graph_edges = [
]
ref_graph_node_attrs = {
ref_graph_node_attrs_with_4_inputs_interpolate = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_data': {
'value': None,
@@ -169,7 +170,7 @@ ref_graph_node_attrs = {
'kind': 'op',
'op': 'Const',
'type': 'Const',
'value': int64_array([2]),
'value': np.array([2], dtype=np.float32),
'shape': int64_array([1]),
},
'scales_data': {
@@ -177,18 +178,36 @@ ref_graph_node_attrs = {
'value': None,
'shape': None,
},
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
'mul': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
'mul_data': {
'kind': 'data',
'value': None,
'shape': None,
},
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
'floor_data': {'kind': 'data', 'shape': None},
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
'interpolate': {'type': 'Interpolate', 'kind': 'op', 'op': 'Interpolate'},
'interpolate_data': {
'kind': 'data',
'value': None,
'shape': int64_array([1, 16, 32, 32, 64]),
},
'axes': {
'kind': 'op',
'op': 'Const',
'type': 'Const',
'value': int64_array([1]),
'shape': int64_array([1]),
},
'axes_data': {
'kind': 'data',
'value': None,
'shape': None,
},
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
'abs_data': {
'kind': 'data',
@@ -198,7 +217,8 @@ ref_graph_node_attrs = {
'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
}
ref_graph_edges = [
ref_graph_edges_attrs_with_4_inputs_interpolate = [
('placeholder', 'placeholder_data'),
('placeholder_data', 'shapeof'),
('shapeof', 'shapeof_data'),
@@ -209,11 +229,20 @@ ref_graph_edges = [
('end', 'end_data'),
('end_data', 'strided_slice', {'in': 2}),
('scales', 'scales_data'),
('strided_slice_data', 'mul', {'in': 0}),
('scales_data', 'mul', {'in': 1}),
('strided_slice_data', 'cast_shape_to_float'),
('cast_shape_to_float', 'cast_shape_to_float_data'),
('cast_shape_to_float_data', 'mul', {'in': 0}),
('scales_data', 'mul', {'out': 0, 'in': 1}),
('scales_data', 'interpolate', {'out': 0, 'in': 2}),
('mul', 'mul_data'),
('mul_data', 'interpolate', {'in': 1}),
('mul_data', 'floor'),
('floor', 'floor_data'),
('floor_data', 'cast_mul_to_float'),
('cast_mul_to_float', 'cast_mul_to_float_data'),
('cast_mul_to_float_data', 'interpolate', {'in': 1}),
('placeholder_data', 'interpolate', {'in': 0}),
('axes', 'axes_data'),
('axes_data', 'interpolate', {'in': 3}),
('interpolate', 'interpolate_data'),
('interpolate_data', 'abs'),
('abs', 'abs_data'),
@@ -298,7 +327,8 @@ graph_edges_when_transformation_is_not_applicable = graph_edges
class UnsqueezeTileReshapeBlockToInterpolateTest(unittest.TestCase):
def test_5d(self):
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
edges=ref_graph_edges_attrs_with_4_inputs_interpolate)
UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
self.assertTrue(flag, resp)
@@ -320,12 +350,13 @@ class UnsqueezeTileReshapeBlockToInterpolateTest(unittest.TestCase):
}
)
ref_graph = build_graph(
nodes_attrs=ref_graph_node_attrs,
edges=ref_graph_edges,
nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
edges=ref_graph_edges_attrs_with_4_inputs_interpolate,
update_attributes={
'placeholder_data': {'shape': int64_array([1, 8, 32, 32])},
'interpolate_data': {'shape': int64_array([1, 16, 32, 32])},
'abs_data': {'shape': int64_array([1, 16, 32, 32])},
'axes': {'shape': int64_array([1]), 'value': int64_array([1])},
}
)
UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(graph)

View File

@@ -26,8 +26,9 @@ from extensions.ops.interpolate import Interpolate
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
from mo.graph.graph import Graph, Node
from mo.graph.graph import Graph, Node, rename_nodes
from mo.middle.replacement import MiddleReplacementPattern
from mo.ops.const import Const
from mo.ops.shape import Shape
from mo.ops.strided_slice import StridedSlice
@@ -63,6 +64,8 @@ class UpsampleToResample(MiddleReplacementPattern):
return
depth_scale = None
layout = graph.graph['layout']
if len(upsample.in_nodes()) == 2:
if upsample.in_node(1).value is None:
return
@@ -71,10 +74,10 @@ class UpsampleToResample(MiddleReplacementPattern):
len(scales), upsample_name)
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
return
height_scale = scales[2]
width_scale = scales[3]
height_scale = scales[get_height_dim(layout, input_shape_rank)]
width_scale = scales[get_width_dim(layout, input_shape_rank)]
if len(scales) == 5:
depth_scale = scales[4]
depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
else:
height_scale = upsample['height_scale']
width_scale = upsample['width_scale']
@@ -82,6 +85,7 @@ class UpsampleToResample(MiddleReplacementPattern):
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
upsample.in_port(1).disconnect()
upsample_name = upsample.soft_get('name', upsample.id)
shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
layout = graph.graph['layout']
@@ -104,10 +108,9 @@ class UpsampleToResample(MiddleReplacementPattern):
'new_axis_mask': int64_array([0]),
'shrink_axis_mask': int64_array([0]),
'ellipsis_mask': int64_array([0])
}
)
})
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})
source = upsample.in_port(0).get_connection().get_source()
source.connect(shape.in_port(0))
@@ -124,16 +127,27 @@ class UpsampleToResample(MiddleReplacementPattern):
get_height_dim(layout, input_shape_rank),
get_width_dim(layout, input_shape_rank)])
resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
axes=axes, mode=upsample.attrs()['mode'],
antialias=0, convert_to_resample=True)).create_node()
axes_node = Const(graph, {'name': upsample_name + '/axis', 'value': axes}).create_node()
interpolate = Interpolate(graph, {'mode': upsample.attrs()['mode'], 'antialias': 0,
'pads_begin': int64_array([0]), 'pads_end': int64_array([0]),
'coordinate_transformation_mode': 'half_pixel',
'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75,
'shape_calculation_mode': 'scales',
'version': 'opset4', 'in_ports_count': 4}).create_node()
upsample.add_input_port(1, skip_if_exist=True)
assert upsample.in_port(1).disconnected()
mul.out_port(0).connect(resample_op.in_port(1))
mul.out_port(0).connect(interpolate.in_port(1))
axes_node.out_port(0).connect(interpolate.in_port(3))
upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
scales_node.out_port(0).connect(interpolate.in_port(2))
upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))
rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])
convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()

View File

@@ -43,6 +43,66 @@ graph_edges = [
('upsample_data', 'output'),
]
new_ref_graph_node_attr = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
'ss_begin': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([2]), 'shape': int64_array([1])},
'ss_begin_data': {'kind': 'data', 'value': int64_array([2]), 'shape': int64_array([1])},
'ss_end': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([4]), 'shape': int64_array([1])},
'ss_end_data': {'kind': 'data', 'value': int64_array([4]), 'shape': int64_array([1])},
'ss_stride': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([1]), 'shape': int64_array([1])},
'ss_stride_data': {'kind': 'data', 'value': int64_array([1]), 'shape': int64_array([1])},
'strided_slice': {'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice'},
'strided_slice_data': {'kind': 'data', 'shape': None, 'value': None},
'cast_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float},
'cast_to_float_d': {'kind': 'data', 'value': None, 'shape': None},
'factor': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
'factor_data': {'kind': 'data', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
'shapeof': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
'shapeof_data': {'kind': 'data', 'shape': None, 'value': None},
'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Multiply'},
'mul_data': {'kind': 'data', 'shape': None, 'value': None},
'cast_to_int': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int32},
'cast_to_int_d': {'kind': 'data', 'shape': None, 'value': None},
'axes_const': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': None, 'shape': None},
'axes_const_data': {'kind': 'data', 'value': None, 'shape': None},
'scales': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
'scales_data': {'kind': 'data', 'value': None, 'shape': None},
'interpolate': {'type': 'Interpolate', 'kind': 'op', 'op': 'Interpolate', 'axes': None},
'interpolate_data': {'kind': 'data', 'shape': None, 'value': None},
'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
}
new_ref_graph_edges = [
('placeholder', 'placeholder_data'),
('placeholder_data', 'shapeof', {'in': 0, 'out': 0}),
('placeholder_data', 'interpolate', {'in': 0, 'out': 0}),
('ss_begin', 'ss_begin_data'),
('ss_begin_data', 'strided_slice', {'in': 1, 'out': 0}),
('ss_end', 'ss_end_data'),
('ss_end_data', 'strided_slice', {'in': 2, 'out': 0}),
('ss_stride', 'ss_stride_data'),
('ss_stride_data', 'strided_slice', {'in': 3, 'out': 0}),
('strided_slice', 'strided_slice_data'),
('strided_slice_data', 'cast_to_float'),
('cast_to_float', 'cast_to_float_d'),
('shapeof', 'shapeof_data'),
('shapeof_data', 'strided_slice', {'in': 0, 'out': 0}),
('factor', 'factor_data'),
('cast_to_float_d', 'mul', {'in': 0, 'out': 0}),
('factor_data', 'mul', {'in': 1, 'out': 0}),
('mul', 'mul_data'),
('mul_data', 'cast_to_int'),
('cast_to_int', 'cast_to_int_d'),
('cast_to_int_d', 'interpolate', {'in': 1, 'out': 0}),
('axes_const', 'axes_const_data'),
('axes_const_data', 'interpolate', {'in': 3, 'out': 0}),
('scales', 'scales_data'),
('scales_data', 'interpolate', {'in': 2, 'out': 0}),
('interpolate', 'interpolate_data'),
('interpolate_data', 'output')
]
ref_graph_node_attrs = {
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
'placeholder_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
@@ -98,30 +158,40 @@ ref_graph_edges = [
@generator
class UpsampleToResampleTest(unittest.TestCase):
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5],),
([2, 20, 30, 40], [1, 1, 3, 3],),
([2, 10, 20, 30], [1, 1, 6, 5],),
([2, 20, 30, 40], [1, 1, 3, 4],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5], [2, 3]),
([2, 20, 30, 40], [1, 1, 3, 3], [2, 3]),
([2, 10, 20, 30], [1, 1, 6, 5], [2, 3]),
([2, 20, 30, 40], [1, 1, 3, 4], [2, 3]),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3], [2, 3, 4]),
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4], [2, 3, 4]),
])
def test_conversion(self, input_shape, scales):
graph = build_graph(graph_node_attrs, graph_edges,
{'placeholder_data': {'shape': int64_array(input_shape)},
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}})
def test_conversion(self, input_shape, scales, axes):
graph = build_graph(graph_node_attrs,
graph_edges,
{
'placeholder_data': {'shape': int64_array(input_shape)},
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}
})
graph.graph['layout'] = 'NCHW'
ref_graph = build_graph(ref_graph_node_attrs, ref_graph_edges,
{'placeholder_data': {'shape': int64_array(input_shape)},
'factor': {'value': int64_array(scales)[2:], 'shape': int64_array(scales[2:]).shape},
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
'interpolate': {'axes': list(range(2, len(input_shape)))}}
)
ref_graph = build_graph(new_ref_graph_node_attr,
new_ref_graph_edges,
{
'placeholder_data': {'shape': int64_array(input_shape)},
'ss_begin': {'value': int64_array([axes[0]])},
'ss_end': {'value': int64_array([axes[-1] + 1])},
'ss_begin_data': {'value': int64_array([axes[0]])},
'ss_end_data': {'value': int64_array([axes[-1] + 1])},
'factor': {'value': int64_array(scales)[2:],
'shape': int64_array(scales[2:]).shape},
'factor_data': {'value': int64_array(scales)[2:],
'shape': int64_array(scales[2:]).shape},
'axes_const': {'value': int64_array(axes), 'shape': int64_array(axes).shape},
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
})
UpsampleToResample().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
self.assertTrue(flag, resp)

View File

@@ -0,0 +1,31 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
class ONNXResize10(Op):
op = 'ONNXResize10'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': self.op,
'in_ports_count': 2,
'out_ports_count': 1,
}
super().__init__(graph, mandatory_props, attrs)

View File

@@ -0,0 +1,67 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from mo.front.common.layout import get_height_dim, get_width_dim
from mo.front.common.partial_infer.utils import int64_array
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
class TFResize(Op):
op = 'TFResize'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'op': self.op,
'out_ports_count': 1,
'in_ports_count': 2,
'infer': TFResize.tf_resize_infer
}
super().__init__(graph, mandatory_props, attrs)
@staticmethod
def tf_resize_infer(node: Node):
input_shape = node.in_port(0).data.get_shape()
if input_shape is None:
return
attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
"the attribute align_corners must be False"
node_name = node.soft_get('name', node.id)
assert not node.half_pixel_centers or (node.half_pixel_centers and not node.align_corners), \
attrs_msg.format(node_name, node.op)
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
assert len(connected_in_ports) == 2, \
"Node {} with op {} number of inputs must be equal to 2.".format(node_name, node.op)
new_sizes_value = node.in_port(1).data.get_value()
assert new_sizes_value is not None, "Node {} with op {} has no value in input port 1".format(node_name, node.op)
input_rank = len(input_shape)
assert input_rank == 4, \
"Resized input data of the node {} with op {} must be 4D tensor".format(node_name, node.op)
len_msg = "Op {} with name {} supports only resize with respect to height and width dimension simultaneously"
assert len(new_sizes_value) == 2, len_msg.format(node_name, node.op)
output_shape = int64_array(input_shape.copy())
layout = node.graph.graph['layout']
output_shape[get_height_dim(layout, input_rank)] = new_sizes_value[0]
output_shape[get_width_dim(layout, input_rank)] = new_sizes_value[1]
node.out_port(0).data.set_shape(output_shape)

View File

@@ -22,6 +22,7 @@ import numpy as np
from mo.front.common.partial_infer.utils import int64_array
from mo.front.extractor import bool_to_str
from mo.graph.graph import Node, Graph
from mo.graph.perm_inputs import PermuteInputs
from mo.ops.op import Op, PermuteAttrs
@@ -61,6 +62,9 @@ def infer_for_opset4(node: Node):
for i, axis in enumerate(axes):
output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)
if node.is_in_port_connected(3):
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'axis')
node.out_port(0).data.set_shape(output_shape)
@@ -103,7 +107,8 @@ def correct_scales_using_dst_shape(node, dst_shape, src_shape, axes):
if scales_value is None or len(scales_value) != len(dst_shape):
corrected_scales = np.zeros(len(dst_shape))
for i, axis in enumerate(list(axes)):
corrected_scales[i] = math.floor((dst_shape[i] / src_shape[axis]) + 1.0e-5)
corrected_scales[i] = dst_shape[i] / src_shape[axis]
node.in_port(2).data.set_value(corrected_scales)
class Interpolate(Op):

View File

@@ -1,5 +1,5 @@
"""
Copyright (C) 2020 Intel Corporation
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.

View File

@@ -1,102 +0,0 @@
"""
Copyright (C) 2018-2020 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging as log
from extensions.ops.resize_factor_utils import factor_update
from mo.front.common.layout import get_batch_dim, get_features_dim, get_height_dim, get_width_dim, shape_for_layout
from mo.graph.graph import Node, Graph
from mo.ops.op import Op
class ResampleOp(Op):
enabled = False
op = 'Resample'
def __init__(self, graph: Graph, attrs: dict):
mandatory_props = {
'type': __class__.op,
'op': __class__.op,
'factor': None,
'in_ports_count': 2,
'out_ports_count': 1,
'infer': None
}
super().__init__(graph, mandatory_props, attrs)
def supported_attrs(self):
return [
'antialias',
'height',
'width',
'resample_type',
'factor',
]
def backend_attrs(self):
return [
'antialias',
'height',
'width',
('type', 'resample_type'),
'factor'
]
@staticmethod
def resample_infer(node: Node):
layout = node.graph.graph['layout']
assert len(layout) == 4
input_shape = node.in_node(0).shape
if input_shape is None:
return
in_height = input_shape[get_height_dim(layout, 4)]
in_width = input_shape[get_width_dim(layout, 4)]
if node.has('fw') and node.fw == 'tf':
dst_shape = node.in_node(1).value
if dst_shape is None or len(input_shape) != 4 or len(dst_shape) != 2:
log.error(
'Node {} with op {} cannot be converted to Resample layer because there is no enough info about '
'src/dst shapes: src_shape = {}, dst_shape = {}'.format(node.name, node.op, input_shape, dst_shape))
node.type = None # prevent translation to a valid IE layer
return
out_height = dst_shape[0]
out_width = dst_shape[1]
else:
if len(node.in_nodes()) == 1:
if node.has('width') and node.has('height'):
out_height = node.height
out_width = node.width
else:
out_height = node.factor * in_height
out_width = node.factor * in_width
else:
out_height = node.in_node(1).shape[get_height_dim(layout, 4)]
out_width = node.in_node(1).shape[get_width_dim(layout, 4)]
node.factor = factor_update(
node.factor,
[float(out_height) / in_height, float(out_width) / in_width],
[in_height, in_width],
[out_height, out_width],
node.soft_get('name'))
node.out_node().shape = shape_for_layout(layout,
batch=input_shape[get_batch_dim(layout, 4)],
features=input_shape[get_features_dim(layout, 4)],
height=out_height,
width=out_width)

View File

@@ -28,6 +28,10 @@ def float_array(l: list):
return np.array(l, dtype=np.float64)
def float32_array(l: list):
return np.array(l, dtype=np.float32)
def mark_input_bins(node, names=('weights', 'biases'), start_port: int = 1):
"""
Preparing necessary attributes for edges at input ports starting from start_port.

View File

@@ -232,6 +232,27 @@ void op::v4::Interpolate::validate_and_infer_types()
input_et == element::i8 || input_et == element::bf16,
"Input element type must be f32, f16, bf16 or i8");
element::Type sizes_et = get_input_element_type(1);
NODE_VALIDATION_CHECK(this,
sizes_et == element::i32 || sizes_et == element::i64 ||
sizes_et == element::u32 || sizes_et == element::u64,
"Sizes element type must be i32, i64, u32 or u64");
element::Type scales_et = get_input_element_type(2);
NODE_VALIDATION_CHECK(this,
scales_et == element::f32 || scales_et == element::f16 ||
scales_et == element::bf16,
"Scales element type must be f32, f16 or bf16");
if (input_values().size() == 4)
{
element::Type axes_et = get_input_element_type(3);
NODE_VALIDATION_CHECK(this,
axes_et == element::i64 || axes_et == element::i32 ||
sizes_et == element::u32 || sizes_et == element::u64,
"Axes element type must be i32, i64, u32 or u64");
}
PartialShape input_shape = PartialShape(get_input_partial_shape(0));
if (!input_shape.rank().is_static())
@@ -240,11 +261,20 @@ void op::v4::Interpolate::validate_and_infer_types()
return;
}
const auto input_rank = input_shape.rank().get_length();
// If the input 'axes' is given and this input is not Constant, we cannot infer any elements
// of the output shape. Hence, all components of the output shape should be dynamic.
if (input_values().size() == 4 && !is_type<op::Constant>(input_value(3).get_node()))
{
PartialShape output_shape = std::vector<Dimension>(input_rank, Dimension::dynamic());
set_output_type(0, get_input_element_type(0), output_shape);
return;
}
auto axes = get_axes();
correct_pads();
const auto input_rank = input_shape.rank().get_length();
PartialShape padded_input_shape = get_padded_input_shape(input_shape);
PartialShape output_shape = padded_input_shape;
@@ -257,7 +287,6 @@ void op::v4::Interpolate::validate_and_infer_types()
}
}
set_output_type(0, get_input_element_type(0), output_shape);
if (m_attrs.shape_calculation_mode == ShapeCalcMode::scales)
{
if (const auto& const_scales = get_constant_from_source(input_value(2)))

View File

@@ -175,44 +175,6 @@ namespace ngraph
return scales;
}
OutputVector build_resize(const Node& node,
const std::shared_ptr<ngraph::Node>& output_shape,
const AxisSet& axes)
{
const auto mode = node.get_attribute_value<std::string>("mode", "nearest");
std::unordered_set<std::string> supported_modes = {"nearest", "linear"};
bool is_mode_supported =
(std::find(supported_modes.begin(), supported_modes.end(), mode) !=
supported_modes.end());
if (!is_mode_supported)
{
std::string supported_modes_str = "";
for (const auto& mode_name : supported_modes)
{
supported_modes_str += (mode_name + ", ");
}
CHECK_VALID_NODE(node,
is_mode_supported,
mode,
" - this type of interpolation mode is not supported."
" Choose one of the following modes: ",
supported_modes_str);
}
auto attrs = ngraph::op::v0::InterpolateAttrs();
attrs.axes = axes;
attrs.mode = mode;
attrs.align_corners = false;
const auto inputs = node.get_ng_inputs();
const auto& data = inputs.at(0);
return {
std::make_shared<ngraph::op::v0::Interpolate>(data, output_shape, attrs)};
}
} // namespace
namespace set_11
@@ -273,17 +235,20 @@ namespace ngraph
const auto& data_shape = data.get_partial_shape();
const auto& scales_shape = scales.get_partial_shape();
auto attrs = get_resize_attrs(node);
if (attrs.mode == InterpolateMode::linear_onnx)
{
attrs.coordinate_transformation_mode = Transform_mode::asymmetric;
}
CHECK_VALID_NODE(
node,
(scales_shape.is_static() || data_shape.rank().is_static()),
" Data rank or shape of scales input is required to be static.");
size_t axes_size = scales_shape.is_static() ? scales_shape[0].get_length()
: data_shape.rank().get_length();
const auto output_shape = calculate_output_shape_based_on_scales(data, scales);
return build_resize(
node, output_shape, AxisSet(common::get_monotonic_range(axes_size)));
return {std::make_shared<default_opset::Interpolate>(
data, output_shape, scales, attrs)};
}
} // namespace set_1

View File

@@ -29,7 +29,7 @@ using ShapeCalcMode = op::v4::Interpolate::ShapeCalcMode;
TEST(type_prop, interpolate_v4)
{
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
@@ -48,12 +48,68 @@ TEST(type_prop, interpolate_v4)
EXPECT_EQ(interp->get_shape(), (Shape{2, 2, 15, 30}));
}
TEST(type_prop, interpolate_v4_non_constant_axes_scales)
{
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
auto target_shape = std::make_shared<op::Parameter>(element::i64, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto start = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{2});
auto stop = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{4});
auto step = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
auto axes = std::make_shared<op::v4::Range>(start, stop, step, element::i32);
InterpolateAttrs attrs;
attrs.mode = InterpolateMode::nearest;
attrs.shape_calculation_mode = ShapeCalcMode::scales;
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
attrs.antialias = false;
attrs.pads_begin = {0, 0, 0, 0};
attrs.pads_end = {0, 0, 0, 0};
attrs.cube_coeff = -0.75;
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
EXPECT_EQ(interp->get_element_type(), element::f32);
auto dyn_dim = Dimension::dynamic();
auto expected_shape = PartialShape{dyn_dim, dyn_dim, dyn_dim, dyn_dim};
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, interpolate_v4_non_constant_axes_sizes)
{
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
auto target_shape = std::make_shared<op::Parameter>(element::i64, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto start = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{2});
auto stop = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{4});
auto step = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
auto axes = std::make_shared<op::v4::Range>(start, stop, step, element::i32);
InterpolateAttrs attrs;
attrs.mode = InterpolateMode::nearest;
attrs.shape_calculation_mode = ShapeCalcMode::sizes;
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
attrs.antialias = false;
attrs.pads_begin = {0, 0, 0, 0};
attrs.pads_end = {0, 0, 0, 0};
attrs.cube_coeff = -0.75;
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
EXPECT_EQ(interp->get_element_type(), element::f32);
auto dyn_dim = Dimension::dynamic();
auto expected_shape = PartialShape{dyn_dim, dyn_dim, dyn_dim, dyn_dim};
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(expected_shape));
}
TEST(type_prop, interpolate_v4_partial)
{
auto partial_shape = PartialShape{2, 2, Dimension::dynamic(), Dimension::dynamic()};
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
@@ -83,7 +139,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank)
auto partial_shape = PartialShape{2, 2, Dimension::dynamic(), Dimension::dynamic()};
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
@@ -109,7 +165,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank2)
auto out_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), 5, 10};
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
@@ -135,7 +191,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank3)
auto out_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 1, 1});
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{1, 1});
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {1.0f / 3.0f, 1.0f / 3.0f});
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});