Modify the condition making batch interpretation true/false (#18283)

* Modify the condition making batch interpretation true/false

- When the user is Convert for Constant node, and tensor is 1d,
- Set needBatchInterpretation to true

* Narrow down the range of the condition

* Merge the condition

* Add additional condition not to check self node

* Fix incomplete condition

* Check if all inputs to binary eltwise is 1d

* Change code style
This commit is contained in:
David Nam
2023-07-28 14:42:07 +09:00
committed by GitHub
parent c9001980ef
commit 1fcdc90989

View File

@@ -96,6 +96,40 @@ static void CreateConstantOp(Program& p, const std::shared_ptr<ngraph::op::v0::C
}
};
auto is_binary_eltwise = [&] (ov::Node* op) -> bool {
if (ngraph::op::is_binary_elementwise_arithmetic(op) ||
ngraph::op::is_binary_elementwise_logical(op) ||
ngraph::op::is_binary_elementwise_comparison(op)) {
return true;
} else {
return false;
}
};
auto is_all_inputs_1d = [&] (ov::Node* op) -> bool {
for (size_t i = 0; i < op->get_input_size(); i++) {
auto& in_shape = op->get_input_partial_shape(i);
if (in_shape.size() > 1)
return false;
}
return true;
};
auto is_convert_into_binary_eltwise = [&] (ov::Node* op) -> bool {
if (ngraph::is_type<ngraph::op::v0::Convert>(op)) {
for (size_t i = 0; i < op->get_output_size(); ++i) {
auto convertUsers = op->get_output_target_inputs(i);
for (auto user : convertUsers) {
if (is_binary_eltwise(user.get_node()) &&
is_all_inputs_1d(user.get_node())) {
return true;
}
}
}
}
return false;
};
// WA to inconsistency between input and const 1d tensors
// For Concat along batch we go with batch interpretation
// For Gather input we go with batch interpretation
@@ -106,17 +140,9 @@ static void CreateConstantOp(Program& p, const std::shared_ptr<ngraph::op::v0::C
if (castedOp->get_axis() == 0) {
consts[op].needsBatchInterpretation = constDims.size() == 1;
}
} else if (ngraph::op::is_binary_elementwise_arithmetic(outOp) ||
ngraph::op::is_binary_elementwise_logical(outOp) ||
ngraph::op::is_binary_elementwise_comparison(outOp) ||
ngraph::is_type<ngraph::op::v0::SquaredDifference>(outOp)) {
bool all_inputs_1d = true;
for (size_t j = 0; j < outOp->get_input_size(); j++) {
auto& in_shape = outOp->get_input_partial_shape(j);
if (in_shape.size() > 1)
all_inputs_1d = false;
}
consts[op].needsBatchInterpretation = all_inputs_1d && constDims.size() == 1;
} else if (((is_binary_eltwise(outOp) || ngraph::is_type<ngraph::op::v0::SquaredDifference>(outOp)) && is_all_inputs_1d(outOp)) ||
is_convert_into_binary_eltwise(outOp)) {
consts[op].needsBatchInterpretation = constDims.size() == 1;
} else if (ngraph::is_type<ngraph::op::v1::Gather>(outOp) ||
ngraph::is_type<ngraph::op::v7::Gather>(outOp) ||
ngraph::is_type<ngraph::op::v8::Gather>(outOp) ||