Review opset1 variadic split for shape inference aspects (#13894)

* Review variadic split shape_infer template
- extend static shape tests
- add default ctor tests

* Use OV_EXPECT_THROW in exception tests

* Review evaluate upper, lower and label propagation

* Fix usage broadcast in bound tests

* VariadicSplit bound check in evaluate lower,upper

* Clean-up tests leftovers
This commit is contained in:
Pawel Raasz 2022-11-11 12:29:42 +01:00 committed by GitHub
parent b8ada02cba
commit cb067de597
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 377 additions and 206 deletions

View File

@ -39,11 +39,16 @@ public:
}
OPENVINO_SUPPRESS_DEPRECATED_START
bool evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs) const override;
bool evaluate_lower(const HostTensorVector& outputs) const override;
bool evaluate_upper(const HostTensorVector& outputs) const override;
OPENVINO_SUPPRESS_DEPRECATED_END
bool has_evaluate() const override;
bool evaluate_label(TensorLabelVector& output_labels) const override;
private:
bool evaluate_variadic_split(const HostTensorVector& outputs, const HostTensorVector& inputs) const;
bool have_axis_and_splits_bound_set() const;
};
} // namespace v1
} // namespace op

View File

@ -110,14 +110,12 @@ void shape_infer(const VariadicSplit* op,
auto out_shape = data_shape;
out_shape[axis] = Dimension::dynamic();
for (int64_t output = 0; output < num_outputs; ++output)
output_shapes.push_back(out_shape);
output_shapes.resize(num_outputs, out_shape);
}
} else {
// we only know num_outputs, only predict the rank
auto out_shape = ov::PartialShape::dynamic(data_shape.rank());
for (int64_t output = 0; output < num_outputs; ++output)
output_shapes.push_back(out_shape);
output_shapes.resize(num_outputs, out_shape);
}
} else {
// we don't even known the number of outputs in this case.

View File

@ -6,6 +6,7 @@
#include <numeric>
#include "compare.hpp"
#include "itt.hpp"
#include "ngraph/runtime/reference/slice.hpp"
#include "ngraph/validation_util.hpp"
@ -30,13 +31,11 @@ bool ngraph::op::v1::VariadicSplit::visit_attributes(AttributeVisitor& visitor)
void ngraph::op::v1::VariadicSplit::validate_and_infer_types() {
OV_OP_SCOPE(v1_VariadicSplit_validate_and_infer_types);
set_input_is_relevant_to_value(0);
set_input_is_relevant_to_value(1);
set_input_is_relevant_to_value(2);
for (size_t i = 0; i < get_input_size(); ++i) {
set_input_is_relevant_to_value(i);
}
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
get_input_partial_shape(1),
get_input_partial_shape(2)};
const auto input_shapes = get_node_input_partial_shapes(*this);
std::vector<ov::PartialShape> output_shapes;
shape_infer(this, input_shapes, output_shapes);
@ -59,9 +58,7 @@ inline bool evaluate(const HostTensorPtr& in,
const Coordinate& lower_bounds,
const Coordinate& upper_bounds) {
const auto& output_shape = out->get_shape();
auto has_nonzero_dims = std::all_of(output_shape.begin(), output_shape.end(), [](size_t dim) {
return dim != 0;
});
const auto has_nonzero_dims = std::none_of(output_shape.begin(), output_shape.end(), ov::cmp::Equal<size_t>(0));
if (has_nonzero_dims) {
runtime::reference::slice(in->get_data_ptr<const char>(),
@ -122,3 +119,28 @@ bool op::v1::VariadicSplit::has_evaluate() const {
OV_OP_SCOPE(v1_VariadicSplit_has_evaluate);
return get_input_element_type(1).is_integral_number() && get_input_element_type(2).is_integral_number();
}
bool op::v1::VariadicSplit::have_axis_and_splits_bound_set() const {
for (size_t i = 1; i < get_input_size(); ++i) {
if (!get_input_tensor(i).has_and_set_bound()) {
return false;
}
}
return true;
}
bool op::v1::VariadicSplit::evaluate_lower(const HostTensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_lower);
return has_evaluate() && have_axis_and_splits_bound_set() && default_lower_bound_evaluator(this, output_values);
}
bool op::v1::VariadicSplit::evaluate_upper(const HostTensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_upper);
return has_evaluate() && have_axis_and_splits_bound_set() && default_upper_bound_evaluator(this, output_values);
}
bool op::v1::VariadicSplit::evaluate_label(TensorLabelVector& output_labels) const {
return have_axis_and_splits_bound_set() && default_label_evaluator(this, output_labels);
}

View File

@ -1349,12 +1349,12 @@ bool ov::default_label_evaluator(const Node* node, TensorLabelVector& output_lab
HostTensorVector input_tensors(input_values.size());
for (size_t i = 0; i < input_values.size(); ++i) {
const auto& input = input_values[i];
if (i != 0)
if (i != 0) {
if (input.get_tensor().has_and_set_bound())
input_tensors[i] = input.get_tensor().get_lower_value();
else
return false;
else {
} else {
const auto& input_labels = input.get_tensor().get_value_label();
if (has_no_labels(input_labels)) {
return false;

View File

@ -353,29 +353,25 @@ INSTANTIATE_TEST_SUITE_P(
PrintToStringParamName());
TEST_P(SplitBoundTest, propagate_label_and_dynamic_value) {
PartialShape labeled_shape = PartialShape{p_shape};
const auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::i32, PartialShape{1});
const auto in_exp_labels = make_in_exp_labels();
set_shape_labels(labeled_shape, in_exp_labels.first);
set_shape_labels(p_shape, in_exp_labels.first);
constexpr auto et = element::i64;
const auto labeled_param = std::make_shared<op::Parameter>(et, labeled_shape);
const auto labeled_param = std::make_shared<op::Parameter>(et, p_shape);
const auto labeled_shape_of = std::make_shared<op::ShapeOf>(labeled_param);
const auto zero = std::vector<int64_t>{0};
const auto axis = std::make_shared<op::v0::Constant>(et, Shape{}, zero);
const auto indices = std::make_shared<op::v0::Constant>(et, Shape{}, zero);
const auto split = std::make_shared<op::v1::Split>(labeled_shape_of, axis, num_of_splits);
for (auto& output : split->outputs()) {
const auto& bc = std::make_shared<op::v3::Broadcast>(param, output);
const auto& bc = std::make_shared<op::v3::Broadcast>(
std::make_shared<ov::op::v0::Parameter>(ov::element::i32, PartialShape{1}),
output);
out_shapes.push_back(bc->get_output_partial_shape(0));
out_labels.push_back(get_shape_labels(bc->get_output_partial_shape(0)));
}
auto exp_labels_it = in_exp_labels.second.begin();
EXPECT_EQ(out_shapes, exp_shapes);
EXPECT_EQ(out_labels, in_exp_labels.second);
}

View File

@ -41,9 +41,12 @@ TEST(type_prop, squeeze_incorrect_negative_axes) {
HasSubstr("Parameter axis -10 out of the tensor rank range"));
}
using SplitTypePropTestParam = std::tuple<PartialShape, std::vector<int64_t>, PartialShape>;
using SqueezeTypePropTestParam = std::tuple<PartialShape, // Input shape
std::vector<int64_t>, // Squeeze axis
PartialShape // Expected shape
>;
class SqueezeTest : public WithParamInterface<SplitTypePropTestParam>, public UnSqueezeFixture {
class SqueezeTest : public WithParamInterface<SqueezeTypePropTestParam>, public UnSqueezeFixture {
protected:
void SetUp() override {
std::tie(p_shape, axes, exp_shape) = GetParam();

View File

@ -79,9 +79,13 @@ TEST(type_prop, unsqueeze_empty_axes) {
FAIL() << "Deduced type check failed for unexpected reason";
}
}
using TypePropTestParam = std::tuple<PartialShape, std::vector<int64_t>, PartialShape>;
class UnsqueezeTest : public WithParamInterface<TypePropTestParam>, public UnSqueezeFixture {
using UnSqueezeTypePropTestParam = std::tuple<PartialShape, // Input shape
std::vector<int64_t>, // Unsqueeze axis
PartialShape // Expected shape
>;
class UnsqueezeTest : public WithParamInterface<UnSqueezeTypePropTestParam>, public UnSqueezeFixture {
protected:
void SetUp() override {
std::tie(p_shape, axes, exp_shape) = GetParam();

View File

@ -2,149 +2,225 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "gtest/gtest.h"
#include "common_test_utils/test_assertions.hpp"
#include "dimension_tracker.hpp"
#include "gmock/gmock.h"
#include "ngraph/ngraph.hpp"
#include "sequnce_generator.hpp"
#include "util/type_prop.hpp"
using namespace std;
using namespace ngraph;
using namespace testing;
TEST(type_prop, variadic_split) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
EXPECT_EQ(split->outputs().size(), 2);
EXPECT_EQ(split->get_output_shape(0), (Shape{2, 2}));
EXPECT_EQ(split->get_output_shape(1), (Shape{2, 4}));
EXPECT_EQ(split->get_output_element_type(0), element::i32);
EXPECT_EQ(split->get_output_element_type(1), element::i32);
using VSplitTypePropTestParam = std::tuple<PartialShape, // Input shape
int64_t, // Split axis
std::vector<int64_t>, // Split lengths
PartialShapes // Expected shapes
>;
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {7, -1, 2}))
->output(1)
.get_shape(),
(Shape{3, 6}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {-1, 7, 2}))
->output(0)
.get_shape(),
(Shape{3, 6}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, Shape{12, 1, 6}),
op::Constant::create<int64_t>(element::i64, Shape{1}, {2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {3, 1, 2}))
->output(2)
.get_shape(),
(Shape{12, 1, 2}));
EXPECT_EQ(make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, Shape{12, 6}),
op::Constant::create<int64_t>(element::i64, Shape{1}, {1}),
op::Constant::create<int64_t>(element::i64, Shape{2}, {6, 0}))
->output(1)
.get_shape(),
(Shape{12, 0}));
}
TEST(type_prop, variadic_split_splits_rank) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try {
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{1, 2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Split lengths should be a 1-D tensor. Got 2 instead."));
class VariadicSplitTest : public TestWithParam<VSplitTypePropTestParam> {
protected:
void SetUp() override {
std::tie(p_shape, axis, split_lengths, exp_shapes) = GetParam();
}
}
TEST(type_prop, variadic_split_incorrect_sum) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try {
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {1, 6});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(),
std::string("Total length of splits: 7 must match the length of the chosen axis: 6"));
PartialShapes get_output_partial_shapes(const Node& n) const {
PartialShapes out;
for (size_t i = 0; i < n.get_output_size(); ++i) {
out.push_back(n.get_output_partial_shape(i));
}
}
TEST(type_prop, variadic_split_incorrect_axis) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
try {
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {-5});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
} catch (const ngraph_error& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Parameter axis -5 out of the tensor rank range [-2, 1]."));
return out;
}
}
TEST(type_prop, variadic_split_splits_invalid_negative) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
std::pair<std::vector<size_t>, std::vector<size_t>> make_in_exp_labels() const {
std::vector<size_t> in_labels;
std::generate_n(std::back_inserter(in_labels), p_shape.size(), ov::SeqGen<size_t>(10));
try {
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{2}, {-2, 4});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Invalid value -2 in split lengths input. Should be >= -1."));
auto exp_labels = in_labels;
const auto n_axis = normalize_axis("", axis, p_shape.rank());
exp_labels[n_axis] = ov::no_label;
return {in_labels, exp_labels};
}
int64_t axis;
std::vector<int64_t> split_lengths;
PartialShape p_shape;
PartialShapes exp_shapes;
};
INSTANTIATE_TEST_SUITE_P(type_prop_static_shape,
VariadicSplitTest,
Values(std::make_tuple(PartialShape{6, 2}, 0, std::vector<int64_t>{6}, PartialShapes{{6, 2}}),
std::make_tuple(PartialShape{6, 2, 10},
-1,
std::vector<int64_t>{6, -1, 3},
PartialShapes{{6, 2, 6}, {6, 2, 1}, {6, 2, 3}}),
std::make_tuple(PartialShape{1, 20, 3},
1,
std::vector<int64_t>{-1, 10, 3, 5},
PartialShapes{{1, 2, 3}, {1, 10, 3}, {1, 3, 3}, {1, 5, 3}})),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(
type_prop_dynamic_shape,
VariadicSplitTest,
Values(
std::make_tuple(PartialShape{{2, 6}, 2}, 0, std::vector<int64_t>{4}, PartialShapes{{4, 2}}),
std::make_tuple(PartialShape{{2, 6}, 2},
0,
std::vector<int64_t>{4, 1, -1},
PartialShapes{{4, 2}, {1, 2}, {-1, 2}}),
std::make_tuple(PartialShape{12, Dimension()},
-2,
std::vector<int64_t>{7, -1, 2},
PartialShapes{{7, -1}, {3, -1}, {2, -1}}),
std::make_tuple(PartialShape{Dimension(), Dimension(), 6},
2,
std::vector<int64_t>{3, 1, 2},
PartialShapes{{-1, -1, 3}, {-1, -1, 1}, {-1, -1, 2}}),
std::make_tuple(PartialShape{Dimension(), 6}, 1, std::vector<int64_t>{6, 0}, PartialShapes{{-1, 6}, {-1, 0}}),
std::make_tuple(PartialShape{{2, 4}, Dimension::dynamic()},
1,
std::vector<int64_t>{4, 1, -1, 3},
PartialShapes{{{2, 4}, 4}, {{2, 4}, 1}, {{2, 4}, -1}, {{2, 4}, 3}})),
PrintToStringParamName());
TEST_P(VariadicSplitTest, dimension_propagation_axis_scalar) {
constexpr auto dtype = element::i32;
const auto data = make_shared<op::Parameter>(dtype, p_shape);
const auto axis_node = make_shared<op::Constant>(element::i16, Shape{}, axis);
const auto lengths_node = std::make_shared<op::Constant>(element::i16, Shape{split_lengths.size()}, split_lengths);
const auto var_split = make_shared<op::v1::VariadicSplit>(data, axis_node, lengths_node);
EXPECT_EQ(var_split->get_output_size(), split_lengths.size());
EXPECT_THAT(var_split->outputs(), Each(Property("Element type", &Output<Node>::get_element_type, dtype)));
EXPECT_THAT(get_output_partial_shapes(*var_split), ElementsAreArray(exp_shapes));
}
TEST(type_prop, variadic_split_splits_multiple_negatives) {
const auto data = make_shared<op::Parameter>(element::i32, Shape{2, 6});
TEST_P(VariadicSplitTest, dimension_propagation_axis_1d) {
constexpr auto dtype = element::u64;
const auto data = make_shared<op::Parameter>(dtype, p_shape);
const auto axis_node = make_shared<op::Constant>(element::i32, Shape{1}, axis);
const auto lengths_node = std::make_shared<op::Constant>(element::i32, Shape{split_lengths.size()}, split_lengths);
try {
const auto axis = op::Constant::create<int64_t>(element::i64, Shape{}, {1});
const auto splits = op::Constant::create<int64_t>(element::i64, Shape{3}, {-1, -1, 3});
const auto split = make_shared<op::v1::VariadicSplit>(data, axis, splits);
FAIL() << "Split node was created with incorrect data.";
} catch (const NodeValidationFailure& error) {
EXPECT_HAS_SUBSTRING(error.what(), std::string("Cannot infer split with multiple -1 values at 0 and 1"));
const auto var_split = make_shared<op::v1::VariadicSplit>(data, axis_node, lengths_node);
EXPECT_EQ(var_split->get_output_size(), split_lengths.size());
EXPECT_THAT(var_split->outputs(), Each(Property("Element type", &Output<Node>::get_element_type, dtype)));
EXPECT_THAT(get_output_partial_shapes(*var_split), ElementsAreArray(exp_shapes));
}
TEST_P(VariadicSplitTest, use_default_ctor) {
constexpr auto dtype = element::f32;
const auto param = make_shared<op::Parameter>(dtype, p_shape);
const auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, axis);
const auto lengths_node = std::make_shared<op::Constant>(element::i64, Shape{split_lengths.size()}, split_lengths);
const auto var_split = make_shared<op::v1::VariadicSplit>();
var_split->set_arguments(NodeVector{param, axis_node, lengths_node});
var_split->validate_and_infer_types();
EXPECT_EQ(var_split->get_output_size(), split_lengths.size());
EXPECT_THAT(var_split->outputs(), Each(Property("Element type", &Output<Node>::get_element_type, dtype)));
EXPECT_THAT(get_output_partial_shapes(*var_split), ElementsAreArray(exp_shapes));
}
TEST_P(VariadicSplitTest, label_propagation) {
std::vector<size_t> in_labels, exp_labels;
std::tie(in_labels, exp_labels) = make_in_exp_labels();
set_shape_labels(p_shape, in_labels);
const auto data = make_shared<op::Parameter>(element::f32, p_shape);
const auto axis_node = make_shared<op::Constant>(element::i64, Shape{}, axis);
const auto lengths_node = std::make_shared<op::Constant>(element::i64, Shape{split_lengths.size()}, split_lengths);
const auto var_split = make_shared<op::v1::VariadicSplit>(data, axis_node, lengths_node);
EXPECT_EQ(var_split->get_output_size(), split_lengths.size());
EXPECT_THAT(
var_split->outputs(),
Each(Property("Partial shape", &Output<Node>::get_partial_shape, ResultOf(get_shape_labels, exp_labels))));
}
class VariadicSplitBoundTest : public VariadicSplitTest {
protected:
std::pair<std::vector<size_t>, std::vector<std::vector<size_t>>> make_in_exp_labels() const {
std::vector<size_t> in_labels;
std::generate_n(std::back_inserter(in_labels), p_shape.size(), ov::SeqGen<size_t>(8));
std::vector<std::vector<size_t>> exp_labels;
auto label_it = in_labels.begin();
for (auto split_length : split_lengths) {
if (split_length == 0) {
exp_labels.emplace_back(std::vector<size_t>(1, ov::no_label));
} else if (split_length == -1) {
split_length = std::accumulate(split_lengths.cbegin(),
split_lengths.cend(),
static_cast<int64_t>(p_shape.size()),
[](const int64_t& a, const int64_t& v) {
return (v != -1) ? a - v : a;
});
exp_labels.emplace_back(label_it, label_it + split_length);
} else {
exp_labels.emplace_back(label_it, label_it + split_length);
}
}
TEST(type_prop, variadic_split_shape_partially_dynamic) {
// Variadic split shape {12,?} into {7,?}, {3,?} and {2,?}
auto var_split1 =
make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, PartialShape{12, Dimension()}),
op::Constant::create<int64_t>(element::i64, Shape{}, {-2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {7, -1, 2}));
EXPECT_TRUE(var_split1->get_output_partial_shape(0).same_scheme(PartialShape{7, Dimension::dynamic()}));
EXPECT_TRUE(var_split1->get_output_partial_shape(1).same_scheme(PartialShape{3, Dimension::dynamic()}));
EXPECT_TRUE(var_split1->get_output_partial_shape(2).same_scheme(PartialShape{2, Dimension::dynamic()}));
// Variadic split shape {?,?,6} into {?,?,3}, {?,?,1} and {?,?,2}
auto var_split2 = make_shared<op::v1::VariadicSplit>(
make_shared<op::Parameter>(element::i32, PartialShape{Dimension(), Dimension(), 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {2}),
op::Constant::create<int64_t>(element::i64, Shape{3}, {3, 1, 2}));
EXPECT_TRUE(var_split2->get_output_partial_shape(0).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 3}));
EXPECT_TRUE(var_split2->get_output_partial_shape(1).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 1}));
EXPECT_TRUE(var_split2->get_output_partial_shape(2).same_scheme(
PartialShape{Dimension::dynamic(), Dimension::dynamic(), 2}));
// Variadic split shape {?,6} into {?,6}, and {?,0}
auto var_split3 =
make_shared<op::v1::VariadicSplit>(make_shared<op::Parameter>(element::i32, PartialShape{Dimension(), 6}),
op::Constant::create<int64_t>(element::i64, Shape{}, {1}),
op::Constant::create<int64_t>(element::i64, Shape{2}, {6, 0}));
EXPECT_TRUE(var_split3->get_output_partial_shape(0).same_scheme(PartialShape{Dimension::dynamic(), 6}));
EXPECT_TRUE(var_split3->get_output_partial_shape(1).same_scheme(PartialShape{Dimension::dynamic(), 0}));
label_it += split_length;
}
return {in_labels, exp_labels};
}
std::vector<PartialShape> out_shapes;
std::vector<std::vector<size_t>> out_labels;
};
INSTANTIATE_TEST_SUITE_P(type_prop_bounds_propagate,
VariadicSplitBoundTest,
Values(std::make_tuple(PartialShape{{2, 6}, 2, 3},
0,
std::vector<int64_t>{2, 1, 0},
PartialShapes{{{2, 6}, 2}, {3}, {1}}),
std::make_tuple(PartialShape{{2, 6}, 2, 3, {-1, 6}, 5},
0,
std::vector<int64_t>{1, -1, 0, 2},
PartialShapes{{{2, 6}}, {2, 3}, {1}, {{-1, 6}, 5}}),
std::make_tuple(PartialShape{{2, 6}, 2, 3, 8, 10, {-1, 6}, 5},
0,
std::vector<int64_t>{-1, 3, 0, 2},
PartialShapes{{{2, 6}, 2}, {3, 8, 10}, {1}, {{-1, 6}, 5}}),
std::make_tuple(PartialShape{{2, 6}, 2, 3, 5},
0,
std::vector<int64_t>{2, -1, 0},
PartialShapes{{{2, 6}, 2}, {3, 5}, {1}})),
PrintToStringParamName());
TEST_P(VariadicSplitBoundTest, propagate_label_and_dynamic_value) {
std::vector<size_t> in_labels;
std::vector<std::vector<size_t>> exp_labels;
std::tie(in_labels, exp_labels) = make_in_exp_labels();
set_shape_labels(p_shape, in_labels);
constexpr auto et = element::i64;
const auto labeled_param = std::make_shared<op::Parameter>(et, p_shape);
const auto labeled_shape_of = std::make_shared<op::ShapeOf>(labeled_param);
const auto zero = std::vector<int64_t>{0};
const auto axis_node = std::make_shared<op::v0::Constant>(et, Shape{}, zero);
const auto lengths_node = std::make_shared<op::Constant>(et, Shape{split_lengths.size()}, split_lengths);
const auto var_split = std::make_shared<op::v1::VariadicSplit>(labeled_shape_of, axis_node, lengths_node);
for (auto& output : var_split->outputs()) {
const auto& bc = std::make_shared<op::v3::Broadcast>(
std::make_shared<ov::op::v0::Parameter>(ov::element::i32, PartialShape{1}),
output,
"BIDIRECTIONAL");
out_shapes.push_back(bc->get_output_partial_shape(0));
out_labels.push_back(get_shape_labels(bc->get_output_partial_shape(0)));
}
EXPECT_EQ(out_shapes, exp_shapes);
EXPECT_EQ(out_labels, exp_labels);
}

