[ONNX FE] ONNX Resize (1-10) import use scale input and fix (#16789)
This commit is contained in:
@@ -139,21 +139,15 @@ OutputVector resize(const onnx_import::Node& node) {
|
||||
const auto& scales = inputs.at(1);
|
||||
|
||||
auto attrs = get_resize_attrs(node);
|
||||
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SIZES;
|
||||
attrs.shape_calculation_mode = default_opset::Interpolate::ShapeCalcMode::SCALES;
|
||||
|
||||
if (attrs.mode == InterpolateMode::NEAREST) {
|
||||
attrs.nearest_mode = Nearest_mode::FLOOR;
|
||||
attrs.nearest_mode = Nearest_mode::SIMPLE;
|
||||
attrs.coordinate_transformation_mode = Transform_mode::ASYMMETRIC;
|
||||
} else if (attrs.mode == InterpolateMode::LINEAR_ONNX) {
|
||||
attrs.coordinate_transformation_mode = Transform_mode::ASYMMETRIC;
|
||||
}
|
||||
|
||||
const auto shape_of_data = std::make_shared<default_opset::Convert>(std::make_shared<default_opset::ShapeOf>(data),
|
||||
scales.get_element_type());
|
||||
const auto multiply = std::make_shared<default_opset::Multiply>(shape_of_data, scales);
|
||||
const auto output_shape = std::make_shared<default_opset::Convert>(multiply, ngraph::element::i64);
|
||||
|
||||
return {std::make_shared<default_opset::Interpolate>(data, output_shape, attrs)};
|
||||
return {std::make_shared<default_opset::Interpolate>(data, scales, attrs)};
|
||||
}
|
||||
|
||||
} // namespace set_1
|
||||
|
||||
@@ -1501,6 +1501,23 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_resize11_empty_constant_as_input) {
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_linear) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
SERIALIZED_ZOO,
|
||||
"onnx/resize10_down_scales_const_linear.onnx"));
|
||||
|
||||
// Input data shape (1, 1, 2, 4)
|
||||
// Input const scales values {1.0, 1.0, 0.6, 0.6}
|
||||
// mode: linear
|
||||
|
||||
Shape expected_output_shape{1, 1, 1, 2};
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
test_case.add_input<float>({1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0});
|
||||
test_case.add_expected_output<float>(expected_output_shape, {1.0f, 2.6666665f});
|
||||
test_case.run();
|
||||
}
|
||||
|
||||
NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_nearest) {
|
||||
const auto function =
|
||||
onnx_import::import_onnx_model(file_util::path_join(CommonTestUtils::getExecutableDirectory(),
|
||||
@@ -1509,7 +1526,7 @@ NGRAPH_TEST(${BACKEND_NAME}, onnx_resize10_down_scales_const_nearest) {
|
||||
|
||||
// Input data shape (1, 1, 2, 4)
|
||||
// Input const scales values {1.0, 1.0, 0.6, 0.6}
|
||||
// mode: linear
|
||||
// mode: nearest
|
||||
|
||||
Shape expected_output_shape{1, 1, 1, 2};
|
||||
auto test_case = test::TestCase(function, s_device);
|
||||
|
||||
Reference in New Issue
Block a user