ConvertInterpolate1ToInterpolate4 fixes (#6019)
* half_pixel -> asymmetric and round_prefer_floor -> simple in ConvertInterpolate1ToInterpolate4 * test fix
This commit is contained in:
parent
4409a74dcf
commit
98f45ffbdd
@ -68,14 +68,20 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
|
attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes;
|
||||||
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor;
|
attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::simple;
|
||||||
attrsV4.pads_begin = attrsV0.pads_begin;
|
attrsV4.pads_begin = attrsV0.pads_begin;
|
||||||
attrsV4.pads_end = attrsV0.pads_end;
|
attrsV4.pads_end = attrsV0.pads_end;
|
||||||
attrsV4.antialias = attrsV0.antialias;
|
attrsV4.antialias = attrsV0.antialias;
|
||||||
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
|
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::asymmetric;
|
||||||
attrsV4.cube_coeff = -0.75f;
|
attrsV4.cube_coeff = -0.75f;
|
||||||
if (attrsV0.align_corners) {
|
if (attrsV0.align_corners) {
|
||||||
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
|
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners;
|
||||||
|
} else if ((attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx ||
|
||||||
|
attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear) &&
|
||||||
|
std::all_of(attrsV4.pads_begin.begin(), attrsV4.pads_begin.end(), [](size_t i){return i == 0;}) &&
|
||||||
|
std::all_of(attrsV4.pads_end.begin(), attrsV4.pads_end.end(), [](size_t i){return i == 0;}) &&
|
||||||
|
!(input_shape_rank - 2 == 2 && attrsV0.axes == AxisSet{2, 3})) {
|
||||||
|
attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
|
auto interpolateV4 = std::make_shared<ngraph::opset4::Interpolate>(interpolationV0->input_value(0), interpolationV0->input_value(1),
|
||||||
|
@ -54,7 +54,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
|
|||||||
|
|
||||||
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
|
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest,
|
||||||
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
|
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
|
||||||
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor,
|
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
|
||||||
false, -0.75);
|
false, -0.75);
|
||||||
|
|
||||||
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
|
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
|
||||||
@ -62,7 +62,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) {
|
|||||||
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref, true, false, false, true, true);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -97,16 +97,16 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) {
|
|||||||
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
|
auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f});
|
||||||
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
|
auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3});
|
||||||
|
|
||||||
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear,
|
auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx,
|
||||||
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
|
opset4::Interpolate::ShapeCalcMode::sizes, std::vector<size_t>{0, 0, 0, 0}, std::vector<size_t>{0, 0, 0, 0},
|
||||||
opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor,
|
opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple,
|
||||||
false, -0.75);
|
true, -0.75);
|
||||||
|
|
||||||
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
|
auto interpolate4 = std::make_shared<opset4::Interpolate>(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr);
|
||||||
|
|
||||||
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
f_ref = std::make_shared<Function>(NodeVector{interpolate4}, ParameterVector{data_node});
|
||||||
}
|
}
|
||||||
|
|
||||||
auto res = compare_functions(f, f_ref);
|
auto res = compare_functions(f, f_ref, true, false, false, true, true);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user