[ShapeInference] GridSample shape infer review (#15102)

* Add more type_prop tests for interval dims and labels

* Add setter for grid sample attributes

* Merge grid sample batch dims

* Add StaticShapeInferenceTest for GridSample

* Fix label test

* Use OpStaticShapeInferenceTest fixture in test
This commit is contained in:
Katarzyna Mitrus 2023-01-18 13:39:05 +01:00 committed by GitHub
parent 32fce5cb40
commit fd6640b6eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 8 deletions

View File

@ -55,6 +55,10 @@ public:
return m_attributes;
}
void set_attributes(const Attributes& attributes) {
m_attributes = attributes;
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END

View File

@ -41,15 +41,9 @@ void shape_infer(const GridSample* op, const std::vector<shape_t>& input_shapes,
if (data_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(
op,
data_shape[0].compatible(grid_shape[0]),
shape_t::value_type::merge(batch_dim, grid_shape[0], data_shape[0]),
"The batch dimension in the input data tensor's shape doesn't match the batch dimension in "
"the grid tensor's shape.");
// both dimensions should match but use the one which is (possibly) static for the output shape
if (data_shape[0].is_static()) {
batch_dim = data_shape[0];
}
channel_dim = data_shape[1];
}
} else if (data_shape.rank().is_static()) {

View File

@ -2,14 +2,44 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "dimension_tracker.hpp"
#include "gtest/gtest.h"
#include "openvino/op/util/attr_types.hpp"
#include "openvino/opsets/opset9.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ov;
using namespace opset9;
using namespace testing;
TEST(type_prop, grid_sample_default) {
TEST(type_prop, grid_sample_default_constructor) {
const auto data = make_shared<Parameter>(element::i32, PartialShape{1, 3, 4, 6});
const auto grid = make_shared<Parameter>(element::f32, PartialShape{1, 7, 8, 2});
auto op = make_shared<GridSample>();
const auto& default_attrs = op->get_attributes();
EXPECT_EQ(default_attrs.align_corners, false);
EXPECT_EQ(default_attrs.mode, GridSample::InterpolationMode::BILINEAR);
EXPECT_EQ(default_attrs.padding_mode, GridSample::PaddingMode::ZEROS);
op->set_argument(0, data);
op->set_argument(1, grid);
op->set_attributes(
GridSample::Attributes(true, GridSample::InterpolationMode::BICUBIC, GridSample::PaddingMode::BORDER));
const auto& new_attrs = op->get_attributes();
EXPECT_EQ(new_attrs.align_corners, true);
EXPECT_EQ(new_attrs.mode, GridSample::InterpolationMode::BICUBIC);
EXPECT_EQ(new_attrs.padding_mode, GridSample::PaddingMode::BORDER);
op->validate_and_infer_types();
EXPECT_EQ(op->get_element_type(), element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{1, 3, 7, 8}));
}
TEST(type_prop, grid_sample_default_attributes) {
const auto data = make_shared<opset9::Parameter>(element::i32, PartialShape{1, 3, 224, 224});
const auto grid = make_shared<opset9::Parameter>(element::f32, PartialShape{1, 10, 10, 2});
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
@ -29,6 +59,82 @@ TEST(type_prop, grid_sample_dynamic_batch) {
<< "The output shape of GridSample is incorrect";
}
TEST(type_prop, grid_sample_interval_dims_and_labels) {
auto data_pshape = PartialShape{{2, 4}, {1, 3}, 128, 256};
set_shape_labels(data_pshape, 10);
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
auto grid_pshape = PartialShape{{3, 8}, {4, 6}, {5, 7}, 2};
set_shape_labels(grid_pshape, 20);
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
const auto& out_shape = grid_sample->get_output_partial_shape(0);
EXPECT_EQ(out_shape, (PartialShape{{3, 4}, {1, 3}, {4, 6}, {5, 7}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, 21, 22));
}
TEST(type_prop, grid_sample_static_batch_data_labeled_dynamic_grid_batch) {
auto data_pshape = PartialShape{2, {1, 3}, 224, 224};
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
auto grid_pshape = PartialShape{-1, {4, 6}, {5, 7}, 2};
set_shape_labels(grid_pshape, 20);
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
const auto& out_shape = grid_sample->get_output_partial_shape(0);
EXPECT_EQ(out_shape, (PartialShape{2, {1, 3}, {4, 6}, {5, 7}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, ov::no_label, 21, 22));
}
TEST(type_prop, grid_sample_labeled_dynamic_batch_data_labeled_static_grid_batch) {
auto data_pshape = PartialShape{-1, {1, 3}, 224, 224};
set_shape_labels(data_pshape, 10);
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
auto grid_pshape = PartialShape{2, Dimension(4, 6), Dimension(5, 7), 2};
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
const auto& out_shape = grid_sample->get_output_partial_shape(0);
EXPECT_EQ(out_shape, (PartialShape{2, {1, 3}, {4, 6}, {5, 7}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, ov::no_label));
}
TEST(type_prop, grid_sample_labeled_interval_batch_data_dynamic_grid_batch) {
auto data_pshape = PartialShape{{2, 4}, 3, 224, 224};
set_shape_labels(data_pshape, 10);
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
auto grid_pshape = PartialShape{-1, 6, 7, 2};
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
const auto& out_shape = grid_sample->get_output_partial_shape(0);
EXPECT_EQ(out_shape, (PartialShape{{2, 4}, 3, 6, 7}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, ov::no_label));
}
TEST(type_prop, grid_sample_dynamic_batch_data_labeled_interval_grid_batch) {
auto data_pshape = PartialShape{-1, 3, 224, 224};
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
auto grid_pshape = PartialShape{{2, 4}, 6, 7, 2};
set_shape_labels(grid_pshape, 20);
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
const auto& out_shape = grid_sample->get_output_partial_shape(0);
EXPECT_EQ(out_shape, (PartialShape{{2, 4}, 3, 6, 7}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, ov::no_label, 21, 22));
}
TEST(type_prop, grid_sample_dynamic_output_spatials) {
const auto data = make_shared<opset9::Parameter>(element::i32, PartialShape{2, 3, 224, 224});
const auto grid =

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "utils.hpp"
#include "openvino/opsets/opset9.hpp"
#include "grid_sample_shape_inference.hpp"
using namespace ov;
using namespace ov::intel_cpu;
class GridSampleStaticShapeInferenceTest : public OpStaticShapeInferenceTest<opset9::GridSample> {};
TEST_F(GridSampleStaticShapeInferenceTest, GridSample) {
const auto data = std::make_shared<opset9::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
const auto grid = std::make_shared<opset9::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
op = make_op(data, grid, opset9::GridSample::Attributes{});
input_shapes = {StaticShape{2, 3, 4, 8}, StaticShape{2, 6, 7, 2}};
output_shapes = {StaticShape{}};
exp_shape = StaticShape{2, 3, 6, 7};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], exp_shape);
}
TEST_F(GridSampleStaticShapeInferenceTest, GridSample_default_constructor) {
op = make_op();
input_shapes = {StaticShape{2, 3, 4, 8}, StaticShape{2, 6, 7, 2}};
output_shapes = {StaticShape{}};
exp_shape = StaticShape{2, 3, 6, 7};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], exp_shape);
}