Partial Value propagation from partial to static value (#10162)

* Partial Value propagation from partial to static value

* Style

* Tests adjustment
This commit is contained in:
Evgenya Stepyreva
2022-02-08 21:55:17 +03:00
committed by GitHub
parent 0dfdadb531
commit a18069926e
3 changed files with 62 additions and 21 deletions

View File

@@ -1210,7 +1210,27 @@ void propagate_rt_info(Node* node, const Output<Node>& final_port) {
}
}
HostTensorPtr evaluate_bound(const Output<Node>& output, bool is_upper) {
// Elementwise comparison of two host tensors. Returns true only when both
// tensors are present, are small enough to inspect (at most
// `max_elements_limit` elements), and every corresponding element compares
// equal. Shapes and element types of the two tensors must already match.
bool are_equal(const HostTensorPtr& lhs, const HostTensorPtr& rhs, size_t max_elements_limit = 10) {
    if (!lhs || !rhs)
        return false;
    const auto& shape_of_lhs = lhs->get_shape();
    const auto& shape_of_rhs = rhs->get_shape();
    OPENVINO_ASSERT(shape_of_lhs == shape_of_rhs);
    const auto& type_of_lhs = lhs->get_element_type();
    const auto& type_of_rhs = rhs->get_element_type();
    OPENVINO_ASSERT(type_of_lhs == type_of_rhs);
    // Comparing large tensors elementwise would be costly; treat them as unequal.
    if (shape_size(shape_of_lhs) > max_elements_limit)
        return false;
    // Evaluate an elementwise Equal op over the two tensors into a boolean mask.
    auto result_mask = std::make_shared<HostTensor>(element::boolean, shape_of_lhs);
    const auto& dummy_input = std::make_shared<op::Parameter>(type_of_lhs, shape_of_lhs);
    op::v1::Equal(dummy_input, dummy_input, ngraph::op::AutoBroadcastType::NUMPY).evaluate({result_mask}, {lhs, rhs});
    const auto flags = op::Constant(result_mask).cast_vector<bool>();
    // Equal only if no element of the mask is false.
    return std::find(flags.begin(), flags.end(), false) == flags.end();
}
HostTensorPtr evaluate_bound(const Output<Node>& output, bool is_upper, bool invalidate_all_unused_values = true) {
// bound is already set in the tensor
if (is_upper && output.get_tensor().get_upper_value() != nullptr)
return output.get_tensor().get_upper_value();
@@ -1229,22 +1249,31 @@ HostTensorPtr evaluate_bound(const Output<Node>& output, bool is_upper) {
TensorLabelVector output_labels(outputs.size());
bool same_inputs = std::all_of(input_values.begin(), input_values.end(), [](const Output<Node>& input) {
return input.get_tensor().has_and_set_bound();
auto& tensor = input.get_tensor();
return tensor.has_and_set_bound() || are_equal(tensor.get_lower_value(), tensor.get_upper_value());
});
for (size_t i = 0; i < outputs.size(); ++i) {
// TODO: should we skip setting value for tensors that have only one consumer?
if ((same_inputs || is_upper) && node->get_output_tensor(i).get_upper_value() == nullptr)
node->get_output_tensor(i).set_upper_value(outputs[i]);
if ((same_inputs || !is_upper) && node->get_output_tensor(i).get_lower_value() == nullptr)
node->get_output_tensor(i).set_lower_value(outputs[i]);
if (are_equal(node->get_output_tensor(i).get_lower_value(),
node->get_output_tensor(i).get_upper_value()))
node->get_output_tensor(i).set_lower_value(node->get_output_tensor(i).get_upper_value());
}
if (node->evaluate_label(output_labels))
for (size_t i = 0; i < outputs.size(); ++i)
node->get_output_tensor(i).set_value_label(output_labels[i]);
for (const auto& input : input_values)
if (input.get_target_inputs().size() == 1)
input.get_tensor().invalidate_values();
for (const auto& input : input_values) {
auto& tensor = input.get_tensor();
bool should_invalidate = invalidate_all_unused_values;
if (tensor.get_lower_value() && shape_size(tensor.get_lower_value()->get_shape()) > 10)
should_invalidate |= true;
if (tensor.get_upper_value() && shape_size(tensor.get_upper_value()->get_shape()) > 10)
should_invalidate |= true;
if (should_invalidate)
tensor.invalidate_values();
}
propagate_rt_info(node, output);
} else {
break;
@@ -1268,7 +1297,7 @@ HostTensorPtr ngraph::evaluate_upper_bound(const Output<Node>& output) {
}
pair<HostTensorPtr, HostTensorPtr> ngraph::evaluate_both_bounds(const Output<Node>& output) {
return {evaluate_lower_bound(output), evaluate_upper_bound(output)};
return {evaluate_bound(output, false, false), evaluate_upper_bound(output)};
}
bool ov::evaluate_as_partial_shape(const Output<Node>& output, PartialShape& pshape) {

View File

@@ -134,6 +134,27 @@ TEST(type_prop, gather_1_dynamic_value_and_label_propagation) {
ASSERT_EQ(ov::DimensionTracker::get_label(output_shape[0]), 10);
}
// Value propagation through ShapeOf -> Gather -> Add(+0) -> Range should keep
// the statically-known dimension (3) intact, so the trailing Gather over the
// channel axis preserves the original partial shape of the input.
TEST(type_prop, dynamic_value_propagation) {
    auto data = make_shared<op::Parameter>(element::f32, PartialShape{-1, 3, -1, -1});
    auto data_shape = std::make_shared<op::v3::ShapeOf>(data, element::i32);
    auto dim_index = op::Constant::create(element::i32, {}, {1});
    auto gather_axis = op::Constant::create(element::i32, {}, {0});
    auto dim = std::make_shared<op::v1::Gather>(data_shape, dim_index, gather_axis);
    // Adding zero keeps the value but forces it through value propagation.
    auto dim_plus_zero = std::make_shared<op::v1::Add>(dim, op::Constant::create(element::i32, {}, {0}));
    auto channel_range = std::make_shared<op::v4::Range>(op::Constant::create(element::i32, {}, {0}),
                                                         dim_plus_zero,
                                                         op::Constant::create(element::i32, {}, {1}),
                                                         element::i64);
    auto reordered = std::make_shared<op::v1::Gather>(data, channel_range, op::Constant::create(element::i32, {}, {1}));
    ASSERT_EQ(reordered->get_element_type(), element::f32);
    ASSERT_EQ(reordered->get_output_partial_shape(0), (PartialShape{-1, 3, -1, -1}));
}
// ------------------------------ V7 ------------------------------
TEST(type_prop, gather_7_axis_0) {

View File

@@ -134,15 +134,10 @@ TEST_F(TransformationTestsF, GatherNegativeIndicesNormalize_static_axis_dim) {
auto indices_type = ngraph::element::i32;
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {2});
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {1});
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
}
}
@@ -164,14 +159,10 @@ TEST_F(TransformationTestsF, GatherNegativeIndicesNormalize_static_axis_dim_neg_
auto indices_type = ngraph::element::i32;
auto data = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::f32, ngraph::PartialShape{DYN, 15, DYN});
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {-1});
auto indices = ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {2});
auto axis = ngraph::opset7::Constant::create(ngraph::element::i32, ngraph::Shape{}, {-2});
auto shape_of = std::make_shared<ngraph::opset7::ShapeOf>(data, indices_type);
auto input_gather = std::make_shared<ngraph::opset7::Gather>(shape_of,
ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {1}), ngraph::opset7::Constant::create(indices_type, ngraph::Shape{}, {0}));
auto add = std::make_shared<ngraph::opset7::Add>(input_gather, indices);
auto gather = std::make_shared<ngraph::opset7::Gather>(data, add, axis);
auto gather = std::make_shared<ngraph::opset7::Gather>(data, indices, axis);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{gather}, ngraph::ParameterVector{data});
}