[ShapeInference] EmbeddingBag-Offsets/Packed-Sum shape infer improvements (#16072)

* shape_infer

* Register EmbeddingBagPackedSum shape_nfer for CPU

* Tests

* Merge shapes to preserve 2rd input info

* More label tests

* Add emb_table size check

* rename shape infer file

* Add more tests

* Update constexpr

* Use OV_EXPECT_THROW

* Style

* Reuse emb_table for dynamic rank

* Add common util to calculate emb output shape

* Update embd shape infer to use common util

* Update embedding shape infer util
This commit is contained in:
Katarzyna Mitrus 2023-03-15 11:12:57 +01:00 committed by GitHub
parent 69ba802e03
commit f0c153858b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 389 additions and 87 deletions

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "openvino/core/node.hpp"
namespace ov {
namespace op {
namespace util {
namespace embedding {
/**
* \brief Return a copy of the `emb_table_shape` with the first dimension replaced
* by the first dimension from the `dim_shape_src`
*
* \tparam TShape Shape type
*
* \param op Pointer to operator.
* \param emb_table_shape The shape to be copied
* \param dim_shape_src The shape to copy the first dimension from, with dynamic or static rank > 1
*
* \return The copy of the `emb_table_shape` with the first dimsnsion overwritten by `dim_shape_src[0]` if the rank is
* static, otherwise fully dynamic shape with dynamic rank.
*/
template <class TShape>
TShape out_shape_infer(const ov::Node* op, const TShape& emb_table_shape, const TShape& dim_shape_src) {
if (emb_table_shape.rank().is_static()) {
NODE_VALIDATION_CHECK(op, emb_table_shape.size() > 0, "EMB_TABLE can't be a scalar.");
auto out_shape = emb_table_shape;
out_shape[0] = dim_shape_src.rank().is_static() ? dim_shape_src[0] : Dimension::dynamic();
return out_shape;
}
return PartialShape::dynamic();
}
} // namespace embedding
} // namespace util
} // namespace op
} // namespace ov

View File

@ -4,21 +4,21 @@
#pragma once
#include <openvino/core/validation_util.hpp>
#include <openvino/op/embeddingbag_offsets_sum.hpp>
#include "embedding_shape_infer_utils.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/embeddingbag_offsets_sum.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <class T>
void shape_infer(const ov::op::util::EmbeddingBagOffsetsBase* op,
const std::vector<T>& input_shapes,
std::vector<T>& output_shapes) {
template <class TShape>
std::vector<TShape> shape_infer(const ov::op::util::EmbeddingBagOffsetsBase* op,
const std::vector<TShape>& input_shapes) {
const auto input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, (input_size >= 3 && input_size <= 5) && output_shapes.size() == 1);
NODE_VALIDATION_CHECK(op, (input_size >= 3 && input_size <= 5));
static constexpr int EMB_TABLE = 0;
static constexpr int INDICES = 1;
@ -26,33 +26,33 @@ void shape_infer(const ov::op::util::EmbeddingBagOffsetsBase* op,
static constexpr int DEFAULT_INDEX = 3;
static constexpr int PER_SAMPLE_WEIGHTS = 4;
NODE_VALIDATION_CHECK(op, input_shapes[INDICES].rank().compatible(1), "INDICES must be 1D");
NODE_VALIDATION_CHECK(op, input_shapes[OFFSETS].rank().compatible(1), "OFFSETS must be 1D");
NODE_VALIDATION_CHECK(op, input_shapes[INDICES].rank().compatible(1), "INDICES must be 1D.");
NODE_VALIDATION_CHECK(op, input_shapes[OFFSETS].rank().compatible(1), "OFFSETS must be 1D.");
if (input_size >= 4) {
NODE_VALIDATION_CHECK(op, input_shapes[DEFAULT_INDEX].rank().compatible(0), "DEFAULT_INDEX must be a scalar");
NODE_VALIDATION_CHECK(op, input_shapes[DEFAULT_INDEX].rank().compatible(0), "DEFAULT_INDEX must be a scalar.");
}
if (input_size == 5) {
NODE_VALIDATION_CHECK(op,
input_shapes[PER_SAMPLE_WEIGHTS].rank().compatible(1),
"PER_SAMPLE_WEIGHTS must be 1D");
"PER_SAMPLE_WEIGHTS must be 1D.");
NODE_VALIDATION_CHECK(op,
input_shapes[INDICES].compatible(input_shapes[PER_SAMPLE_WEIGHTS]),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
"INDICES and PER_SAMPLE_WEIGHTS shape must be same.");
}
const auto& emb_table_shape = input_shapes[EMB_TABLE];
const auto& offsets_shape = input_shapes[OFFSETS];
if (emb_table_shape.rank().is_static()) {
output_shapes[0] = emb_table_shape;
output_shapes[0][0] = offsets_shape.rank().is_static() ? offsets_shape[0] : Dimension::dynamic();
} else {
output_shapes[0] = PartialShape::dynamic();
}
return {embedding::out_shape_infer(op, input_shapes[EMB_TABLE], input_shapes[OFFSETS])};
}
template <class TShape>
void shape_infer(const ov::op::util::EmbeddingBagOffsetsBase* op,
const std::vector<TShape>& input_shapes,
std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -0,0 +1,50 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "embedding_shape_infer_utils.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/util/embeddingbag_packed_base.hpp"
#include "utils.hpp"
namespace ov {
namespace op {
namespace util {
template <class TShape>
std::vector<TShape> shape_infer(const ov::op::util::EmbeddingBagPackedBase* op,
const std::vector<TShape>& input_shapes) {
const auto input_size = input_shapes.size();
NODE_VALIDATION_CHECK(op, input_size == 2 || input_size == 3);
constexpr size_t EMB_TABLE = 0;
constexpr size_t INDICES = 1;
constexpr size_t PER_SAMPLE_WEIGHTS = 2;
auto indices_shape = input_shapes[INDICES];
NODE_VALIDATION_CHECK(op, indices_shape.rank().compatible(2), "INDICES must be 2D.");
if (input_size == 3) {
NODE_VALIDATION_CHECK(op,
input_shapes[PER_SAMPLE_WEIGHTS].rank().compatible(2),
"PER_SAMPLE_WEIGHTS must be 2D.");
NODE_VALIDATION_CHECK(op,
TShape::merge_into(indices_shape, input_shapes[PER_SAMPLE_WEIGHTS]),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same.");
}
return {embedding::out_shape_infer(op, input_shapes[EMB_TABLE], indices_shape)};
}
template <class TShape>
void shape_infer(const ov::op::util::EmbeddingBagPackedBase* op,
const std::vector<TShape>& input_shapes,
std::vector<TShape>& output_shapes) {
output_shapes = shape_infer(op, input_shapes);
}
} // namespace util
} // namespace op
} // namespace ov

View File

@ -79,16 +79,9 @@ void ov::op::util::EmbeddingBagOffsetsBase::validate_and_infer_types() {
")");
}
element::Type result_et = get_input_element_type(EMB_TABLE);
std::vector<PartialShape> result_shapes = {PartialShape::dynamic()};
std::vector<PartialShape> input_shapes;
for (size_t i = 0; i < get_input_size(); i++)
input_shapes.push_back(get_input_partial_shape(i));
shape_infer(this, input_shapes, result_shapes);
set_output_type(0, result_et, result_shapes[0]);
const auto& result_et = get_input_element_type(EMB_TABLE);
const auto input_shapes = get_node_input_partial_shapes(*this);
set_output_type(0, result_et, shape_infer(this, input_shapes)[0]);
}
bool ov::op::util::EmbeddingBagOffsetsBase::visit_attributes(AttributeVisitor& visitor) {

View File

@ -4,8 +4,10 @@
#include "ngraph/op/util/embeddingbag_packed_base.hpp"
#include "embeddingbag_packed_shape_inference.hpp"
#include "itt.hpp"
#include "ngraph/op/constant.hpp"
#include "openvino/core/validation_util.hpp"
using namespace std;
@ -28,11 +30,6 @@ void ov::op::util::EmbeddingBagPackedBase::validate_and_infer_types() {
get_input_element_type(INDICES) == element::i64 || get_input_element_type(INDICES) == element::i32,
"INDICES type must be i32 or i64");
NODE_VALIDATION_CHECK(
this,
get_input_partial_shape(INDICES).is_dynamic() || get_input_partial_shape(INDICES).to_shape().size() == 2,
"INDICES must be 2D");
if (get_input_size() == 3) {
NODE_VALIDATION_CHECK(this,
get_input_element_type(EMB_TABLE).compatible(get_input_element_type(PER_SAMPLE_WEIGHTS)),
@ -41,31 +38,11 @@ void ov::op::util::EmbeddingBagPackedBase::validate_and_infer_types() {
") must match embedding table element type (",
get_input_element_type(EMB_TABLE),
")");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(PER_SAMPLE_WEIGHTS).is_dynamic() ||
get_input_partial_shape(PER_SAMPLE_WEIGHTS).to_shape().size() == 2,
"PER_SAMPLE_WEIGHTS must be 2D");
NODE_VALIDATION_CHECK(this,
get_input_partial_shape(INDICES).compatible(get_input_partial_shape(PER_SAMPLE_WEIGHTS)),
"INDICES and PER_SAMPLE_WEIGHTS shape must be same");
}
element::Type result_et = get_input_element_type(EMB_TABLE);
const PartialShape& emb_table_shape = get_input_partial_shape(EMB_TABLE);
const PartialShape& indices_shape = get_input_partial_shape(INDICES);
PartialShape result_shape;
if (emb_table_shape.rank().is_static()) {
result_shape = emb_table_shape;
result_shape[0] = indices_shape.rank().is_static() ? indices_shape[0] : Dimension::dynamic();
} else {
result_shape = PartialShape::dynamic();
}
set_output_type(0, result_et, result_shape);
const auto& emb_et = get_input_element_type(EMB_TABLE);
const auto input_shapes = get_node_input_partial_shapes(*this);
set_output_type(0, emb_et, shape_infer(this, input_shapes)[0]);
}
bool ov::op::util::EmbeddingBagPackedBase::visit_attributes(AttributeVisitor& visitor) {

View File

@ -2,12 +2,48 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
TEST(type_prop, ebos_default_ctor) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
auto offsets = make_shared<op::Parameter>(element::i64, Shape{3});
auto per_sample_weights = make_shared<op::Parameter>(element::f32, Shape{4});
auto default_index = make_shared<op::Parameter>(element::i64, Shape{});
auto op = make_shared<op::v3::EmbeddingBagOffsetsSum>();
op->set_arguments(OutputVector{emb_table, indices, offsets, default_index, per_sample_weights});
op->validate_and_infer_types();
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 2, 6}));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
}
TEST(type_prop, ebos_labeled_interval_dims) {
auto emb_shape = PartialShape{{5, 10}, {2, 4}, {1, 3}};
set_shape_labels(emb_shape, 10);
auto off_shape = PartialShape{{6, 8}};
set_shape_labels(off_shape, 20);
auto emb_table = make_shared<op::Parameter>(element::f32, emb_shape);
auto indices = make_shared<op::Parameter>(element::i64, PartialShape{{3, 4}});
auto offsets = make_shared<op::Parameter>(element::i64, off_shape);
auto per_sample_weights = make_shared<op::Parameter>(element::f32, PartialShape{{3, 4}});
auto default_index = make_shared<op::Parameter>(element::i64, Shape{});
auto op =
make_shared<op::v3::EmbeddingBagOffsetsSum>(emb_table, indices, offsets, default_index, per_sample_weights);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{6, 8}, {2, 4}, {1, 3}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, 11, 12));
}
TEST(type_prop, ebos) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2});
@ -231,6 +267,20 @@ TEST(type_prop, ebos_fail_indices_1d) {
}
}
TEST(type_prop, ebos_fail_emb_table_0d) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});
auto offsets = make_shared<op::Parameter>(element::i64, Shape{3});
auto per_sample_weights = make_shared<op::Parameter>(element::f32, Shape{4});
auto default_index = make_shared<op::Parameter>(element::i64, Shape{});
OV_EXPECT_THROW(
auto op =
make_shared<op::v3::EmbeddingBagOffsetsSum>(emb_table, indices, offsets, default_index, per_sample_weights),
NodeValidationFailure,
HasSubstr("EMB_TABLE can't be a scalar"));
}
TEST(type_prop, ebos_fail_offsets_1d) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2});
auto indices = make_shared<op::Parameter>(element::i64, Shape{4});

