* 51145 - PriorBox dynamic shape propagation

* code style
This commit is contained in:
Evgenya Stepyreva 2022-01-21 00:32:04 +03:00 committed by GitHub
parent e39614954b
commit 5aa43d560a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 20 deletions

View File

@ -51,18 +51,16 @@ void op::v0::PriorBox::validate_and_infer_types() {
set_input_is_relevant_to_shape(0);
if (auto const_shape = get_constant_from_source(input_value(0))) {
PartialShape spatials;
if (evaluate_as_partial_shape(input_value(0), spatials)) {
NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2,
spatials.rank().is_static() && spatials.size() == 2,
"Layer shape must have rank 2",
const_shape->get_shape());
spatials);
auto layer_shape = const_shape->get_shape_val();
set_output_type(
0,
set_output_type(0,
element::f32,
ov::Shape{2, 4 * layer_shape[0] * layer_shape[1] * static_cast<size_t>(number_of_priors(m_attrs))});
ov::PartialShape{2, spatials[0] * spatials[1] * Dimension(4 * number_of_priors(m_attrs))});
} else {
set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()});
}
@ -243,18 +241,16 @@ void op::v8::PriorBox::validate_and_infer_types() {
set_input_is_relevant_to_shape(0);
if (auto const_shape = get_constant_from_source(input_value(0))) {
PartialShape spatials;
if (evaluate_as_partial_shape(input_value(0), spatials)) {
NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2,
spatials.rank().is_static() && spatials.size() == 2,
"Layer shape must have rank 2",
const_shape->get_shape());
spatials);
auto layer_shape = const_shape->get_shape_val();
set_output_type(
0,
set_output_type(0,
element::f32,
ov::Shape{2, 4 * layer_shape[0] * layer_shape[1] * static_cast<size_t>(number_of_priors(m_attrs))});
ov::PartialShape{2, spatials[0] * spatials[1] * Dimension(4 * number_of_priors(m_attrs))});
} else {
set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()});
}

View File

@ -92,11 +92,11 @@ void op::ROIPooling::validate_and_infer_types() {
Dimension{static_cast<int64_t>(m_output_size[0])},
Dimension{static_cast<int64_t>(m_output_size[1])}}};
if (coords_ps.rank().is_static() && coords_ps[0].is_static()) {
if (coords_ps.rank().is_static()) {
output_shape[0] = coords_ps[0];
}
if (feat_maps_ps.rank().is_static() && feat_maps_ps[1].is_static()) {
if (feat_maps_ps.rank().is_static()) {
output_shape[1] = feat_maps_ps[1];
}