View File

@ -0,0 +1,117 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "gmock/gmock.h"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/variadic_split.hpp"
#include "utils.hpp"
#include "variadic_split_shape_inference.hpp"
using namespace ov;
using namespace ov::intel_cpu;
using namespace testing;
using VariadicSplitTestParams = std::tuple<ShapeVector, // Input shapes
int64_t, // Split axis
std::vector<int64_t>, // split lengths
ShapeVector // Expected shapes
>;
class VariadicSplitStaticShapeInferenceTest : public OpStaticShapeInferenceTest<op::v1::VariadicSplit>,
public WithParamInterface<VariadicSplitTestParams> {
protected:
void SetUp() override {
std::tie(input_shapes, axis, split_lengths, exp_shapes) = GetParam();
output_shapes = ShapeVector();
data = std::make_shared<op::v0::Parameter>(element::f32, input_shapes.front().get_shape());
}
int64_t axis;
std::vector<int64_t> split_lengths;
ShapeVector exp_shapes;
std::shared_ptr<op::v0::Parameter> data;
};
INSTANTIATE_TEST_SUITE_P(
1d_shapes,
VariadicSplitStaticShapeInferenceTest,
Values(make_tuple(ShapeVector{{0}, {}, {1}}, 0, std::vector<int64_t>{0}, ShapeVector{{0}}),
make_tuple(ShapeVector{{15}, {}, {3}}, -1, std::vector<int64_t>{2, 3, 10}, ShapeVector{{2}, {3}, {10}}),
make_tuple(ShapeVector{{15}, {}, {3}}, 0, std::vector<int64_t>{5, -1, 2}, ShapeVector{{5}, {8}, {2}})),
PrintToStringParamName());
INSTANTIATE_TEST_SUITE_P(multi_dim_shapes,
VariadicSplitStaticShapeInferenceTest,
Values(make_tuple(ShapeVector{{2, 6, 5}, {}, {3}},
2,
std::vector<int64_t>{2, 1, 2},
ShapeVector{{2, 6, 2}, {2, 6, 1}, {2, 6, 2}}),
make_tuple(ShapeVector{{2, 6, 5}, {}, {2}},
-2,
std::vector<int64_t>{2, 4},
ShapeVector{{2, 2, 5}, {2, 4, 5}}),
make_tuple(ShapeVector{{4, 6, 5}, {}, {3}},
0,
std::vector<int64_t>{-1, 3, 1},
ShapeVector{{0, 6, 5}, {3, 6, 5}, {1, 6, 5}}),
make_tuple(ShapeVector{{4, 6, 5}, {}, {3}},
0,
std::vector<int64_t>{3, -1, 1},
ShapeVector{{3, 6, 5}, {0, 6, 5}, {1, 6, 5}}),
make_tuple(ShapeVector{{4, 6, 5}, {}, {3}},
0,
std::vector<int64_t>{3, 1, -1},
ShapeVector{{3, 6, 5}, {1, 6, 5}, {0, 6, 5}})),
PrintToStringParamName());
TEST_P(VariadicSplitStaticShapeInferenceTest, shape_inference_empty_const_map) {
const auto axis_node = std::make_shared<op::v0::Constant>(element::i64, Shape{}, axis);
const auto split_len_node =
std::make_shared<op::v0::Constant>(element::i64, Shape{split_lengths.size()}, split_lengths);
op = make_op(data, axis_node, split_len_node);
shape_inference(op.get(), input_shapes, output_shapes);
EXPECT_EQ(output_shapes.size(), split_lengths.size());
EXPECT_EQ(output_shapes, exp_shapes);
}
TEST_P(VariadicSplitStaticShapeInferenceTest, shape_inference_axis_in_const_map) {
const auto axis_node = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
const auto split_len_node =
std::make_shared<op::v0::Constant>(element::i64, Shape{split_lengths.size()}, split_lengths);
op = make_op(data, axis_node, split_len_node);
const auto axis_const = std::make_shared<op::v0::Constant>(element::i64, ov::Shape{}, axis);
const auto axis_tensor = std::make_shared<ngraph::runtime::HostTensor>(axis_const);
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {{1, axis_tensor}};
shape_inference(op.get(), input_shapes, output_shapes, constant_data);
EXPECT_EQ(output_shapes.size(), split_lengths.size());
EXPECT_EQ(output_shapes, exp_shapes);
}
TEST_P(VariadicSplitStaticShapeInferenceTest, shape_inference_all_const_in_map) {
const auto axis_node = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
const auto split_len_node = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic());
op = make_op(data, axis_node, split_len_node);
const auto axis_const = std::make_shared<op::v0::Constant>(element::i64, Shape{}, axis);
const auto axis_tensor = std::make_shared<ngraph::runtime::HostTensor>(axis_const);
const auto split_len_const =
std::make_shared<op::v0::Constant>(element::i64, Shape{split_lengths.size()}, split_lengths);
const auto split_len_tensor = std::make_shared<ngraph::runtime::HostTensor>(split_len_const);
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data = {{2, split_len_tensor},
{1, axis_tensor}};
shape_inference(op.get(), input_shapes, output_shapes, constant_data);
EXPECT_EQ(output_shapes.size(), split_lengths.size());
EXPECT_EQ(output_shapes, exp_shapes);
}

