Review bucketize shape inference (#16136)
* Review bucketize shape inference: - check interval dimension and label propagation - check template shape_infer implementation - minor refactoring and add tests * Add missing using of namespaces
This commit is contained in:
parent
ab684036f4
commit
bdf1923972
@ -4,17 +4,15 @@
|
|||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <openvino/core/validation_util.hpp>
|
|
||||||
#include <openvino/op/bucketize.hpp>
|
#include <openvino/op/bucketize.hpp>
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
namespace ov {
|
namespace ov {
|
||||||
namespace op {
|
namespace op {
|
||||||
namespace v3 {
|
namespace v3 {
|
||||||
|
|
||||||
template <class T>
|
template <class TShape>
|
||||||
void shape_infer(const Bucketize* op, const std::vector<T>& input_shapes, std::vector<T>& output_shapes) {
|
std::vector<TShape> shape_infer(const Bucketize* op, const std::vector<TShape>& input_shapes) {
|
||||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2) && output_shapes.size() == 1);
|
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 2));
|
||||||
|
|
||||||
const auto& data_shape = input_shapes[0];
|
const auto& data_shape = input_shapes[0];
|
||||||
const auto& buckets_shape = input_shapes[1];
|
const auto& buckets_shape = input_shapes[1];
|
||||||
@ -23,7 +21,12 @@ void shape_infer(const Bucketize* op, const std::vector<T>& input_shapes, std::v
|
|||||||
buckets_shape.rank().compatible(1),
|
buckets_shape.rank().compatible(1),
|
||||||
"Buckets input must be a 1D tensor. Got: ",
|
"Buckets input must be a 1D tensor. Got: ",
|
||||||
buckets_shape);
|
buckets_shape);
|
||||||
output_shapes[0] = data_shape;
|
return {data_shape};
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class TShape>
|
||||||
|
void shape_infer(const Bucketize* op, const std::vector<TShape>& input_shapes, std::vector<TShape>& output_shapes) {
|
||||||
|
output_shapes = shape_infer(op, input_shapes);
|
||||||
}
|
}
|
||||||
} // namespace v3
|
} // namespace v3
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
@ -1,15 +1,17 @@
|
|||||||
// Copyright (C) 2018-2023 Intel Corporation
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
#include "openvino/op/bucketize.hpp"
|
||||||
|
|
||||||
#include "ngraph/op/bucketize.hpp"
|
#include <array>
|
||||||
|
|
||||||
#include "bucketize_shape_inference.hpp"
|
#include "bucketize_shape_inference.hpp"
|
||||||
#include "itt.hpp"
|
#include "itt.hpp"
|
||||||
|
#include "openvino/core/validation_util.hpp"
|
||||||
|
|
||||||
using namespace ngraph;
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
op::v3::Bucketize::Bucketize(const Output<Node>& data,
|
op::v3::Bucketize::Bucketize(const Output<Node>& data,
|
||||||
const Output<Node>& buckets,
|
const Output<Node>& buckets,
|
||||||
const element::Type output_type,
|
const element::Type output_type,
|
||||||
@ -29,36 +31,28 @@ bool op::v3::Bucketize::visit_attributes(AttributeVisitor& visitor) {
|
|||||||
|
|
||||||
void op::v3::Bucketize::validate_and_infer_types() {
|
void op::v3::Bucketize::validate_and_infer_types() {
|
||||||
OV_OP_SCOPE(v3_Bucketize_validate_and_infer_types);
|
OV_OP_SCOPE(v3_Bucketize_validate_and_infer_types);
|
||||||
const ov::PartialShape& data_pshape = get_input_partial_shape(0);
|
static constexpr std::array<const char*, 2> input_names{"Data", "Buckets"};
|
||||||
const ov::PartialShape& buckets_pshape = get_input_partial_shape(1);
|
|
||||||
|
|
||||||
const auto data_et = get_input_element_type(0);
|
for (size_t i = 0; i < input_names.size(); ++i) {
|
||||||
const auto buckets_et = get_input_element_type(1);
|
const auto& in_et = get_input_element_type(i);
|
||||||
|
NODE_VALIDATION_CHECK(this,
|
||||||
NODE_VALIDATION_CHECK(this,
|
in_et.is_real() || in_et.is_integral_number(),
|
||||||
data_et.is_real() || data_et.is_integral_number(),
|
input_names[i],
|
||||||
"Data input type must be numeric. Got: ",
|
" input type must be numeric. Got: ",
|
||||||
data_et);
|
in_et);
|
||||||
|
}
|
||||||
NODE_VALIDATION_CHECK(this,
|
|
||||||
buckets_et.is_real() || buckets_et.is_integral_number(),
|
|
||||||
"Buckets input type must be numeric. Got: ",
|
|
||||||
buckets_et);
|
|
||||||
|
|
||||||
NODE_VALIDATION_CHECK(this,
|
NODE_VALIDATION_CHECK(this,
|
||||||
m_output_type == element::i64 || m_output_type == element::i32,
|
m_output_type == element::i64 || m_output_type == element::i32,
|
||||||
"Output type must be i32 or i64. Got: ",
|
"Output type must be i32 or i64. Got: ",
|
||||||
m_output_type);
|
m_output_type);
|
||||||
|
|
||||||
std::vector<ov::PartialShape> input_shapes = {data_pshape, buckets_pshape};
|
const auto input_shapes = get_node_input_partial_shapes(*this);
|
||||||
std::vector<ov::PartialShape> output_shapes = {ov::PartialShape::dynamic()};
|
const auto output_shapes = shape_infer(this, input_shapes);
|
||||||
shape_infer(this, input_shapes, output_shapes);
|
|
||||||
|
|
||||||
if (data_pshape.is_dynamic()) {
|
if (get_input_partial_shape(0).is_dynamic()) {
|
||||||
set_input_is_relevant_to_shape(0);
|
set_input_is_relevant_to_shape(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
set_output_size(1);
|
|
||||||
set_output_type(0, m_output_type, output_shapes[0]);
|
set_output_type(0, m_output_type, output_shapes[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -68,3 +62,4 @@ shared_ptr<Node> op::v3::Bucketize::clone_with_new_inputs(const OutputVector& in
|
|||||||
|
|
||||||
return make_shared<v3::Bucketize>(inputs.at(0), inputs.at(1), m_output_type, m_with_right_bound);
|
return make_shared<v3::Bucketize>(inputs.at(0), inputs.at(1), m_output_type, m_with_right_bound);
|
||||||
}
|
}
|
||||||
|
} // namespace ov
|
||||||
|
@ -2,123 +2,138 @@
|
|||||||
// SPDX-License-Identifier: Apache-2.0
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
//
|
//
|
||||||
|
|
||||||
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
#include "gtest/gtest.h"
|
#include "gtest/gtest.h"
|
||||||
#include "ngraph/ngraph.hpp"
|
#include "ngraph/ngraph.hpp"
|
||||||
|
#include "openvino/opsets/opset11.hpp"
|
||||||
#include "util/type_prop.hpp"
|
#include "util/type_prop.hpp"
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace ngraph;
|
using namespace ov;
|
||||||
|
using namespace ov::opset11;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
class TypePropBucketizeV3Test : public TypePropOpTest<op::v3::Bucketize> {};
|
||||||
|
|
||||||
|
TEST_F(TypePropBucketizeV3Test, default_ctor) {
|
||||||
|
auto data = make_shared<Parameter>(element::f32, Shape{2, 3, 2});
|
||||||
|
auto buckets = make_shared<Parameter>(element::f32, Shape{4});
|
||||||
|
|
||||||
|
auto bucketize = make_op();
|
||||||
|
bucketize->set_arguments(OutputVector{data, buckets});
|
||||||
|
bucketize->set_output_type(element::i64);
|
||||||
|
bucketize->set_with_right_bound(true);
|
||||||
|
bucketize->validate_and_infer_types();
|
||||||
|
|
||||||
TEST(type_prop, bucketize) {
|
|
||||||
auto data = make_shared<op::Parameter>(element::f32, Shape{2, 3, 2});
|
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{4});
|
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
|
|
||||||
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
|
||||||
EXPECT_TRUE(bucketize->get_with_right_bound());
|
EXPECT_TRUE(bucketize->get_with_right_bound());
|
||||||
EXPECT_TRUE(bucketize->get_output_partial_shape(0).same_scheme(PartialShape{2, 3, 2}));
|
EXPECT_EQ(bucketize->get_output_type(), element::i64);
|
||||||
|
EXPECT_EQ(bucketize->get_input_size(), 2);
|
||||||
|
EXPECT_EQ(bucketize->get_output_size(), 1);
|
||||||
|
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
||||||
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), (PartialShape{2, 3, 2}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_output_type) {
|
TEST_F(TypePropBucketizeV3Test, simple_shape) {
|
||||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
auto data = make_shared<Parameter>(element::f32, Shape{2, 3, 2});
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{5});
|
auto buckets = make_shared<Parameter>(element::f32, Shape{4});
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets, element::i32);
|
auto bucketize = make_op(data, buckets);
|
||||||
|
|
||||||
|
EXPECT_TRUE(bucketize->get_with_right_bound());
|
||||||
|
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
||||||
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), (PartialShape{2, 3, 2}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropBucketizeV3Test, output_type_i32) {
|
||||||
|
auto data = make_shared<Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||||
|
auto buckets = make_shared<Parameter>(element::f32, Shape{5});
|
||||||
|
auto bucketize = make_op(data, buckets, element::i32);
|
||||||
|
|
||||||
ASSERT_EQ(bucketize->get_output_element_type(0), element::i32);
|
ASSERT_EQ(bucketize->get_output_element_type(0), element::i32);
|
||||||
EXPECT_TRUE(bucketize->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 3, 4}));
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), (PartialShape{1, 2, 3, 4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_output_type_right_bound) {
|
TEST_F(TypePropBucketizeV3Test, output_type_right_bound) {
|
||||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
auto data = make_shared<Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{5});
|
auto buckets = make_shared<Parameter>(element::f32, Shape{5});
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets, element::i32, false);
|
auto bucketize = make_op(data, buckets, element::i32, false);
|
||||||
|
|
||||||
ASSERT_EQ(bucketize->get_output_element_type(0), element::i32);
|
ASSERT_EQ(bucketize->get_output_element_type(0), element::i32);
|
||||||
EXPECT_TRUE(bucketize->get_output_partial_shape(0).same_scheme(PartialShape{1, 2, 3, 4}));
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), (PartialShape{1, 2, 3, 4}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_dynamic_input) {
|
TEST_F(TypePropBucketizeV3Test, dynamic_input) {
|
||||||
auto data = make_shared<op::Parameter>(element::f16, PartialShape{4, Dimension::dynamic()});
|
auto data_shape = PartialShape::dynamic();
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{5});
|
auto data = make_shared<Parameter>(element::f16, data_shape);
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
|
auto buckets = make_shared<Parameter>(element::f32, Shape{5});
|
||||||
|
auto bucketize = make_op(data, buckets);
|
||||||
|
|
||||||
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
||||||
EXPECT_TRUE(bucketize->get_output_partial_shape(0).same_scheme(PartialShape{4, Dimension::dynamic()}));
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), PartialShape::dynamic());
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_dynamic_buckets) {
|
TEST_F(TypePropBucketizeV3Test, dynamic_buckets) {
|
||||||
auto data = make_shared<op::Parameter>(element::f16, PartialShape{4, Dimension::dynamic()});
|
auto data = make_shared<Parameter>(element::f16, PartialShape{4, Dimension::dynamic()});
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, PartialShape{Dimension::dynamic()});
|
auto buckets = make_shared<Parameter>(element::f32, PartialShape{Dimension::dynamic()});
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
|
auto bucketize = make_op(data, buckets);
|
||||||
|
|
||||||
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
||||||
EXPECT_TRUE(bucketize->get_output_partial_shape(0).same_scheme(PartialShape{4, Dimension::dynamic()}));
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), (PartialShape{4, Dimension::dynamic()}));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_invalid_input_types) {
|
TEST_F(TypePropBucketizeV3Test, interval_dimensions) {
|
||||||
// Invalid data input element type
|
auto data_shape = PartialShape{{10, 30}, {12, -1}, -1, {0, 30}};
|
||||||
try {
|
set_shape_labels(data_shape, 10);
|
||||||
auto data = make_shared<op::Parameter>(element::boolean, Shape{1, 2, 3, 4});
|
auto data = make_shared<Parameter>(element::f16, data_shape);
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{5});
|
auto buckets = make_shared<Parameter>(element::f32, PartialShape{{2, 4}});
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets, element::i32);
|
auto bucketize = make_op(data, buckets);
|
||||||
// Data input expected to be of numeric type
|
|
||||||
FAIL() << "Invalid input type not detected";
|
|
||||||
} catch (const NodeValidationFailure& error) {
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Data input type must be numeric"));
|
|
||||||
} catch (...) {
|
|
||||||
FAIL() << "Input type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
|
|
||||||
// Invalid buckets input element type
|
EXPECT_EQ(bucketize->get_element_type(), element::i64);
|
||||||
try {
|
EXPECT_EQ(bucketize->get_output_partial_shape(0), data_shape);
|
||||||
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 2, 3, 4});
|
EXPECT_THAT(get_shape_labels(bucketize->get_output_partial_shape(0)), ElementsAre(10, 11, 12, 13));
|
||||||
auto buckets = make_shared<op::Parameter>(element::boolean, Shape{5});
|
}
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets, element::i32);
|
|
||||||
// Buckets input expected to be of numeric type
|
TEST_F(TypePropBucketizeV3Test, invalid_data_element_type) {
|
||||||
FAIL() << "Invalid input type not detected";
|
auto data = make_shared<Parameter>(element::boolean, Shape{1, 2, 3, 4});
|
||||||
} catch (const NodeValidationFailure& error) {
|
auto buckets = make_shared<Parameter>(element::f32, Shape{5});
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Buckets input type must be numeric"));
|
OV_EXPECT_THROW(auto bucketize = make_op(data, buckets, element::i32),
|
||||||
} catch (...) {
|
NodeValidationFailure,
|
||||||
FAIL() << "Input type check failed for unexpected reason";
|
HasSubstr("Data input type must be numeric"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropBucketizeV3Test, invalid_bucket_element_types) {
|
||||||
|
auto data = make_shared<Parameter>(element::f32, Shape{1, 2, 3, 4});
|
||||||
|
auto buckets = make_shared<Parameter>(element::boolean, Shape{5});
|
||||||
|
|
||||||
|
OV_EXPECT_THROW(auto bucketize = make_op(data, buckets, element::i32),
|
||||||
|
NodeValidationFailure,
|
||||||
|
HasSubstr("Buckets input type must be numeric"));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypePropBucketizeV3Test, invalid_output_types) {
|
||||||
|
vector<element::Type_t> output_types = {element::f64,
|
||||||
|
element::f32,
|
||||||
|
element::f16,
|
||||||
|
element::bf16,
|
||||||
|
element::i16,
|
||||||
|
element::i8,
|
||||||
|
element::u64,
|
||||||
|
element::u32,
|
||||||
|
element::u16,
|
||||||
|
element::u8,
|
||||||
|
element::boolean};
|
||||||
|
auto data = make_shared<Parameter>(element::f32, PartialShape{4, Dimension::dynamic()});
|
||||||
|
auto buckets = make_shared<Parameter>(element::f32, Shape{5});
|
||||||
|
for (const auto& output_type : output_types) {
|
||||||
|
OV_EXPECT_THROW(auto bucketize = make_op(data, buckets, output_type),
|
||||||
|
NodeValidationFailure,
|
||||||
|
HasSubstr("Output type must be i32 or i64"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(type_prop, bucketize_invalid_output_types) {
|
TEST_F(TypePropBucketizeV3Test, invalid_buckets_dim) {
|
||||||
vector<ngraph::element::Type_t> output_types = {ngraph::element::f64,
|
auto data = make_shared<Parameter>(element::f32, PartialShape{4, Dimension::dynamic()});
|
||||||
ngraph::element::f32,
|
auto buckets = make_shared<Parameter>(element::f16, Shape{5, 5});
|
||||||
ngraph::element::f16,
|
OV_EXPECT_THROW(auto bucketize = make_op(data, buckets),
|
||||||
ngraph::element::bf16,
|
NodeValidationFailure,
|
||||||
ngraph::element::i16,
|
HasSubstr("Buckets input must be a 1D tensor"));
|
||||||
ngraph::element::i8,
|
|
||||||
ngraph::element::u64,
|
|
||||||
ngraph::element::u32,
|
|
||||||
ngraph::element::u16,
|
|
||||||
ngraph::element::u8,
|
|
||||||
ngraph::element::boolean};
|
|
||||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape{4, Dimension::dynamic()});
|
|
||||||
auto buckets = make_shared<op::Parameter>(element::f32, Shape{5});
|
|
||||||
for (auto output_type : output_types) {
|
|
||||||
try {
|
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets, output_type);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Invalid output type not detected";
|
|
||||||
} catch (const NodeValidationFailure& error) {
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Output type must be i32 or i64"));
|
|
||||||
} catch (...) {
|
|
||||||
FAIL() << "Deduced type check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST(type_prop, bucketize_invalid_buckets_dim) {
|
|
||||||
auto data = make_shared<op::Parameter>(element::f32, PartialShape{4, Dimension::dynamic()});
|
|
||||||
auto buckets = make_shared<op::Parameter>(element::f16, Shape{5, 5});
|
|
||||||
try {
|
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
|
|
||||||
// Should have thrown, so fail if it didn't
|
|
||||||
FAIL() << "Invalid output type not detected";
|
|
||||||
} catch (const NodeValidationFailure& error) {
|
|
||||||
EXPECT_HAS_SUBSTRING(error.what(), std::string("Buckets input must be a 1D tensor"));
|
|
||||||
} catch (...) {
|
|
||||||
FAIL() << "Buckets dimension check failed for unexpected reason";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -0,0 +1,64 @@
|
|||||||
|
// Copyright (C) 2018-2023 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "bucketize_shape_inference.hpp"
|
||||||
|
#include "common_test_utils/test_assertions.hpp"
|
||||||
|
#include "utils.hpp"
|
||||||
|
|
||||||
|
using namespace ov;
|
||||||
|
using namespace ov::intel_cpu;
|
||||||
|
using namespace testing;
|
||||||
|
|
||||||
|
class BucketizeV3StaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v3::Bucketize> {
|
||||||
|
void SetUp() override {
|
||||||
|
output_shapes.resize(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(BucketizeV3StaticShapeInferenceTest, default_ctor) {
|
||||||
|
op = make_op();
|
||||||
|
op->set_output_type(element::i32);
|
||||||
|
op->set_with_right_bound(false);
|
||||||
|
|
||||||
|
input_shapes = ShapeVector{{3, 2, 7, 89}, {3}};
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_shapes.size(), 1);
|
||||||
|
EXPECT_EQ(output_shapes.front(), StaticShape({3, 2, 7, 89}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BucketizeV3StaticShapeInferenceTest, dynamic_rank_inputs) {
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f16, PartialShape::dynamic());
|
||||||
|
const auto buckets = std::make_shared<op::v0::Parameter>(element::f32, PartialShape::dynamic());
|
||||||
|
op = make_op(data, buckets, element::i32);
|
||||||
|
|
||||||
|
input_shapes = ShapeVector{{10, 12, 1}, {5}};
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_shapes.size(), 1);
|
||||||
|
EXPECT_EQ(output_shapes.front(), StaticShape({10, 12, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BucketizeV3StaticShapeInferenceTest, static_rank_inputs) {
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f16, PartialShape{-1, -1});
|
||||||
|
const auto buckets = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
op = make_op(data, buckets);
|
||||||
|
|
||||||
|
input_shapes = ShapeVector{{100, 11}, {1}};
|
||||||
|
shape_inference(op.get(), input_shapes, output_shapes);
|
||||||
|
|
||||||
|
EXPECT_EQ(output_shapes.size(), 1);
|
||||||
|
EXPECT_EQ(output_shapes.front(), StaticShape({100, 11}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(BucketizeV3StaticShapeInferenceTest, bucket_incorrect_rank) {
|
||||||
|
const auto data = std::make_shared<op::v0::Parameter>(element::f16, PartialShape{-1, -1});
|
||||||
|
const auto buckets = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{-1});
|
||||||
|
op = make_op(data, buckets, element::i32);
|
||||||
|
|
||||||
|
input_shapes = ShapeVector{{100, 11}, {2, 1}};
|
||||||
|
OV_EXPECT_THROW(shape_inference(op.get(), input_shapes, output_shapes),
|
||||||
|
NodeValidationFailure,
|
||||||
|
HasSubstr("Buckets input must be a 1D tensor"));
|
||||||
|
}
|
@ -1,19 +0,0 @@
|
|||||||
// Copyright (C) 2018-2023 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <bucketize_shape_inference.hpp>
|
|
||||||
|
|
||||||
#include "utils.hpp"
|
|
||||||
|
|
||||||
using namespace ov;
|
|
||||||
using namespace ov::intel_cpu;
|
|
||||||
using namespace std;
|
|
||||||
|
|
||||||
TEST(StaticShapeInferenceTest, BucketizeV3) {
|
|
||||||
auto data = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1, -1, -1});
|
|
||||||
auto buckets = make_shared<op::v0::Parameter>(element::f32, ov::PartialShape{-1});
|
|
||||||
auto bucketize = make_shared<op::v3::Bucketize>(data, buckets);
|
|
||||||
|
|
||||||
check_static_shape(bucketize.get(), {StaticShape{2, 3, 2}, StaticShape{4}}, {StaticShape{2, 3, 2}});
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user