Cache const-folded value for the same input connected to MultiSubGraphOp during PushConstantToSubgraph (#18067)

This commit is contained in:
Mateusz Bencer 2023-06-23 17:45:55 +02:00 committed by GitHub
parent 8a2af2418f
commit 92fbf96300
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,16 +14,16 @@ using MultiSubGraphOp = ov::op::util::MultiSubGraphOp;
static std::shared_ptr<ov::op::v0::Constant> try_constantfold_input(
const std::shared_ptr<MultiSubGraphOp>& op,
const MultiSubGraphOp::InputDescription::Ptr& input_desc,
std::unordered_map<size_t, std::shared_ptr<ov::op::v0::Constant>>& cache) {
std::map<ov::Output<ov::Node>, std::shared_ptr<ov::op::v0::Constant>>& cache) {
if (!std::dynamic_pointer_cast<MultiSubGraphOp::InvariantInputDescription>(input_desc)) {
return nullptr;
}
const auto input_index = input_desc->m_input_index;
auto it = cache.find(input_index);
const auto outer_input = op->input_value(input_desc->m_input_index);
auto it = cache.find(outer_input);
if (it == cache.end()) {
auto constant = ov::util::constantfold_subgraph(op->input_value(input_index));
auto constant = ov::util::constantfold_subgraph(outer_input);
if (constant) {
cache.insert({input_index, constant});
cache.insert({outer_input, constant});
}
return constant;
}
@ -81,7 +81,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
}
// cache for already constant folded inputs
std::unordered_map<size_t, std::shared_ptr<op::v0::Constant>> cache;
std::map<ov::Output<ov::Node>, std::shared_ptr<op::v0::Constant>> cache;
// bitmask describing which MultiSubGraphOp's input to remove
std::vector<bool> remove_inputs_mask(multi_sub_graph_op->get_input_size(), false);
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());