Handle corner case 1x1 with tests (#12867)
This commit is contained in:
parent
194e3e766f
commit
9eac5d8281
@ -493,6 +493,56 @@ std::vector<GridSampleParams> generateBicubicBatchesParams() {
|
||||
return params;
|
||||
}
|
||||
|
||||
std::vector<GridSampleParams> generateCornerCaseData1x1Params() {
|
||||
std::vector<GridSampleParams> params;
|
||||
|
||||
const reference_tests::Tensor data{{1, 1, 1, 1}, element::f32, std::vector<float>{7}};
|
||||
const reference_tests::Tensor grid{{1, 1, 5, 2},
|
||||
element::f32,
|
||||
std::vector<float>{1, -1, 0, 0, -1, 0, 0.5, 0.5, 2, -4}};
|
||||
const reference_tests::Tensor sevens{{1, 1, 1, 5}, element::f32, std::vector<float>{7, 7, 7, 7, 7}};
|
||||
|
||||
params.emplace_back(
|
||||
data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_ZEROS},
|
||||
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{1.75, 7, 3.5, 3.9375, 0}},
|
||||
"bilinear_zeros_no_align_data1x1");
|
||||
|
||||
params.emplace_back(data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{false, GS_NEAREST, GS_ZEROS},
|
||||
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{7, 7, 7, 7, 0}},
|
||||
"nearest_zeros_no_align_data1x1");
|
||||
|
||||
params.emplace_back(
|
||||
data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{false, GS_BICUBIC, GS_ZEROS},
|
||||
reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector<float>{2.4677734, 7, 4.15625, 5.4073334, 0}},
|
||||
"bicubic_zeros_no_align_data1x1");
|
||||
|
||||
params.emplace_back(data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{true, GS_BICUBIC, GS_ZEROS},
|
||||
sevens,
|
||||
"bicubic_zeros_align_data1x1");
|
||||
|
||||
params.emplace_back(data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_REFLECTION},
|
||||
sevens,
|
||||
"bilinear_reflection_noalign_data1x1");
|
||||
|
||||
params.emplace_back(data,
|
||||
grid,
|
||||
op::v9::GridSample::Attributes{true, GS_NEAREST, GS_BORDER},
|
||||
sevens,
|
||||
"nearest_border_align_data1x1");
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
std::vector<GridSampleParams> generateGridSampleParams() {
|
||||
std::vector<std::vector<GridSampleParams>> combo_params{generateNearestParamsOddDimensionsInnerGrids(),
|
||||
generateNearestParamsOddDimensionsOuterGrids(),
|
||||
@ -501,7 +551,8 @@ std::vector<GridSampleParams> generateGridSampleParams() {
|
||||
generateBilinearParamsOddDimensionsOuterGrids(),
|
||||
generateBilinearParamsEvenDimensions(),
|
||||
generateBicubicParams(),
|
||||
generateBicubicBatchesParams()};
|
||||
generateBicubicBatchesParams(),
|
||||
generateCornerCaseData1x1Params()};
|
||||
std::vector<GridSampleParams> test_params;
|
||||
for (auto& params : combo_params)
|
||||
std::move(params.begin(), params.end(), std::back_inserter(test_params));
|
||||
|
@ -109,8 +109,8 @@ DATA_ET reflection_data_with_align(const DATA_ET* data,
|
||||
long x_d) {
|
||||
const auto H = static_cast<long>(data_shape[2]);
|
||||
const auto W = static_cast<long>(data_shape[3]);
|
||||
const auto H_2_2 = 2 * (H - 1);
|
||||
const auto W_2_2 = 2 * (W - 1);
|
||||
const auto H_2_2 = H == 1 ? 1 : 2 * (H - 1);
|
||||
const auto W_2_2 = W == 1 ? 1 : 2 * (W - 1);
|
||||
y_d = std::abs(y_d) % H_2_2;
|
||||
x_d = std::abs(x_d) % W_2_2;
|
||||
const auto y = static_cast<size_t>(y_d >= H ? H_2_2 - y_d : y_d);
|
||||
|
Loading…
Reference in New Issue
Block a user