[ShapeInference] GatherTree shape infer improvements (#15399)
* Shape infer improvments * Add type_prop label and interval dims tests * Update shape_infer tests * Use new shape_infer * Revert headers changes * Rename test file
This commit is contained in:
parent
de74d3c837
commit
407590cfc2
@ -7,22 +7,22 @@
|
||||
namespace ov {
|
||||
namespace op {
|
||||
namespace v1 {
|
||||
template <class T>
|
||||
void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4 && output_shapes.size() == 1);
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
const auto& step_ids_pshape = input_shapes[0];
|
||||
const auto& parent_idx_pshape = input_shapes[1];
|
||||
|
||||
template <class TShape>
|
||||
std::vector<TShape> shape_infer(const GatherTree* op, const std::vector<TShape>& input_shapes) {
|
||||
NODE_VALIDATION_CHECK(op, input_shapes.size() == 4);
|
||||
using DimType = typename std::iterator_traits<typename TShape::iterator>::value_type;
|
||||
const auto& step_ids_shape = input_shapes[0];
|
||||
const auto& parent_idx_shape = input_shapes[1];
|
||||
const auto& max_seq_len_pshape = input_shapes[2];
|
||||
const auto& end_token_pshape = input_shapes[3];
|
||||
auto& result_pshape = output_shapes[0];
|
||||
result_pshape = step_ids_pshape;
|
||||
auto result_shape = step_ids_shape;
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
T::merge_into(result_pshape, parent_idx_pshape) && result_pshape.rank().compatible(3),
|
||||
TShape::merge_into(result_shape, parent_idx_shape) && result_shape.rank().compatible(3),
|
||||
"step_ids and parent_idx inputs must have the same shape with rank 3. Got: ",
|
||||
step_ids_pshape,
|
||||
step_ids_shape,
|
||||
" and ",
|
||||
parent_idx_pshape,
|
||||
parent_idx_shape,
|
||||
", respectively");
|
||||
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
@ -30,12 +30,12 @@ void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::
|
||||
"max_seq_len input must have rank 1. Got: ",
|
||||
max_seq_len_pshape);
|
||||
|
||||
if (result_pshape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
|
||||
if (result_shape.rank().is_static() && max_seq_len_pshape.rank().is_static()) {
|
||||
NODE_VALIDATION_CHECK(op,
|
||||
DimType::merge(result_pshape[1], result_pshape[1], max_seq_len_pshape[0]),
|
||||
DimType::merge(result_shape[1], result_shape[1], max_seq_len_pshape[0]),
|
||||
"Number of elements of max_seq_len input must match BATCH_SIZE dimension of "
|
||||
"step_ids/parent_idx inputs. Got: ",
|
||||
result_pshape[1],
|
||||
result_shape[1],
|
||||
" and ",
|
||||
max_seq_len_pshape[0],
|
||||
", respectively");
|
||||
@ -45,6 +45,12 @@ void shape_infer(const GatherTree* op, const std::vector<T>& input_shapes, std::
|
||||
end_token_pshape.rank().compatible(0),
|
||||
"end_token input must be scalar. Got: ",
|
||||
end_token_pshape);
|
||||
return {result_shape};
|
||||
}
|
||||
|
||||
template <class TShape>
|
||||
void shape_infer(const GatherTree* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
|
||||
output_shapes = shape_infer(op, input_shapes);
|
||||
}
|
||||
} // namespace v1
|
||||
} // namespace op
|
||||
|
@ -4,10 +4,10 @@
|
||||
|
||||
#include "ngraph/op/gather_tree.hpp"
|
||||
|
||||
#include <gather_tree_shape_inference.hpp>
|
||||
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
#include "itt.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
@ -59,13 +59,6 @@ void op::v1::GatherTree::validate_and_infer_types() {
|
||||
"Element type of inputs must be numeric. Got: ",
|
||||
result_et);
|
||||
|
||||
const auto& step_ids_pshape = get_input_partial_shape(0);
|
||||
const auto& parent_idx_pshape = get_input_partial_shape(1);
|
||||
const auto& max_seq_len_pshape = get_input_partial_shape(2);
|
||||
const auto& end_token_pshape = get_input_partial_shape(3);
|
||||
std::vector<PartialShape> input_shapes = {step_ids_pshape, parent_idx_pshape, max_seq_len_pshape, end_token_pshape},
|
||||
output_shapes = {PartialShape{}};
|
||||
shape_infer(this, input_shapes, output_shapes);
|
||||
|
||||
set_output_type(0, result_et, output_shapes[0]);
|
||||
const auto output_shape = shape_infer(this, ov::get_node_input_partial_shapes(*this)).front();
|
||||
set_output_type(0, result_et, output_shape);
|
||||
}
|
||||
|
@ -5,12 +5,15 @@
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
#include "dimension_tracker.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "util/type_prop.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
using namespace testing;
|
||||
|
||||
namespace {
|
||||
constexpr size_t step_ids_input_idx = 0;
|
||||
@ -39,6 +42,25 @@ std::shared_ptr<Node> makeGatherTreeOp(const GatherTreeInputParams& p) {
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST(type_prop, gather_tree_default_constructor) {
|
||||
auto op = std::make_shared<op::v1::GatherTree>();
|
||||
|
||||
auto step_ids = std::make_shared<op::Parameter>(element::i32, PartialShape{2, 4, 3});
|
||||
auto parent_idx = std::make_shared<op::Parameter>(element::i32, PartialShape{2, 4, 3});
|
||||
auto max_seq_len = std::make_shared<op::Parameter>(element::i32, PartialShape{4});
|
||||
auto end_token = std::make_shared<op::Parameter>(element::i32, PartialShape{});
|
||||
|
||||
op->set_argument(0, step_ids);
|
||||
op->set_argument(1, parent_idx);
|
||||
op->set_argument(2, max_seq_len);
|
||||
op->set_argument(3, end_token);
|
||||
|
||||
op->validate_and_infer_types();
|
||||
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 4, 3}));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_invalid_input_element_type) {
|
||||
Shape scalar_shape{};
|
||||
Shape vector_shape{2};
|
||||
@ -230,12 +252,12 @@ TEST(type_prop, gather_tree_output_shape) {
|
||||
std::vector<std::pair<PartialShape, PartialShape>> input_shapes = {
|
||||
{PartialShape{1, 2, 3}, PartialShape{2}},
|
||||
{PartialShape{1, 2, 3}, PartialShape::dynamic(1)},
|
||||
{PartialShape{Dimension(), 2, Dimension()}, PartialShape{2}},
|
||||
{PartialShape{-1, 2, -1}, PartialShape{2}},
|
||||
{
|
||||
PartialShape::dynamic(3),
|
||||
PartialShape{4},
|
||||
},
|
||||
{PartialShape{Dimension(), Dimension(3, 5), Dimension()}, PartialShape{Dimension(1, 3)}},
|
||||
{PartialShape{-1, {3, 5}, -1}, PartialShape{{1, 3}}},
|
||||
{PartialShape::dynamic(), PartialShape::dynamic()}};
|
||||
std::vector<GatherTreeInputParams> test_cases;
|
||||
std::for_each(std::begin(input_shapes), std::end(input_shapes), [&](std::pair<PartialShape, PartialShape> shapes) {
|
||||
@ -254,8 +276,8 @@ TEST(type_prop, gather_tree_output_shape) {
|
||||
if (result_shape.rank().is_static() && max_seq_len_shape.rank().is_static()) {
|
||||
result_shape[1] = result_shape[1] & max_seq_len_shape[0];
|
||||
}
|
||||
ASSERT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
|
||||
ASSERT_EQ(gather_tree->get_output_element_type(0), et);
|
||||
EXPECT_EQ(gather_tree->get_output_partial_shape(0), result_shape);
|
||||
EXPECT_EQ(gather_tree->get_output_element_type(0), et);
|
||||
} catch (...) {
|
||||
FAIL() << "Output shape check failed for unexpected reason";
|
||||
}
|
||||
@ -289,3 +311,85 @@ TEST(type_prop, gather_tree_invalid_end_token_rank) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_interval_labeled_dims_all) {
|
||||
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
|
||||
set_shape_labels(step_ids_shape, 10);
|
||||
|
||||
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
|
||||
set_shape_labels(parent_ids_shape, 20);
|
||||
|
||||
auto max_seq_len_shape = PartialShape{{2, 6}};
|
||||
set_shape_labels(max_seq_len_shape, 30);
|
||||
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
|
||||
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
|
||||
|
||||
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 30, 22));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_interval_labeled_dims_step_ids) {
|
||||
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
|
||||
set_shape_labels(step_ids_shape, 10);
|
||||
|
||||
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
|
||||
auto max_seq_len_shape = PartialShape{{2, 6}};
|
||||
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
|
||||
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
|
||||
|
||||
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(10, 11, 12));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_interval_labeled_dims_parent_ids) {
|
||||
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
|
||||
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
|
||||
set_shape_labels(parent_ids_shape, 20);
|
||||
|
||||
auto max_seq_len_shape = PartialShape{{2, 6}};
|
||||
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
|
||||
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
|
||||
|
||||
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(20, 21, 22));
|
||||
}
|
||||
|
||||
TEST(type_prop, gather_tree_interval_labeled_dims_seq_len) {
|
||||
auto step_ids_shape = PartialShape{{2, 5}, {4, 8}, {3, 6}};
|
||||
auto parent_ids_shape = PartialShape{{3, 7}, {3, 7}, {1, 4}};
|
||||
auto max_seq_len_shape = PartialShape{{2, 6}};
|
||||
set_shape_labels(max_seq_len_shape, 30);
|
||||
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::i32, step_ids_shape);
|
||||
auto parent_ids = std::make_shared<op::v0::Parameter>(element::i32, parent_ids_shape);
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::i32, max_seq_len_shape);
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
|
||||
|
||||
auto op = std::make_shared<op::v1::GatherTree>(step_ids, parent_ids, max_seq_len, end_token);
|
||||
|
||||
const auto& out_shape = op->get_output_partial_shape(0);
|
||||
EXPECT_EQ(op->get_output_element_type(0), element::i32);
|
||||
EXPECT_EQ(out_shape, (PartialShape{{3, 5}, {4, 6}, {3, 4}}));
|
||||
EXPECT_THAT(get_shape_labels(out_shape), ElementsAre(ov::no_label, 30, ov::no_label));
|
||||
}
|
||||
|
@ -1,31 +0,0 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <gather_tree_shape_inference.hpp>
|
||||
#include <openvino/op/gather_tree.hpp>
|
||||
#include <openvino/op/ops.hpp>
|
||||
#include <openvino/op/parameter.hpp>
|
||||
#include <utils/shape_inference/shape_inference.hpp>
|
||||
#include <utils/shape_inference/static_shape.hpp>
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
TEST(StaticShapeInferenceTest, GatherTreeTest) {
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto parent_idx = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Shape{}});
|
||||
auto gather_tree = std::make_shared<op::v1::GatherTree>(step_ids, parent_idx, max_seq_len, end_token);
|
||||
// Test StaticShape
|
||||
std::vector<StaticShape> static_input_shapes = {StaticShape{1, 2, 3},
|
||||
StaticShape{1, 2, 3},
|
||||
StaticShape{2},
|
||||
StaticShape{}},
|
||||
static_output_shapes = {StaticShape{}};
|
||||
shape_inference(gather_tree.get(), static_input_shapes, static_output_shapes);
|
||||
ASSERT_EQ(static_output_shapes[0], (StaticShape{1, 2, 3}));
|
||||
}
|
@ -0,0 +1,36 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "gather_tree_shape_inference.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "openvino/op/ops.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace testing;
|
||||
|
||||
class GatherTreeStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v1::GatherTree> {};
|
||||
|
||||
TEST_F(GatherTreeStaticShapeInferenceTest, gather_tree) {
|
||||
auto step_ids = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto parent_idx = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1, -1, -1});
|
||||
auto max_seq_len = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||
auto end_token = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{Shape{}});
|
||||
op = make_op(step_ids, parent_idx, max_seq_len, end_token);
|
||||
|
||||
input_shapes = {StaticShape{1, 2, 3}, StaticShape{1, 2, 3}, StaticShape{2}, StaticShape{}};
|
||||
output_shapes = {StaticShape{}};
|
||||
shape_inference(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], (StaticShape{1, 2, 3}));
|
||||
}
|
||||
|
||||
TEST_F(GatherTreeStaticShapeInferenceTest, gather_tree_default_ctor) {
|
||||
op = make_op();
|
||||
input_shapes = {StaticShape{2, 4, 3}, StaticShape{2, 4, 3}, StaticShape{4}, StaticShape{}};
|
||||
output_shapes = {StaticShape{}};
|
||||
shape_infer(op.get(), input_shapes, output_shapes);
|
||||
EXPECT_EQ(output_shapes[0], (StaticShape{2, 4, 3}));
|
||||
}
|
Loading…
Reference in New Issue
Block a user