PushConstantToSubgraph - use vector<bool> for remove_inputs_mask (#15827)
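When a MultiSubGraphOp has more than 32 inputs, an int bitmask (tested with `1 << input_index`) does not have enough bits to track which inputs are to be removed, so the mask is replaced with a std::vector<bool> sized to the number of inputs. Ticket: 103248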
This commit is contained in:
@@ -46,12 +46,12 @@ static void replace_body_parameter(const std::shared_ptr<ov::Model>& body,
|
||||
}
|
||||
|
||||
static void update_multi_sub_graph_op_inputs(const std::shared_ptr<MultiSubGraphOp>& multi_sub_graph_op,
|
||||
int remove_inputs_mask) {
|
||||
const std::vector<bool>& remove_inputs_mask) {
|
||||
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
|
||||
auto inputs = multi_sub_graph_op->input_values();
|
||||
for (size_t i = multi_sub_graph_op->get_input_size(); i > 0; i--) {
|
||||
const auto input_index = i - 1;
|
||||
if ((remove_inputs_mask & (1 << input_index)) != 0) {
|
||||
if (remove_inputs_mask[input_index]) {
|
||||
// remove MultiSubGraphOp's input if it was marked to be removed
|
||||
// (meaning it was constfolded and pushed to inner subgraph)
|
||||
inputs.erase(inputs.begin() + input_index);
|
||||
@@ -83,7 +83,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
|
||||
// cache for already constant folded inputs
|
||||
std::unordered_map<size_t, std::shared_ptr<op::v0::Constant>> cache;
|
||||
// bitmask describing which MultiSubGraphOp's input to remove
|
||||
int remove_inputs_mask = 0;
|
||||
std::vector<bool> remove_inputs_mask(multi_sub_graph_op->get_input_size(), false);
|
||||
int num_subgraphs = static_cast<int>(multi_sub_graph_op->get_internal_subgraphs_size());
|
||||
|
||||
for (int body_idx = 0; body_idx < num_subgraphs; body_idx++) {
|
||||
@@ -95,7 +95,7 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
|
||||
const auto input_index = desc->m_input_index;
|
||||
const auto constant = try_constantfold_input(multi_sub_graph_op, desc, cache);
|
||||
if (!constant) {
|
||||
remove_inputs_mask &= ~(1 << input_index);
|
||||
remove_inputs_mask[input_index] = false;
|
||||
desc_it++;
|
||||
continue;
|
||||
}
|
||||
@@ -103,12 +103,12 @@ bool ov::pass::PushConstantToSubgraph::run_on_model(const std::shared_ptr<Model>
|
||||
desc_it = descriptions.erase(desc_it);
|
||||
auto& body_param = body_params[body_parameter_index];
|
||||
replace_body_parameter(body, body_param, body_parameter_index, constant, descriptions);
|
||||
remove_inputs_mask |= 1 << input_index;
|
||||
remove_inputs_mask[input_index] = true;
|
||||
result = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (remove_inputs_mask > 0) {
|
||||
if (result) {
|
||||
update_multi_sub_graph_op_inputs(multi_sub_graph_op, remove_inputs_mask);
|
||||
}
|
||||
|
||||
|
||||
@@ -179,3 +179,78 @@ TEST_F(TransformationTestsF, PushConstantToSubgraphIf) {
|
||||
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
|
||||
comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
|
||||
}
|
||||
|
||||
// Checks PushConstantToSubgraph on a Loop that has more than 32 constant inputs.
//
// The tested model feeds a Loop with one sliced Parameter input and 33 invariant
// Constant inputs; the loop body concatenates all of them. The reference model
// describes the expected result: the Loop keeps only the sliced input, and the
// 33 constants live directly inside the loop body.
TEST_F(TransformationTestsF, PushConstantToSubgraphLoopMoreThan32Inputs) {
    // One more than 32, so the number of removable inputs exceeds 32.
    const int num_const_inputs = 33;

    // Model under test: constants are supplied from outside the Loop.
    {
        auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
        auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});

        std::shared_ptr<Model> loop_body;
        {
            // Body: Concat(X, P_0, ..., P_32) along axis 1, plus a constant `true` condition.
            auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});

            ParameterVector params;
            params.reserve(num_const_inputs + 1);
            params.push_back(X);

            NodeVector concat_inputs;
            concat_inputs.reserve(num_const_inputs + 1);
            concat_inputs.push_back(X);

            for (int i = 0; i < num_const_inputs; i++) {
                params.push_back(std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2}));
                concat_inputs.push_back(params.back());
            }

            auto concat = std::make_shared<opset10::Concat>(concat_inputs, 1);
            auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
            loop_body = std::make_shared<Model>(NodeVector{concat, cond}, params);
        }

        auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
        loop->set_function(loop_body);

        auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});

        // Outer constants that are wired to the body parameters P_0..P_32.
        NodeVector constants;
        constants.reserve(num_const_inputs);
        for (int i = 0; i < num_const_inputs; i++) {
            constants.push_back(opset10::Constant::create(element::f32, Shape{1, 2}, {-2}));
        }

        const auto& loop_params = loop_body->get_parameters();
        loop->set_special_body_ports({-1, 1});
        loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);
        for (int i = 0; i < num_const_inputs; i++) {
            // loop_params[0] is the sliced input, hence the +1 offset.
            loop->set_invariant_input(loop_params[i + 1], constants[i]);
        }

        auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);
        function = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});

        manager.register_pass<pass::PushConstantToSubgraph>();
    }

    // Reference model: constants are already inside the loop body and the Loop
    // has the sliced input only.
    {
        auto trip_count = opset10::Constant::create(element::i32, Shape{}, {2});
        auto term_cond = opset10::Constant::create(element::boolean, Shape{}, {true});

        std::shared_ptr<Model> loop_body;
        {
            auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 2});

            NodeVector concat_inputs;
            concat_inputs.reserve(num_const_inputs + 1);
            concat_inputs.push_back(X);
            for (int i = 0; i < num_const_inputs; i++) {
                concat_inputs.push_back(opset10::Constant::create(element::f32, Shape{1, 2}, {-2}));
            }

            auto concat = std::make_shared<opset10::Concat>(concat_inputs, 1);
            auto cond = opset10::Constant::create(element::boolean, Shape{}, {true});
            loop_body = std::make_shared<Model>(NodeVector{concat, cond}, ParameterVector{X});
        }

        auto loop = std::make_shared<opset10::Loop>(trip_count, term_cond);
        loop->set_function(loop_body);

        auto X = std::make_shared<opset10::Parameter>(element::f32, Shape{2, 2});

        const auto& loop_params = loop_body->get_parameters();
        loop->set_special_body_ports({-1, 1});
        loop->set_sliced_input(loop_params[0], X, 0, 1, 1, -1, 0);

        auto out = loop->get_concatenated_slices(loop_body->get_results()[0], 0, 1, 1, -1, 0);
        function_ref = std::make_shared<Model>(OutputVector{out}, ParameterVector{X});
    }

    comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::ACCURACY);
}
|
||||
|
||||
Reference in New Issue
Block a user