Fix get_constant_from_source (#10350)

This commit is contained in:
Gleb Kazantaev 2022-02-14 16:03:12 +03:00 committed by GitHub
parent d1477b8569
commit a3d5b6501d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 1 deletions

View File

@ -1271,7 +1271,7 @@ HostTensorPtr evaluate_bound(const Output<Node>& output, bool is_upper, bool inv
should_invalidate |= true;
if (tensor.get_upper_value() && shape_size(tensor.get_upper_value()->get_shape()) > 10)
should_invalidate |= true;
if (should_invalidate)
if (should_invalidate && input.get_target_inputs().size() == 1)
tensor.invalidate_values();
}
propagate_rt_info(node, output);

View File

@ -410,6 +410,7 @@ set(SRC
visitors/op/unsqueeze.cpp
visitors/op/variadic_split.cpp
uint4.cpp
validation_utils.cpp
)
# For type relaxed types

View File

@ -0,0 +1,36 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <openvino/core/type.hpp>
#include <openvino/core/validation_util.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/util/common_util.hpp>
TEST(get_constant_from_source, invalidation_check) {
auto a = ov::opset8::Constant::create(ov::element::i64, {100}, {123});
auto b = ov::opset8::Constant::create(ov::element::i64, {1}, {123});
auto div = std::make_shared<ov::opset8::Divide>(a, b);
auto s = std::make_shared<ov::opset8::ShapeOf>(a);
auto r = std::make_shared<ov::opset8::Reshape>(div, s, true);
auto tmp_consumer = std::make_shared<ov::opset8::ShapeOf>(s);
ASSERT_TRUE(ov::get_constant_from_source(r));
ASSERT_TRUE(r->get_output_tensor(0).get_lower_value());
ASSERT_TRUE(r->get_output_tensor(0).get_upper_value());
ASSERT_TRUE(s->get_output_tensor(0).get_lower_value());
ASSERT_TRUE(s->get_output_tensor(0).get_upper_value());
ASSERT_TRUE(b->get_output_tensor(0).get_lower_value());
ASSERT_TRUE(b->get_output_tensor(0).get_upper_value());
ASSERT_TRUE(a->get_output_tensor(0).get_lower_value());
ASSERT_TRUE(a->get_output_tensor(0).get_upper_value());
ASSERT_FALSE(div->get_output_tensor(0).get_lower_value());
ASSERT_FALSE(div->get_output_tensor(0).get_upper_value());
}