Shape inference adoption for dimension tracking (#10016)
* Shape inference adoption for dimension tracking * Style * test adj * tests fixed
This commit is contained in:
parent
d5c837cc1b
commit
9ad09f2120
@ -87,7 +87,11 @@ template <typename T>
|
||||
void infer_using_scales(T& output_shape, const std::vector<int64_t>& axes, const std::vector<float>& scales) {
|
||||
size_t i = 0;
|
||||
static constexpr float epsilon = 1.0e-6f;
|
||||
for (auto axis : axes) {
|
||||
for (const auto& axis : axes) {
|
||||
if (scales[i] == 1.) {
|
||||
++i;
|
||||
continue;
|
||||
}
|
||||
const auto& current_dim = output_shape[axis];
|
||||
float multiplier = scales[i] + epsilon;
|
||||
if (current_dim.is_static()) {
|
||||
@ -109,6 +113,7 @@ void shape_infer(const Interpolate* op,
|
||||
std::vector<T>& output_shapes,
|
||||
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
|
||||
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3 || input_shapes.size() == 4) && output_shapes.size() == 1);
|
||||
using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
|
||||
|
||||
const auto& input_shape = input_shapes[0];
|
||||
auto& output_shape = output_shapes[0];
|
||||
@ -137,9 +142,7 @@ void shape_infer(const Interpolate* op,
|
||||
|
||||
// Get padded input shape
|
||||
for (int64_t i = 0; i < input_rank; ++i) {
|
||||
if (input_shape[i].is_static()) {
|
||||
output_shape[i] = pads_begin[i] + pads_end[i] + input_shape[i].get_length();
|
||||
}
|
||||
output_shape[i] = DimType(pads_begin[i]) + DimType(pads_end[i]) + input_shape[i];
|
||||
}
|
||||
|
||||
if (op->m_attrs.shape_calculation_mode == Interpolate::ShapeCalcMode::SCALES) {
|
||||
@ -147,7 +150,7 @@ void shape_infer(const Interpolate* op,
|
||||
if (get_data_as_float<T>(2, op, scales, constant_data)) {
|
||||
infer_using_scales(output_shape, axes, scales);
|
||||
} else {
|
||||
for (auto axis : axes) {
|
||||
for (const auto& axis : axes) {
|
||||
output_shape[axis] = ov::Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
@ -155,11 +158,11 @@ void shape_infer(const Interpolate* op,
|
||||
T target_spatial_shape;
|
||||
if (get_data_as_shape<T>(1, op, target_spatial_shape, constant_data)) {
|
||||
size_t i = 0;
|
||||
for (auto axis : axes) {
|
||||
for (const auto& axis : axes) {
|
||||
output_shape[axis] = target_spatial_shape[i++];
|
||||
}
|
||||
} else {
|
||||
for (auto axis : axes) {
|
||||
for (const auto& axis : axes) {
|
||||
output_shape[axis] = ov::Dimension::dynamic();
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@
|
||||
|
||||
#include "ngraph/op/matmul.hpp"
|
||||
|
||||
#include <dimension_tracker.hpp>
|
||||
#include <memory>
|
||||
#include <numeric>
|
||||
|
||||
@ -151,6 +152,17 @@ ov::PartialShape validate_matmul_output_shape(const ov::PartialShape& arg0_shape
|
||||
: arg0_shape_tmp[i].get_max_length();
|
||||
}
|
||||
output_shape[i] = Dimension(lower_bound, upper_bound);
|
||||
// label setting
|
||||
size_t out_label = 0;
|
||||
size_t label_0 = ov::DimensionTracker::get_label(arg0_shape_tmp[i]);
|
||||
size_t label_1 = ov::DimensionTracker::get_label(arg1_shape_tmp[i]);
|
||||
if (label_0 == label_1 || label_1 == 0)
|
||||
out_label = label_0;
|
||||
else if (label_0 == 0)
|
||||
out_label = label_1;
|
||||
output_shape[i] = Dimension(lower_bound, upper_bound);
|
||||
if (out_label)
|
||||
ov::DimensionTracker::set_label(output_shape[i], out_label);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5,6 +5,7 @@
|
||||
#include "ngraph/op/reshape.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
#include <dimension_tracker.hpp>
|
||||
#include <ngraph/validation_util.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
@ -88,6 +89,9 @@ void op::v1::Reshape::validate_and_infer_types() {
|
||||
auto upper_bound = std::make_shared<op::v0::Constant>(ub)->cast_vector<int64_t>();
|
||||
shape_can_be_calculated = true;
|
||||
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
|
||||
const TensorLabel& labels = get_input_source_output(1).get_tensor().get_value_label();
|
||||
NGRAPH_CHECK(labels.empty() || lower_bound.size() == labels.size());
|
||||
|
||||
for (size_t i = 0; i < lower_bound.size(); ++i) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
lower_bound[i] >= -1 && upper_bound[i] >= -1,
|
||||
@ -104,8 +108,10 @@ void op::v1::Reshape::validate_and_infer_types() {
|
||||
upper_bound[i] == std::numeric_limits<std::int32_t>::max()) {
|
||||
upper_bound[i] = std::numeric_limits<std::int64_t>::max();
|
||||
}
|
||||
|
||||
reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
|
||||
auto d = Dimension(lower_bound[i], upper_bound[i]);
|
||||
if (!labels.empty() && labels[i])
|
||||
ov::DimensionTracker::set_label(d, labels[i]);
|
||||
reshape_pattern.emplace_back(d);
|
||||
}
|
||||
// For scalar case reshape_patter should be empty but scalar reshape pattern should be empty
|
||||
// or equal to 1
|
||||
@ -232,20 +238,155 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
|
||||
return false;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool fully_eq(const Dimension& rhs, const Dimension& lhs) {
|
||||
return rhs == lhs && ov::DimensionTracker::get_label(rhs) == ov::DimensionTracker::get_label(lhs) &&
|
||||
(ov::DimensionTracker::get_label(rhs) || rhs.is_static());
|
||||
}
|
||||
|
||||
Dimension resolve_minus_one(const Node* reshape_node,
|
||||
vector<Dimension>& input_product,
|
||||
vector<Dimension>& output_product) {
|
||||
std::vector<Dimension> to_delete_from_output, to_delete_from_input;
|
||||
Dimension input_const_part(1), output_const_part(1);
|
||||
|
||||
for (const auto& dim : output_product)
|
||||
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
|
||||
output_const_part *= dim;
|
||||
to_delete_from_output.push_back(dim);
|
||||
}
|
||||
|
||||
for (const auto& dim : input_product)
|
||||
if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
|
||||
input_const_part *= dim;
|
||||
to_delete_from_input.push_back(dim);
|
||||
}
|
||||
|
||||
for (const auto& dim : to_delete_from_input) {
|
||||
input_product.erase(std::remove_if(input_product.begin(),
|
||||
input_product.end(),
|
||||
[=](const Dimension& d) {
|
||||
return fully_eq(dim, d);
|
||||
}),
|
||||
input_product.end());
|
||||
}
|
||||
for (const auto& dim : to_delete_from_output) {
|
||||
output_product.erase(std::remove_if(output_product.begin(),
|
||||
output_product.end(),
|
||||
[=](const Dimension& d) {
|
||||
return fully_eq(dim, d);
|
||||
}),
|
||||
output_product.end());
|
||||
}
|
||||
|
||||
to_delete_from_input.clear();
|
||||
to_delete_from_output.clear();
|
||||
|
||||
if (input_const_part != output_const_part) {
|
||||
input_product.push_back(input_const_part);
|
||||
output_product.push_back(output_const_part);
|
||||
}
|
||||
|
||||
for (const auto& out_dim : output_product) {
|
||||
const auto& it = std::find_if(input_product.begin(), input_product.end(), [out_dim](const Dimension& in_dim) {
|
||||
return fully_eq(out_dim, in_dim);
|
||||
});
|
||||
if (it != input_product.end()) {
|
||||
to_delete_from_output.push_back(out_dim);
|
||||
to_delete_from_input.push_back(out_dim);
|
||||
}
|
||||
}
|
||||
for (const auto& dim : to_delete_from_input) {
|
||||
input_product.erase(std::remove_if(input_product.begin(),
|
||||
input_product.end(),
|
||||
[=](const Dimension& d) {
|
||||
return fully_eq(dim, d);
|
||||
}),
|
||||
input_product.end());
|
||||
}
|
||||
for (const auto& dim : to_delete_from_output) {
|
||||
output_product.erase(std::remove_if(output_product.begin(),
|
||||
output_product.end(),
|
||||
[=](const Dimension& d) {
|
||||
return fully_eq(dim, d);
|
||||
}),
|
||||
output_product.end());
|
||||
}
|
||||
|
||||
if (output_product.empty() && input_product.size() == 1)
|
||||
return input_product[0];
|
||||
|
||||
Dimension input_dim(1), output_dim(1);
|
||||
for (const auto& i : input_product) {
|
||||
input_dim *= i;
|
||||
}
|
||||
for (const auto& i : output_product) {
|
||||
output_dim *= i;
|
||||
}
|
||||
|
||||
if (output_dim == 0) {
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
input_dim == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
return Dimension(0);
|
||||
} else {
|
||||
if (input_dim.is_static() && output_dim.is_static()) {
|
||||
NODE_VALIDATION_CHECK(reshape_node,
|
||||
input_dim.get_length() % output_dim.get_length() == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
}
|
||||
if (output_dim.get_min_length() == 0 || output_dim == Dimension() || input_dim == Dimension()) {
|
||||
return Dimension::dynamic();
|
||||
} else {
|
||||
Dimension::value_type lower;
|
||||
if (input_dim.get_min_length() == 0)
|
||||
lower = 0;
|
||||
else if (input_dim.get_min_length() == -1 || output_dim.get_max_length() == 0 ||
|
||||
output_dim.get_max_length() == -1)
|
||||
lower = -1; // dynamic
|
||||
else
|
||||
lower = static_cast<Dimension::value_type>(
|
||||
ceil(static_cast<double>(input_dim.get_min_length()) / output_dim.get_max_length()));
|
||||
|
||||
Dimension::value_type upper;
|
||||
if (input_dim.get_max_length() == 0)
|
||||
upper = 0;
|
||||
else if (input_dim.get_max_length() == -1 || output_dim.get_min_length() == 0 ||
|
||||
output_dim.get_min_length() == -1)
|
||||
upper = -1; // dynamic
|
||||
else
|
||||
upper = static_cast<Dimension::value_type>(
|
||||
floor(static_cast<double>(input_dim.get_max_length()) / output_dim.get_min_length()));
|
||||
|
||||
if (lower == -1)
|
||||
return Dimension::dynamic();
|
||||
else if (upper == -1)
|
||||
return Dimension(lower, upper);
|
||||
else if (lower > upper) // empty intersection
|
||||
return Dimension::dynamic();
|
||||
else
|
||||
return Dimension(lower, upper);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
|
||||
const int64_t& minus_one_idx,
|
||||
const ov::PartialShape& input_pshape,
|
||||
vector<Dimension>& output_shape) const {
|
||||
Dimension output_product(1);
|
||||
std::vector<Dimension> output_product;
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i) {
|
||||
if (i == minus_one_idx) // resolving everything except -1
|
||||
continue;
|
||||
|
||||
auto pattern_dim = reshape_pattern[i];
|
||||
if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 && get_special_zero()) {
|
||||
if (pattern_dim == 0 && get_special_zero()) {
|
||||
if (input_pshape.rank().is_dynamic()) {
|
||||
output_shape[i] = Dimension::dynamic();
|
||||
output_product *= Dimension::dynamic();
|
||||
output_product.push_back(Dimension::dynamic());
|
||||
} else {
|
||||
NODE_VALIDATION_CHECK(this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
|
||||
output_shape[i] = input_pshape[i];
|
||||
@ -257,71 +398,23 @@ void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
|
||||
}
|
||||
} else {
|
||||
output_shape[i] = pattern_dim;
|
||||
output_product *= pattern_dim;
|
||||
output_product.push_back(pattern_dim);
|
||||
}
|
||||
}
|
||||
Dimension input_product(1);
|
||||
std::vector<Dimension> input_product;
|
||||
if (input_pshape.rank().is_static())
|
||||
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i) {
|
||||
if (i < static_cast<int64_t>(reshape_pattern.size()) && reshape_pattern[i].get_min_length() == 0 &&
|
||||
reshape_pattern[i].get_max_length() == 0)
|
||||
continue;
|
||||
input_product *= input_pshape[i];
|
||||
input_product.push_back(input_pshape[i]);
|
||||
}
|
||||
else
|
||||
input_product = Dimension::dynamic();
|
||||
input_product.push_back(Dimension::dynamic());
|
||||
|
||||
if (minus_one_idx != -1) // resolving -1 masked dimension
|
||||
{
|
||||
if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0) {
|
||||
// TODO: Decide if this is desired behavior here. (NumPy seems
|
||||
// to fail.)
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_product.get_min_length() == 0 && input_product.get_max_length() == 0,
|
||||
"Cannot infer '-1' dimension with zero-size output "
|
||||
"dimension unless at least one input dimension is "
|
||||
"also zero-size");
|
||||
output_shape[minus_one_idx] = Dimension(0);
|
||||
} else {
|
||||
if (input_product.is_static() && output_product.is_static()) {
|
||||
NODE_VALIDATION_CHECK(this,
|
||||
input_product.get_length() % output_product.get_length() == 0,
|
||||
"Non-'-1' output dimensions do not evenly divide the input dimensions");
|
||||
}
|
||||
if (output_product.get_min_length() == 0 || output_product == Dimension() || input_product == Dimension()) {
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
} else {
|
||||
Dimension::value_type lower;
|
||||
if (input_product.get_min_length() == 0)
|
||||
lower = 0;
|
||||
else if (input_product.get_min_length() == -1 || output_product.get_max_length() == 0 ||
|
||||
output_product.get_max_length() == -1)
|
||||
lower = -1; // dynamic
|
||||
else
|
||||
lower = static_cast<Dimension::value_type>(
|
||||
ceil(static_cast<double>(input_product.get_min_length()) / output_product.get_max_length()));
|
||||
output_shape[minus_one_idx] = resolve_minus_one(this, input_product, output_product);
|
||||
|
||||
Dimension::value_type upper;
|
||||
if (input_product.get_max_length() == 0)
|
||||
upper = 0;
|
||||
else if (input_product.get_max_length() == -1 || output_product.get_min_length() == 0 ||
|
||||
output_product.get_min_length() == -1)
|
||||
upper = -1; // dynamic
|
||||
else
|
||||
upper = static_cast<Dimension::value_type>(
|
||||
floor(static_cast<double>(input_product.get_max_length()) / output_product.get_min_length()));
|
||||
|
||||
if (lower == -1)
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else if (upper == -1)
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
else if (lower > upper) // empty intersection
|
||||
output_shape[minus_one_idx] = Dimension::dynamic();
|
||||
else
|
||||
output_shape[minus_one_idx] = Dimension(lower, upper);
|
||||
}
|
||||
}
|
||||
}
|
||||
ov::PartialShape output_pshape(output_shape);
|
||||
if (input_pshape.is_static() && output_pshape.is_static()) {
|
||||
size_t zero_dims = std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
|
||||
|
@ -14,6 +14,7 @@
|
||||
//*****************************************************************************
|
||||
|
||||
#include <vector>
|
||||
#include <dimension_tracker.hpp>
|
||||
#include "gtest/gtest.h"
|
||||
#include "ngraph/ngraph.hpp"
|
||||
|
||||
@ -202,6 +203,63 @@ TYPED_TEST_P(ArithmeticOperator, full_dynamic_shape)
|
||||
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_a)
|
||||
{
|
||||
Dimension b = -1;
|
||||
ov::DimensionTracker::set_label(b, 10);
|
||||
PartialShape A = {b, 3, 224, 224}, B = {1, 3, 1, 1};
|
||||
|
||||
auto paramA = std::make_shared<op::Parameter>(element::f64, A);
|
||||
auto paramB = std::make_shared<op::Parameter>(element::f64, B);
|
||||
const auto op = std::make_shared<TypeParam>(paramA, paramB);
|
||||
|
||||
const auto shape = op->get_output_partial_shape(0);
|
||||
|
||||
ASSERT_EQ(shape, A);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_b)
|
||||
{
|
||||
Dimension b = -1;
|
||||
ov::DimensionTracker::set_label(b, 10);
|
||||
PartialShape A = {b, 3, 224, 224}, B = {1, 3, 1, 1};
|
||||
|
||||
auto paramA = std::make_shared<op::Parameter>(element::f64, A);
|
||||
auto paramB = std::make_shared<op::Parameter>(element::f64, B);
|
||||
const auto op = std::make_shared<TypeParam>(paramB, paramA);
|
||||
|
||||
const auto shape = op->get_output_partial_shape(0);
|
||||
|
||||
ASSERT_EQ(shape, A);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
|
||||
}
|
||||
|
||||
TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_different_rank)
|
||||
{
|
||||
Dimension b = -1;
|
||||
ov::DimensionTracker::set_label(b, 10);
|
||||
PartialShape A = {b, -1, -1, -1}, B = {3, 1, 1};
|
||||
|
||||
auto paramA = std::make_shared<op::Parameter>(element::f64, A);
|
||||
auto paramB = std::make_shared<op::Parameter>(element::f64, B);
|
||||
const auto op = std::make_shared<TypeParam>(paramA, paramB);
|
||||
|
||||
const auto shape = op->get_output_partial_shape(0);
|
||||
|
||||
ASSERT_EQ(shape, ov::PartialShape({-1, 3, -1, -1}));
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
|
||||
shape_inference_2D,
|
||||
shape_inference_4D,
|
||||
@ -219,4 +277,7 @@ REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
|
||||
shape_inference_5D_x_5D_incompatible,
|
||||
dynamic_shape_3D,
|
||||
dynamic_shape_5D,
|
||||
full_dynamic_shape);
|
||||
full_dynamic_shape,
|
||||
dynamic_shape_static_rank_with_labels_a,
|
||||
dynamic_shape_static_rank_with_labels_b,
|
||||
dynamic_shape_static_rank_with_labels_different_rank);
|
||||
|
@ -633,3 +633,20 @@ TEST(type_prop, reshape_dynamic_value_and_label_propagation) {
|
||||
const auto& output_shape = bc->get_output_partial_shape(0);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
}
|
||||
|
||||
TEST(type_prop, reshape_label_shape_propagation_minus_one) {
|
||||
Dimension marked_0 = Dimension(-1);
|
||||
ov::DimensionTracker::set_label(marked_0, 10);
|
||||
|
||||
PartialShape initial_shape = PartialShape{marked_0, 4, 3, 1};
|
||||
|
||||
auto input = std::make_shared<op::Parameter>(element::f32, initial_shape);
|
||||
auto output_pattern = std::make_shared<op::Constant>(element::i64, Shape{2}, std::vector<int64_t>{-1, 12});
|
||||
|
||||
const auto reshape = std::make_shared<op::v1::Reshape>(input, output_pattern, false);
|
||||
|
||||
auto output_shape = reshape->get_output_partial_shape(0);
|
||||
ASSERT_EQ(output_shape, PartialShape({-1, 12}));
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
|
||||
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[1]), 0);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user