diff --git a/docs/template_plugin/tests/functional/op_reference/grid_sample.cpp b/docs/template_plugin/tests/functional/op_reference/grid_sample.cpp index 6093be86859..1923fe011a9 100644 --- a/docs/template_plugin/tests/functional/op_reference/grid_sample.cpp +++ b/docs/template_plugin/tests/functional/op_reference/grid_sample.cpp @@ -493,6 +493,56 @@ std::vector generateBicubicBatchesParams() { return params; } +std::vector generateCornerCaseData1x1Params() { + std::vector params; + + const reference_tests::Tensor data{{1, 1, 1, 1}, element::f32, std::vector{7}}; + const reference_tests::Tensor grid{{1, 1, 5, 2}, + element::f32, + std::vector{1, -1, 0, 0, -1, 0, 0.5, 0.5, 2, -4}}; + const reference_tests::Tensor sevens{{1, 1, 1, 5}, element::f32, std::vector{7, 7, 7, 7, 7}}; + + params.emplace_back( + data, + grid, + op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_ZEROS}, + reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector{1.75, 7, 3.5, 3.9375, 0}}, + "bilinear_zeros_no_align_data1x1"); + + params.emplace_back(data, + grid, + op::v9::GridSample::Attributes{false, GS_NEAREST, GS_ZEROS}, + reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector{7, 7, 7, 7, 0}}, + "nearest_zeros_no_align_data1x1"); + + params.emplace_back( + data, + grid, + op::v9::GridSample::Attributes{false, GS_BICUBIC, GS_ZEROS}, + reference_tests::Tensor{{1, 1, 1, 5}, element::f32, std::vector{2.4677734, 7, 4.15625, 5.4073334, 0}}, + "bicubic_zeros_no_align_data1x1"); + + params.emplace_back(data, + grid, + op::v9::GridSample::Attributes{true, GS_BICUBIC, GS_ZEROS}, + sevens, + "bicubic_zeros_align_data1x1"); + + params.emplace_back(data, + grid, + op::v9::GridSample::Attributes{false, GS_BILINEAR, GS_REFLECTION}, + sevens, + "bilinear_reflection_noalign_data1x1"); + + params.emplace_back(data, + grid, + op::v9::GridSample::Attributes{true, GS_NEAREST, GS_BORDER}, + sevens, + "nearest_border_align_data1x1"); + + return params; +} + std::vector generateGridSampleParams() { std::vector> combo_params{generateNearestParamsOddDimensionsInnerGrids(), generateNearestParamsOddDimensionsOuterGrids(), @@ -501,7 +551,8 @@ std::vector generateGridSampleParams() { generateBilinearParamsOddDimensionsOuterGrids(), generateBilinearParamsEvenDimensions(), generateBicubicParams(), - generateBicubicBatchesParams()}; + generateBicubicBatchesParams(), + generateCornerCaseData1x1Params()}; std::vector test_params; for (auto& params : combo_params) std::move(params.begin(), params.end(), std::back_inserter(test_params)); diff --git a/src/core/reference/include/ngraph/runtime/reference/grid_sample.hpp b/src/core/reference/include/ngraph/runtime/reference/grid_sample.hpp index 7ffe07bce40..577b9be65fc 100644 --- a/src/core/reference/include/ngraph/runtime/reference/grid_sample.hpp +++ b/src/core/reference/include/ngraph/runtime/reference/grid_sample.hpp @@ -109,8 +109,8 @@ DATA_ET reflection_data_with_align(const DATA_ET* data, long x_d) { const auto H = static_cast(data_shape[2]); const auto W = static_cast(data_shape[3]); - const auto H_2_2 = 2 * (H - 1); - const auto W_2_2 = 2 * (W - 1); + const auto H_2_2 = H == 1 ? 1 : 2 * (H - 1); + const auto W_2_2 = W == 1 ? 1 : 2 * (W - 1); y_d = std::abs(y_d) % H_2_2; x_d = std::abs(x_d) % W_2_2; const auto y = static_cast(y_d >= H ? H_2_2 - y_d : y_d);