View File

@ -1,51 +0,0 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <variadic_split_shape_inference.hpp>
#include "utils.hpp"
using namespace ov;
using namespace ov::intel_cpu;
static std::shared_ptr<op::v1::VariadicSplit> build_variadic_split(PartialShape data_shape,
std::initializer_list<int64_t> axis_value,
std::initializer_list<int64_t> splits) {
std::shared_ptr<ov::Node> axis;
std::shared_ptr<ov::Node> splits_len;
const auto data = std::make_shared<op::v0::Parameter>(element::i32, data_shape);
if (axis_value.size())
axis = op::v0::Constant::create(element::i64, ov::Shape{}, {*axis_value.begin()});
else
axis = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic(ov::Rank(0)));
if (splits.size())
splits_len = op::v0::Constant::create(element::i64, ov::Shape{splits.size()}, splits);
else
splits_len = std::make_shared<op::v0::Parameter>(element::i64, ov::PartialShape::dynamic(ov::Rank(1)));
return std::make_shared<op::v1::VariadicSplit>(data, axis, splits_len);
}
TEST(StaticShapeInferenceTest, VariadicSplitV1) {
const auto split = build_variadic_split(ov::PartialShape::dynamic(), {}, {});
check_static_shape(split.get(),
{StaticShape{12, 6}, {-2}, {7, -1, 2}},
{{7, 6}, {3, 6}, {2, 6}});
check_static_shape(split.get(),
{StaticShape{12, 6}, {-2}, {-1, 7, 2}},
{{3, 6}, {7, 6}, {2, 6}});
check_static_shape(split.get(),
{StaticShape{12, 1, 6}, {2}, {3, 1, 2}},
{{12, 1, 3}, {12, 1, 1}, {12, 1, 2}});
check_static_shape(split.get(), {StaticShape{12, 6}, {1}, {6, 0}}, {{12, 6}, {12, 0}});
}
TEST(StaticShapeInferenceTest, VariadicSplitV1_StaticWithConstMap) {
check_static_shape(build_variadic_split(ov::PartialShape{-1, -1}, {}, {}).get(),
{StaticShape{12, 6}, {-2}, {7, -1, 2}},
{{7, 6}, {3, 6}, {2, 6}});
}

View File

@ -34,10 +34,9 @@ protected:
std::shared_ptr<ov::op::v0::Parameter> param;
};
using BoundTestParam = std::tuple<ov::PartialShape, ov::PartialShape>;
/** \brief Test fixture for Unsqueeze/Squeeze type_prop bound tests. */
class UnSqueezeBoundTest : public testing::WithParamInterface<BoundTestParam>, public UnSqueezeFixture {
class UnSqueezeBoundTest : public testing::WithParamInterface<std::tuple<ov::PartialShape, ov::PartialShape>>,
public UnSqueezeFixture {
protected:
void SetUp() override {
std::tie(p_shape, exp_shape) = GetParam();
@ -46,3 +45,5 @@ protected:
std::vector<size_t> in_labels;
};
using PartialShapes = std::vector<ov::PartialShape>;