Cache const-folded value for the same input connected to MultiSubGraphOp during PushConstantToSubgraph (#18067)
This commit is contained in:
parent
8a2af2418f
commit
92fbf96300
@ -14,16 +14,16 @@ using MultiSubGraphOp = ov::op::util::MultiSubGraphOp;
|
||||
static std::shared_ptr<ov::op::v0::Constant> try_constantfold_input(
|
||||
const std::shared_ptr<MultiSubGraphOp>& op,
|
||||
const MultiSubGraphOp::InputDescription::Ptr& input_desc,
|
||||
std::unordered_map<size_t, std::shared_ptr<ov::op::v0::Constant>>& cache) {
|
||||
std::map<ov::Output<ov::Node>, std::shared_ptr<ov::op::v0::Constant>>& cache) {
|
||||
if (!std::dynamic_pointer_cast<MultiSubGraphOp::InvariantInputDescription>(input_desc)) {
|
||||
return nullptr;
|
||||
}
|
||||
const auto input_index = input_desc->m_input_index;
|
||||
auto it = cache.find(input_index);
|
||||
const auto outer_input = op->input_value(input_desc->m_input_index);
|
||||
auto it = cache.find(outer_input);
|
||||
if (it == cache.end()) {
|
||||
auto constant = ov::util::constantfold_subgraph(op->input_value(input_index));
|
||||
auto constant = ov::util::constantfold_subgraph(outer_input);
|
||||
if (constant) {
|
||||
cache.insert({input_index, constant});
|
||||
cache.insert({outer_input, constant});
|
||||
}
|
||||
return constant;
|
||||
}
|
||||
@ -81,7 +81,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
|
||||
}
|
||||
|
||||
// cache for already constant folded inputs
|
||||
std::unordered_map<size_t, std::shared_ptr<op::v0::Constant>> cache;
|
||||
std::map<ov::Output<ov::Node>, std::shared_ptr<op::v0::Constant>> cache;
|
||||
// bitmask describing which MultiSubGraphOp's input to remove
|
||||
std::vector<bool> remove_inputs_mask(multi_sub_graph_op->get_input_size(), false);
|
||||
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
|
||||
|
Loading…
Reference in New Issue
Block a user