PushConstantToSubgraph - use vector<bool> for remove_inputs_mask (#15827)

When a MultiSubGraphOp has more than 32 inputs, an int bitmask does not
have enough bits to track which inputs are to be removed.

Ticket: 103248
This commit is contained in:
Mateusz Tabaka
2023-02-22 15:31:34 +01:00
committed by GitHub
parent 5251133202
commit eaf368a5f5
2 changed files with 81 additions and 6 deletions

View File

@@ -46,12 +46,12 @@ static void replace_body_parameter(const std::shared_ptr<ov::Model>& body,
}
static void update_multi_sub_graph_op_inputs(const std::shared_ptr<MultiSubGraphOp>& multi_sub_graph_op,
int remove_inputs_mask) {
const std::vector<bool>& remove_inputs_mask) {
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
auto inputs = multi_sub_graph_op->input_values();
for (size_t i = multi_sub_graph_op->get_input_size(); i > 0; i--) {
const auto input_index = i - 1;
if ((remove_inputs_mask & (1 << input_index)) != 0) {
if (remove_inputs_mask[input_index]) {
// remove MultiSubGraphOp's input if it was marked to be removed
// (meaning it was constfolded and pushed to inner subgraph)
inputs.erase(inputs.begin() + input_index);
@@ -83,7 +83,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
// cache for already constant folded inputs
std::unordered_map<size_t, std::shared_ptr<op::v0::Constant>> cache;
// bitmask describing which MultiSubGraphOp's input to remove
int remove_inputs_mask = 0;
std::vector<bool> remove_inputs_mask(multi_sub_graph_op->get_input_size(), false);
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
for (int body_idx = 0; body_idx < num_subgraphs; body_idx++) {
@@ -95,7 +95,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
const auto input_index = desc->m_input_index;
const auto constant = try_constantfold_input(multi_sub_graph_op, desc, cache);
if (!constant) {
remove_inputs_mask &= ~(1 << input_index);
remove_inputs_mask[input_index] = false;
desc_it++;
continue;
}
@@ -103,12 +103,12 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
desc_it = descriptions.erase(desc_it);
auto& body_param = body_params[body_parameter_index];
replace_body_parameter(body, body_param, body_parameter_index, constant, descriptions);
remove_inputs_mask |= 1 << input_index;
remove_inputs_mask[input_index] = true;
result = true;
}
}
if (remove_inputs_mask > 0) {
if (result) {
update_multi_sub_graph_op_inputs(multi_sub_graph_op, remove_inputs_mask);
}

View File

@@ -179,3 +179,78 @@ TEST_F(TransformationTestsF, PushConstantToSubgraphIf) {
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
// Regression test: a Loop fed by more than 32 constant invariant inputs.
// Every constant input is expected to be pushed into the loop body, leaving
// the sliced input X as the only data input wired from outside the body.
TEST_F(TransformationTestsF, PushConstantToSubgraphLoopMoreThan32Inputs) {
    // One more than 32, so the constant inputs cannot be tracked by the bits of a 32-bit mask.
    const int num_const_inputs = 33;
    {
        // Model under test: the Loop body takes X plus num_const_inputs parameters,
        // each of which is bound to a Constant outside of the body.
        auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
        auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});
        std::shared_ptr<Model> loop_body;
        {
            auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
            ParameterVector params;
            params.reserve(num_const_inputs + 1);
            params.push_back(X);
            NodeVector concat_inputs;
            concat_inputs.reserve(num_const_inputs + 1);
            concat_inputs.push_back(X);
            for (int i = 0; i < num_const_inputs; i++) {
                params.push_back(std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2}));
                concat_inputs.push_back(params.back());
            }
            auto concat = std::make_shared<opset10::Concat>(concat_inputs, 1);
            auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
            loop_body = std::make_shared<Model>(NodeVector{concat, cond}, params);
        }
        auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
        loop->set_function(loop_body);

        auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});
        NodeVector constants;
        constants.reserve(num_const_inputs);
        for (int i = 0; i < num_const_inputs; i++) {
            constants.push_back(opset10::Constant::create(element::f32, Shape{1, 2}, {-2}));
        }

        const auto& loop_params = loop_body->get_parameters();
        // Body result 1 (`cond`) is registered as the body condition output.
        loop->set_special_body_ports({-1, 1});
        // Body parameter 0 receives slices of X; the remaining parameters
        // are invariant inputs fed by the constants created above.
        loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);
        for (int i = 0; i < num_const_inputs; i++) {
            loop->set_invariant_input(loop_params[i + 1], constants[i]);
        }
        auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);

        function = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});
        manager.register_pass<pass::PushConstantToSubgraph>();
    }

    {
        // Reference model: the constants live inside the Loop body, so the body
        // has a single parameter and the Loop has no invariant inputs.
        auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
        auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});
        std::shared_ptr<Model> loop_body;
        {
            auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});
            NodeVector concat_inputs;
            concat_inputs.reserve(num_const_inputs + 1);
            concat_inputs.push_back(X);
            for (int i = 0; i < num_const_inputs; i++) {
                concat_inputs.push_back(opset10::Constant::create(element::f32, Shape{1, 2}, {-2}));
            }
            auto concat = std::make_shared<opset10::Concat>(concat_inputs, 1);
            auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
            loop_body = std::make_shared<Model>(NodeVector{concat, cond}, ParameterVector{X});
        }
        auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
        loop->set_function(loop_body);

        auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});
        const auto& loop_params = loop_body->get_parameters();
        loop->set_special_body_ports({-1, 1});
        loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);
        auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);

        function_ref = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});
    }

    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}