Dimension tracking: Always output batch the same as it was before (#9953)

This commit is contained in:
Evgenya Stepyreva 2022-01-27 16:29:16 +03:00 committed by GitHub
parent bd7a5db029
commit 1d5352dae3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 12 additions and 10 deletions

View File

@ -35,7 +35,8 @@ namespace batch_util {
void mark_no_batch(const std::shared_ptr<ov::opset1::Parameter> &parameter, P2Btype &map);
void mark_layout_independent_batch(const std::shared_ptr<ov::opset1::Parameter> &parameter, const std::shared_ptr<ov::Node> & result, P2Btype &map);
void mark_with_unique_dimension_labels(const std::shared_ptr<Model> &m, const ov::DimensionTracker &dt);
void restore_original_dimensions_except_batch(const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape);
void restore_original_dimensions(
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape, bool leave_batch_dynamic = true);
bool check_batch_tracks_through_all_the_nodes(const std::shared_ptr<ov::Model>& m);
P2Btype find_batch(const std::shared_ptr<ov::Model> &m);
} // namespace batch_util

View File

@ -158,8 +158,8 @@ P2Btype ov::batch_util::find_batch(const std::shared_ptr<ov::Model>& f) {
return parameter_to_batch_labels;
}
void ov::batch_util::restore_original_dimensions_except_batch(
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape) {
void ov::batch_util::restore_original_dimensions(
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape, bool leave_batch_dynamic) {
for (const auto& item : parameter_to_shape) {
const auto& batch_marked_shape = item.first->get_partial_shape();
auto original_shape = item.second;
@ -168,7 +168,8 @@ void ov::batch_util::restore_original_dimensions_except_batch(
for (size_t n = 0; n < batch_marked_shape.size(); ++n) {
if (const auto& label = ov::DimensionTracker::get_label(batch_marked_shape[n])) {
original_shape[n] = Dimension::dynamic();
if (leave_batch_dynamic)
original_shape[n] = Dimension::dynamic();
ov::DimensionTracker::set_label(original_shape[n], label);
}
}
@ -238,18 +239,18 @@ bool ov::pass::FindBatch::run_on_model(const std::shared_ptr<ov::Model>& m) {
ov::batch_util::find_batch(m);
ov::batch_util::restore_original_dimensions_except_batch(parameter_to_shape);
ov::batch_util::restore_original_dimensions(parameter_to_shape);
m->validate_nodes_and_infer_types();
bool failed_to_propagate_batch = ov::batch_util::check_batch_tracks_through_all_the_nodes(m);
if (failed_to_propagate_batch) {
// return function to the initial state
for (const auto& item : parameter_to_shape) {
if (failed_to_propagate_batch) { // restore original input shape with labels
for (const auto& item : parameter_to_shape)
item.first->set_partial_shape(item.second);
}
m->validate_nodes_and_infer_types();
} else { // restore original input shape with batch labels
ov::batch_util::restore_original_dimensions(parameter_to_shape, false);
}
m->validate_nodes_and_infer_types();
return true;
}