Random Uniform: precise shape inference (#17740)
This commit is contained in:
parent
1dad2c003b
commit
b1b5d65951
@ -41,10 +41,8 @@ void op::v8::RandomUniform::validate_and_infer_types() {
|
||||
"The rank of the tensor defining output shape must be equal to 1.");
|
||||
|
||||
OPENVINO_SUPPRESS_DEPRECATED_START
|
||||
if (const auto& const_shape = get_constant_from_source(input_value(0))) {
|
||||
if (!evaluate_as_partial_shape(input_value(0), output_shape)) {
|
||||
OPENVINO_SUPPRESS_DEPRECATED_END
|
||||
output_shape = ov::PartialShape(const_shape->cast_vector<int64_t>());
|
||||
} else {
|
||||
output_shape = ov::PartialShape::dynamic(input_shape[0]);
|
||||
}
|
||||
}
|
||||
|
@ -43,6 +43,19 @@ TEST(type_prop, random_uniform_dynamic_shape) {
|
||||
EXPECT_TRUE(r->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_dynamic_shape_1) {
|
||||
auto shape = make_shared<opset8::Parameter>(element::i32, PartialShape{{0, 10}, 4, {3, 7}, -1});
|
||||
auto out_shape = make_shared<opset8::ShapeOf>(shape);
|
||||
|
||||
auto min_val = make_shared<opset8::Constant>(element::i64, Shape{}, 5);
|
||||
auto max_val = make_shared<opset8::Constant>(element::i64, Shape{}, 10);
|
||||
|
||||
auto r = make_shared<opset8::RandomUniform>(out_shape, min_val, max_val, element::i64, 100, 200);
|
||||
|
||||
EXPECT_EQ(r->get_output_element_type(0), element::i64);
|
||||
EXPECT_EQ(r->get_output_partial_shape(0), PartialShape({{0, 10}, 4, {3, 7}, -1}));
|
||||
}
|
||||
|
||||
TEST(type_prop, random_uniform_dynamic_rank) {
|
||||
auto out_shape = make_shared<opset8::Parameter>(element::i32, PartialShape::dynamic());
|
||||
auto min_val = make_shared<opset8::Constant>(element::f64, Shape{}, 5);
|
||||
|
Loading…
Reference in New Issue
Block a user