* 51145 - PriorBox dynamic shape propagation

* code style
This commit is contained in:
Evgenya Stepyreva 2022-01-21 00:32:04 +03:00 committed by GitHub
parent e39614954b
commit 5aa43d560a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 20 deletions

View File

@ -51,18 +51,16 @@ void op::v0::PriorBox::validate_and_infer_types() {
set_input_is_relevant_to_shape(0); set_input_is_relevant_to_shape(0);
if (auto const_shape = get_constant_from_source(input_value(0))) { PartialShape spatials;
if (evaluate_as_partial_shape(input_value(0), spatials)) {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2, spatials.rank().is_static() && spatials.size() == 2,
"Layer shape must have rank 2", "Layer shape must have rank 2",
const_shape->get_shape()); spatials);
auto layer_shape = const_shape->get_shape_val(); set_output_type(0,
element::f32,
set_output_type( ov::PartialShape{2, spatials[0] * spatials[1] * Dimension(4 * number_of_priors(m_attrs))});
0,
element::f32,
ov::Shape{2, 4 * layer_shape[0] * layer_shape[1] * static_cast<size_t>(number_of_priors(m_attrs))});
} else { } else {
set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()}); set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()});
} }
@ -243,18 +241,16 @@ void op::v8::PriorBox::validate_and_infer_types() {
set_input_is_relevant_to_shape(0); set_input_is_relevant_to_shape(0);
if (auto const_shape = get_constant_from_source(input_value(0))) { PartialShape spatials;
if (evaluate_as_partial_shape(input_value(0), spatials)) {
NODE_VALIDATION_CHECK(this, NODE_VALIDATION_CHECK(this,
shape_size(const_shape->get_shape()) == 2, spatials.rank().is_static() && spatials.size() == 2,
"Layer shape must have rank 2", "Layer shape must have rank 2",
const_shape->get_shape()); spatials);
auto layer_shape = const_shape->get_shape_val(); set_output_type(0,
element::f32,
set_output_type( ov::PartialShape{2, spatials[0] * spatials[1] * Dimension(4 * number_of_priors(m_attrs))});
0,
element::f32,
ov::Shape{2, 4 * layer_shape[0] * layer_shape[1] * static_cast<size_t>(number_of_priors(m_attrs))});
} else { } else {
set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()}); set_output_type(0, element::f32, ov::PartialShape{2, Dimension::dynamic()});
} }

View File

@ -92,11 +92,11 @@ void op::ROIPooling::validate_and_infer_types() {
Dimension{static_cast<int64_t>(m_output_size[0])}, Dimension{static_cast<int64_t>(m_output_size[0])},
Dimension{static_cast<int64_t>(m_output_size[1])}}}; Dimension{static_cast<int64_t>(m_output_size[1])}}};
if (coords_ps.rank().is_static() && coords_ps[0].is_static()) { if (coords_ps.rank().is_static()) {
output_shape[0] = coords_ps[0]; output_shape[0] = coords_ps[0];
} }
if (feat_maps_ps.rank().is_static() && feat_maps_ps[1].is_static()) { if (feat_maps_ps.rank().is_static()) {
output_shape[1] = feat_maps_ps[1]; output_shape[1] = feat_maps_ps[1];
} }