[ShapeInference] GatherTree shape infer improvements (#15399)

* Shape infer improvments

* Add type_prop label and interval dims tests

* Update shape_infer tests

* Use new shape_infer

* Revert headers changes

* Rename test file
This commit is contained in:
Katarzyna Mitrus 2023-01-31 14:04:19 +01:00 committed by GitHub
parent de74d3c837
commit 407590cfc2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 168 additions and 60 deletions

View File

@ -7,22 +7,22 @@
namespace ov {
namespace op {
namespace v1 {
template <class T>
void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 && output_shapes.size() == 1);
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
const auto& step_ids_pshape = input_shapes[0];
const auto& parent_idx_pshape = input_shapes[1];
template <class TShape>
std::vector<TShape> shape_infer(const GatherTree* op, const std::vector<TShape>& input_shapes) {
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4);
using DimType = typename std::iterator_traits<typename TShape::iterator>::value_type;
const auto& step_ids_shape = input_shapes[0];
const auto& parent_idx_shape = input_shapes[1];
const auto& max_seq_len_pshape = input_shapes[2];
const auto& end_token_pshape = input_shapes[3];
auto& result_pshape = output_shapes[0];
result_pshape = step_ids_pshape;
auto result_shape = step_ids_shape;
NODE_VALIDATION_CHECK(op,
T::merge_into(result_pshape, parent_idx_pshape) && result_pshape.rank().compatible(3),
TShape::merge_into(result_shape, parent_idx_shape) && result_shape.rank().compatible(3),
"step_ids and parent_idx inputs must have the same shape with rank 3. Got: ",
step_ids_pshape,
step_ids_shape,
" and ",
parent_idx_pshape,
parent_idx_shape,
", respectively");
NODE_VALIDATION_CHECK(op,
@ -30,12 +30,12 @@ void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::
"max_seq_len input must have rank 1. Got: ",
max_seq_len_pshape);
if (result_pshape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
if (result_shape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
NODE_VALIDATION_CHECK(op,
DimType::merge(result_pshape[1], result_pshape[1], max_seq_len_pshape[0]),
DimType::merge(result_shape[1], result_shape[1], max_seq_len_pshape[0]),
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
"step_ids/parent_idx inputs. Got: ",
result_pshape[1],
result_shape[1],
" and ",
max_seq_len_pshape[0],
", respectively");
@ -45,6 +45,12 @@ void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::
end_token_pshape.rank().compatible(0),
"end_token input must be scalar. Got: ",
end_token_pshape);
return {result_shape};
}
template <class TShape>
void shape_infer(const GatherTree* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace v1
} // namespace op

View File

@ -4,10 +4,10 @@
#include "ngraph/op/gather_tree.hpp"
#include <gather_tree_shape_inference.hpp>
#include "gather_tree_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/shape.hpp"
#include "openvino/core/validation_util.hpp"
using namespace std;
using namespace ngraph;
@ -59,13 +59,6 @@ void op::v1::GatherTree::validate_and_infer_types() {
"Element type of inputs must be numeric. Got: ",
result_et);
const auto& step_ids_pshape = get_input_partial_shape(0);
const auto& parent_idx_pshape = get_input_partial_shape(1);
const auto& max_seq_len_pshape = get_input_partial_shape(2);
const auto& end_token_pshape = get_input_partial_shape(3);
std::vector<PartialShape> input_shapes = {step_ids_pshape, parent_idx_pshape, max_seq_len_pshape, end_token_pshape},
output_shapes = {PartialShape{}};
shape_infer(this, input_shapes, output_shapes);
set_output_type(0, result_et, output_shapes[0]);
const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
set_output_type(0, result_et, output_shape);
}

View File

@ -5,12 +5,15 @@
#include <array>
#include <utility>
#include "dimension_tracker.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "openvino/op/ops.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
namespace {
constexpr size_t step_ids_input_idx = 0;
@ -39,6 +42,25 @@ std::shared_ptr<Node> makeGatherTreeOp(const GatherTreeInputParams& p) {
}
} // namespace
TEST(type_prop, gather_tree_default_constructor) {
auto op = std::make_shared<op::v1::GatherTree>();
auto step_ids = std::make_shared<op::Parameter>(element::i32, PartialShape{2, 4, 3});
auto parent_idx = std::make_shared<op::Parameter>(element::i32, PartialShape{2, 4, 3});
auto max_seq_len = std::make_shared<op::Parameter>(element::i32, PartialShape{4});
auto end_token = std::make_shared<op::Parameter>(element::i32, PartialShape{});
op->set_argument(0, step_ids);
op->set_argument(1, parent_idx);
op->set_argument(2, max_seq_len);
op->set_argument(3, end_token);
op->validate_and_infer_types();
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 4, 3}));
}
TEST(type_prop, gather_tree_invalid_input_element_type) {
Shape scalar_shape{};
Shape vector_shape{2};
@ -230,12 +252,12 @@ TEST(type_prop, gather_tree_output_shape) {
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
{PartialShape{1, 2, 3}, PartialShape{2}},
{PartialShape{1, 2, 3}, PartialShape::dynamic(1)},
{PartialShape{Dimension(), 2, Dimension()}, PartialShape{2}},
{PartialShape{-1, 2, -1}, PartialShape{2}},
{
PartialShape::dynamic(3),
PartialShape{4},
},
{PartialShape{Dimension(), Dimension(3, 5), Dimension()}, PartialShape{Dimension(1, 3)}},
{PartialShape{-1, {3, 5}, -1}, PartialShape{{1, 3}}},
{PartialShape::dynamic(), PartialShape::dynamic()}};
std::vector<GatherTreeInputParams> test_cases;
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
@ -254,8 +276,8 @@ TEST(type_prop, gather_tree_output_shape) {
if (result_shape.rank().is_static() && max_seq_len_shape.rank().is_static()) {
result_shape[1] = result_shape[1] & max_seq_len_shape[0];
}
ASSERT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
ASSERT_EQ(gather_tree->get_output_element_type(0), et);
EXPECT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
EXPECT_EQ(gather_tree->get_output_element_type(0), et);
} catch (...) {
FAIL() << "Output shape check failed for unexpected reason";
}
@ -289,3 +311,85 @@ TEST(type_prop, gather_tree_invalid_end_token_rank) {
}
}
}
TEST(type_prop, gather_tree_interval_labeled_dims_all) {
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
set_shape_labels(step_ids_shape, 10);
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
set_shape_labels(parent_ids_shape, 20);
auto max_seq_len_shape = PartialShape{{2, 6}};
set_shape_labels(max_seq_len_shape, 30);
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 30, 22));
}
TEST(type_prop, gather_tree_interval_labeled_dims_step_ids) {
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
set_shape_labels(step_ids_shape, 10);
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
auto max_seq_len_shape = PartialShape{{2, 6}};
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, 12));
}
TEST(type_prop, gather_tree_interval_labeled_dims_parent_ids) {
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
set_shape_labels(parent_ids_shape, 20);
auto max_seq_len_shape = PartialShape{{2, 6}};
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22));
}
TEST(type_prop, gather_tree_interval_labeled_dims_seq_len) {
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
auto max_seq_len_shape = PartialShape{{2, 6}};
set_shape_labels(max_seq_len_shape, 30);
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
const auto& out_shape = op->get_output_partial_shape(0);
EXPECT_EQ(op->get_output_element_type(0), element::i32);
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(ov::no_label, 30, ov::no_label));
}

