Shape inference adoption for dimension tracking (#10016)

* Shape inference adoption for dimension tracking

* Style

* test adj

* tests fixed
Evgenya Stepyreva 2022-02-10 15:30:18 +03:00 committed by GitHub
parent d5c837cc1b
commit 9ad09f2120
5 changed files with 252 additions and 66 deletions

@@ -87,7 +87,11 @@ template <typename T>
void infer_using_scales(T& output_shape, const std::vector<int64_t>& axes, const std::vector<float>& scales) {
size_t i = 0;
static constexpr float epsilon = 1.0e-6f;
- for (auto axis : axes) {
+ for (const auto& axis : axes) {
+ if (scales[i] == 1.) {
+ ++i;
+ continue;
+ }
const auto& current_dim = output_shape[axis];
float multiplier = scales[i] + epsilon;
if (current_dim.is_static()) {
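
Why the `scales[i] == 1.` early-continue added above matters: skipping the axis keeps the already padded `output_shape[axis]` exactly as it is, bounds and label included, whereas running it through the float multiplier would collapse a dynamic, labeled dimension into an anonymous dynamic one. A minimal sketch of the difference, assuming (as the new tests in this commit rely on) that copying an `ov::Dimension` preserves its label; include paths are taken from this commit and may vary by OpenVINO version:

```cpp
#include <dimension_tracker.hpp>        // dev-API header, as included elsewhere in this commit
#include <openvino/core/dimension.hpp>  // assumed location of ov::Dimension

int main() {
    ov::Dimension d = ov::Dimension::dynamic();
    ov::DimensionTracker::set_label(d, 7);

    ov::Dimension kept = d;  // scale == 1: the dimension is reused, label survives

    // scale != 1: bounds are recomputed through float math, yielding a fresh
    // dimension that carries no label
    ov::Dimension recomputed = ov::Dimension::dynamic();

    return ov::DimensionTracker::get_label(kept) == 7 &&
                   ov::DimensionTracker::get_label(recomputed) == 0
               ? 0
               : 1;
}
```
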
@@ -109,6 +113,7 @@ void shape_infer(const Interpolate* op,
std::vector<T>& output_shapes,
const std::map<size_t, std::shared_ptr<ngraph::runtime::HostTensor>>& constant_data) {
NODE_VALIDATION_CHECK(op, (input_shapes.size() == 3 || input_shapes.size() == 4) && output_shapes.size() == 1);
+ using DimType = typename std::iterator_traits<typename T::iterator>::value_type;
const auto& input_shape = input_shapes[0];
auto& output_shape = output_shapes[0];
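
The `DimType` alias added above extracts the dimension type from whatever shape container `T` the function is instantiated with: `ov::Dimension` for `PartialShape`, a plain integral type for the static-shape containers that plugin-side shape inference passes in. A self-contained sketch of the extraction (the `DimTypeOf` name and the `std::vector<size_t>` stand-in are illustrative only):

```cpp
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <vector>

// Recovers the element (dimension) type of a shape container T, mirroring the
// DimType alias above.
template <typename T>
using DimTypeOf = typename std::iterator_traits<typename T::iterator>::value_type;

// Toy check with a vector standing in for a static-shape container:
static_assert(std::is_same<DimTypeOf<std::vector<size_t>>, size_t>::value,
              "a static shape yields an integral dimension type");
```
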
@@ -137,9 +142,7 @@ void shape_infer(const Interpolate* op,
// Get padded input shape
for (int64_t i = 0; i < input_rank; ++i) {
- if (input_shape[i].is_static()) {
- output_shape[i] = pads_begin[i] + pads_end[i] + input_shape[i].get_length();
- }
+ output_shape[i] = DimType(pads_begin[i]) + DimType(pads_end[i]) + input_shape[i];
}
if (op->m_attrs.shape_calculation_mode == Interpolate::ShapeCalcMode::SCALES) {
@@ -147,7 +150,7 @@ void shape_infer(const Interpolate* op,
if (get_data_as_float<T>(2, op, scales, constant_data)) {
infer_using_scales(output_shape, axes, scales);
} else {
- for (auto axis : axes) {
+ for (const auto& axis : axes) {
output_shape[axis] = ov::Dimension::dynamic();
}
}
@@ -155,11 +158,11 @@ void shape_infer(const Interpolate* op,
T target_spatial_shape;
if (get_data_as_shape<T>(1, op, target_spatial_shape, constant_data)) {
size_t i = 0;
- for (auto axis : axes) {
+ for (const auto& axis : axes) {
output_shape[axis] = target_spatial_shape[i++];
}
} else {
- for (auto axis : axes) {
+ for (const auto& axis : axes) {
output_shape[axis] = ov::Dimension::dynamic();
}
}
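
The padding rewrite in this file only works because `ov::Dimension` defines interval arithmetic: `DimType(pads_begin[i]) + DimType(pads_end[i]) + input_shape[i]` pads both bounds of a dynamic dimension, where the old `is_static()`-guarded code left dynamic dimensions untouched. A small sketch of that arithmetic (header path assumed):

```cpp
#include <openvino/core/dimension.hpp>  // assumed location of ov::Dimension

int main() {
    ov::Dimension bounded(2, 8);  // dynamic dimension with bounds [2, 8]
    // pads_begin = 1, pads_end = 1: interval addition shifts both bounds
    ov::Dimension padded = ov::Dimension(1) + ov::Dimension(1) + bounded;
    return padded == ov::Dimension(4, 10) ? 0 : 1;
}
```
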

@@ -4,6 +4,7 @@
#include "ngraph/op/matmul.hpp"
+ #include <dimension_tracker.hpp>
#include <memory>
#include <numeric>
@@ -151,6 +152,17 @@ ov::PartialShape validate_matmul_output_shape(const ov::PartialShape& arg0_shape
: arg0_shape_tmp[i].get_max_length();
}
- output_shape[i] = Dimension(lower_bound, upper_bound);
+ // label setting
+ size_t out_label = 0;
+ size_t label_0 = ov::DimensionTracker::get_label(arg0_shape_tmp[i]);
+ size_t label_1 = ov::DimensionTracker::get_label(arg1_shape_tmp[i]);
+ if (label_0 == label_1 || label_1 == 0)
+ out_label = label_0;
+ else if (label_0 == 0)
+ out_label = label_1;
+ output_shape[i] = Dimension(lower_bound, upper_bound);
+ if (out_label)
+ ov::DimensionTracker::set_label(output_shape[i], out_label);
}
}
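
The label-merging rule the hunk above applies to every broadcast dimension pair: a label survives only when both inputs agree or one side is untracked; conflicting labels cancel. Restated as a standalone helper (hypothetical name, not part of the commit):

```cpp
#include <cstddef>

// Label 0 means "untracked".
size_t merge_labels(size_t label_0, size_t label_1) {
    if (label_0 == label_1 || label_1 == 0)
        return label_0;  // labels agree, or only input 0 is tracked
    if (label_0 == 0)
        return label_1;  // only input 1 is tracked
    return 0;            // conflicting labels: drop tracking for this dim
}
```
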

@@ -5,6 +5,7 @@
#include "ngraph/op/reshape.hpp"
#include <algorithm>
+ #include <dimension_tracker.hpp>
#include <ngraph/validation_util.hpp>
#include "itt.hpp"
@@ -88,6 +89,9 @@ void op::v1::Reshape::validate_and_infer_types() {
auto upper_bound = std::make_shared<op::v0::Constant>(ub)->cast_vector<int64_t>();
shape_can_be_calculated = true;
NGRAPH_CHECK(lower_bound.size() == upper_bound.size());
+ const TensorLabel& labels = get_input_source_output(1).get_tensor().get_value_label();
+ NGRAPH_CHECK(labels.empty() || lower_bound.size() == labels.size());
for (size_t i = 0; i < lower_bound.size(); ++i) {
NODE_VALIDATION_CHECK(this,
lower_bound[i] >= -1 && upper_bound[i] >= -1,
@@ -104,8 +108,10 @@ void op::v1::Reshape::validate_and_infer_types() {
upper_bound[i] == std::numeric_limits<std::int32_t>::max()) {
upper_bound[i] = std::numeric_limits<std::int64_t>::max();
}
- reshape_pattern.emplace_back(lower_bound[i], upper_bound[i]);
+ auto d = Dimension(lower_bound[i], upper_bound[i]);
+ if (!labels.empty() && labels[i])
+ ov::DimensionTracker::set_label(d, labels[i]);
+ reshape_pattern.emplace_back(d);
}
// For the scalar case reshape_pattern should be empty or equal to 1
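
Together, the two hunks above let value labels on the reshape pattern input (for example, from a `ShapeOf` over a labeled shape) land on the corresponding pattern dimensions. One iteration of the added loop body, isolated as a sketch (`make_pattern_dim` is a hypothetical name; headers as in this commit):

```cpp
#include <cstdint>
#include <dimension_tracker.hpp>        // dev-API header, as included in this commit
#include <openvino/core/dimension.hpp>  // assumed location of ov::Dimension

// Builds one reshape-pattern dimension from normalized bounds plus an optional
// value label (0 = unlabeled), mirroring the added loop body above.
ov::Dimension make_pattern_dim(int64_t lower, int64_t upper, size_t label) {
    auto d = ov::Dimension(lower, upper);
    if (label)
        ov::DimensionTracker::set_label(d, label);
    return d;
}
```
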
@@ -232,20 +238,155 @@ bool op::v1::Reshape::constant_fold(OutputVector& output_values, const OutputVec
return false;
}
+ namespace {
+ bool fully_eq(const Dimension& rhs, const Dimension& lhs) {
+ return rhs == lhs && ov::DimensionTracker::get_label(rhs) == ov::DimensionTracker::get_label(lhs) &&
+ (ov::DimensionTracker::get_label(rhs) || rhs.is_static());
+ }
+ Dimension resolve_minus_one(const Node* reshape_node,
+ vector<Dimension>& input_product,
+ vector<Dimension>& output_product) {
+ std::vector<Dimension> to_delete_from_output, to_delete_from_input;
+ Dimension input_const_part(1), output_const_part(1);
+ for (const auto& dim : output_product)
+ if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
+ output_const_part *= dim;
+ to_delete_from_output.push_back(dim);
+ }
+ for (const auto& dim : input_product)
+ if (!ov::DimensionTracker::get_label(dim) && dim.is_static()) {
+ input_const_part *= dim;
+ to_delete_from_input.push_back(dim);
+ }
+ for (const auto& dim : to_delete_from_input) {
+ input_product.erase(std::remove_if(input_product.begin(),
+ input_product.end(),
+ [=](const Dimension& d) {
+ return fully_eq(dim, d);
+ }),
+ input_product.end());
+ }
+ for (const auto& dim : to_delete_from_output) {
+ output_product.erase(std::remove_if(output_product.begin(),
+ output_product.end(),
+ [=](const Dimension& d) {
+ return fully_eq(dim, d);
+ }),
+ output_product.end());
+ }
+ to_delete_from_input.clear();
+ to_delete_from_output.clear();
+ if (input_const_part != output_const_part) {
+ input_product.push_back(input_const_part);
+ output_product.push_back(output_const_part);
+ }
+ for (const auto& out_dim : output_product) {
+ const auto& it = std::find_if(input_product.begin(), input_product.end(), [out_dim](const Dimension& in_dim) {
+ return fully_eq(out_dim, in_dim);
+ });
+ if (it != input_product.end()) {
+ to_delete_from_output.push_back(out_dim);
+ to_delete_from_input.push_back(out_dim);
+ }
+ }
+ for (const auto& dim : to_delete_from_input) {
+ input_product.erase(std::remove_if(input_product.begin(),
+ input_product.end(),
+ [=](const Dimension& d) {
+ return fully_eq(dim, d);
+ }),
+ input_product.end());
+ }
+ for (const auto& dim : to_delete_from_output) {
+ output_product.erase(std::remove_if(output_product.begin(),
+ output_product.end(),
+ [=](const Dimension& d) {
+ return fully_eq(dim, d);
+ }),
+ output_product.end());
+ }
+ if (output_product.empty() && input_product.size() == 1)
+ return input_product[0];
+ Dimension input_dim(1), output_dim(1);
+ for (const auto& i : input_product) {
+ input_dim *= i;
+ }
+ for (const auto& i : output_product) {
+ output_dim *= i;
+ }
+ if (output_dim == 0) {
+ NODE_VALIDATION_CHECK(reshape_node,
+ input_dim == 0,
+ "Cannot infer '-1' dimension with zero-size output "
+ "dimension unless at least one input dimension is "
+ "also zero-size");
+ return Dimension(0);
+ } else {
+ if (input_dim.is_static() && output_dim.is_static()) {
+ NODE_VALIDATION_CHECK(reshape_node,
+ input_dim.get_length() % output_dim.get_length() == 0,
+ "Non-'-1' output dimensions do not evenly divide the input dimensions");
+ }
+ if (output_dim.get_min_length() == 0 || output_dim == Dimension() || input_dim == Dimension()) {
+ return Dimension::dynamic();
+ } else {
+ Dimension::value_type lower;
+ if (input_dim.get_min_length() == 0)
+ lower = 0;
+ else if (input_dim.get_min_length() == -1 || output_dim.get_max_length() == 0 ||
+ output_dim.get_max_length() == -1)
+ lower = -1; // dynamic
+ else
+ lower = static_cast<Dimension::value_type>(
+ ceil(static_cast<double>(input_dim.get_min_length()) / output_dim.get_max_length()));
+ Dimension::value_type upper;
+ if (input_dim.get_max_length() == 0)
+ upper = 0;
+ else if (input_dim.get_max_length() == -1 || output_dim.get_min_length() == 0 ||
+ output_dim.get_min_length() == -1)
+ upper = -1; // dynamic
+ else
+ upper = static_cast<Dimension::value_type>(
+ floor(static_cast<double>(input_dim.get_max_length()) / output_dim.get_min_length()));
+ if (lower == -1)
+ return Dimension::dynamic();
+ else if (upper == -1)
+ return Dimension(lower, upper);
+ else if (lower > upper) // empty intersection
+ return Dimension::dynamic();
+ else
+ return Dimension(lower, upper);
+ }
+ }
+ }
+ } // namespace
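
On the shape used by the new type_prop test at the end of this commit, the cancellation plays out as follows: input `{N[label=10], 4, 3, 1}` with pattern `{-1, 12}` gives `input_product = {N, 4, 3, 1}` and `output_product = {12}`; the unlabeled static parts fold to 12 on each side and cancel, leaving `{N}` versus `{}`, so `resolve_minus_one` returns `N` itself, label intact, where the old divide-the-products code could only answer an unlabeled dynamic dimension. A simplified toy of that flow (labels and intervals elided, fully static pattern assumed):

```cpp
#include <cstdint>
#include <vector>

// Toy resolve_minus_one: non-positive entries stand for dynamic dims,
// positive entries are static. Fold the static factors of each side; if they
// cancel exactly and a single dynamic input dim survives, the -1 resolves to
// that dim (with its label, in the real implementation).
int64_t resolve_toy(const std::vector<int64_t>& in, const std::vector<int64_t>& out) {
    int64_t in_static = 1, dynamic_left = 0;
    for (int64_t d : in) {
        if (d > 0)
            in_static *= d;
        else
            ++dynamic_left;
    }
    int64_t out_static = 1;
    for (int64_t d : out)
        out_static *= d;
    if (dynamic_left == 1 && in_static == out_static)
        return -1;  // the surviving dynamic input dim *is* the -1 output
    if (dynamic_left == 0 && out_static != 0 && in_static % out_static == 0)
        return in_static / out_static;  // ordinary static resolution
    return -2;  // would be Dimension::dynamic() in the real code
}

// resolve_toy({-1, 4, 3, 1}, {12}) == -1 (N survives, label intact);
// resolve_toy({5, 4, 3, 1}, {12})  == 5  (plain static division).
```
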
void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
const int64_t& minus_one_idx,
const ov::PartialShape& input_pshape,
vector<Dimension>& output_shape) const {
- Dimension output_product(1);
+ std::vector<Dimension> output_product;
for (int64_t i = 0; i < static_cast<int64_t>(reshape_pattern.size()); ++i) {
if (i == minus_one_idx) // resolving everything except -1
continue;
auto pattern_dim = reshape_pattern[i];
- if (pattern_dim.get_min_length() == 0 && pattern_dim.get_max_length() == 0 && get_special_zero()) {
+ if (pattern_dim == 0 && get_special_zero()) {
if (input_pshape.rank().is_dynamic()) {
output_shape[i] = Dimension::dynamic();
- output_product *= Dimension::dynamic();
+ output_product.push_back(Dimension::dynamic());
} else {
NODE_VALIDATION_CHECK(this, i < input_pshape.rank().get_length(), "'0' dimension is out of range");
output_shape[i] = input_pshape[i];
@@ -257,71 +398,23 @@ void op::v1::Reshape::calculate_output_shape(vector<Dimension>& reshape_pattern,
}
} else {
output_shape[i] = pattern_dim;
- output_product *= pattern_dim;
+ output_product.push_back(pattern_dim);
}
}
- Dimension input_product(1);
+ std::vector<Dimension> input_product;
if (input_pshape.rank().is_static())
for (int64_t i = 0; i < input_pshape.rank().get_length(); ++i) {
if (i < static_cast<int64_t>(reshape_pattern.size()) && reshape_pattern[i].get_min_length() == 0 &&
reshape_pattern[i].get_max_length() == 0)
continue;
- input_product *= input_pshape[i];
+ input_product.push_back(input_pshape[i]);
}
else
- input_product = Dimension::dynamic();
+ input_product.push_back(Dimension::dynamic());
if (minus_one_idx != -1) // resolving -1 masked dimension
{
- if (output_product.get_min_length() == 0 && output_product.get_max_length() == 0) {
- // TODO: Decide if this is desired behavior here. (NumPy seems
- // to fail.)
- NODE_VALIDATION_CHECK(this,
- input_product.get_min_length() == 0 && input_product.get_max_length() == 0,
- "Cannot infer '-1' dimension with zero-size output "
- "dimension unless at least one input dimension is "
- "also zero-size");
- output_shape[minus_one_idx] = Dimension(0);
- } else {
- if (input_product.is_static() && output_product.is_static()) {
- NODE_VALIDATION_CHECK(this,
- input_product.get_length() % output_product.get_length() == 0,
- "Non-'-1' output dimensions do not evenly divide the input dimensions");
- }
- if (output_product.get_min_length() == 0 || output_product == Dimension() || input_product == Dimension()) {
- output_shape[minus_one_idx] = Dimension::dynamic();
- } else {
- Dimension::value_type lower;
- if (input_product.get_min_length() == 0)
- lower = 0;
- else if (input_product.get_min_length() == -1 || output_product.get_max_length() == 0 ||
- output_product.get_max_length() == -1)
- lower = -1; // dynamic
- else
- lower = static_cast<Dimension::value_type>(
- ceil(static_cast<double>(input_product.get_min_length()) / output_product.get_max_length()));
- Dimension::value_type upper;
- if (input_product.get_max_length() == 0)
- upper = 0;
- else if (input_product.get_max_length() == -1 || output_product.get_min_length() == 0 ||
- output_product.get_min_length() == -1)
- upper = -1; // dynamic
- else
- upper = static_cast<Dimension::value_type>(
- floor(static_cast<double>(input_product.get_max_length()) / output_product.get_min_length()));
- if (lower == -1)
- output_shape[minus_one_idx] = Dimension::dynamic();
- else if (upper == -1)
- output_shape[minus_one_idx] = Dimension(lower, upper);
- else if (lower > upper) // empty intersection
- output_shape[minus_one_idx] = Dimension::dynamic();
- else
- output_shape[minus_one_idx] = Dimension(lower, upper);
- }
- }
+ output_shape[minus_one_idx] = resolve_minus_one(this, input_product, output_product);
}
ov::PartialShape output_pshape(output_shape);
if (input_pshape.is_static() && output_pshape.is_static()) {
size_t zero_dims = std::count_if(reshape_pattern.begin(), reshape_pattern.end(), [](Dimension dim) {
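
The underlying design change in this hunk: `input_product`/`output_product` were running `Dimension` products, and multiplying intervals eagerly discards label information; keeping the factors in vectors defers the arithmetic so `resolve_minus_one` can cancel matching factors symbolically first. A sketch of what eager multiplication loses, assuming (as the new code implies) that an interval product carries no label:

```cpp
#include <dimension_tracker.hpp>        // dev-API header, as included in this commit
#include <openvino/core/dimension.hpp>  // assumed location of ov::Dimension

int main() {
    ov::Dimension n = ov::Dimension::dynamic();
    ov::DimensionTracker::set_label(n, 10);

    // Old approach: fold everything into one product -> fresh, unlabeled interval.
    ov::Dimension product = n * ov::Dimension(4);

    // New approach keeps {n, 4} as separate factors, so the static 4 can be
    // cancelled against the pattern and n handed back unchanged.
    return ov::DimensionTracker::get_label(product) == 0 ? 0 : 1;
}
```
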

@@ -14,6 +14,7 @@
//*****************************************************************************
#include <vector>
+ #include <dimension_tracker.hpp>
#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
@@ -202,6 +203,63 @@ TYPED_TEST_P(ArithmeticOperator, full_dynamic_shape)
ASSERT_TRUE(op->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}
+ TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_a)
+ {
+ Dimension b = -1;
+ ov::DimensionTracker::set_label(b, 10);
+ PartialShape A = {b, 3, 224, 224}, B = {1, 3, 1, 1};
+ auto paramA = std::make_shared<op::Parameter>(element::f64, A);
+ auto paramB = std::make_shared<op::Parameter>(element::f64, B);
+ const auto op = std::make_shared<TypeParam>(paramA, paramB);
+ const auto shape = op->get_output_partial_shape(0);
+ ASSERT_EQ(shape, A);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
+ }
+ TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_b)
+ {
+ Dimension b = -1;
+ ov::DimensionTracker::set_label(b, 10);
+ PartialShape A = {b, 3, 224, 224}, B = {1, 3, 1, 1};
+ auto paramA = std::make_shared<op::Parameter>(element::f64, A);
+ auto paramB = std::make_shared<op::Parameter>(element::f64, B);
+ const auto op = std::make_shared<TypeParam>(paramB, paramA);
+ const auto shape = op->get_output_partial_shape(0);
+ ASSERT_EQ(shape, A);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
+ }
+ TYPED_TEST_P(ArithmeticOperator, dynamic_shape_static_rank_with_labels_different_rank)
+ {
+ Dimension b = -1;
+ ov::DimensionTracker::set_label(b, 10);
+ PartialShape A = {b, -1, -1, -1}, B = {3, 1, 1};
+ auto paramA = std::make_shared<op::Parameter>(element::f64, A);
+ auto paramB = std::make_shared<op::Parameter>(element::f64, B);
+ const auto op = std::make_shared<TypeParam>(paramA, paramB);
+ const auto shape = op->get_output_partial_shape(0);
+ ASSERT_EQ(shape, ov::PartialShape({-1, 3, -1, -1}));
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[0]), 10);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[1]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[2]), 0);
+ ASSERT_EQ(ov::DimensionTracker::get_label(shape[3]), 0);
+ }
REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
shape_inference_2D,
shape_inference_4D,
@@ -219,4 +277,7 @@ REGISTER_TYPED_TEST_SUITE_P(ArithmeticOperator,
shape_inference_5D_x_5D_incompatible,
dynamic_shape_3D,
dynamic_shape_5D,
- full_dynamic_shape);
+ full_dynamic_shape,
+ dynamic_shape_static_rank_with_labels_a,
+ dynamic_shape_static_rank_with_labels_b,
+ dynamic_shape_static_rank_with_labels_different_rank);

@@ -633,3 +633,20 @@ TEST(type_prop, reshape_dynamic_value_and_label_propagation) {
const auto& output_shape = bc->get_output_partial_shape(0);
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
}
+ TEST(type_prop, reshape_label_shape_propagation_minus_one) {
+ Dimension marked_0 = Dimension(-1);
+ ov::DimensionTracker::set_label(marked_0, 10);
+ PartialShape initial_shape = PartialShape{marked_0, 4, 3, 1};
+ auto input = std::make_shared<op::Parameter>(element::f32, initial_shape);
+ auto output_pattern = std::make_shared<op::Constant>(element::i64, Shape{2}, std::vector<int64_t>{-1, 12});
+ const auto reshape = std::make_shared<op::v1::Reshape>(input, output_pattern, false);
+ auto output_shape = reshape->get_output_partial_shape(0);
+ ASSERT_EQ(output_shape, PartialShape({-1, 12}));
+ ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
+ ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[1]), 0);
+ }