Random Uniform: precise shape inference (#17740)

This commit is contained in:
Evgenya Stepyreva 2023-05-26 09:05:42 +04:00 committed by GitHub
parent 1dad2c003b
commit b1b5d65951
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 3 deletions

View File

@ -41,10 +41,8 @@ void op::v8::RandomUniform::validate_and_infer_types() {
"The rank of the tensor defining output shape must be equal to 1."); "The rank of the tensor defining output shape must be equal to 1.");
OPENVINO_SUPPRESS_DEPRECATED_START OPENVINO_SUPPRESS_DEPRECATED_START
if (const auto& const_shape = get_constant_from_source(input_value(0))) { if (!evaluate_as_partial_shape(input_value(0), output_shape)) {
OPENVINO_SUPPRESS_DEPRECATED_END OPENVINO_SUPPRESS_DEPRECATED_END
output_shape = ov::PartialShape(const_shape->cast_vector<int64_t>());
} else {
output_shape = ov::PartialShape::dynamic(input_shape[0]); output_shape = ov::PartialShape::dynamic(input_shape[0]);
} }
} }

View File

@ -43,6 +43,19 @@ TEST(type_prop, random_uniform_dynamic_shape) {
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic())); EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
} }
TEST(type_prop, random_uniform_dynamic_shape_1) {
auto shape = make_shared<opset8::Parameter>(element::i32, PartialShape{{0, 10}, 4, {3, 7}, -1});
auto out_shape = make_shared<opset8::ShapeOf>(shape);
auto min_val = make_shared<opset8::Constant>(element::i64, Shape{}, 5);
auto max_val = make_shared<opset8::Constant>(element::i64, Shape{}, 10);
auto r = make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i64, 100, 200);
EXPECT_EQ(r->get_output_element_type(0), element::i64);
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape({{0, 10}, 4, {3, 7}, -1}));
}
TEST(type_prop, random_uniform_dynamic_rank) { TEST(type_prop, random_uniform_dynamic_rank) {
auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape::dynamic()); auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape::dynamic());
auto min_val = make_shared<opset8::Constant>(element::f64, Shape{}, 5); auto min_val = make_shared<opset8::Constant>(element::f64, Shape{}, 5);