[nGraph] get_constant_from_source fix (#10787)
This commit is contained in:
parent
660c6a3e84
commit
060a149322
@ -1298,7 +1298,9 @@ HostTensorPtr ngraph::evaluate_upper_bound(const Output<Node>& output) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pair<HostTensorPtr, HostTensorPtr> ngraph::evaluate_both_bounds(const Output<Node>& output) {
|
pair<HostTensorPtr, HostTensorPtr> ngraph::evaluate_both_bounds(const Output<Node>& output) {
|
||||||
return {evaluate_bound(output, false, false), evaluate_upper_bound(output)};
|
evaluate_bound(output, false, false);
|
||||||
|
evaluate_upper_bound(output);
|
||||||
|
return {output.get_tensor_ptr()->get_lower_value(), output.get_tensor_ptr()->get_upper_value()};
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ov::evaluate_as_partial_shape(const Output<Node>& output, PartialShape& pshape) {
|
bool ov::evaluate_as_partial_shape(const Output<Node>& output, PartialShape& pshape) {
|
||||||
|
@ -34,3 +34,16 @@ TEST(get_constant_from_source, invalidation_check) {
|
|||||||
ASSERT_FALSE(div->get_output_tensor(0).get_lower_value());
|
ASSERT_FALSE(div->get_output_tensor(0).get_lower_value());
|
||||||
ASSERT_FALSE(div->get_output_tensor(0).get_upper_value());
|
ASSERT_FALSE(div->get_output_tensor(0).get_upper_value());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(get_constant_from_source, extract_static_dim_from_dynamic_shape_check) {
|
||||||
|
auto data = std::make_shared<ov::opset8::Parameter>(ngraph::element::f32, ov::PartialShape{-1, 1, 128});
|
||||||
|
auto shape = std::make_shared<ov::opset8::ShapeOf>(data);
|
||||||
|
auto one = ov::opset8::Constant::create(ov::element::i64, {1}, {1});
|
||||||
|
auto zero = ov::opset8::Constant::create(ov::element::i64, {1}, {0});
|
||||||
|
const auto extract_static_dimension = std::make_shared<ov::opset8::Gather>(shape, one, zero);
|
||||||
|
|
||||||
|
ASSERT_TRUE(ov::get_constant_from_source(extract_static_dimension));
|
||||||
|
|
||||||
|
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_lower_value());
|
||||||
|
ASSERT_TRUE(extract_static_dimension->get_output_tensor(0).get_upper_value());
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user