Handle corner case 1x1 with tests (#12867)

This commit is contained in:
Tomasz Jankowski 2022-09-05 16:18:12 +02:00 committed by GitHub
parent 194e3e766f
commit 9eac5d8281
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 54 additions and 3 deletions

View File

@ -493,6 +493,56 @@ std::vector<GridSampleParams> generateBicubicBatchesParams() {
return params;
}
std::vector<GridSampleParams> generateCornerCaseData1x1Params() {
std::vector<GridSampleParams> params;
const reference_tests::Tensor data{{1, 1, 1, 1}, element::f32, std::vector<float>{7}};
const reference_tests::Tensor grid{{1, 1, 5, 2},
element::f32,
std::vector<float>{1, -1, 0, 0, -1, 0, 0.5, 0.5, 2, -4}};
const reference_tests::Tensor sevens{{1, 1, 1, 5}, element::f32, std::vector<float>{7, 7, 7, 7, 7}};
params.emplace_back(
data,
grid,
op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_ZEROS},
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{1.75, 7, 3.5, 3.9375, 0}},
"bilinear_zeros_no_align_data1x1");
params.emplace_back(data,
grid,
op::v9::GridSample::Attributes{false, GS_NEAREST, GS_ZEROS},
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{7, 7, 7, 7, 0}},
"nearest_zeros_no_align_data1x1");
params.emplace_back(
data,
grid,
op::v9::GridSample::Attributes{false, GS_BICUBIC, GS_ZEROS},
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{2.4677734, 7, 4.15625, 5.4073334, 0}},
"bicubic_zeros_no_align_data1x1");
params.emplace_back(data,
grid,
op::v9::GridSample::Attributes{true, GS_BICUBIC, GS_ZEROS},
sevens,
"bicubic_zeros_align_data1x1");
params.emplace_back(data,
grid,
op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_REFLECTION},
sevens,
"bilinear_reflection_noalign_data1x1");
params.emplace_back(data,
grid,
op::v9::GridSample::Attributes{true, GS_NEAREST, GS_BORDER},
sevens,
"nearest_border_align_data1x1");
return params;
}
std::vector<GridSampleParams> generateGridSampleParams() {
std::vector<std::vector<GridSampleParams>> combo_params{generateNearestParamsOddDimensionsInnerGrids(),
generateNearestParamsOddDimensionsOuterGrids(),
@ -501,7 +551,8 @@ std::vector<GridSampleParams> generateGridSampleParams() {
generateBilinearParamsOddDimensionsOuterGrids(),
generateBilinearParamsEvenDimensions(),
generateBicubicParams(),
generateBicubicBatchesParams()};
generateBicubicBatchesParams(),
generateCornerCaseData1x1Params()};
std::vector<GridSampleParams> test_params;
for (auto& params : combo_params)
std::move(params.begin(), params.end(), std::back_inserter(test_params));

View File

@ -109,8 +109,8 @@ DATA_ET reflection_data_with_align(const DATA_ET* data,
long x_d) {
const auto H = static_cast<long>(data_shape[2]);
const auto W = static_cast<long>(data_shape[3]);
const auto H_2_2 = 2 * (H - 1);
const auto W_2_2 = 2 * (W - 1);
const auto H_2_2 = H == 1 ? 1 : 2 * (H - 1);
const auto W_2_2 = W == 1 ? 1 : 2 * (W - 1);
y_d = std::abs(y_d) % H_2_2;
x_d = std::abs(x_d) % W_2_2;
const auto y = static_cast<size_t>(y_d >= H ? H_2_2 - y_d : y_d);