ConvertInterpolate1ToInterpolate4 fixes (#6019)

* half_pixel -> asymmetric and round_prefer_floor -> simple in ConvertInterpolate1ToInterpolate4

* test fix
This commit is contained in:
Maxim Andronov 2021-06-08 10:19:25 +03:00 committed by GitHub
parent 4409a74dcf
commit 98f45ffbdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 8 deletions

View File

@ -68,14 +68,20 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
return false; return false;
} }
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes; attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor; attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::simple;
attrsV4.pads_begin = attrsV0.pads_begin; attrsV4.pads_begin = attrsV0.pads_begin;
attrsV4.pads_end = attrsV0.pads_end; attrsV4.pads_end = attrsV0.pads_end;
attrsV4.antialias = attrsV0.antialias; attrsV4.antialias = attrsV0.antialias;
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel; attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::asymmetric;
attrsV4.cube_coeff = -0.75f; attrsV4.cube_coeff = -0.75f;
if (attrsV0.align_corners) { if (attrsV0.align_corners) {
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners; attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
} else if ((attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx ||
attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear) &&
std::all_of(attrsV4.pads_begin.begin(), attrsV4.pads_begin.end(), [](size_t i){return i == 0;}) &&
std::all_of(attrsV4.pads_end.begin(), attrsV4.pads_end.end(), [](size_t i){return i == 0;}) &&
!(input_shape_rank - 2 == 2 && attrsV0.axes == AxisSet{2, 3})) {
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
} }
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1), auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),

View File

@ -54,7 +54,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest, auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0}, opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor, opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
false, -0.75); false, -0.75);
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
@ -62,7 +62,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node}); f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
} }
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
} }
@ -97,16 +97,16 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f}); auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3}); auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear, auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0}, opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor, opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
false, -0.75); true, -0.75);
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node}); f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
} }
auto res = compare_functions(f, f_ref); auto res = compare_functions(f, f_ref, true, false, false, true, true);
ASSERT_TRUE(res.first) << res.second; ASSERT_TRUE(res.first) << res.second;
} }