RandomUniform-8: shape inference fix (#11047)

* Shape inference fix

* Update src/core/src/op/random_uniform.cpp

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>

* Update src/core/tests/type_prop/random_uniform.cpp

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>

Co-authored-by: Evgenya Stepyreva <evgenya.stepyreva@intel.com>
This commit is contained in:
Tomasz Dołbniak 2022-03-21 10:04:51 +01:00 committed by GitHub
parent a887b41db6
commit b480a49d66
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 0 deletions

View File

@ -41,8 +41,11 @@ void op::v8::RandomUniform::validate_and_infer_types() {
NODE_VALIDATION_CHECK(this,
input_shape.rank() == 1,
"The rank of the tensor defining output shape must be equal to 1.");
if (const auto& const_shape = get_constant_from_source(input_value(0))) {
output_shape = ov::PartialShape(const_shape->cast_vector<int64_t>());
} else {
output_shape = ov::PartialShape::dynamic(input_shape[0]);
}
}

View File

@ -21,6 +21,17 @@ TEST(type_prop, random_uniform_type_shape) {
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape{2, 3, 4, 5}));
}
TEST(type_prop, random_uniform_param_input) {
auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape{3});
auto min_val = make_shared<opset8::Constant>(element::i64, Shape{}, 5);
auto max_val = make_shared<opset8::Constant>(element::i64, Shape{}, 10);
auto r = make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i64, 100, 200);
EXPECT_EQ(r->get_output_element_type(0), element::i64);
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape::dynamic(3));
}
TEST(type_prop, random_uniform_dynamic_shape) {
auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape{Dimension::dynamic()});
auto min_val = make_shared<opset8::Constant>(element::i64, Shape{}, 5);