Dimension tracking: Always output batch the same as it was before (#9953)
This commit is contained in:
parent
bd7a5db029
commit
1d5352dae3
@ -35,7 +35,8 @@ namespace batch_util {
|
||||
void mark_no_batch(const std::shared_ptr<ov::opset1::Parameter> ¶meter, P2Btype &map);
|
||||
void mark_layout_independent_batch(const std::shared_ptr<ov::opset1::Parameter> ¶meter, const std::shared_ptr<ov::Node> & result, P2Btype &map);
|
||||
void mark_with_unique_dimension_labels(const std::shared_ptr<Model> &m, const ov::DimensionTracker &dt);
|
||||
void restore_original_dimensions_except_batch(const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape);
|
||||
void restore_original_dimensions(
|
||||
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape, bool leave_batch_dynamic = true);
|
||||
bool check_batch_tracks_through_all_the_nodes(const std::shared_ptr<ov::Model>& m);
|
||||
P2Btype find_batch(const std::shared_ptr<ov::Model> &m);
|
||||
} // namespace batch_util
|
||||
|
@ -158,8 +158,8 @@ P2Btype ov::batch_util::find_batch(const std::shared_ptr<ov::Model>& f) {
|
||||
return parameter_to_batch_labels;
|
||||
}
|
||||
|
||||
void ov::batch_util::restore_original_dimensions_except_batch(
|
||||
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape) {
|
||||
void ov::batch_util::restore_original_dimensions(
|
||||
const std::map<std::shared_ptr<ov::opset1::Parameter>, ov::PartialShape>& parameter_to_shape, bool leave_batch_dynamic) {
|
||||
for (const auto& item : parameter_to_shape) {
|
||||
const auto& batch_marked_shape = item.first->get_partial_shape();
|
||||
auto original_shape = item.second;
|
||||
@ -168,7 +168,8 @@ void ov::batch_util::restore_original_dimensions_except_batch(
|
||||
|
||||
for (size_t n = 0; n < batch_marked_shape.size(); ++n) {
|
||||
if (const auto& label = ov::DimensionTracker::get_label(batch_marked_shape[n])) {
|
||||
original_shape[n] = Dimension::dynamic();
|
||||
if (leave_batch_dynamic)
|
||||
original_shape[n] = Dimension::dynamic();
|
||||
ov::DimensionTracker::set_label(original_shape[n], label);
|
||||
}
|
||||
}
|
||||
@ -238,18 +239,18 @@ bool ov::pass::FindBatch::run_on_model(const std::shared_ptr<ov::Model>& m) {
|
||||
|
||||
ov::batch_util::find_batch(m);
|
||||
|
||||
ov::batch_util::restore_original_dimensions_except_batch(parameter_to_shape);
|
||||
ov::batch_util::restore_original_dimensions(parameter_to_shape);
|
||||
|
||||
m->validate_nodes_and_infer_types();
|
||||
|
||||
bool failed_to_propagate_batch = ov::batch_util::check_batch_tracks_through_all_the_nodes(m);
|
||||
|
||||
if (failed_to_propagate_batch) {
|
||||
// return function to the initial state
|
||||
for (const auto& item : parameter_to_shape) {
|
||||
if (failed_to_propagate_batch) { // restore original input shape with labels
|
||||
for (const auto& item : parameter_to_shape)
|
||||
item.first->set_partial_shape(item.second);
|
||||
}
|
||||
m->validate_nodes_and_infer_types();
|
||||
} else { // restore original input shape with batch labels
|
||||
ov::batch_util::restore_original_dimensions(parameter_to_shape, false);
|
||||
}
|
||||
m->validate_nodes_and_infer_types();
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user