[ShapeInference] GridSample shape infer review (#15102)
* Add more type_prop tests for interval dims and labels * Add setter for grid sample attributes * Merge grid sample batch dims * Add StaticShapeInferenceTest for GridSample * Fix label test * Use OpStaticShapeInferenceTest fixture in test
This commit is contained in:
parent
32fce5cb40
commit
fd6640b6eb
@ -55,6 +55,10 @@ public:
|
||||
return m_attributes;
|
||||
}
|
||||
|
||||
void set_attributes(const Attributes& attributes) {
|
||||
m_attributes = attributes;
|
||||
}
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
|
@ -41,15 +41,9 @@ void shape_infer(const GridSample* op, const std::vector<shape_t>& input_shapes,
|
||||
if (data_shape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(
|
||||
op,
|
||||
data_shape[0].compatible(grid_shape[0]),
|
||||
shape_t::value_type::merge(batch_dim, grid_shape[0], data_shape[0]),
|
||||
"The batch dimension in the input data tensor's shape doesn't match the batch dimension in "
|
||||
"the grid tensor's shape.");
|
||||
|
||||
// both dimensions should match but use the one which is (possibly) static for the output shape
|
||||
if (data_shape[0].is_static()) {
|
||||
batch_dim = data_shape[0];
|
||||
}
|
||||
|
||||
channel_dim = data_shape[1];
|
||||
}
|
||||
} else if (data_shape.rank().is_static()) {
|
||||
|
@ -2,14 +2,44 @@
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "openvino/op/util/attr_types.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov;
|
||||
using namespace opset9;
|
||||
using namespace testing;
|
||||
|
||||
TEST(type_prop, grid_sample_default) {
|
||||
TEST(type_prop, grid_sample_default_constructor) {
|
||||
const auto data = make_shared<Parameter>(element::i32, PartialShape{1, 3, 4, 6});
|
||||
const auto grid = make_shared<Parameter>(element::f32, PartialShape{1, 7, 8, 2});
|
||||
auto op = make_shared<GridSample>();
|
||||
|
||||
const auto& default_attrs = op->get_attributes();
|
||||
EXPECT_EQ(default_attrs.align_corners, false);
|
||||
EXPECT_EQ(default_attrs.mode, GridSample::InterpolationMode::BILINEAR);
|
||||
EXPECT_EQ(default_attrs.padding_mode, GridSample::PaddingMode::ZEROS);
|
||||
|
||||
op->set_argument(0, data);
|
||||
op->set_argument(1, grid);
|
||||
|
||||
op->set_attributes(
|
||||
GridSample::Attributes(true, GridSample::InterpolationMode::BICUBIC, GridSample::PaddingMode::BORDER));
|
||||
const auto& new_attrs = op->get_attributes();
|
||||
EXPECT_EQ(new_attrs.align_corners, true);
|
||||
EXPECT_EQ(new_attrs.mode, GridSample::InterpolationMode::BICUBIC);
|
||||
EXPECT_EQ(new_attrs.padding_mode, GridSample::PaddingMode::BORDER);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(op->get_element_type(), element::i32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{1, 3, 7, 8}));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_default_attributes) {
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, PartialShape{1, 3, 224, 224});
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, PartialShape{1, 10, 10, 2});
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
@ -29,6 +59,82 @@ TEST(type_prop, grid_sample_dynamic_batch) {
|
||||
<< "The output shape of GridSample is incorrect";
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_interval_dims_and_labels) {
|
||||
auto data_pshape = PartialShape{{2, 4}, {1, 3}, 128, 256};
|
||||
set_shape_labels(data_pshape, 10);
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
|
||||
|
||||
auto grid_pshape = PartialShape{{3, 8}, {4, 6}, {5, 7}, 2};
|
||||
set_shape_labels(grid_pshape, 20);
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
|
||||
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
const auto& out_shape = grid_sample->get_output_partial_shape(0);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{3, 4}, {1, 3}, {4, 6}, {5, 7}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, 21, 22));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_static_batch_data_labeled_dynamic_grid_batch) {
|
||||
auto data_pshape = PartialShape{2, {1, 3}, 224, 224};
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
|
||||
|
||||
auto grid_pshape = PartialShape{-1, {4, 6}, {5, 7}, 2};
|
||||
set_shape_labels(grid_pshape, 20);
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
|
||||
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
const auto& out_shape = grid_sample->get_output_partial_shape(0);
|
||||
EXPECT_EQ(out_shape, (PartialShape{2, {1, 3}, {4, 6}, {5, 7}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, ov::no_label, 21, 22));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_labeled_dynamic_batch_data_labeled_static_grid_batch) {
|
||||
auto data_pshape = PartialShape{-1, {1, 3}, 224, 224};
|
||||
set_shape_labels(data_pshape, 10);
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
|
||||
|
||||
auto grid_pshape = PartialShape{2, Dimension(4, 6), Dimension(5, 7), 2};
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
|
||||
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
const auto& out_shape = grid_sample->get_output_partial_shape(0);
|
||||
EXPECT_EQ(out_shape, (PartialShape{2, {1, 3}, {4, 6}, {5, 7}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_labeled_interval_batch_data_dynamic_grid_batch) {
|
||||
auto data_pshape = PartialShape{{2, 4}, 3, 224, 224};
|
||||
set_shape_labels(data_pshape, 10);
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
|
||||
|
||||
auto grid_pshape = PartialShape{-1, 6, 7, 2};
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
|
||||
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
const auto& out_shape = grid_sample->get_output_partial_shape(0);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{2, 4}, 3, 6, 7}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, ov::no_label, ov::no_label));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_dynamic_batch_data_labeled_interval_grid_batch) {
|
||||
auto data_pshape = PartialShape{-1, 3, 224, 224};
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, data_pshape);
|
||||
|
||||
auto grid_pshape = PartialShape{{2, 4}, 6, 7, 2};
|
||||
set_shape_labels(grid_pshape, 20);
|
||||
const auto grid = make_shared<opset9::Parameter>(element::f32, grid_pshape);
|
||||
|
||||
const auto grid_sample = make_shared<opset9::GridSample>(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
const auto& out_shape = grid_sample->get_output_partial_shape(0);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{2, 4}, 3, 6, 7}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, ov::no_label, 21, 22));
|
||||
}
|
||||
|
||||
TEST(type_prop, grid_sample_dynamic_output_spatials) {
|
||||
const auto data = make_shared<opset9::Parameter>(element::i32, PartialShape{2, 3, 224, 224});
|
||||
const auto grid =
|
||||
|
@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "utils.hpp"
|
||||
#include "openvino/opsets/opset9.hpp"
|
||||
#include "grid_sample_shape_inference.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
class GridSampleStaticShapeInferenceTest : public OpStaticShapeInferenceTest<opset9::GridSample> {};
|
||||
|
||||
TEST_F(GridSampleStaticShapeInferenceTest, GridSample) {
|
||||
const auto data = std::make_shared<opset9::Parameter>(element::i32, PartialShape{-1, -1, -1, -1});
|
||||
const auto grid = std::make_shared<opset9::Parameter>(element::f32, PartialShape{-1, -1, -1, -1});
|
||||
|
||||
op = make_op(data, grid, opset9::GridSample::Attributes{});
|
||||
|
||||
input_shapes = {StaticShape{2, 3, 4, 8}, StaticShape{2, 6, 7, 2}};
|
||||
output_shapes = {StaticShape{}};
|
||||
exp_shape = StaticShape{2, 3, 6, 7};
|
||||
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], exp_shape);
|
||||
}
|
||||
|
||||
TEST_F(GridSampleStaticShapeInferenceTest, GridSample_default_constructor) {
|
||||
op = make_op();
|
||||
|
||||
input_shapes = {StaticShape{2, 3, 4, 8}, StaticShape{2, 6, 7, 2}};
|
||||
output_shapes = {StaticShape{}};
|
||||
exp_shape = StaticShape{2, 3, 6, 7};
|
||||
|
||||
shape_infer(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], exp_shape);
|
||||
}
|
Loading…
Reference in New Issue
Block a user