diff --git a/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp b/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp index 94173079c62..36a58551a68 100644 --- a/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp +++ b/inference-engine/src/transformations/src/transformations/op_conversions/convert_interpolate1_to_interpolate4.cpp @@ -68,14 +68,20 @@ ngraph::pass::ConvertInterpolate1ToInterpolate4::ConvertInterpolate1ToInterpolat return false; } attrsV4.shape_calculation_mode = ngraph::opset4::Interpolate::ShapeCalcMode::sizes; - attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::round_prefer_floor; + attrsV4.nearest_mode = ngraph::opset4::Interpolate::NearestMode::simple; attrsV4.pads_begin = attrsV0.pads_begin; attrsV4.pads_end = attrsV0.pads_end; attrsV4.antialias = attrsV0.antialias; - attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel; + attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::asymmetric; attrsV4.cube_coeff = -0.75f; if (attrsV0.align_corners) { attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::align_corners; + } else if ((attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear_onnx || + attrsV4.mode == ngraph::op::v4::Interpolate::InterpolateMode::linear) && + std::all_of(attrsV4.pads_begin.begin(), attrsV4.pads_begin.end(), [](size_t i){return i == 0;}) && + std::all_of(attrsV4.pads_end.begin(), attrsV4.pads_end.end(), [](size_t i){return i == 0;}) && + !(input_shape_rank - 2 == 2 && attrsV0.axes == AxisSet{2, 3})) { + attrsV4.coordinate_transformation_mode = ngraph::opset4::Interpolate::CoordinateTransformMode::half_pixel; } auto interpolateV4 = std::make_shared(interpolationV0->input_value(0), interpolationV0->input_value(1), diff --git a/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp b/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp index 9468db9287d..12177f78cbc 100644 --- a/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp +++ b/inference-engine/tests/functional/inference_engine/transformations/convert_interpolate1_to_interpolate4_test.cpp @@ -54,7 +54,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) { auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::nearest, opset4::Interpolate::ShapeCalcMode::sizes, std::vector{0, 0, 0, 0}, std::vector{0, 0, 0, 0}, - opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::floor, + opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple, false, -0.75); auto interpolate4 = std::make_shared(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); @@ -62,7 +62,7 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4) { f_ref = std::make_shared(NodeVector{interpolate4}, ParameterVector{data_node}); } - auto res = compare_functions(f, f_ref); + auto res = compare_functions(f, f_ref, true, false, false, true, true); ASSERT_TRUE(res.first) << res.second; } @@ -97,16 +97,16 @@ TEST(TransformationTests, ConvertInterpolate1ToInterpolate4_1) { auto default_scales_node = opset1::Constant::create(ngraph::element::f32, Shape{2}, {4.0f / 3.0f, 4.0f / 3.0f}); auto axes_node = opset1::Constant::create(ngraph::element::i64, Shape{2}, {2, 3}); - auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear, + auto interpolate4_attr = opset4::Interpolate::InterpolateAttrs(opset4::Interpolate::InterpolateMode::linear_onnx, opset4::Interpolate::ShapeCalcMode::sizes, std::vector{0, 0, 0, 0}, std::vector{0, 0, 0, 0}, - opset4::Interpolate::CoordinateTransformMode::align_corners, opset4::Interpolate::NearestMode::floor, - false, -0.75); + opset4::Interpolate::CoordinateTransformMode::asymmetric, opset4::Interpolate::NearestMode::simple, + true, -0.75); auto interpolate4 = std::make_shared(data_node, out_shape_node, default_scales_node, axes_node, interpolate4_attr); f_ref = std::make_shared(NodeVector{interpolate4}, ParameterVector{data_node}); } - auto res = compare_functions(f, f_ref); + auto res = compare_functions(f, f_ref, true, false, false, true, true); ASSERT_TRUE(res.first) << res.second; }