Completely transition of MO to Interpolate-4 (#970)
* Refactored infer function and function supported_attrs for the layer Interpolate. * Small change. * Deleted unneeded checks in transformations ResizeToInterpolate2D and ResizeToInterpolate3D. * Small fix in the extractor of ONNX Resize. * Now the extractor of TF ResizeBilinear generates Interpolate-1 again, because 'axis' in final version of Interpolate-4 specification is an input but is not attribute. * Now the extractor of TF ResizeNearest generates Interpolate-1 again, because 'axis' in final version of Interpolate-4 specification is an input but is not attribute. * Added static method get_axis into class Interpolate. * Refactored class CanBeFused in the transformation InterpolateSequenceToInterpolate. * Fixed transformation InterpolateSequenceToInterpolate according to the last version of the specification of Interpolate-4. * Started to write support of Interpolate-4 in the transformation InterpolateWithConcat. * Added support for Interpolate-4 into the transformation InterpolateWithConcat. * Added support for Interpolate-4 into the transformation InterpolateConcat. * Added support for Interpolate-4 into the transformation InterpolateReshapeWA. * Added support for Interpolate-4 into the transformation InterpolateTranspose. * Started to add test for opset4 case of the transformation InterpolateSequenceToInterpolate. * Added test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_1_opset4_case). * Added test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_4_opset4_case). * Added another test for InterpolateSequenceToInterpolate (test_2d_interpolate_sequence_5_opset4_case). * Added another test for InterpolateSequenceToInterpolate (test_3d_interpolate_sequence_1_opset4_case). * Finished addition of tests for opset4 case of InterpolateSequenceToInterpolate. * Small change. * Now opset is only opset1 or opset4 in the transformation InterpolateTranspose. * Small fixes in transformations ResizeToInterpolate2D and ResizeToInterpolate3D. * Deleted reading of unused ONNX attributes. * Fixed docstring of the transformation InterpolateV1ToInterpolateV4. * Added node name in assert about axes input. * Fixes in the definition of the operation ONNXResize11. * Now Interpolate-4 cannot have 'extension' as opset. * Now the transformation InterpolateV1ToInterpolateV4 uses find_and_replace_pattern but not replace_sub_graph. * Fixed tests for transformations InterpolateReshapeWA and InterpolateConcat. * Fixed some tests. * Rewritten operation Interpolate-4 class according to new variant of documentation. * Some fixes in ONNXResize11 operation class. * Now the transformation ONNXResize11ToInterpolate generates Interpolate-4 with 4 inputs. * Now the transformation UpsampleToResample generates Interpolate-4 with 4 inputs. * Now the transformation NearestNeighborUpsampling generates Interpolate-4 with 4 inputs. * Now transformations ResizeToInterpolate2D and ResizeToInterpolate3D generate Interpolate-4 with 4 inputs. * Now the transformation SplitConcatPairToInterpolate generates Interpolate-4 with 4 inputs. * Now the transformation UnsqueezeTileReshapeBlockToInterpolate generates Interpolate-4 with 4 inputs. * Now the transformation InterpolateV1ToInterpolateV4 generates Interpolate-4 with 4 inputs. * Some fixes. * Fixed the transformation InterpolateSequenceToInterpolate according to new variant of Interpolate-4 specification. * Fixed typos. * Added shape_calculation_mode to supported_attrs. * Small fixes. * Added operation ONNXResize10 and the transformation ONNXResize10ToInterpolate4. * Fixed function correct_scales_using_dst_shape. * Some fixes in InterpolateSequenceToInterpolate. * Fixed bug in the method __call__ of the class CanBeFused: now self.accumulated_axes is correctly cleared in all cases. * Small change. * Fixed tests for the transformation SplitConcatPairToInterpolate. * Now transformations InterpolateWithConcat, InterpolateReshapeWA, InterpolateConcat support Interpolate-4. * Fixed the transformation InterpolateTranspose for the case of Interpolate-4. * Written the back transformation InterpolateV4AxesCorrection to convert 'axes' input of Interpolate-4 from NHWC to NCHW layout. * Added PermuteInput in Interpolate-4 infer. * Fixed typos. * Deleted the transformation InterpolateAxesCorrection. * Now Interpolate-4 permutes axis, not shape in input port 3. * Small fix. * Some fix. * Fixed bug in the transformation UpsampleToResample. * Added some debug prints. * Added more debug prints. * Now ONNX Upsample-9 operation is read as ONNXResize10. * Small fix. * Small fixes. * Fixed tests for the transformation SplitConcatPairToInterpolate. * Deleted debug prints. * Deleted some debug prints. * Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests. * Small fix in the transformation InterpolateSequenceToInterpolate. * Started to write nGraph transformation to convert Interpolate-1 to Interpolate-4. * Deleted redundant files. * Small fixes. * Small fix. * Written draft of the transformation Interpolate-1 -> Interpolate-4. * Small fix. * Now ONNX Importer reads Resize-10 as Interpolate-4. * Fixes in the test onnx_model_resize10_import_only. * Small fix in the test for the conversion Interpolate-1 -> Interpolate-4. * Small fixes. * Fixed NGraphReaderTests for Interpolate. * Some fixes. * Deleted class for Resample operation. * Fix in the transformation NearestNeighborUpsampling: fixed precision of the input 'scales' of generated Interpolate-4. * Fixed typo. * Now the TF operations ResizeBilinear is readed as internal MO operation TFResizeBilinear. This internal operation is converted into Interpolate-4. * Small fix in BOM-file. * Added checks of existence of attributes of TF ResizeBilinear operation. * Small fixes in the conversion of the internal MO operation TFResizeBilinear to Interpolate-4. * Small fixes. * Small fixes. * Now the transformation ONNXResize10ToInterpolateV4 calculates sizes input as input_shape * (scales + epsilon). * Added the internal MO operation TFResizeNearestNeighbor. * Fixes in the transformation SplitConcatPairToInterpolate and its tests. * Fixes in the transformation UnsqueezeTileReshapeBlockToInterpolate and its tests. * Written the transformation that converts the internal operation TFResizeNearestNeighbor into Interpolate-4. * Now MO reads the TF operation ResizeNearestNeighbor as the internal MO operation TFResizeNearestNeighbor. * Small fix. * Now the specification of Interpolate-4 clarifies that the mode linear_onnx supports only 2D or 4D input tensors. * Small fix. * Some fixes. * Moved the transformation ONNXResize10ToInterpolateV4 to the front stage. * Deleted infer function and function supported_attrs for ONNXResize10 operation. * Deleted supported_attrs() for TFResizeBilinear and TFResizeNearestNeighbor. * Some fixes. * Fixes in the shape infer function of the nGraph operation Interpolate-4. Now 'axes' input can be non-constant. In the such case, all elements of the output shape are Dimension::dynamic(). * Deleted corner cases processing in transformations TFResizeBilinearToInterpolateV4 and TFResizeNearestNeighborToInterpolateV4. * Rewritten the function replace_resize_bilinear. * Written inner MO operation TFResize that covers TF operations ResizeBilinear and ResizeNearestNeighbor. * Now TF operations ResizeBilinear and ResizeNearestNeighbor are read as an internal operation TFResize in MO. Transformations TFResizeNearestNeighborToInterpolateV4 and TFResizeBilinearToInterpolateV4 are fused into one transformation TFResizeToInterpolateV4. * Some changes in the shape infer function of nGraph op Interpolate-4. * Small fix. * Some changes. * The transformation TFResizeToInterpolateV4 is moved to the front stage. * Deleted redundant assert. * Deleted transformations ResizeToInterpolate2D and ResizeToInterpolate3D. * Some renaming. * Small change. * Deleted .copy() in the shape infer function of the internal operation TFResize. * Small fix. * Small fixes. * Added comment about the case when the input 'axes' of Interpolate-4 is non-constant. * Written test for Interpolate-4 shape infer, for the case when the input 'axes' is non-constant and shape_calculation_mode = scales. * Some fixes. * Small fixes. * Small fix. * Added yet another test for the case of non-constant 'axes' input of Interpolate-4 (when shape_calculation_mode = sizes). * Added some comment. * Small fix. * Reverted changes for InterpolateWithConcat. * Added type checks for all inputs of nGraph operation Interpolate-4. * Added u32 and u64 to supported element types of sizes and axes inputs of nGraph operation Interpolate-4. * Fixed some functional tests. * Some changes. * Added helper function float32_array. * Now the MO transformation InterpolateV1ToInterpolate preserves names of layers. * Small fix. * Small fix. * Reverted some change. * Small fixes. * Small fix. * Small fix. * Small fix. * Small fix. * Reverted changes in the nGraph reader tests for Interpolate-1. * Some revert. * Fixed some copyright year.
This commit is contained in:
committed by
GitHub
parent
bd3884b602
commit
c1136cd7b0
@@ -15,6 +15,7 @@
|
||||
* **Type**: string
|
||||
* **Default value**: none
|
||||
* **Required**: *yes*
|
||||
* **Note**: Only 2D and 4D tensors with `axes = {0, 1}` and `axes = {2, 3}` respectively are supported for `"mode" == "linear_onnx"`.
|
||||
|
||||
* *shape_calculation_mode*
|
||||
|
||||
|
||||
@@ -17,52 +17,59 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::ConvertInterpolate1ToInterpolate4, "Convert
|
||||
|
||||
ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolate4() {
|
||||
MATCHER_SCOPE(ConvertInterpolate1ToInterpolate4);
|
||||
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_rank()), pattern::any_input()});
|
||||
auto interpolate1 = ngraph::pattern::wrap_type<ngraph::opset1::Interpolate>({pattern::any_input(pattern::has_static_shape()), pattern::any_input()});
|
||||
ngraph::matcher_pass_callback callback = [this](pattern::Matcher& m) {
|
||||
auto interpolate1 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
|
||||
if (!interpolate1)
|
||||
auto interpolationV0 = std::dynamic_pointer_cast<ngraph::opset1::Interpolate>(m.get_match_root());
|
||||
if (!interpolationV0) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto interpolate_attrs = interpolate1->get_attrs();
|
||||
auto input_shape_rank = interpolate1->input(0).get_partial_shape().rank().get_length();
|
||||
auto& inp_partial_shape = interpolationV0->get_input_partial_shape(0);
|
||||
auto& out_shape = interpolationV0->get_output_shape(0);
|
||||
auto attrsV0 = interpolationV0->get_attrs();
|
||||
|
||||
// attrs
|
||||
auto mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode();
|
||||
if (interpolate_attrs.mode == "nearest") {
|
||||
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::nearest;
|
||||
} else if (interpolate_attrs.mode == "cubic") {
|
||||
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::cubic;
|
||||
} else if (interpolate_attrs.mode == "linear") {
|
||||
if (input_shape_rank < 5) {
|
||||
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx;
|
||||
} else if (input_shape_rank == 5) {
|
||||
mode_v4 = ngraph::op::v4::Interpolate::InterpolateMode::linear;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
std::vector<float> scales(attrsV0.axes.size(), 1.0f);
|
||||
auto inp_shape = inp_partial_shape.to_shape();
|
||||
size_t i = 0;
|
||||
for (std::size_t axis : attrsV0.axes) {
|
||||
scales[i] = static_cast<float>(out_shape.at(axis))/inp_shape.at(axis);
|
||||
i++;
|
||||
}
|
||||
|
||||
auto scalesConstant = ngraph::op::Constant::create(ngraph::element::f32, {scales.size()}, scales);
|
||||
auto axisConstant = ngraph::op::Constant::create(ngraph::element::i64, {attrsV0.axes.size()},
|
||||
std::vector<std::size_t>{attrsV0.axes.begin(), attrsV0.axes.end()});
|
||||
|
||||
ngraph::opset4::Interpolate::InterpolateAttrs attrsV4;
|
||||
|
||||
if (attrsV0.mode == "nearest") {
|
||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::nearest;
|
||||
} else if (attrsV0.mode == "linear") {
|
||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::linear;
|
||||
} else if (attrsV0.mode == "cubic") {
|
||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::cubic;
|
||||
} else if (attrsV0.mode == "linear_onnx") {
|
||||
attrsV4.mode = ngraph::opset4::Interpolate::InterpolateMode::linear_onnx;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
auto nearest_mode_v4 = ngraph::op::v4::Interpolate::NearestMode::floor;
|
||||
auto shape_calculation_mode_v4 = ngraph::op::v4::Interpolate::ShapeCalcMode::sizes;
|
||||
auto coordinate_transformation_mode_v4 = interpolate_attrs.align_corners ? ngraph::op::v4::Interpolate::CoordinateTransformMode::align_corners :
|
||||
ngraph::op::v4::Interpolate::CoordinateTransformMode::asymmetric;
|
||||
auto interpolate4_attr = ngraph::op::v4::Interpolate::InterpolateAttrs(mode_v4, shape_calculation_mode_v4,
|
||||
interpolate_attrs.pads_begin, interpolate_attrs.pads_end,
|
||||
coordinate_transformation_mode_v4, nearest_mode_v4, interpolate_attrs.antialias, -0.75);
|
||||
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
|
||||
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor;
|
||||
attrsV4.pads_begin = attrsV0.pads_begin;
|
||||
attrsV4.pads_end = attrsV0.pads_end;
|
||||
attrsV4.antialias = attrsV0.antialias;
|
||||
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
|
||||
attrsV4.cube_coeff = -0.75f;
|
||||
if (attrsV0.align_corners) {
|
||||
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
|
||||
}
|
||||
|
||||
// input
|
||||
auto axes = interpolate_attrs.axes.to_vector();
|
||||
auto axes_node = ngraph::opset4::Constant::create(element::i64, {axes.size()}, axes);
|
||||
auto default_scales = std::vector<float>(axes.size(), 1.f);
|
||||
auto scales_node = ngraph::opset4::Constant::create(element::f32, {axes.size()}, default_scales);
|
||||
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
|
||||
scalesConstant, axisConstant, attrsV4);
|
||||
|
||||
auto interpolate4 = std::make_shared<ngraph::opset4::Interpolate>(interpolate1->input_value(0), interpolate1->input_value(1),
|
||||
scales_node, axes_node, interpolate4_attr);
|
||||
|
||||
interpolate4->set_friendly_name(interpolate1->get_friendly_name());
|
||||
ngraph::copy_runtime_info(interpolate1, interpolate4);
|
||||
ngraph::replace_node(interpolate1, interpolate4);
|
||||
interpolateV4->set_friendly_name(interpolationV0->get_friendly_name());
|
||||
ngraph::copy_runtime_info(interpolationV0, interpolateV4);
|
||||
ngraph::replace_node(interpolationV0, interpolateV4);
|
||||
return true;
|
||||
};
|
||||
|
||||
|
||||
@@ -399,4 +399,4 @@ TEST_F(NGraphReaderTests, ReadInterpolate4Network) {
|
||||
fdata[2] = 2.0;
|
||||
fdata[3] = 2.0;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
|
||||
{
|
||||
auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
|
||||
auto out_shape_node = opset1::Constant::create(element::i32, Shape{4}, {2, 4, 40, 40});
|
||||
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 1.f, 1.f});
|
||||
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{4}, {1.f, 1.f, 4.0f / 3.0f, 4.0f / 3.0f});
|
||||
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{4}, {0, 1, 2, 3});
|
||||
|
||||
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
|
||||
@@ -94,10 +94,10 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
|
||||
{
|
||||
auto data_node = std::make_shared<opset1::Parameter>(element::f32, Shape{2, 4, 30, 30});
|
||||
auto out_shape_node = opset1::Constant::create(element::i32, Shape{2}, {40, 40});
|
||||
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {1.f, 1.f});
|
||||
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
|
||||
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
|
||||
|
||||
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
|
||||
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear,
|
||||
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
|
||||
opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor,
|
||||
false, -0.75);
|
||||
|
||||
@@ -85,7 +85,7 @@ void InterpolateTransformation::validate() {
|
||||
if (attributes.mode == "nearest") {
|
||||
ASSERT_EQ("ScaleShiftIE", typeName);
|
||||
} else {
|
||||
ASSERT_EQ("Interp", typeName);
|
||||
ASSERT_TRUE("Interp" == typeName || "Interpolate" == typeName);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -144,6 +144,7 @@ extensions/front/input_cut.py
|
||||
extensions/front/instance_normalization.py
|
||||
extensions/front/interpolate_reshape.py
|
||||
extensions/front/InterpolateNormalizer.py
|
||||
extensions/front/InterpolateV1ToInterpolate.py
|
||||
extensions/front/kaldi/__init__.py
|
||||
extensions/front/kaldi/add_permute_after_convolution.py
|
||||
extensions/front/kaldi/add_reshape_around_convolution.py
|
||||
@@ -298,6 +299,7 @@ extensions/front/onnx/normalize_ext.py
|
||||
extensions/front/onnx/normalize_l2_normalize.py
|
||||
extensions/front/onnx/one_hot_ext.py
|
||||
extensions/front/onnx/one_hot_normalize.py
|
||||
extensions/front/onnx/ONNXResize10ToInterpolate.py
|
||||
extensions/front/onnx/pad_converter.py
|
||||
extensions/front/onnx/pad_ext.py
|
||||
extensions/front/onnx/parameter_ext.py
|
||||
@@ -317,7 +319,6 @@ extensions/front/onnx/reduce_ext.py
|
||||
extensions/front/onnx/remove_filtering_boxes_by_size.py
|
||||
extensions/front/onnx/reshape_ext.py
|
||||
extensions/front/onnx/resize_ext.py
|
||||
extensions/front/onnx/resize_to_interpolate.py
|
||||
extensions/front/onnx/reverse_sequence_ext.py
|
||||
extensions/front/onnx/rnn_ext.py
|
||||
extensions/front/onnx/roialign_ext.py
|
||||
@@ -483,6 +484,7 @@ extensions/front/tf/SwitchMergeOptimization.py
|
||||
extensions/front/tf/TensorArrayExtractors.py
|
||||
extensions/front/tf/TensorArrayGatherV3.py
|
||||
extensions/front/tf/tensorflow_custom_operations_config_update.py
|
||||
extensions/front/tf/TFResizeToInterpolate.py
|
||||
extensions/front/tf/TFSliceToSlice.py
|
||||
extensions/front/tf/tile_ext.py
|
||||
extensions/front/tf/topk_ext.py
|
||||
@@ -571,7 +573,7 @@ extensions/middle/MulFakeQuantizeFuse.py
|
||||
extensions/middle/MXNetRNNSequenceNormalize.py
|
||||
extensions/middle/MXNetSplitMultiLayers.py
|
||||
extensions/middle/MXTileReplacer.py
|
||||
extensions/middle/ONNXResize11ToInterpolateV4.py
|
||||
extensions/middle/ONNXResize11ToInterpolate.py
|
||||
extensions/middle/ONNXRNNSequenceNormalize.py
|
||||
extensions/middle/PartialInfer.py
|
||||
extensions/middle/pass_separator.py
|
||||
@@ -677,6 +679,7 @@ extensions/ops/non_zero.py
|
||||
extensions/ops/normalize.py
|
||||
extensions/ops/normalize_l2.py
|
||||
extensions/ops/one_hot.py
|
||||
extensions/ops/ONNXResize10.py
|
||||
extensions/ops/ONNXResize11.py
|
||||
extensions/ops/pack.py
|
||||
extensions/ops/parameter.py
|
||||
@@ -697,7 +700,6 @@ extensions/ops/rank.py
|
||||
extensions/ops/ReduceOps.py
|
||||
extensions/ops/regionyolo.py
|
||||
extensions/ops/reorgyolo.py
|
||||
extensions/ops/resample.py
|
||||
extensions/ops/resize.py
|
||||
extensions/ops/resize_factor_utils.py
|
||||
extensions/ops/Reverse.py
|
||||
@@ -733,6 +735,7 @@ extensions/ops/TensorArrayScatter.py
|
||||
extensions/ops/TensorArraySize.py
|
||||
extensions/ops/TensorArrayWrite.py
|
||||
extensions/ops/TensorIterator_ops.py
|
||||
extensions/ops/TFResize.py
|
||||
extensions/ops/topk.py
|
||||
extensions/ops/topkrois_onnx.py
|
||||
extensions/ops/transpose.py
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Mul
|
||||
@@ -54,7 +55,6 @@ class InterpolateConcat(BackReplacementPattern):
|
||||
\ /
|
||||
Concat(axis=1)
|
||||
shape=[1, 7, 60, 160]
|
||||
|
||||
"""
|
||||
enabled = True
|
||||
graph_condition = [lambda graph: not graph.graph['cmd_params'].static_shape]
|
||||
@@ -89,7 +89,7 @@ class InterpolateConcat(BackReplacementPattern):
|
||||
interpolate.in_port(1).get_connection().set_source(gather.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate', version='opset1'):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate'):
|
||||
if interpolate.in_port(1).get_source().node.soft_get('type') != 'Const':
|
||||
continue
|
||||
dsts = interpolate.out_port(0).get_destinations()
|
||||
@@ -101,14 +101,12 @@ class InterpolateReshapeWA(BackReplacementPattern):
|
||||
"""
|
||||
Replaces hard-coded 1-port input of Interpolate with reshape-able sub-graph.
|
||||
WARNING: Could cause troubles if model has hard-coded Interpolate intentionally -- rare situation
|
||||
|
||||
BEFORE:
|
||||
input Const
|
||||
shape=[1, 3, 30, 40] value=[60, 160]
|
||||
\ /
|
||||
Interpolate(axes=(2, 3))
|
||||
shape=[1, 3, 60, 160]
|
||||
|
||||
AFTER:
|
||||
input
|
||||
shape=[1, 3, 30, 40]
|
||||
@@ -151,6 +149,6 @@ class InterpolateReshapeWA(BackReplacementPattern):
|
||||
interpolate.in_port(1).get_connection().set_source(mul.out_port(0))
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate', version='opset1'):
|
||||
for interpolate in graph.get_op_nodes(type='Interpolate'):
|
||||
if interpolate.in_port(1).get_source().node.soft_get('type') == 'Const':
|
||||
self.make_interpolate_reshapeable(interpolate)
|
||||
self.make_interpolate_reshapeable(interpolate)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Copyright (C) 2018-2021 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementPattern
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
|
||||
|
||||
def correct_pad(pad):
|
||||
return int64_array([pad] if not isinstance(pad, list) else pad)
|
||||
|
||||
|
||||
class InterpolateV1ToInterpolate(FrontReplacementPattern):
|
||||
"""
|
||||
This transformation replaces the operation Interpolate-1 with the operation Interpolate-4.
|
||||
"""
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
|
||||
return [InterpolateNormalizer]
|
||||
|
||||
def find_and_replace_pattern(self, graph: Graph):
|
||||
for node in graph.get_op_nodes(op='Interpolate', version='opset1'):
|
||||
transformation_mode = 'align_corners' if int(node.soft_get('align_corners', 0)) else 'half_pixel'
|
||||
interpolate1_name = node.soft_get('name', node.id)
|
||||
interpolate4 = create_op_with_const_inputs(graph, Interpolate,
|
||||
{
|
||||
2: np.array([1.0, 1.0]),
|
||||
3: int64_array(node.axes)
|
||||
},
|
||||
{
|
||||
'mode': node.mode,
|
||||
'antialias': node.antialias,
|
||||
'coordinate_transformation_mode': transformation_mode,
|
||||
'pads_begin': correct_pad(node.soft_get('pads_begin', 0)),
|
||||
'pads_end': correct_pad(node.soft_get('pads_end', 0)),
|
||||
'nearest_mode': 'round_prefer_floor',
|
||||
'cube_coeff': -0.75,
|
||||
'shape_calculation_mode': 'sizes',
|
||||
'version': 'opset4',
|
||||
'in_ports_count': 4,
|
||||
})
|
||||
|
||||
interpolate1_input_connection = node.in_port(0).get_connection()
|
||||
interpolate1_input_connection.set_destination(interpolate4.in_port(0))
|
||||
|
||||
sizes_connection = node.in_port(1).get_connection()
|
||||
sizes_connection.set_destination(interpolate4.in_port(1))
|
||||
|
||||
node.out_port(0).get_connection().set_source(interpolate4.out_port(0))
|
||||
rename_nodes([(node, interpolate1_name + '/delete'), (interpolate4, interpolate1_name)])
|
||||
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Add, Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from extensions.ops.range import Range
|
||||
from extensions.ops.rank import Rank
|
||||
from mo.front.common.partial_infer.utils import int64_array, float_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.passes.convert_data_type import data_type_str_to_np
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
|
||||
def replace_resize(graph: Graph, resize: Node):
|
||||
log.debug("Converting of ONNX Resize-10 to Interpolate-4 "
|
||||
"is triggered for node {}.".format(resize.soft_get('name', resize.id)))
|
||||
|
||||
resize_name = resize.soft_get('name', resize.id)
|
||||
|
||||
rank_node = Rank(graph, {'name': resize_name + '/max_axes'}).create_node()
|
||||
range_node = create_op_with_const_inputs(graph, Range, {0: int64_array(2), 2: int64_array(1)},
|
||||
{'name': resize_name + '/axes'})
|
||||
|
||||
sizes_ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||
{1: int64_array([2]),
|
||||
2: int64_array([0]),
|
||||
3: int64_array([1])},
|
||||
{'name': resize_name + '/sizes_ss',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([0]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])})
|
||||
scales_ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||
{1: int64_array([2]),
|
||||
2: int64_array([0]),
|
||||
3: int64_array([1])},
|
||||
{'name': resize_name + '/scales_ss',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([0]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])})
|
||||
|
||||
rank_node.out_port(0).connect(range_node.in_port(1))
|
||||
|
||||
interpolate_node = Interpolate(graph, {'version': 'opset4',
|
||||
'mode': 'linear_onnx' if resize.mode == 'linear' else 'nearest',
|
||||
'coordinate_transformation_mode': 'asymmetric',
|
||||
'cube_coeff': -0.75,
|
||||
'nearest_mode': 'simple',
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'antialias': 0,
|
||||
'shape_calculation_mode': 'scales',
|
||||
'in_ports_count': 4}).create_node()
|
||||
|
||||
range_node.out_port(0).connect(interpolate_node.in_port(3))
|
||||
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()
|
||||
|
||||
# When we calculate 'sizes' input as floor(input_shape * scales), we can get incorrect 'sizes' if, e.g.,
|
||||
# scales = [1.0, 1.0, 1.33333, 2.0], input_shape = [1, 3, 30, 200], because
|
||||
# input_shape * scales = [1, 3, 39.9999, 400], and floor(input_shape * scales)[2] == 39, not 40.
|
||||
# Maybe we need to calculate 'sizes' input as floor(input_shape * scales + eps), where eps is some small
|
||||
# floating point number, e.g. 1.0e-5. But, in this case, if scales = [1.0, 1.0, 1.333333, 2.0],
|
||||
# input_shape = [1, 3, 30, 200], floor(input_shape * scales + eps) = 39, not 40, because
|
||||
# input_shape[2] * scales[2] + 1.0e-5 = 39.99991.
|
||||
# Hence, we need to calculate 'sizes' as floor(input_shape * (scales + eps)).
|
||||
add_node = create_op_with_const_inputs(graph, Add,
|
||||
{1: float_array([1.0e-5])},
|
||||
{'name': resize_name + '/Add'})
|
||||
|
||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node([cast_shape_to_float, add_node])
|
||||
floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node([mul_node])
|
||||
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node([floor_node])
|
||||
cast_mul_result_to_int.out_port(0).connect(sizes_ss.in_port(0))
|
||||
sizes_ss.out_port(0).connect(interpolate_node.in_port(1))
|
||||
|
||||
scales_ss.out_port(0).connect(interpolate_node.in_port(2))
|
||||
|
||||
connection_of_resize_input = resize.in_port(0).get_connection()
|
||||
connection_of_resize_input.set_destination(interpolate_node.in_port(0))
|
||||
|
||||
connection_of_scales = resize.in_port(1).get_connection()
|
||||
connection_of_scales.set_destination(scales_ss.in_port(0))
|
||||
|
||||
connection_of_resize_input.get_source().connect(shape_of.in_port(0))
|
||||
connection_of_resize_input.get_source().connect(rank_node.in_port(0))
|
||||
connection_of_scales.get_source().connect(add_node.in_port(0))
|
||||
|
||||
rename_nodes([(resize, resize_name + '/delete'), (interpolate_node, resize_name)])
|
||||
resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
|
||||
|
||||
|
||||
class ONNXResize10ToInterpolate(FrontReplacementOp):
|
||||
"""
|
||||
The transformation replaces ONNX Resize 10 with Interpolate-4.
|
||||
"""
|
||||
op = 'ONNXResize10'
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
|
||||
return [InterpolateNormalizer]
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
resize = match['op']
|
||||
replace_resize(graph, resize)
|
||||
@@ -14,7 +14,7 @@
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from extensions.ops.upsample import UpsampleOp
|
||||
from extensions.ops.ONNXResize10 import ONNXResize10
|
||||
from extensions.ops.ONNXResize11 import ONNXResize11Op
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
|
||||
@@ -43,5 +43,5 @@ class ResizeExtractor(FrontExtractorOp):
|
||||
ONNXResize11Op.update_node_stat(node, attrs)
|
||||
else:
|
||||
mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
|
||||
UpsampleOp.update_node_stat(node, {'mode': mode})
|
||||
ONNXResize10.update_node_stat(node, {'mode': mode})
|
||||
return cls.enabled
|
||||
|
||||
@@ -1,192 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementSubgraph
|
||||
from mo.graph.graph import Graph
|
||||
from mo.ops.const import Const
|
||||
|
||||
|
||||
class ResizeToInterpolate2D(FrontReplacementSubgraph):
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('input', dict()),
|
||||
('shape_1', dict(op='ShapeOf')),
|
||||
('shape_2', dict(op='ShapeOf')),
|
||||
('shape_3', dict(op='ShapeOf')),
|
||||
('gather_1', dict(type='Gather')),
|
||||
('gather_2', dict(type='Gather')),
|
||||
('mul_1', dict(op='Mul')),
|
||||
('mul_2', dict(op='Mul')),
|
||||
('unsqueeze_1', dict(op='ExpandDims')),
|
||||
('unsqueeze_2', dict(op='ExpandDims')),
|
||||
('slice', dict(op='Slice')),
|
||||
('slice_start', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([2])))),
|
||||
('slice_end', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([4])))),
|
||||
('concat_1', dict(op='Concat')),
|
||||
('cast_1', dict(op='Cast')),
|
||||
('cast_2', dict(op='Cast')),
|
||||
('div', dict(op='Div')),
|
||||
('concat_2', dict(op='Concat')),
|
||||
('resize', dict(op='Upsample')),
|
||||
],
|
||||
edges=[
|
||||
('input', 'resize', {'in': 0}),
|
||||
('input', 'shape_1', {'in': 0}),
|
||||
('input', 'shape_2', {'in': 0}),
|
||||
('input', 'shape_3', {'in': 0}),
|
||||
('shape_1', 'gather_1', {'in': 0}),
|
||||
('shape_2', 'gather_2', {'in': 0}),
|
||||
('shape_3', 'slice', {'in': 0}),
|
||||
('slice_start', 'slice', {'in': 1}),
|
||||
('slice_end', 'slice', {'in': 2}),
|
||||
('gather_1', 'mul_1', {'in': 0}),
|
||||
('gather_2', 'mul_2', {'in': 0}),
|
||||
('mul_1', 'unsqueeze_1', {'in': 0}),
|
||||
('mul_2', 'unsqueeze_2', {'in': 0}),
|
||||
('unsqueeze_1', 'concat_1', {'in': 0}),
|
||||
('unsqueeze_2', 'concat_1', {'in': 1}),
|
||||
('concat_1', 'cast_1', {'in': 0}),
|
||||
('slice', 'cast_2', {'in': 0}),
|
||||
('cast_1', 'div', {'in': 0}),
|
||||
('cast_2', 'div', {'in': 1}),
|
||||
('div', 'concat_2', {'in': 1}),
|
||||
('concat_2', 'resize', {'in': 1}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
resize_node = match['resize']
|
||||
|
||||
if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value:
|
||||
log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
|
||||
return
|
||||
|
||||
interpolate_node = Interpolate(graph, {'name': resize_node.name + '/Interpolate',
|
||||
'mode': resize_node.mode, 'axes': int64_array([2, 3])}).create_node()
|
||||
|
||||
scale = match['mul_1'].in_node(1).value
|
||||
scale_value = int64_array([scale, scale])
|
||||
scale_const = Const(graph, {'value': scale_value, 'name': resize_node.name + '/Scale'}).create_node()
|
||||
|
||||
interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
|
||||
match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
|
||||
scale_const.out_port(0).connect(interpolated_shape.in_port(1))
|
||||
|
||||
resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
|
||||
interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
|
||||
resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
|
||||
|
||||
|
||||
class ResizeToInterpolate3D(FrontReplacementSubgraph):
|
||||
enabled = True
|
||||
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('input', dict()),
|
||||
('shape_1', dict(op='ShapeOf')),
|
||||
('shape_2', dict(op='ShapeOf')),
|
||||
('shape_3', dict(op='ShapeOf')),
|
||||
('shape_4', dict(op='ShapeOf')),
|
||||
('gather_1', dict(type='Gather')),
|
||||
('gather_2', dict(type='Gather')),
|
||||
('gather_3', dict(type='Gather')),
|
||||
('mul_1', dict(op='Mul')),
|
||||
('mul_2', dict(op='Mul')),
|
||||
('mul_3', dict(op='Mul')),
|
||||
('cast_1', dict(op='Cast')),
|
||||
('cast_2', dict(op='Cast')),
|
||||
('cast_3', dict(op='Cast')),
|
||||
('unsqueeze_1', dict(op='ExpandDims')),
|
||||
('unsqueeze_2', dict(op='ExpandDims')),
|
||||
('unsqueeze_3', dict(op='ExpandDims')),
|
||||
('floor_1', dict(op='Floor')),
|
||||
('floor_2', dict(op='Floor')),
|
||||
('floor_3', dict(op='Floor')),
|
||||
('slice', dict(op='Slice')),
|
||||
('slice_start', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([2])))),
|
||||
('slice_end', dict(op='Const', value=lambda x: x is not None and np.array_equal(x, int64_array([5])))),
|
||||
('concat_1', dict(op='Concat')),
|
||||
('cast_4', dict(op='Cast')),
|
||||
('cast_5', dict(op='Cast')),
|
||||
('div', dict(op='Div')),
|
||||
('concat_2', dict(op='Concat')),
|
||||
('resize', dict(op='Upsample')),
|
||||
],
|
||||
edges=[
|
||||
('input', 'resize', {'in': 0}),
|
||||
('input', 'shape_1', {'in': 0}),
|
||||
('input', 'shape_2', {'in': 0}),
|
||||
('input', 'shape_3', {'in': 0}),
|
||||
('input', 'shape_4', {'in': 0}),
|
||||
('shape_1', 'gather_1', {'in': 0}),
|
||||
('shape_2', 'gather_2', {'in': 0}),
|
||||
('shape_3', 'gather_3', {'in': 0}),
|
||||
('shape_4', 'slice', {'in': 0}),
|
||||
('slice_start', 'slice', {'in': 1}),
|
||||
('slice_end', 'slice', {'in': 2}),
|
||||
('gather_1', 'mul_1', {'in': 0}),
|
||||
('gather_2', 'mul_2', {'in': 0}),
|
||||
('gather_3', 'mul_3', {'in': 0}),
|
||||
('mul_1', 'cast_1', {'in': 0}),
|
||||
('mul_2', 'cast_2', {'in': 0}),
|
||||
('mul_3', 'cast_3', {'in': 0}),
|
||||
('cast_1', 'floor_1', {'in': 0}),
|
||||
('cast_2', 'floor_2', {'in': 0}),
|
||||
('cast_3', 'floor_3', {'in': 0}),
|
||||
('floor_1', 'unsqueeze_1', {'in': 0}),
|
||||
('floor_2', 'unsqueeze_2', {'in': 0}),
|
||||
('floor_3', 'unsqueeze_3', {'in': 0}),
|
||||
('unsqueeze_1', 'concat_1', {'in': 0}),
|
||||
('unsqueeze_2', 'concat_1', {'in': 1}),
|
||||
('unsqueeze_3', 'concat_1', {'in': 2}),
|
||||
('concat_1', 'cast_4', {'in': 0}),
|
||||
('slice', 'cast_5', {'in': 0}),
|
||||
('cast_4', 'div', {'in': 0}),
|
||||
('cast_5', 'div', {'in': 1}),
|
||||
('div', 'concat_2', {'in': 1}),
|
||||
('concat_2', 'resize', {'in': 1}),
|
||||
])
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
resize_node = match['resize']
|
||||
if match['mul_1'].in_node(1).value != match['mul_2'].in_node(1).value or \
|
||||
match['mul_1'].in_node(1).value != match['mul_3'].in_node(1).value:
|
||||
log.info('Pattern matched around resize op {} has different scale values.'.format(resize_node.name))
|
||||
return
|
||||
|
||||
interpolate_node = Interpolate(graph, {'name': resize_node.name + '/Interpolate',
|
||||
'mode': resize_node.mode, 'axes': int64_array([2, 3, 4])}).create_node()
|
||||
|
||||
scale = match['mul_1'].in_node(1).value
|
||||
scale_value = int64_array([scale, scale, scale])
|
||||
scale_const = Const(graph, {'value': scale_value, 'name': resize_node.name + '/Scale'}).create_node()
|
||||
|
||||
interpolated_shape = Mul(graph, {'name': resize_node.name + '/OutputShape'}).create_node()
|
||||
match['slice'].out_port(0).connect(interpolated_shape.in_port(0))
|
||||
scale_const.out_port(0).connect(interpolated_shape.in_port(1))
|
||||
|
||||
resize_node.in_port(0).get_connection().set_destination(interpolate_node.in_port(0))
|
||||
interpolated_shape.out_port(0).connect(interpolate_node.in_port(1))
|
||||
resize_node.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
|
||||
@@ -18,9 +18,10 @@ import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.ONNXResize10 import ONNXResize10
|
||||
from extensions.ops.upsample import UpsampleOp
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
@@ -30,42 +31,47 @@ class UpsampleFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
|
||||
scales = onnx_attr(node, 'scales', 'floats', dst_type=lambda x: np.array(x, dtype=np.float32))
|
||||
width_scale = onnx_attr(node, 'width_scale', 'f')
|
||||
height_scale = onnx_attr(node, 'height_scale', 'f')
|
||||
onnx_opset_version = get_onnx_opset_version(node)
|
||||
if onnx_opset_version is not None and onnx_opset_version >= 9:
|
||||
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
|
||||
ONNXResize10.update_node_stat(node, {'mode': mode})
|
||||
else:
|
||||
mode = onnx_attr(node, 'mode', 's', default='nearest', dst_type=lambda x: x.decode())
|
||||
scales = onnx_attr(node, 'scales', 'floats', dst_type=lambda x: np.array(x, dtype=np.float32))
|
||||
width_scale = onnx_attr(node, 'width_scale', 'f')
|
||||
height_scale = onnx_attr(node, 'height_scale', 'f')
|
||||
|
||||
supported_modes = ['nearest', 'linear']
|
||||
if mode not in supported_modes:
|
||||
raise Error(
|
||||
'Error decoding Upsample node {}, mode = {} is not in the list of supported modes {}.',
|
||||
node.name,
|
||||
mode,
|
||||
supported_modes
|
||||
)
|
||||
|
||||
if scales is not None:
|
||||
if scales.shape != (4,):
|
||||
supported_modes = ['nearest', 'linear']
|
||||
if mode not in supported_modes:
|
||||
raise Error(
|
||||
'Upsample scales attribute is wrong for node {}. Only 4D scales are supported.',
|
||||
'Error decoding Upsample node {}, mode = {} is not in the list of supported modes {}.',
|
||||
node.name,
|
||||
mode,
|
||||
supported_modes
|
||||
)
|
||||
|
||||
if scales is not None:
|
||||
if scales.shape != (4,):
|
||||
raise Error(
|
||||
'Upsample scales attribute is wrong for node {}. Only 4D scales are supported.',
|
||||
node.name
|
||||
)
|
||||
if math.fabs(scales[0] - 1) > 1e-5 or math.fabs(scales[1] - 1) > 1e-5:
|
||||
raise Error(
|
||||
'Upsampling of batch and feature dimensions is not supported for node {}.',
|
||||
node.name
|
||||
)
|
||||
height_scale = scales[2]
|
||||
width_scale = scales[3]
|
||||
|
||||
if (width_scale is None or height_scale is None) and len(node.in_nodes()) != 2:
|
||||
raise Error(
|
||||
'One/both of widths_scale = {} and height_scale = {} is not defined for Upsample node {}.',
|
||||
width_scale,
|
||||
height_scale,
|
||||
node.name
|
||||
)
|
||||
if math.fabs(scales[0] - 1) > 1e-5 or math.fabs(scales[1] - 1) > 1e-5:
|
||||
raise Error(
|
||||
'Upsampling of batch and feature dimensions is not supported for node {}.',
|
||||
node.name
|
||||
)
|
||||
height_scale = scales[2]
|
||||
width_scale = scales[3]
|
||||
|
||||
if (width_scale is None or height_scale is None) and len(node.in_nodes()) != 2:
|
||||
raise Error(
|
||||
'One/both of widths_scale = {} and height_scale = {} is not defined for Upsample node {}.',
|
||||
width_scale,
|
||||
height_scale,
|
||||
node.name
|
||||
)
|
||||
|
||||
UpsampleOp.update_node_stat(node, {'mode': mode, 'height_scale': height_scale,
|
||||
'width_scale': width_scale})
|
||||
UpsampleOp.update_node_stat(node, {'mode': mode, 'height_scale': height_scale,
|
||||
'width_scale': width_scale})
|
||||
return cls.enabled
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.replacement import FrontReplacementSubgraph
|
||||
from mo.graph.graph import Graph
|
||||
@@ -35,8 +36,7 @@ class InterpolateTranspose(FrontReplacementSubgraph):
|
||||
('interpolate',
|
||||
{
|
||||
'kind': 'op',
|
||||
'op': 'Interpolate',
|
||||
'axes': lambda axes: axes is not None and np.array_equal(axes, int64_array([1, 2]))
|
||||
'op': 'Interpolate'
|
||||
}),
|
||||
('transpose_1', {'kind': 'op', 'op': 'Transpose'}),
|
||||
('transpose_1_order',
|
||||
@@ -70,8 +70,18 @@ class InterpolateTranspose(FrontReplacementSubgraph):
|
||||
transpose_1 = match['transpose_1']
|
||||
transpose_2 = match['transpose_2']
|
||||
|
||||
axes = Interpolate.get_axes(interpolate)
|
||||
if axes is None or not np.array_equal(axes, int64_array([1, 2])):
|
||||
return
|
||||
|
||||
# because we remove Transpose layers the ResizeNearestNeighbor should be updated for NCHW layout
|
||||
interpolate.axes = int64_array([2, 3])
|
||||
opset = interpolate.get_opset()
|
||||
assert opset in ['opset1', 'opset4'], \
|
||||
'Interpolate node with name {} has unsupported opset'.format(interpolate.soft_get('name', interpolate.id))
|
||||
if opset == 'opset1':
|
||||
interpolate.axes = int64_array([2, 3])
|
||||
else:
|
||||
interpolate.in_port(3).data.set_value(int64_array([2, 3]))
|
||||
|
||||
transpose_1.in_port(0).get_connection().set_destination(interpolate.in_port(0))
|
||||
transpose_2.out_port(0).get_connection().set_source(interpolate.out_port(0))
|
||||
|
||||
133
model-optimizer/extensions/front/tf/TFResizeToInterpolate.py
Normal file
133
model-optimizer/extensions/front/tf/TFResizeToInterpolate.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
|
||||
import numpy as np
|
||||
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Div
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.layout import get_height_dim, get_width_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.common.replacement import FrontReplacementOp
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
|
||||
def replace_tf_resize(graph: Graph, resize: Node, interpolation_mode: str):
|
||||
resize_name = resize.soft_get('name', resize.id)
|
||||
log.debug("Converting of {} to Interpolate-4 is triggered for node {}.".format(resize.op, resize_name))
|
||||
|
||||
num_of_inputs = len([port for port in resize.in_ports().values() if not port.disconnected()])
|
||||
assert num_of_inputs == 2, \
|
||||
"Number of inputs of {} (with name {}) should be equal to 2".format(resize.op, resize_name)
|
||||
|
||||
attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
|
||||
"the attribute align_corners must be False"
|
||||
assert not resize.half_pixel_centers or (resize.half_pixel_centers and not resize.align_corners), \
|
||||
attrs_msg.format(resize_name, resize.op)
|
||||
|
||||
shape = Shape(graph, {'name': resize_name + '/shapeof'}).create_node()
|
||||
|
||||
layout = graph.graph['layout']
|
||||
height_dim = get_height_dim(layout, 4)
|
||||
width_dim = get_width_dim(layout, 4)
|
||||
|
||||
ss = create_op_with_const_inputs(graph, StridedSlice,
|
||||
{1: int64_array([height_dim]),
|
||||
2: int64_array([width_dim + 1]),
|
||||
3: int64_array([1])
|
||||
},
|
||||
{'name': resize_name + '/StridedSlice',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])
|
||||
})
|
||||
|
||||
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
|
||||
|
||||
shape_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
|
||||
size_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
|
||||
|
||||
size_to_float.out_port(0).connect(div_node.in_port(0))
|
||||
shape_to_float.out_port(0).connect(div_node.in_port(1))
|
||||
ss.out_port(0).connect(shape_to_float.in_port(0))
|
||||
shape.out_port(0).connect(ss.in_port(0))
|
||||
|
||||
align_corners = resize.align_corners
|
||||
half_pixel_centers = resize.half_pixel_centers
|
||||
|
||||
nearest_mode = 'floor' if interpolation_mode == 'nearest' else 'round_prefer_floor'
|
||||
if align_corners:
|
||||
coordinate_transformation_mode = 'align_corners'
|
||||
if interpolation_mode == 'nearest':
|
||||
nearest_mode = 'round_prefer_ceil'
|
||||
elif half_pixel_centers:
|
||||
coordinate_transformation_mode = 'tf_half_pixel_for_nn' if interpolation_mode == 'nearest' else 'half_pixel'
|
||||
else:
|
||||
coordinate_transformation_mode = 'asymmetric'
|
||||
|
||||
interpolate4 = create_op_with_const_inputs(graph, Interpolate,
|
||||
{
|
||||
3: int64_array([height_dim, width_dim])
|
||||
},
|
||||
{
|
||||
'name': resize_name + '/interpolate_4',
|
||||
'mode': interpolation_mode,
|
||||
'antialias': False,
|
||||
'coordinate_transformation_mode': coordinate_transformation_mode,
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'nearest_mode': nearest_mode,
|
||||
'cube_coeff': -0.75,
|
||||
'shape_calculation_mode': 'sizes',
|
||||
'version': 'opset4',
|
||||
'in_ports_count': 4,
|
||||
})
|
||||
|
||||
resize_input_connection = resize.in_port(0).get_connection()
|
||||
resize_input_connection.set_destination(interpolate4.in_port(0))
|
||||
resize_input_connection.get_source().connect(shape.in_port(0))
|
||||
|
||||
div_node.out_port(0).connect(interpolate4.in_port(2))
|
||||
|
||||
sizes_connection = resize.in_port(1).get_connection()
|
||||
sizes_connection.set_destination(interpolate4.in_port(1))
|
||||
sizes_connection.get_source().connect(size_to_float.in_port(0))
|
||||
|
||||
resize.out_port(0).get_connection().set_source(interpolate4.out_port(0))
|
||||
rename_nodes([(resize, resize_name + '/delete'), (interpolate4, resize_name)])
|
||||
|
||||
|
||||
class TFResizeToInterpolate(FrontReplacementOp):
|
||||
"""
|
||||
The transformation replaces TFResize with Interpolate-4.
|
||||
"""
|
||||
op = 'TFResize'
|
||||
enabled = True
|
||||
|
||||
def run_after(self):
|
||||
from extensions.front.InterpolateNormalizer import InterpolateNormalizer
|
||||
return [InterpolateNormalizer]
|
||||
|
||||
def replace_sub_graph(self, graph: Graph, match: dict):
|
||||
resize = match['op']
|
||||
replace_tf_resize(graph, resize, resize.mode)
|
||||
@@ -70,17 +70,30 @@ class NearestNeighborUpsampling(FrontReplacementSubgraph):
|
||||
log.warning('Failed to determine scaling parameters from the topology. Do not apply pattern.')
|
||||
return
|
||||
|
||||
axes = int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
|
||||
|
||||
resample_op = Interpolate(graph, {'name': 'Resample_', 'antialias': 0, 'mode': 'nearest', 'axes': axes})
|
||||
reshape2_name = match['reshape_2'].name
|
||||
resample_op = Interpolate(graph,
|
||||
{'mode': 'nearest', 'antialias': 0, 'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]), 'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75, 'version': 'opset4',
|
||||
'name': reshape2_name + '/Resample', 'shape_calculation_mode': 'scales',
|
||||
'in_ports_count': 4})
|
||||
resample_node = resample_op.create_node([match['op']])
|
||||
const = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
|
||||
'name': resample_node.name + '/target_shape'}).create_node()
|
||||
axes_node = Const(graph,
|
||||
{
|
||||
'name': resample_node.name + '/axes',
|
||||
'value': int64_array([2, 3]) if graph.graph['layout'] == 'NCHW' else int64_array([1, 2])
|
||||
}).create_node()
|
||||
sizes_node = Const(graph, {'value': np.array([input_height * height_scale, input_width * width_scale]),
|
||||
'name': resample_node.name + '/target_shape'}).create_node()
|
||||
scales_node = Const(graph, {'value': np.array([height_scale, width_scale], dtype=np.float32),
|
||||
'name': resample_node.name + '/scales'}).create_node()
|
||||
|
||||
match['reshape_2'].replace_node(resample_node)
|
||||
|
||||
resample_node.add_input_port(1, skip_if_exist=True)
|
||||
assert resample_node.in_port(1).disconnected()
|
||||
const.out_port(0).connect(resample_node.in_port(1))
|
||||
sizes_node.out_port(0).connect(resample_node.in_port(1))
|
||||
scales_node.out_port(0).connect(resample_node.in_port(2))
|
||||
axes_node.out_port(0).connect(resample_node.in_port(3))
|
||||
|
||||
graph.remove_nodes_from([node.id for node in match.values() if node.id != match['op'].id])
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
|
||||
from extensions.ops.TFResize import TFResize
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
@@ -24,10 +24,18 @@ class ResizeBilinearFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
mapping_rule = {
|
||||
'align_corners': int(node.pb.attr['align_corners'].b),
|
||||
'mode': 'linear',
|
||||
'axes': int64_array([1, 2]),
|
||||
align_corners = False
|
||||
if 'align_corners' in node.pb.attr:
|
||||
align_corners = node.pb.attr['align_corners'].b
|
||||
|
||||
half_pixel_centers = False
|
||||
if 'half_pixel_centers' in node.pb.attr:
|
||||
half_pixel_centers = node.pb.attr['half_pixel_centers'].b
|
||||
|
||||
attrs = {
|
||||
'align_corners': align_corners,
|
||||
'half_pixel_centers': half_pixel_centers,
|
||||
'mode': 'linear'
|
||||
}
|
||||
Interpolate.update_node_stat(node, mapping_rule)
|
||||
TFResize.update_node_stat(node, attrs)
|
||||
return cls.enabled
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
|
||||
from extensions.ops.TFResize import TFResize
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
|
||||
|
||||
@@ -24,10 +24,18 @@ class ResizeNearestNeighborFrontExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
mapping_rule = {
|
||||
'mode': 'nearest',
|
||||
'antialias': 0,
|
||||
'axes': int64_array([1, 2]),
|
||||
align_corners = False
|
||||
if 'align_corners' in node.pb.attr:
|
||||
align_corners = node.pb.attr['align_corners'].b
|
||||
|
||||
half_pixel_centers = False
|
||||
if 'half_pixel_centers' in node.pb.attr:
|
||||
half_pixel_centers = node.pb.attr['half_pixel_centers'].b
|
||||
|
||||
attrs = {
|
||||
'align_corners': align_corners,
|
||||
'half_pixel_centers': half_pixel_centers,
|
||||
'mode': 'nearest'
|
||||
}
|
||||
Interpolate.update_node_stat(node, mapping_rule)
|
||||
TFResize.update_node_stat(node, attrs)
|
||||
return cls.enabled
|
||||
|
||||
@@ -40,12 +40,12 @@ def is_next(first: Node, second: Node) -> bool:
|
||||
:param second: another Interpolate layer
|
||||
:return: True, if 'first' is an predecessor of 'second', and False otherwise.
|
||||
"""
|
||||
if not node_has_one_consumer(first):
|
||||
return False
|
||||
dests = first.out_port(0).get_destinations()
|
||||
if len(dests) != 1:
|
||||
return False
|
||||
return second.id == dests[0].node.id
|
||||
if node_has_one_consumer(first):
|
||||
return second.id == dests[0].node.id
|
||||
elif first.soft_get('maybe_part_of_sequence', False):
|
||||
return len(dests) == 2 and second.id in [d.node.id for d in dests]
|
||||
return False
|
||||
|
||||
|
||||
class CanBeFused:
|
||||
@@ -131,7 +131,8 @@ class CanBeFused:
|
||||
:param second: the second of fused nodes
|
||||
:return: True, if nodes can be fused, and False otherwise
|
||||
"""
|
||||
if not self._compare_attributes(first, second):
|
||||
if not (is_next(first, second) and self._compare_attributes(first, second)):
|
||||
self.accumulated_axes = set()
|
||||
return False
|
||||
|
||||
fst_axes = set([a for a in Interpolate.get_axes(first)])
|
||||
@@ -250,7 +251,7 @@ def replace_sequence(seq: List[Node], graph: Graph):
|
||||
|
||||
last_interp_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
|
||||
|
||||
rename_nodes([(last_interp_node, last_interp_node_name + '/delete_'), (interp_node, last_interp_node_name)])
|
||||
rename_nodes([(last_interp_node, last_interp_node_name + '/delete'), (interp_node, last_interp_node_name)])
|
||||
|
||||
|
||||
class InterpolateSequenceToInterpolate(MiddleReplacementPattern):
|
||||
@@ -267,6 +268,6 @@ class InterpolateSequenceToInterpolate(MiddleReplacementPattern):
|
||||
log.debug('Enabled replacement of a sequence of Interpolate layers with one Interpolate layer.')
|
||||
interps = [n for n in graph.pseudo_topological_sort() if n.kind == 'op' and n.op == 'Interpolate']
|
||||
fuser = CanBeFused()
|
||||
sequences = group_by_with_binary_predicate(interps, lambda prev, x: is_next(prev, x) and fuser(prev, x))
|
||||
sequences = group_by_with_binary_predicate(interps, fuser)
|
||||
for seq in sequences:
|
||||
replace_sequence(seq, graph)
|
||||
|
||||
@@ -84,7 +84,7 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])})
|
||||
axes_node = Const(graph,
|
||||
{'name': resize_name + '/axis_',
|
||||
{'name': resize_name + '/axis',
|
||||
'value': int64_array(np.arange(begin_dim, end_dim))}).create_node()
|
||||
|
||||
shape_calculation_mode = 'scales' if num_of_inputs == 3 else 'sizes'
|
||||
@@ -101,21 +101,21 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
'in_ports_count': 4}).create_node()
|
||||
|
||||
axes_node.out_port(0).connect(interpolate_node.in_port(3))
|
||||
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf_'}).create_node()
|
||||
shape_of = Shape(graph, {'name': resize_name + '/ShapeOf'}).create_node()
|
||||
|
||||
add_node = create_op_with_const_inputs(graph, Add,
|
||||
{1: float_array([1.0e-5])},
|
||||
{'name': resize_name + '/Add_'})
|
||||
{'name': resize_name + '/Add'})
|
||||
|
||||
input_data_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)
|
||||
|
||||
if num_of_inputs == 3:
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul_'}).create_node()
|
||||
mul_node = Mul(graph, {'name': resize_name + '/Mul'}).create_node()
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
||||
cast_add_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
|
||||
floor_node = Floor(graph, {'name': resize_name + '/Floor_'}).create_node()
|
||||
floor_node = Floor(graph, {'name': resize_name + '/Floor'}).create_node()
|
||||
mul_node.out_port(0).connect(add_node.in_port(0))
|
||||
add_node.out_port(0).connect(floor_node.in_port(0))
|
||||
floor_node.out_port(0).connect(cast_add_result_to_int.in_port(0))
|
||||
@@ -134,7 +134,7 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
else:
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
cast_sizes_to_float = Cast(graph, {'dst_type': input_data_type}).create_node()
|
||||
div_node = Div(graph, {'name': resize_name + '/Div_'}).create_node()
|
||||
div_node = Div(graph, {'name': resize_name + '/Div'}).create_node()
|
||||
cast_sizes_to_float.out_port(0).connect(div_node.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(div_node.in_port(1))
|
||||
shape_of.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
@@ -156,7 +156,7 @@ def replace_resize(graph: Graph, resize: Node):
|
||||
resize.out_port(0).get_connection().set_source(interpolate_node.out_port(0))
|
||||
|
||||
|
||||
class ONNXResize11ToInterpolate4(MiddleReplacementPattern):
|
||||
class ONNXResize11ToInterpolate(MiddleReplacementPattern):
|
||||
"""
|
||||
The transformation replaces ONNX Resize 11 with Interpolate-4.
|
||||
"""
|
||||
@@ -15,8 +15,11 @@
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
@@ -74,20 +77,21 @@ def get_split_scale(split: Node) -> int:
|
||||
|
||||
def replace_interpolate_pattern(graph: Graph, match: dict):
|
||||
split = match['split']
|
||||
scale = int64_array([get_split_scale(split)])
|
||||
scale = np.array([get_split_scale(split)], dtype=np.float32)
|
||||
axis = int(split.in_port(1).get_connection().get_source().node.value)
|
||||
split_node_name = split.name
|
||||
axis_node = Const(graph, {'name': split_node_name + '/axis', 'value': int64_array([axis])}).create_node()
|
||||
|
||||
shape_node = Shape(graph, dict(name=split_node_name + '/Shape_')).create_node()
|
||||
scales_node = Const(graph, dict(name=split_node_name + '/scales_', value=scale)).create_node()
|
||||
mul_node = Mul(graph, dict(name=split_node_name + '/Mul_')).create_node()
|
||||
shape_node = Shape(graph, dict(name=split_node_name + '/Shape')).create_node()
|
||||
scales_node = Const(graph, dict(name=split_node_name + '/scales', value=scale)).create_node()
|
||||
mul_node = Mul(graph, dict(name=split_node_name + '/Mul')).create_node()
|
||||
scales_node.out_port(0).connect(mul_node.in_port(1))
|
||||
|
||||
strided_slice_node = create_op_with_const_inputs(graph,
|
||||
StridedSlice,
|
||||
{1: int64_array([axis]), 2: int64_array([axis + 1])},
|
||||
{
|
||||
'name': split_node_name + '/StridedSlice_',
|
||||
'name': split_node_name + '/StridedSlice',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
@@ -96,12 +100,28 @@ def replace_interpolate_pattern(graph: Graph, match: dict):
|
||||
})
|
||||
shape_node.out_port(0).connect(strided_slice_node.in_port(0))
|
||||
|
||||
strided_slice_node.out_port(0).connect(mul_node.in_port(0))
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
|
||||
|
||||
interp_node = Interpolate(graph, dict(name=split_node_name + '/Interpolate_',
|
||||
axes=int64_array([axis]),
|
||||
mode='nearest')).create_node()
|
||||
mul_node.out_port(0).connect(interp_node.in_port(1))
|
||||
strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
||||
|
||||
interp_node = Interpolate(graph,
|
||||
dict(name=split_node_name + '/Interpolate',
|
||||
mode='nearest',
|
||||
antialias=0, pads_begin=int64_array([0]), pads_end=int64_array([0]),
|
||||
coordinate_transformation_mode='half_pixel', nearest_mode='round_prefer_floor',
|
||||
cube_coeff=-0.75, version='opset4', shape_calculation_mode='scales',
|
||||
in_ports_count=4, maybe_part_of_sequence=True)).create_node()
|
||||
|
||||
floor_node = Floor(graph, {'name': split_node_name + '/Floor'}).create_node()
|
||||
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
|
||||
|
||||
mul_node.out_port(0).connect(floor_node.in_port(0))
|
||||
floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
|
||||
|
||||
cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
|
||||
scales_node.out_port(0).connect(interp_node.in_port(2))
|
||||
axis_node.out_port(0).connect(interp_node.in_port(3))
|
||||
|
||||
match['concat'].out_port(0).get_connection().set_source(interp_node.out_port(0))
|
||||
|
||||
|
||||
@@ -39,7 +39,11 @@ graph_node_attrs_for_2d_spatial_case = {
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
|
||||
'split_axis_const_data': {
|
||||
'value': np.array(3, dtype=np.int64),
|
||||
'shape': np.array(3, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
|
||||
'split_data_0': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
|
||||
'split_data_1': {'value': None, 'shape': int64_array([1, 100, 120, 50]), 'kind': 'data'},
|
||||
@@ -66,7 +70,11 @@ graph_node_attrs_for_3d_spatial_case = {
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split_axis_const_data': {'value': None, 'shape': np.array(4, dtype=np.int64).shape, 'kind': 'data'},
|
||||
'split_axis_const_data': {
|
||||
'value': np.array(4, dtype=np.int64),
|
||||
'shape': np.array(4, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 4},
|
||||
'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
|
||||
'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 120, 50]), 'kind': 'data'},
|
||||
@@ -99,7 +107,7 @@ graph_edges = [
|
||||
]
|
||||
|
||||
|
||||
ref_graph_edges = [
|
||||
ref_graph_edges_opset4 = [
|
||||
('placeholder', 'placeholder_data'),
|
||||
('placeholder_data', 'interpolate', {'in': 0}),
|
||||
('placeholder_data', 'shape'),
|
||||
@@ -110,17 +118,100 @@ ref_graph_edges = [
|
||||
('slice_end', 'slice_end_data'),
|
||||
('slice_end_data', 'sslice', {'in': 2}),
|
||||
('sslice', 'sslice_data'),
|
||||
('sslice_data', 'cast_shape_to_float'),
|
||||
('cast_shape_to_float', 'cast_shape_to_float_data'),
|
||||
('scales', 'scales_data'),
|
||||
('sslice_data', 'mul', {'in': 0}),
|
||||
('scales_data', 'mul', {'in': 1}),
|
||||
('axes', 'axes_data'),
|
||||
('cast_shape_to_float_data', 'mul', {'in': 0}),
|
||||
('scales_data', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'mul_data'),
|
||||
('mul_data', 'interpolate', {'in': 1}),
|
||||
('mul_data', 'floor'),
|
||||
('floor', 'floor_data'),
|
||||
('floor_data', 'cast_mul_to_float'),
|
||||
('cast_mul_to_float', 'cast_mul_to_float_data'),
|
||||
('cast_mul_to_float_data', 'interpolate', {'in': 1}),
|
||||
('scales_data', 'interpolate', {'in': 2, 'out': 0}),
|
||||
('axes_data', 'interpolate', {'in': 3}),
|
||||
('interpolate', 'interpolate_data'),
|
||||
('interpolate_data', 'abs'),
|
||||
('abs', 'abs_data'),
|
||||
('abs_data', 'output'),
|
||||
]
|
||||
|
||||
ref_graph_node_attrs_for_2d_spatial_case_1_opset4 = {
|
||||
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data': {
|
||||
'value': None,
|
||||
'shape': int64_array([1, 100, 120, 150]),
|
||||
'kind': 'data',
|
||||
'data_type': None
|
||||
},
|
||||
'interpolate': {
|
||||
'type': 'Interpolate',
|
||||
'kind': 'op',
|
||||
'op': 'Interpolate',
|
||||
'mode': 'nearest',
|
||||
'antialias': 0,
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor',
|
||||
'cube_coeff': -0.75,
|
||||
'version': 'opset4',
|
||||
'shape_calculation_mode': 'scales'
|
||||
},
|
||||
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'slice_begin': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'slice_begin_data': {'kind': 'data', 'shape': int64_array([1]), 'value': int64_array([3])},
|
||||
'slice_end': {'type': 'Const', 'op': 'Const', 'kind': 'op', 'value': int64_array([4]), 'shape': int64_array([1])},
|
||||
'slice_end_data': {'kind': 'data', 'value': int64_array([4]), 'shape': int64_array([1])},
|
||||
'sslice': {
|
||||
'kind': 'op',
|
||||
'type': 'StridedSlice',
|
||||
'op': 'StridedSlice',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0]),
|
||||
},
|
||||
'sslice_data': {'kind': 'data', 'shape': None},
|
||||
'scales': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': np.array([2], dtype=np.float32),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'scales_data': {'kind': 'data', 'shape': None},
|
||||
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
|
||||
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'axes': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'axes_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None},
|
||||
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
|
||||
'floor_data': {'kind': 'data', 'shape': None},
|
||||
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
|
||||
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
|
||||
'output': {'kind': 'op', 'op': 'Result'},
|
||||
}
|
||||
|
||||
ref_graph_node_attrs_for_2d_spatial_case_1 = {
|
||||
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
@@ -168,6 +259,14 @@ ref_graph_node_attrs_for_2d_spatial_case_1 = {
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'scales_data': {'kind': 'data', 'shape': None},
|
||||
'axes': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'axes_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 120, 300]), 'kind': 'data'},
|
||||
@@ -188,8 +287,15 @@ ref_graph_node_attrs_for_2d_spatial_case_2 = {
|
||||
'type': 'Interpolate',
|
||||
'kind': 'op',
|
||||
'op': 'Interpolate',
|
||||
'axes': int64_array([2]),
|
||||
'mode': 'nearest'
|
||||
'mode': 'nearest',
|
||||
'antialias': 0,
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor',
|
||||
'cube_coeff': -0.75,
|
||||
'version': 'opset4',
|
||||
'shape_calculation_mode': 'scales'
|
||||
},
|
||||
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
@@ -218,12 +324,26 @@ ref_graph_node_attrs_for_2d_spatial_case_2 = {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([2]),
|
||||
'value': np.array([2], dtype=np.float32),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'scales_data': {'kind': 'data', 'shape': None},
|
||||
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
|
||||
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'axes': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'axes_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None},
|
||||
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
|
||||
'floor_data': {'kind': 'data', 'shape': None},
|
||||
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
|
||||
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {'value': None, 'shape': int64_array([1, 100, 240, 150]), 'kind': 'data'},
|
||||
@@ -243,8 +363,15 @@ ref_graph_node_attrs_for_3d_spatial_case_1 = {
|
||||
'type': 'Interpolate',
|
||||
'kind': 'op',
|
||||
'op': 'Interpolate',
|
||||
'axes': int64_array([4]),
|
||||
'mode': 'nearest'
|
||||
'mode': 'nearest',
|
||||
'antialias': 0,
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor',
|
||||
'cube_coeff': -0.75,
|
||||
'version': 'opset4',
|
||||
'shape_calculation_mode': 'scales'
|
||||
},
|
||||
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
@@ -273,12 +400,26 @@ ref_graph_node_attrs_for_3d_spatial_case_1 = {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([2]),
|
||||
'value': np.array([2], dtype=np.float32),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'scales_data': {'kind': 'data', 'shape': None},
|
||||
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
|
||||
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'axes': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'axes_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None},
|
||||
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
|
||||
'floor_data': {'kind': 'data', 'shape': None},
|
||||
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
|
||||
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 120, 300]), 'kind': 'data'},
|
||||
@@ -298,8 +439,15 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
|
||||
'type': 'Interpolate',
|
||||
'kind': 'op',
|
||||
'op': 'Interpolate',
|
||||
'axes': int64_array([3]),
|
||||
'mode': 'nearest'
|
||||
'mode': 'nearest',
|
||||
'antialias': 0,
|
||||
'pads_begin': int64_array([0]),
|
||||
'pads_end': int64_array([0]),
|
||||
'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor',
|
||||
'cube_coeff': -0.75,
|
||||
'version': 'opset4',
|
||||
'shape_calculation_mode': 'scales'
|
||||
},
|
||||
'shape': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shape_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
@@ -328,12 +476,26 @@ ref_graph_node_attrs_for_3d_spatial_case_2 = {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([2]),
|
||||
'value': np.array([2], dtype=np.float32),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'scales_data': {'kind': 'data', 'shape': None},
|
||||
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
|
||||
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'axes': {
|
||||
'type': 'Const',
|
||||
'op': 'Const',
|
||||
'kind': 'op',
|
||||
'value': int64_array([3]),
|
||||
'shape': int64_array([1])
|
||||
},
|
||||
'axes_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'kind': 'op', 'op': 'Mul', 'type': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None},
|
||||
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
|
||||
'floor_data': {'kind': 'data', 'shape': None},
|
||||
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
|
||||
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {'value': None, 'shape': int64_array([1, 3, 100, 240, 150]), 'kind': 'data'},
|
||||
@@ -348,8 +510,8 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
edges=graph_edges
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1,
|
||||
edges=ref_graph_edges
|
||||
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_1_opset4,
|
||||
edges=ref_graph_edges_opset4
|
||||
)
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
@@ -367,7 +529,11 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split_axis_const_data': {'value': None, 'shape': np.array(2, dtype=np.int64).shape, 'kind': 'data'},
|
||||
'split_axis_const_data': {
|
||||
'value': np.array(2, dtype=np.int64),
|
||||
'shape': np.array(2, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 2},
|
||||
'split_data_0': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
|
||||
'split_data_1': {'value': None, 'shape': int64_array([1, 100, 40, 150]), 'kind': 'data'},
|
||||
@@ -378,7 +544,10 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=ref_graph_node_attrs_for_2d_spatial_case_2,
|
||||
edges=ref_graph_edges
|
||||
edges=ref_graph_edges_opset4,
|
||||
update_attributes={
|
||||
'axes': {'shape': int64_array([1]), 'value': int64_array([2])}
|
||||
}
|
||||
)
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
@@ -391,7 +560,10 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_1,
|
||||
edges=ref_graph_edges
|
||||
edges=ref_graph_edges_opset4,
|
||||
update_attributes={
|
||||
'axes': {'shape': int64_array([1]), 'value': int64_array([4])}
|
||||
}
|
||||
)
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
@@ -409,7 +581,11 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
'op': 'Const',
|
||||
'type': 'Const'
|
||||
},
|
||||
'split_axis_const_data': {'value': None, 'shape': np.array(3, dtype=np.int64).shape, 'kind': 'data'},
|
||||
'split_axis_const_data': {
|
||||
'value': np.array(3, dtype=np.int64),
|
||||
'shape': np.array(3, dtype=np.int64).shape,
|
||||
'kind': 'data'
|
||||
},
|
||||
'concat': {'type': 'Concat', 'kind': 'op', 'axis': 3},
|
||||
'split_data_0': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
|
||||
'split_data_1': {'value': None, 'shape': int64_array([1, 3, 100, 40, 150]), 'kind': 'data'},
|
||||
@@ -420,7 +596,7 @@ class SplitConcatPairToInterpolateTest(unittest.TestCase):
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=ref_graph_node_attrs_for_3d_spatial_case_2,
|
||||
edges=ref_graph_edges
|
||||
edges=ref_graph_edges_opset4
|
||||
)
|
||||
SplitConcatPairToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
|
||||
@@ -15,12 +15,14 @@
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
|
||||
from extensions.ops.activation_ops import Floor
|
||||
from extensions.ops.Cast import Cast
|
||||
from extensions.ops.elementwise import Mul
|
||||
from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Graph
|
||||
from mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
from mo.graph.graph import Graph, rename_nodes
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.shape import Shape
|
||||
@@ -53,6 +55,7 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
|
||||
force_shape_inference = True
|
||||
|
||||
def run_before(self):
|
||||
from extensions.middle.InterpolateSequenceToInterpolate import InterpolateSequenceToInterpolate
|
||||
return [InterpolateSequenceToInterpolate]
|
||||
|
||||
def pattern(self):
|
||||
@@ -91,19 +94,20 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
|
||||
if len(input_shape_of_unsqueeze) not in {4, 5}:
|
||||
return
|
||||
|
||||
scale = int64_array([second_input_of_tile.value[d_idx]])
|
||||
scale = float32_array([second_input_of_tile.value[d_idx]])
|
||||
axis = d_idx - 1
|
||||
axis_node = Const(graph, {'name': unsqueeze_name + '/axis', 'value': int64_array([axis])}).create_node()
|
||||
|
||||
shape_node = Shape(graph, dict(name=unsqueeze_name + '/Shape_')).create_node()
|
||||
scales_node = Const(graph, dict(name=unsqueeze_name + '/scales_', value=scale)).create_node()
|
||||
mul_node = Mul(graph, dict(name=unsqueeze_name + '/Mul_')).create_node()
|
||||
shape_node = Shape(graph, dict(name=unsqueeze_name + '/Shape')).create_node()
|
||||
scales_node = Const(graph, dict(name=unsqueeze_name + '/scales', value=scale)).create_node()
|
||||
mul_node = Mul(graph, dict(name=unsqueeze_name + '/Mul')).create_node()
|
||||
scales_node.out_port(0).connect(mul_node.in_port(1))
|
||||
|
||||
slice_begin = Const(graph, dict(name=unsqueeze_name + '/slice_begin_', value=int64_array([axis]))).create_node()
|
||||
slice_end = Const(graph, dict(name=unsqueeze_name + '/slice_end_', value=int64_array([axis + 1]))).create_node()
|
||||
slice_begin = Const(graph, dict(name=unsqueeze_name + '/slice_begin', value=int64_array([axis]))).create_node()
|
||||
slice_end = Const(graph, dict(name=unsqueeze_name + '/slice_end', value=int64_array([axis + 1]))).create_node()
|
||||
|
||||
strided_slice_node = StridedSlice(graph,
|
||||
{'name': unsqueeze_name + '/StridedSlice_',
|
||||
{'name': unsqueeze_name + '/StridedSlice',
|
||||
'begin_mask': int64_array([1]),
|
||||
'end_mask': int64_array([1]),
|
||||
'new_axis_mask': int64_array([0]),
|
||||
@@ -113,14 +117,36 @@ class UnsqueezeTileReshapeBlockToInterpolate(MiddleReplacementPattern):
|
||||
shape_node.out_port(0).connect(strided_slice_node.in_port(0))
|
||||
slice_begin.out_port(0).connect(strided_slice_node.in_port(1))
|
||||
slice_end.out_port(0).connect(strided_slice_node.in_port(2))
|
||||
strided_slice_node.out_port(0).connect(mul_node.in_port(0))
|
||||
|
||||
interp_node = Interpolate(graph, dict(name=unsqueeze_name + '/Interpolate_',
|
||||
axes=int64_array([axis]),
|
||||
mode='nearest')).create_node()
|
||||
mul_node.out_port(0).connect(interp_node.in_port(1))
|
||||
cast_shape_to_float = Cast(graph, {'dst_type': np.float32}).create_node()
|
||||
|
||||
match['reshape'].out_port(0).get_connection().set_source(interp_node.out_port(0))
|
||||
strided_slice_node.out_port(0).connect(cast_shape_to_float.in_port(0))
|
||||
cast_shape_to_float.out_port(0).connect(mul_node.in_port(0))
|
||||
|
||||
interp_node = Interpolate(graph,
|
||||
dict(mode='nearest',
|
||||
antialias=0, pads_begin=int64_array([0]),
|
||||
pads_end=int64_array([0]), coordinate_transformation_mode='half_pixel',
|
||||
nearest_mode='round_prefer_floor', cube_coeff=-0.75,
|
||||
version='opset4', shape_calculation_mode='scales',
|
||||
in_ports_count=4,
|
||||
maybe_part_of_sequence=True)).create_node()
|
||||
|
||||
floor_node = Floor(graph, {'name': unsqueeze_name + '/Floor'}).create_node()
|
||||
cast_mul_result_to_int = Cast(graph, {'dst_type': np.int64}).create_node()
|
||||
|
||||
mul_node.out_port(0).connect(floor_node.in_port(0))
|
||||
floor_node.out_port(0).connect(cast_mul_result_to_int.in_port(0))
|
||||
|
||||
cast_mul_result_to_int.out_port(0).connect(interp_node.in_port(1))
|
||||
scales_node.out_port(0).connect(interp_node.in_port(2))
|
||||
axis_node.out_port(0).connect(interp_node.in_port(3))
|
||||
|
||||
reshape_node = match['reshape']
|
||||
|
||||
reshape_node.out_port(0).get_connection().set_source(interp_node.out_port(0))
|
||||
reshape_name = reshape_node.soft_get('name', reshape_node.id)
|
||||
rename_nodes([(reshape_node, reshape_name + '/delete'), (interp_node, reshape_name)])
|
||||
|
||||
unsqueeze_connection = match['unsqueeze'].in_port(0).get_connection()
|
||||
before_unsqueeze = unsqueeze_connection.get_source().node
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
|
||||
from extensions.middle.UnsqueezeTileReshapeBlockToInterpolate import UnsqueezeTileReshapeBlockToInterpolate
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
@@ -112,7 +113,7 @@ graph_edges = [
|
||||
]
|
||||
|
||||
|
||||
ref_graph_node_attrs = {
|
||||
ref_graph_node_attrs_with_4_inputs_interpolate = {
|
||||
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data': {
|
||||
'value': None,
|
||||
@@ -169,7 +170,7 @@ ref_graph_node_attrs = {
|
||||
'kind': 'op',
|
||||
'op': 'Const',
|
||||
'type': 'Const',
|
||||
'value': int64_array([2]),
|
||||
'value': np.array([2], dtype=np.float32),
|
||||
'shape': int64_array([1]),
|
||||
},
|
||||
'scales_data': {
|
||||
@@ -177,18 +178,36 @@ ref_graph_node_attrs = {
|
||||
'value': None,
|
||||
'shape': None,
|
||||
},
|
||||
'cast_shape_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float32},
|
||||
'cast_shape_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'mul': {'type': 'Mul', 'kind': 'op', 'op': 'Mul'},
|
||||
'mul_data': {
|
||||
'kind': 'data',
|
||||
'value': None,
|
||||
'shape': None,
|
||||
},
|
||||
'floor': {'kind': 'op', 'op': 'Floor', 'type': 'Floor'},
|
||||
'floor_data': {'kind': 'data', 'shape': None},
|
||||
'cast_mul_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int64},
|
||||
'cast_mul_to_float_data': {'kind': 'data', 'shape': None},
|
||||
'interpolate': {'type': 'Interpolate', 'kind': 'op', 'op': 'Interpolate'},
|
||||
'interpolate_data': {
|
||||
'kind': 'data',
|
||||
'value': None,
|
||||
'shape': int64_array([1, 16, 32, 32, 64]),
|
||||
},
|
||||
'axes': {
|
||||
'kind': 'op',
|
||||
'op': 'Const',
|
||||
'type': 'Const',
|
||||
'value': int64_array([1]),
|
||||
'shape': int64_array([1]),
|
||||
},
|
||||
'axes_data': {
|
||||
'kind': 'data',
|
||||
'value': None,
|
||||
'shape': None,
|
||||
},
|
||||
'abs': {'type': 'Abs', 'kind': 'op', 'op': 'Abs'},
|
||||
'abs_data': {
|
||||
'kind': 'data',
|
||||
@@ -198,7 +217,8 @@ ref_graph_node_attrs = {
|
||||
'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
|
||||
}
|
||||
|
||||
ref_graph_edges = [
|
||||
|
||||
ref_graph_edges_attrs_with_4_inputs_interpolate = [
|
||||
('placeholder', 'placeholder_data'),
|
||||
('placeholder_data', 'shapeof'),
|
||||
('shapeof', 'shapeof_data'),
|
||||
@@ -209,11 +229,20 @@ ref_graph_edges = [
|
||||
('end', 'end_data'),
|
||||
('end_data', 'strided_slice', {'in': 2}),
|
||||
('scales', 'scales_data'),
|
||||
('strided_slice_data', 'mul', {'in': 0}),
|
||||
('scales_data', 'mul', {'in': 1}),
|
||||
('strided_slice_data', 'cast_shape_to_float'),
|
||||
('cast_shape_to_float', 'cast_shape_to_float_data'),
|
||||
('cast_shape_to_float_data', 'mul', {'in': 0}),
|
||||
('scales_data', 'mul', {'out': 0, 'in': 1}),
|
||||
('scales_data', 'interpolate', {'out': 0, 'in': 2}),
|
||||
('mul', 'mul_data'),
|
||||
('mul_data', 'interpolate', {'in': 1}),
|
||||
('mul_data', 'floor'),
|
||||
('floor', 'floor_data'),
|
||||
('floor_data', 'cast_mul_to_float'),
|
||||
('cast_mul_to_float', 'cast_mul_to_float_data'),
|
||||
('cast_mul_to_float_data', 'interpolate', {'in': 1}),
|
||||
('placeholder_data', 'interpolate', {'in': 0}),
|
||||
('axes', 'axes_data'),
|
||||
('axes_data', 'interpolate', {'in': 3}),
|
||||
('interpolate', 'interpolate_data'),
|
||||
('interpolate_data', 'abs'),
|
||||
('abs', 'abs_data'),
|
||||
@@ -298,7 +327,8 @@ graph_edges_when_transformation_is_not_applicable = graph_edges
|
||||
class UnsqueezeTileReshapeBlockToInterpolateTest(unittest.TestCase):
|
||||
def test_5d(self):
|
||||
graph = build_graph(nodes_attrs=graph_node_attrs, edges=graph_edges)
|
||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs, edges=ref_graph_edges)
|
||||
ref_graph = build_graph(nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
|
||||
edges=ref_graph_edges_attrs_with_4_inputs_interpolate)
|
||||
UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
@@ -320,12 +350,13 @@ class UnsqueezeTileReshapeBlockToInterpolateTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
ref_graph = build_graph(
|
||||
nodes_attrs=ref_graph_node_attrs,
|
||||
edges=ref_graph_edges,
|
||||
nodes_attrs=ref_graph_node_attrs_with_4_inputs_interpolate,
|
||||
edges=ref_graph_edges_attrs_with_4_inputs_interpolate,
|
||||
update_attributes={
|
||||
'placeholder_data': {'shape': int64_array([1, 8, 32, 32])},
|
||||
'interpolate_data': {'shape': int64_array([1, 16, 32, 32])},
|
||||
'abs_data': {'shape': int64_array([1, 16, 32, 32])},
|
||||
'axes': {'shape': int64_array([1]), 'value': int64_array([1])},
|
||||
}
|
||||
)
|
||||
UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(graph)
|
||||
|
||||
@@ -26,8 +26,9 @@ from extensions.ops.interpolate import Interpolate
|
||||
from mo.front.common.layout import get_height_dim, get_width_dim, get_depth_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.tf.graph_utils import create_op_with_const_inputs, create_op_node_with_second_input
|
||||
from mo.graph.graph import Graph, Node
|
||||
from mo.graph.graph import Graph, Node, rename_nodes
|
||||
from mo.middle.replacement import MiddleReplacementPattern
|
||||
from mo.ops.const import Const
|
||||
from mo.ops.shape import Shape
|
||||
from mo.ops.strided_slice import StridedSlice
|
||||
|
||||
@@ -63,6 +64,8 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
return
|
||||
|
||||
depth_scale = None
|
||||
layout = graph.graph['layout']
|
||||
|
||||
if len(upsample.in_nodes()) == 2:
|
||||
if upsample.in_node(1).value is None:
|
||||
return
|
||||
@@ -71,10 +74,10 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
len(scales), upsample_name)
|
||||
if not (math.isclose(scales[0], 1, rel_tol=1e-5) and math.isclose(scales[1], 1, rel_tol=1e-5)):
|
||||
return
|
||||
height_scale = scales[2]
|
||||
width_scale = scales[3]
|
||||
height_scale = scales[get_height_dim(layout, input_shape_rank)]
|
||||
width_scale = scales[get_width_dim(layout, input_shape_rank)]
|
||||
if len(scales) == 5:
|
||||
depth_scale = scales[4]
|
||||
depth_scale = scales[get_depth_dim(layout, input_shape_rank)]
|
||||
else:
|
||||
height_scale = upsample['height_scale']
|
||||
width_scale = upsample['width_scale']
|
||||
@@ -82,6 +85,7 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
if 1 in upsample.in_ports() and not upsample.in_port(1).disconnected():
|
||||
upsample.in_port(1).disconnect()
|
||||
|
||||
upsample_name = upsample.soft_get('name', upsample.id)
|
||||
shape = Shape(graph, {'name': upsample_name + '/0_port'}).create_node()
|
||||
|
||||
layout = graph.graph['layout']
|
||||
@@ -104,10 +108,9 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
'new_axis_mask': int64_array([0]),
|
||||
'shrink_axis_mask': int64_array([0]),
|
||||
'ellipsis_mask': int64_array([0])
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul_'})
|
||||
mul = create_op_node_with_second_input(graph, Mul, factor_value, {'name': upsample_name + '/factor_mul'})
|
||||
|
||||
source = upsample.in_port(0).get_connection().get_source()
|
||||
source.connect(shape.in_port(0))
|
||||
@@ -124,16 +127,27 @@ class UpsampleToResample(MiddleReplacementPattern):
|
||||
get_height_dim(layout, input_shape_rank),
|
||||
get_width_dim(layout, input_shape_rank)])
|
||||
|
||||
resample_op = Interpolate(graph, dict(name=upsample_name + '/Interpolate',
|
||||
axes=axes, mode=upsample.attrs()['mode'],
|
||||
antialias=0, convert_to_resample=True)).create_node()
|
||||
axes_node = Const(graph, {'name': upsample_name + '/axis', 'value': axes}).create_node()
|
||||
|
||||
interpolate = Interpolate(graph, {'mode': upsample.attrs()['mode'], 'antialias': 0,
|
||||
'pads_begin': int64_array([0]), 'pads_end': int64_array([0]),
|
||||
'coordinate_transformation_mode': 'half_pixel',
|
||||
'nearest_mode': 'round_prefer_floor', 'cube_coeff': -0.75,
|
||||
'shape_calculation_mode': 'scales',
|
||||
'version': 'opset4', 'in_ports_count': 4}).create_node()
|
||||
|
||||
upsample.add_input_port(1, skip_if_exist=True)
|
||||
assert upsample.in_port(1).disconnected()
|
||||
mul.out_port(0).connect(resample_op.in_port(1))
|
||||
mul.out_port(0).connect(interpolate.in_port(1))
|
||||
axes_node.out_port(0).connect(interpolate.in_port(3))
|
||||
|
||||
upsample.in_port(0).get_connection().set_destination(resample_op.in_port(0))
|
||||
upsample.out_port(0).get_connection().set_source(resample_op.out_port(0))
|
||||
scales_node = Const(graph, {'name': upsample_name + '/scales', 'value': factor_value}).create_node()
|
||||
scales_node.out_port(0).connect(interpolate.in_port(2))
|
||||
|
||||
upsample.in_port(0).get_connection().set_destination(interpolate.in_port(0))
|
||||
upsample.out_port(0).get_connection().set_source(interpolate.out_port(0))
|
||||
|
||||
rename_nodes([(upsample, upsample_name + '/delete'), (interpolate, upsample_name)])
|
||||
|
||||
convert_to_float = Cast(graph, dict(dst_type=np.float32)).create_node()
|
||||
convert_to_int = Cast(graph, dict(dst_type=np.int64)).create_node()
|
||||
|
||||
@@ -43,6 +43,66 @@ graph_edges = [
|
||||
('upsample_data', 'output'),
|
||||
]
|
||||
|
||||
new_ref_graph_node_attr = {
|
||||
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
'ss_begin': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([2]), 'shape': int64_array([1])},
|
||||
'ss_begin_data': {'kind': 'data', 'value': int64_array([2]), 'shape': int64_array([1])},
|
||||
'ss_end': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([4]), 'shape': int64_array([1])},
|
||||
'ss_end_data': {'kind': 'data', 'value': int64_array([4]), 'shape': int64_array([1])},
|
||||
'ss_stride': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([1]), 'shape': int64_array([1])},
|
||||
'ss_stride_data': {'kind': 'data', 'value': int64_array([1]), 'shape': int64_array([1])},
|
||||
'strided_slice': {'type': 'StridedSlice', 'kind': 'op', 'op': 'StridedSlice'},
|
||||
'strided_slice_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'cast_to_float': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.float},
|
||||
'cast_to_float_d': {'kind': 'data', 'value': None, 'shape': None},
|
||||
'factor': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
|
||||
'factor_data': {'kind': 'data', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
|
||||
'shapeof': {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'},
|
||||
'shapeof_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'mul': {'type': 'Multiply', 'kind': 'op', 'op': 'Multiply'},
|
||||
'mul_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'cast_to_int': {'kind': 'op', 'op': 'Cast', 'type': 'Convert', 'dst_type': np.int32},
|
||||
'cast_to_int_d': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'axes_const': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': None, 'shape': None},
|
||||
'axes_const_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
'scales': {'kind': 'op', 'op': 'Const', 'type': 'Const', 'value': int64_array([5, 5]), 'shape': int64_array([2])},
|
||||
'scales_data': {'kind': 'data', 'value': None, 'shape': None},
|
||||
'interpolate': {'type': 'Interpolate', 'kind': 'op', 'op': 'Interpolate', 'axes': None},
|
||||
'interpolate_data': {'kind': 'data', 'shape': None, 'value': None},
|
||||
'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
|
||||
}
|
||||
|
||||
new_ref_graph_edges = [
|
||||
('placeholder', 'placeholder_data'),
|
||||
('placeholder_data', 'shapeof', {'in': 0, 'out': 0}),
|
||||
('placeholder_data', 'interpolate', {'in': 0, 'out': 0}),
|
||||
('ss_begin', 'ss_begin_data'),
|
||||
('ss_begin_data', 'strided_slice', {'in': 1, 'out': 0}),
|
||||
('ss_end', 'ss_end_data'),
|
||||
('ss_end_data', 'strided_slice', {'in': 2, 'out': 0}),
|
||||
('ss_stride', 'ss_stride_data'),
|
||||
('ss_stride_data', 'strided_slice', {'in': 3, 'out': 0}),
|
||||
('strided_slice', 'strided_slice_data'),
|
||||
('strided_slice_data', 'cast_to_float'),
|
||||
('cast_to_float', 'cast_to_float_d'),
|
||||
('shapeof', 'shapeof_data'),
|
||||
('shapeof_data', 'strided_slice', {'in': 0, 'out': 0}),
|
||||
('factor', 'factor_data'),
|
||||
('cast_to_float_d', 'mul', {'in': 0, 'out': 0}),
|
||||
('factor_data', 'mul', {'in': 1, 'out': 0}),
|
||||
('mul', 'mul_data'),
|
||||
('mul_data', 'cast_to_int'),
|
||||
('cast_to_int', 'cast_to_int_d'),
|
||||
('cast_to_int_d', 'interpolate', {'in': 1, 'out': 0}),
|
||||
('axes_const', 'axes_const_data'),
|
||||
('axes_const_data', 'interpolate', {'in': 3, 'out': 0}),
|
||||
('scales', 'scales_data'),
|
||||
('scales_data', 'interpolate', {'in': 2, 'out': 0}),
|
||||
('interpolate', 'interpolate_data'),
|
||||
('interpolate_data', 'output')
|
||||
]
|
||||
|
||||
ref_graph_node_attrs = {
|
||||
'placeholder': {'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
|
||||
'placeholder_data': {'value': None, 'shape': None, 'kind': 'data', 'data_type': np.float32},
|
||||
@@ -98,30 +158,40 @@ ref_graph_edges = [
|
||||
|
||||
@generator
|
||||
class UpsampleToResampleTest(unittest.TestCase):
|
||||
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5],),
|
||||
([2, 20, 30, 40], [1, 1, 3, 3],),
|
||||
([2, 10, 20, 30], [1, 1, 6, 5],),
|
||||
([2, 20, 30, 40], [1, 1, 3, 4],),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3],),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3],),
|
||||
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3],),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4],),
|
||||
@generate(*[([2, 10, 20, 30], [1, 1, 5, 5], [2, 3]),
|
||||
([2, 20, 30, 40], [1, 1, 3, 3], [2, 3]),
|
||||
([2, 10, 20, 30], [1, 1, 6, 5], [2, 3]),
|
||||
([2, 20, 30, 40], [1, 1, 3, 4], [2, 3]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 3, 3], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 4, 3], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 4, 3, 3], [2, 3, 4]),
|
||||
([2, 3, 20, 30, 40], [1, 1, 3, 3, 4], [2, 3, 4]),
|
||||
])
|
||||
def test_conversion(self, input_shape, scales):
|
||||
graph = build_graph(graph_node_attrs, graph_edges,
|
||||
{'placeholder_data': {'shape': int64_array(input_shape)},
|
||||
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}})
|
||||
def test_conversion(self, input_shape, scales, axes):
|
||||
graph = build_graph(graph_node_attrs,
|
||||
graph_edges,
|
||||
{
|
||||
'placeholder_data': {'shape': int64_array(input_shape)},
|
||||
'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
|
||||
'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)}
|
||||
})
|
||||
graph.graph['layout'] = 'NCHW'
|
||||
|
||||
ref_graph = build_graph(ref_graph_node_attrs, ref_graph_edges,
|
||||
{'placeholder_data': {'shape': int64_array(input_shape)},
|
||||
'factor': {'value': int64_array(scales)[2:], 'shape': int64_array(scales[2:]).shape},
|
||||
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
|
||||
'interpolate': {'axes': list(range(2, len(input_shape)))}}
|
||||
)
|
||||
|
||||
ref_graph = build_graph(new_ref_graph_node_attr,
|
||||
new_ref_graph_edges,
|
||||
{
|
||||
'placeholder_data': {'shape': int64_array(input_shape)},
|
||||
'ss_begin': {'value': int64_array([axes[0]])},
|
||||
'ss_end': {'value': int64_array([axes[-1] + 1])},
|
||||
'ss_begin_data': {'value': int64_array([axes[0]])},
|
||||
'ss_end_data': {'value': int64_array([axes[-1] + 1])},
|
||||
'factor': {'value': int64_array(scales)[2:],
|
||||
'shape': int64_array(scales[2:]).shape},
|
||||
'factor_data': {'value': int64_array(scales)[2:],
|
||||
'shape': int64_array(scales[2:]).shape},
|
||||
'axes_const': {'value': int64_array(axes), 'shape': int64_array(axes).shape},
|
||||
'interpolate_data': {'shape': int64_array(input_shape) * int64_array(scales)},
|
||||
})
|
||||
UpsampleToResample().find_and_replace_pattern(graph)
|
||||
(flag, resp) = compare_graphs(graph, ref_graph, 'output')
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
31
model-optimizer/extensions/ops/ONNXResize10.py
Normal file
31
model-optimizer/extensions/ops/ONNXResize10.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class ONNXResize10(Op):
|
||||
op = 'ONNXResize10'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'op': self.op,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
67
model-optimizer/extensions/ops/TFResize.py
Normal file
67
model-optimizer/extensions/ops/TFResize.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
from mo.front.common.layout import get_height_dim, get_width_dim
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class TFResize(Op):
|
||||
op = 'TFResize'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'op': self.op,
|
||||
'out_ports_count': 1,
|
||||
'in_ports_count': 2,
|
||||
'infer': TFResize.tf_resize_infer
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
|
||||
@staticmethod
|
||||
def tf_resize_infer(node: Node):
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
if input_shape is None:
|
||||
return
|
||||
|
||||
attrs_msg = "If half_pixel_centers attribute of the node {} with op {} is True, " \
|
||||
"the attribute align_corners must be False"
|
||||
node_name = node.soft_get('name', node.id)
|
||||
assert not node.half_pixel_centers or (node.half_pixel_centers and not node.align_corners), \
|
||||
attrs_msg.format(node_name, node.op)
|
||||
|
||||
connected_in_ports = [port for port in node.in_ports().values() if not port.disconnected()]
|
||||
assert len(connected_in_ports) == 2, \
|
||||
"Node {} with op {} number of inputs must be equal to 2.".format(node_name, node.op)
|
||||
|
||||
new_sizes_value = node.in_port(1).data.get_value()
|
||||
assert new_sizes_value is not None, "Node {} with op {} has no value in input port 1".format(node_name, node.op)
|
||||
|
||||
input_rank = len(input_shape)
|
||||
assert input_rank == 4, \
|
||||
"Resized input data of the node {} with op {} must be 4D tensor".format(node_name, node.op)
|
||||
|
||||
len_msg = "Op {} with name {} supports only resize with respect to height and width dimension simultaneously"
|
||||
assert len(new_sizes_value) == 2, len_msg.format(node_name, node.op)
|
||||
|
||||
output_shape = int64_array(input_shape.copy())
|
||||
|
||||
layout = node.graph.graph['layout']
|
||||
output_shape[get_height_dim(layout, input_rank)] = new_sizes_value[0]
|
||||
output_shape[get_width_dim(layout, input_rank)] = new_sizes_value[1]
|
||||
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
@@ -22,6 +22,7 @@ import numpy as np
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.extractor import bool_to_str
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.graph.perm_inputs import PermuteInputs
|
||||
from mo.ops.op import Op, PermuteAttrs
|
||||
|
||||
|
||||
@@ -61,6 +62,9 @@ def infer_for_opset4(node: Node):
|
||||
for i, axis in enumerate(axes):
|
||||
output_shape[axis] = math.floor(scales[i] * output_shape[axis] + 1.0e-5)
|
||||
|
||||
if node.is_in_port_connected(3):
|
||||
PermuteInputs().set_input_permutation(node.in_node(3), node, 'input:0', 'axis')
|
||||
|
||||
node.out_port(0).data.set_shape(output_shape)
|
||||
|
||||
|
||||
@@ -103,7 +107,8 @@ def correct_scales_using_dst_shape(node, dst_shape, src_shape, axes):
|
||||
if scales_value is None or len(scales_value) != len(dst_shape):
|
||||
corrected_scales = np.zeros(len(dst_shape))
|
||||
for i, axis in enumerate(list(axes)):
|
||||
corrected_scales[i] = math.floor((dst_shape[i] / src_shape[axis]) + 1.0e-5)
|
||||
corrected_scales[i] = dst_shape[i] / src_shape[axis]
|
||||
node.in_port(2).data.set_value(corrected_scales)
|
||||
|
||||
|
||||
class Interpolate(Op):
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Copyright (C) 2020 Intel Corporation
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""
|
||||
Copyright (C) 2018-2020 Intel Corporation
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
import logging as log
|
||||
|
||||
from extensions.ops.resize_factor_utils import factor_update
|
||||
from mo.front.common.layout import get_batch_dim, get_features_dim, get_height_dim, get_width_dim, shape_for_layout
|
||||
from mo.graph.graph import Node, Graph
|
||||
from mo.ops.op import Op
|
||||
|
||||
|
||||
class ResampleOp(Op):
|
||||
enabled = False
|
||||
op = 'Resample'
|
||||
|
||||
def __init__(self, graph: Graph, attrs: dict):
|
||||
mandatory_props = {
|
||||
'type': __class__.op,
|
||||
'op': __class__.op,
|
||||
'factor': None,
|
||||
'in_ports_count': 2,
|
||||
'out_ports_count': 1,
|
||||
'infer': None
|
||||
}
|
||||
super().__init__(graph, mandatory_props, attrs)
|
||||
|
||||
def supported_attrs(self):
|
||||
return [
|
||||
'antialias',
|
||||
'height',
|
||||
'width',
|
||||
'resample_type',
|
||||
'factor',
|
||||
]
|
||||
|
||||
def backend_attrs(self):
|
||||
return [
|
||||
'antialias',
|
||||
'height',
|
||||
'width',
|
||||
('type', 'resample_type'),
|
||||
'factor'
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def resample_infer(node: Node):
|
||||
layout = node.graph.graph['layout']
|
||||
assert len(layout) == 4
|
||||
|
||||
input_shape = node.in_node(0).shape
|
||||
if input_shape is None:
|
||||
return
|
||||
in_height = input_shape[get_height_dim(layout, 4)]
|
||||
in_width = input_shape[get_width_dim(layout, 4)]
|
||||
|
||||
if node.has('fw') and node.fw == 'tf':
|
||||
dst_shape = node.in_node(1).value
|
||||
if dst_shape is None or len(input_shape) != 4 or len(dst_shape) != 2:
|
||||
log.error(
|
||||
'Node {} with op {} cannot be converted to Resample layer because there is no enough info about '
|
||||
'src/dst shapes: src_shape = {}, dst_shape = {}'.format(node.name, node.op, input_shape, dst_shape))
|
||||
node.type = None # prevent translation to a valid IE layer
|
||||
return
|
||||
out_height = dst_shape[0]
|
||||
out_width = dst_shape[1]
|
||||
else:
|
||||
if len(node.in_nodes()) == 1:
|
||||
if node.has('width') and node.has('height'):
|
||||
out_height = node.height
|
||||
out_width = node.width
|
||||
else:
|
||||
out_height = node.factor * in_height
|
||||
out_width = node.factor * in_width
|
||||
else:
|
||||
out_height = node.in_node(1).shape[get_height_dim(layout, 4)]
|
||||
out_width = node.in_node(1).shape[get_width_dim(layout, 4)]
|
||||
|
||||
node.factor = factor_update(
|
||||
node.factor,
|
||||
[float(out_height) / in_height, float(out_width) / in_width],
|
||||
[in_height, in_width],
|
||||
[out_height, out_width],
|
||||
node.soft_get('name'))
|
||||
|
||||
node.out_node().shape = shape_for_layout(layout,
|
||||
batch=input_shape[get_batch_dim(layout, 4)],
|
||||
features=input_shape[get_features_dim(layout, 4)],
|
||||
height=out_height,
|
||||
width=out_width)
|
||||
@@ -28,6 +28,10 @@ def float_array(l: list):
|
||||
return np.array(l, dtype=np.float64)
|
||||
|
||||
|
||||
def float32_array(l: list):
|
||||
return np.array(l, dtype=np.float32)
|
||||
|
||||
|
||||
def mark_input_bins(node, names=('weights', 'biases'), start_port: int = 1):
|
||||
"""
|
||||
Preparing necessary attributes for edges at input ports starting from start_port.
|
||||
|
||||
@@ -232,6 +232,27 @@ void op::v4::Interpolate::validate_and_infer_types()
|
||||
input_et == element::i8 || input_et == element::bf16,
|
||||
"Input element type must be f32, f16, bf16 or i8");
|
||||
|
||||
element::Type sizes_et = get_input_element_type(1);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
sizes_et == element::i32 || sizes_et == element::i64 ||
|
||||
sizes_et == element::u32 || sizes_et == element::u64,
|
||||
"Sizes element type must be i32, i64, u32 or u64");
|
||||
|
||||
element::Type scales_et = get_input_element_type(2);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
scales_et == element::f32 || scales_et == element::f16 ||
|
||||
scales_et == element::bf16,
|
||||
"Scales element type must be f32, f16 or bf16");
|
||||
|
||||
if (input_values().size() == 4)
|
||||
{
|
||||
element::Type axes_et = get_input_element_type(3);
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
axes_et == element::i64 || axes_et == element::i32 ||
|
||||
sizes_et == element::u32 || sizes_et == element::u64,
|
||||
"Axes element type must be i32, i64, u32 or u64");
|
||||
}
|
||||
|
||||
PartialShape input_shape = PartialShape(get_input_partial_shape(0));
|
||||
|
||||
if (!input_shape.rank().is_static())
|
||||
@@ -240,11 +261,20 @@ void op::v4::Interpolate::validate_and_infer_types()
|
||||
return;
|
||||
}
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
// If the input 'axes' is given and this input is not Constant, we cannot infer any elements
|
||||
// of the output shape. Hence, all components of the output shape should be dynamic.
|
||||
if (input_values().size() == 4 && !is_type<op::Constant>(input_value(3).get_node()))
|
||||
{
|
||||
PartialShape output_shape = std::vector<Dimension>(input_rank, Dimension::dynamic());
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
return;
|
||||
}
|
||||
|
||||
auto axes = get_axes();
|
||||
correct_pads();
|
||||
|
||||
const auto input_rank = input_shape.rank().get_length();
|
||||
|
||||
PartialShape padded_input_shape = get_padded_input_shape(input_shape);
|
||||
PartialShape output_shape = padded_input_shape;
|
||||
|
||||
@@ -257,7 +287,6 @@ void op::v4::Interpolate::validate_and_infer_types()
|
||||
}
|
||||
}
|
||||
|
||||
set_output_type(0, get_input_element_type(0), output_shape);
|
||||
if (m_attrs.shape_calculation_mode == ShapeCalcMode::scales)
|
||||
{
|
||||
if (const auto& const_scales = get_constant_from_source(input_value(2)))
|
||||
|
||||
@@ -175,44 +175,6 @@ namespace ngraph
|
||||
|
||||
return scales;
|
||||
}
|
||||
|
||||
OutputVector build_resize(const Node& node,
|
||||
const std::shared_ptr<ngraph::Node>& output_shape,
|
||||
const AxisSet& axes)
|
||||
{
|
||||
const auto mode = node.get_attribute_value<std::string>("mode", "nearest");
|
||||
|
||||
std::unordered_set<std::string> supported_modes = {"nearest", "linear"};
|
||||
bool is_mode_supported =
|
||||
(std::find(supported_modes.begin(), supported_modes.end(), mode) !=
|
||||
supported_modes.end());
|
||||
|
||||
if (!is_mode_supported)
|
||||
{
|
||||
std::string supported_modes_str = "";
|
||||
for (const auto& mode_name : supported_modes)
|
||||
{
|
||||
supported_modes_str += (mode_name + ", ");
|
||||
}
|
||||
CHECK_VALID_NODE(node,
|
||||
is_mode_supported,
|
||||
mode,
|
||||
" - this type of interpolation mode is not supported."
|
||||
" Choose one of the following modes: ",
|
||||
supported_modes_str);
|
||||
}
|
||||
|
||||
auto attrs = ngraph::op::v0::InterpolateAttrs();
|
||||
attrs.axes = axes;
|
||||
attrs.mode = mode;
|
||||
attrs.align_corners = false;
|
||||
|
||||
const auto inputs = node.get_ng_inputs();
|
||||
const auto& data = inputs.at(0);
|
||||
|
||||
return {
|
||||
std::make_shared<ngraph::op::v0::Interpolate>(data, output_shape, attrs)};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace set_11
|
||||
@@ -273,17 +235,20 @@ namespace ngraph
|
||||
const auto& data_shape = data.get_partial_shape();
|
||||
const auto& scales_shape = scales.get_partial_shape();
|
||||
|
||||
auto attrs = get_resize_attrs(node);
|
||||
if (attrs.mode == InterpolateMode::linear_onnx)
|
||||
{
|
||||
attrs.coordinate_transformation_mode = Transform_mode::asymmetric;
|
||||
}
|
||||
|
||||
CHECK_VALID_NODE(
|
||||
node,
|
||||
(scales_shape.is_static() || data_shape.rank().is_static()),
|
||||
" Data rank or shape of scales input is required to be static.");
|
||||
|
||||
size_t axes_size = scales_shape.is_static() ? scales_shape[0].get_length()
|
||||
: data_shape.rank().get_length();
|
||||
|
||||
const auto output_shape = calculate_output_shape_based_on_scales(data, scales);
|
||||
return build_resize(
|
||||
node, output_shape, AxisSet(common::get_monotonic_range(axes_size)));
|
||||
return {std::make_shared<default_opset::Interpolate>(
|
||||
data, output_shape, scales, attrs)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
@@ -29,7 +29,7 @@ using ShapeCalcMode = op::v4::Interpolate::ShapeCalcMode;
|
||||
TEST(type_prop, interpolate_v4)
|
||||
{
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
@@ -48,12 +48,68 @@ TEST(type_prop, interpolate_v4)
|
||||
EXPECT_EQ(interp->get_shape(), (Shape{2, 2, 15, 30}));
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v4_non_constant_axes_scales)
|
||||
{
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i64, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
|
||||
auto start = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{2});
|
||||
auto stop = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{4});
|
||||
auto step = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
|
||||
auto axes = std::make_shared<op::v4::Range>(start, stop, step, element::i32);
|
||||
|
||||
InterpolateAttrs attrs;
|
||||
attrs.mode = InterpolateMode::nearest;
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::scales;
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
|
||||
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
|
||||
attrs.antialias = false;
|
||||
attrs.pads_begin = {0, 0, 0, 0};
|
||||
attrs.pads_end = {0, 0, 0, 0};
|
||||
attrs.cube_coeff = -0.75;
|
||||
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
|
||||
|
||||
EXPECT_EQ(interp->get_element_type(), element::f32);
|
||||
auto dyn_dim = Dimension::dynamic();
|
||||
auto expected_shape = PartialShape{dyn_dim, dyn_dim, dyn_dim, dyn_dim};
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(expected_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v4_non_constant_axes_sizes)
|
||||
{
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 30, 60});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i64, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
|
||||
auto start = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{2});
|
||||
auto stop = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{4});
|
||||
auto step = std::make_shared<op::Constant>(element::i32, Shape{}, std::vector<int32_t>{1});
|
||||
auto axes = std::make_shared<op::v4::Range>(start, stop, step, element::i32);
|
||||
|
||||
InterpolateAttrs attrs;
|
||||
attrs.mode = InterpolateMode::nearest;
|
||||
attrs.shape_calculation_mode = ShapeCalcMode::sizes;
|
||||
attrs.coordinate_transformation_mode = CoordinateTransformMode::half_pixel;
|
||||
attrs.nearest_mode = Nearest_mode::round_prefer_floor;
|
||||
attrs.antialias = false;
|
||||
attrs.pads_begin = {0, 0, 0, 0};
|
||||
attrs.pads_end = {0, 0, 0, 0};
|
||||
attrs.cube_coeff = -0.75;
|
||||
auto interp = std::make_shared<op::v4::Interpolate>(image, target_shape, scales, axes, attrs);
|
||||
|
||||
EXPECT_EQ(interp->get_element_type(), element::f32);
|
||||
auto dyn_dim = Dimension::dynamic();
|
||||
auto expected_shape = PartialShape{dyn_dim, dyn_dim, dyn_dim, dyn_dim};
|
||||
ASSERT_TRUE(interp->get_output_partial_shape(0).same_scheme(expected_shape));
|
||||
}
|
||||
|
||||
TEST(type_prop, interpolate_v4_partial)
|
||||
{
|
||||
auto partial_shape = PartialShape{2, 2, Dimension::dynamic(), Dimension::dynamic()};
|
||||
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
@@ -83,7 +139,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank)
|
||||
auto partial_shape = PartialShape{2, 2, Dimension::dynamic(), Dimension::dynamic()};
|
||||
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
@@ -109,7 +165,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank2)
|
||||
auto out_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), 5, 10};
|
||||
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 15, 30});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{15, 30});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {0.5f, 0.5f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
@@ -135,7 +191,7 @@ TEST(type_prop, interpolate_v4_partial_static_rank3)
|
||||
auto out_shape = PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1, 1};
|
||||
|
||||
auto image = std::make_shared<op::Parameter>(element::f32, partial_shape);
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::f32, Shape{2, 2, 1, 1});
|
||||
auto target_shape = std::make_shared<op::Parameter>(element::i32, Shape{1, 1});
|
||||
auto scales = op::Constant::create<float>(element::f32, Shape{2}, {1.0f / 3.0f, 1.0f / 3.0f});
|
||||
auto axes = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 3});
|
||||
|
||||
|
||||
Reference in New Issue
Block a user