View File

@ -2,12 +2,60 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "common_test_utils/test_assertions.hpp"
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
TEST(type_prop, ebps_default_ctor) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2, 6});
auto indices = make_shared<op::Parameter>(element::i64, Shape{3, 4});
auto per_sample_weights = make_shared<op::Parameter>(element::f32, Shape{3, 4});
auto op = make_shared<op::v3::EmbeddingBagPackedSum>();
op->set_arguments(OutputVector{emb_table, indices, per_sample_weights});
op->validate_and_infer_types();
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 2, 6}));
EXPECT_EQ(op->get_output_element_type(0), element::f32);
}
TEST(type_prop, ebps_labeled_interval_dims_2in) {
auto emb_shape = PartialShape{{5, 10}, {2, 4}, {1, 3}};
set_shape_labels(emb_shape, 10);
auto ind_shape = PartialShape{{6, 8}, 4};
set_shape_labels(ind_shape, 20);
auto emb_table = make_shared<op::Parameter>(element::f32, emb_shape);
auto indices = make_shared<op::Parameter>(element::i64, ind_shape);
auto op = make_shared<op::v3::EmbeddingBagPackedSum>(emb_table, indices);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{6, 8}, {2, 4}, {1, 3}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(20, 11, 12));
}
TEST(type_prop, ebps_labeled_interval_dims_3in) {
auto emb_shape = PartialShape{{5, 10}, {2, 4}, {1, 3}};
set_shape_labels(emb_shape, 10);
auto ind_shape = PartialShape{{2, 6}, 4};
set_shape_labels(ind_shape, 20);
auto sample_shape = PartialShape{{4, 8}, 4};
set_shape_labels(sample_shape, 30);
auto emb_table = make_shared<op::Parameter>(element::f32, emb_shape);
auto indices = make_shared<op::Parameter>(element::i64, ind_shape);
auto per_sample_weights = make_shared<op::Parameter>(element::f32, sample_shape);
auto op = make_shared<op::v3::EmbeddingBagPackedSum>(emb_table, indices, per_sample_weights);
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{4, 6}, {2, 4}, {1, 3}}));
EXPECT_THAT(get_shape_labels(op->get_output_partial_shape(0)), ElementsAre(30, 11, 12));
}
TEST(type_prop, ebps) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2});
@ -115,6 +163,16 @@ TEST(type_prop, ebps_fail_indices_1d) {
}
}
TEST(type_prop, ebps_fail_emb_table_0d) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{});
auto indices = make_shared<op::Parameter>(element::i64, Shape{3, 4});
auto per_sample_weights = make_shared<op::Parameter>(element::f32, Shape{3, 4});
OV_EXPECT_THROW(auto op = make_shared<op::v3::EmbeddingBagPackedSum>(emb_table, indices, per_sample_weights),
NodeValidationFailure,
HasSubstr("EMB_TABLE can't be a scalar"));
}
TEST(type_prop, ebps_fail_per_sample_weights_1d) {
auto emb_table = make_shared<op::Parameter>(element::f32, Shape{5, 2});
auto indices = make_shared<op::Parameter>(element::i64, Shape{3, 4});

View File

@ -28,6 +28,7 @@
#include "einsum_shape_inference.hpp"
#include "embedding_segments_sum_shape_inference.hpp"
#include "embeddingbag_offsets_shape_inference.hpp"
#include "embeddingbag_packed_shape_inference.hpp"
#include "experimental_detectron_detection_output_shape_inference.hpp"
#include "experimental_detectron_generate_proposals_shape_inference.hpp"
#include "experimental_detectron_prior_grid_generator_shape_inference.hpp"
@ -561,6 +562,7 @@ const IShapeInferCommonFactory::TRegistry IShapeInferCommonFactory::registry{
_OV_OP_SHAPE_INFER_REG(DFT, entryIOC),
_OV_OP_SHAPE_INFER_REG(Einsum, entryIO),
_OV_OP_SHAPE_INFER_REG(EmbeddingBagOffsetsSum, entryIO),
_OV_OP_SHAPE_INFER_REG(EmbeddingBagPackedSum, entryIO),
_OV_OP_SHAPE_INFER_REG(EmbeddingSegmentsSum, entryIOC),
_OV_OP_SHAPE_INFER_REG(ExperimentalDetectronDetectionOutput, entryIO),
_OV_OP_SHAPE_INFER_REG(ExperimentalDetectronGenerateProposalsSingleImage, entryIO),

View File

@ -0,0 +1,94 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <array>
#include "common_test_utils/test_assertions.hpp"
#include "embeddingbag_offsets_shape_inference.hpp"
#include "gmock/gmock.h"
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace ov::opset10;
using namespace testing;
class EmbeddingBagOffsetsSumV3StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v3::EmbeddingBagOffsetsSum> {
protected:
void SetUp() override {
output_shapes.resize(1);
}
};
TEST_F(EmbeddingBagOffsetsSumV3StaticShapeInferenceTest, default_ctor) {
const auto op = make_op();
const auto batch = 8;
auto expected_output = StaticShape{batch, 4, 5, 6};
// 3 inputs
{
input_shapes = {StaticShape{3, 4, 5, 6}, StaticShape{2}, StaticShape{batch}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
// 4 inputs
{
input_shapes = {StaticShape{3, 4, 5, 6}, StaticShape{2}, StaticShape{batch}, StaticShape{}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
// 5 inputs
{
input_shapes = {StaticShape{3, 4, 5, 6}, StaticShape{2}, StaticShape{batch}, StaticShape{}, StaticShape{2}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
}
TEST_F(EmbeddingBagOffsetsSumV3StaticShapeInferenceTest, basic_3in) {
auto emb_table = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto offsets = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto op = make_op(emb_table, indices, offsets);
auto expected_output = StaticShape{3, 2, 6};
input_shapes = {StaticShape{5, 2, 6}, StaticShape{4}, StaticShape{3}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
TEST_F(EmbeddingBagOffsetsSumV3StaticShapeInferenceTest, basic_4in) {
auto emb_table = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto offsets = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto default_index = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto op = make_op(emb_table, indices, offsets, default_index);
auto expected_output = StaticShape{3, 2, 6};
input_shapes = {StaticShape{5, 2, 6}, StaticShape{4}, StaticShape{3}, StaticShape{}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
TEST_F(EmbeddingBagOffsetsSumV3StaticShapeInferenceTest, basic_5in) {
auto emb_table = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto offsets = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto default_index = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto per_sample_weights = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto op = make_op(emb_table, indices, offsets, default_index, per_sample_weights);
auto expected_output = StaticShape{3, 2, 6};
input_shapes = {StaticShape{5, 2, 6}, StaticShape{4}, StaticShape{3}, StaticShape{}, StaticShape{4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}

View File

@ -1,27 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <embeddingbag_offsets_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace std;
TEST(StaticShapeInferenceTest, EmbeddingBagOffsetsSumV3) {
auto emb_table = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto offsets = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto default_index = make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
auto per_sample_weights = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape::dynamic());
auto ebos =
make_shared<op::v3::EmbeddingBagOffsetsSum>(emb_table, indices, offsets, default_index, per_sample_weights);
check_static_shape(
ebos.get(),
{StaticShape{5, 2}, StaticShape{4}, StaticShape{3}, StaticShape{}, StaticShape{4}},
{StaticShape{3, 2}});
}

View File

@ -0,0 +1,66 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <array>
#include "common_test_utils/test_assertions.hpp"
#include "embeddingbag_packed_shape_inference.hpp"
#include "gmock/gmock.h"
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace ov::opset10;
using namespace testing;
class EmbeddingBagPackedSumV3StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v3::EmbeddingBagPackedSum> {
protected:
void SetUp() override {
output_shapes.resize(1);
}
};
TEST_F(EmbeddingBagPackedSumV3StaticShapeInferenceTest, default_ctor) {
const auto op = make_op();
const auto batch = 8;
auto expected_output = StaticShape{batch, 4, 5, 6};
// 2 inputs
{
input_shapes = {StaticShape{3, 4, 5, 6}, StaticShape{batch, 2}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
// 3 inputs
{
input_shapes = {StaticShape{3, 4, 5, 6}, StaticShape{batch, 2}, StaticShape{batch, 2}};
shape_infer(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], expected_output);
}
}
TEST_F(EmbeddingBagPackedSumV3StaticShapeInferenceTest, basic_2in) {
auto emb_table = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto op = make_op(emb_table, indices);
input_shapes = {StaticShape{5, 2, 6}, StaticShape{3, 4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{3, 2, 6}));
}
TEST_F(EmbeddingBagPackedSumV3StaticShapeInferenceTest, basic_3in) {
auto emb_table = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto indices = std::make_shared<Parameter>(element::i64, ov::PartialShape::dynamic());
auto per_sample_weights = std::make_shared<Parameter>(element::f32, ov::PartialShape::dynamic());
auto op = make_op(emb_table, indices, per_sample_weights);
input_shapes = {StaticShape{5, 2, 6}, StaticShape{3, 4}, StaticShape{3, 4}};
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes[0], (StaticShape{3, 2, 6}));
}