[ONNX FE] ONNX Resize (1-10) import use scale input and fix (#16789)

This commit is contained in:
Katarzyna Mitrus
2023-05-09 13:05:29 +02:00
committed by GitHub
parent eb178a753b
commit 43936bd18a
2 changed files with 21 additions and 10 deletions

View File

@@ -139,21 +139,15 @@ OutputVector resize(const onnx_import::Node& node) {
const auto& scales = inputs.at(1);
auto attrs = get_resize_attrs(node);
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SIZES;
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SCALES;
if (attrs.mode == InterpolateMode::NEAREST) {
attrs.nearest_mode = Nearest_mode::FLOOR;
attrs.nearest_mode = Nearest_mode::SIMPLE;
attrs.coordinate_transformation_mode = Transform_mode::ASYMMETRIC;
} else if (attrs.mode == InterpolateMode::LINEAR_ONNX) {
attrs.coordinate_transformation_mode = Transform_mode::ASYMMETRIC;
}
const auto shape_of_data = std::make_shared<default_opset::Convert>(std::make_shared<default_opset::ShapeOf>(data),
scales.get_element_type());
const auto multiply = std::make_shared<default_opset::Multiply>(shape_of_data, scales);
const auto output_shape = std::make_shared<default_opset::Convert>(multiply, ngraph::element::i64);
return {std::make_shared<default_opset::Interpolate>(data, output_shape, attrs)};
return {std::make_shared<default_opset::Interpolate>(data, scales, attrs)};
}
} // namespace set_1

View File

@@ -1501,6 +1501,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input) {
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_linear) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/resize10_down_scales_const_linear.onnx"));
// Input data shape (1, 1, 2, 4)
// Input const scales values {1.0, 1.0, 0.6, 0.6}
// mode: linear
Shape expected_output_shape{1, 1, 1, 2};
auto test_case = test::TestCase(function, s_device);
test_case.add_input<float>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0});
test_case.add_expected_output<float>(expected_output_shape, {1.0f, 2.6666665f});
test_case.run();
}
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_nearest) {
const auto function =
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
@@ -1509,7 +1526,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_nearest) {
// Input data shape (1, 1, 2, 4)
// Input const scales values {1.0, 1.0, 0.6, 0.6}
// mode: linear
// mode: nearest
Shape expected_output_shape{1, 1, 1, 2};
auto test_case = test::TestCase(function, s_device);