View File

@ -1,31 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <gather_tree_shape_inference.hpp>
#include <openvino/op/gather_tree.hpp>
#include <openvino/op/ops.hpp>
#include <openvino/op/parameter.hpp>
#include <utils/shape_inference/shape_inference.hpp>
#include <utils/shape_inference/static_shape.hpp>
using namespace ov;
using namespace ov::intel_cpu;
TEST(StaticShapeInferenceTest, GatherTreeTest) {
auto step_ids = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto parent_idx = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto end_token = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Shape{}});
auto gather_tree = std::make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
// Test StaticShape
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 2, 3},
StaticShape{1, 2, 3},
StaticShape{2},
StaticShape{}},
static_output_shapes = {StaticShape{}};
shape_inference(gather_tree.get(), static_input_shapes, static_output_shapes);
ASSERT_EQ(static_output_shapes[0], (StaticShape{1, 2, 3}));
}

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gather_tree_shape_inference.hpp"
#include <gtest/gtest.h>
#include "openvino/op/ops.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
class GatherTreeStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v1::GatherTree> {};
TEST_F(GatherTreeStaticShapeInferenceTest, gather_tree) {
auto step_ids = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto parent_idx = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
auto end_token = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Shape{}});
op = make_op(step_ids, parent_idx, max_seq_len, end_token);
input_shapes = {StaticShape{1, 2, 3}, StaticShape{1, 2, 3}, StaticShape{2}, StaticShape{}};
output_shapes = {StaticShape{}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{1, 2, 3}));
}
TEST_F(GatherTreeStaticShapeInferenceTest, gather_tree_default_ctor) {
op = make_op();
input_shapes = {StaticShape{2, 4, 3}, StaticShape{2, 4, 3}, StaticShape{4}, StaticShape{}};
output_shapes = {StaticShape{}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{2, 4, 3}));
}