[nGraph] get_constant_from_source fix (#10787)

This commit is contained in:
Vladislav Golubev 2022-03-15 14:50:21 +03:00 committed by GitHub
parent 660c6a3e84
commit 060a149322
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 1 deletions

View File

@ -1298,7 +1298,9 @@ HostTensorPtr ngraph::evaluate_upper_bound(const Output<Node>& output) {
}
pair<HostTensorPtr, HostTensorPtr> ngraph::evaluate_both_bounds(const Output<Node>& output) {
return {evaluate_bound(output, false, false), evaluate_upper_bound(output)};
evaluate_bound(output, false, false);
evaluate_upper_bound(output);
return {output.get_tensor_ptr()->get_lower_value(), output.get_tensor_ptr()->get_upper_value()};
}
bool ov::evaluate_as_partial_shape(const Output<Node>& output, PartialShape& pshape) {

View File

@ -34,3 +34,16 @@ TEST(get_constant_from_source, invalidation_check) {
ASSERT_FALSE(div->get_output_tensor(0).get_lower_value());
ASSERT_FALSE(div->get_output_tensor(0).get_upper_value());
}
TEST(get_constant_from_source, extract_static_dim_from_dynamic_shape_check) {
auto data = std::make_shared<ov::opset8::Parameter>(ngraph::element::f32, ov::PartialShape{-1, 1, 128});
auto shape = std::make_shared<ov::opset8::ShapeOf>(data);
auto one = ov::opset8::Constant::create(ov::element::i64, {1}, {1});
auto zero = ov::opset8::Constant::create(ov::element::i64, {1}, {0});
const auto extract_static_dimension = std::make_shared<ov::opset8::Gather>(shape, one, zero);
ASSERT_TRUE(ov::get_constant_from_source(extract_static_dimension));
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_lower_value());
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_upper_value());
}