Modify the condition making batch interpretation true/false (#18283)
* Modify the condition making batch interpretation true/false - When the user is Convert for Constant node, and tensor is 1d, - Set needBatchInterpretation to true * Narrow down the range of the condition * Merge the condition * Add additional condition not to check self node * Fix incomplete condition * Check if all inputs to binary eltwise is 1d * Change code style
This commit is contained in:
@@ -96,6 +96,40 @@ static void CreateConstantOp(Program& p, const std::shared_ptr<ngraph::op::v0::C
|
||||
}
|
||||
};
|
||||
|
||||
auto is_binary_eltwise = [&] (ov::Node* op) -> bool {
|
||||
if (ngraph::op::is_binary_elementwise_arithmetic(op) ||
|
||||
ngraph::op::is_binary_elementwise_logical(op) ||
|
||||
ngraph::op::is_binary_elementwise_comparison(op)) {
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
auto is_all_inputs_1d = [&] (ov::Node* op) -> bool {
|
||||
for (size_t i = 0; i < op->get_input_size(); i++) {
|
||||
auto& in_shape = op->get_input_partial_shape(i);
|
||||
if (in_shape.size() > 1)
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
auto is_convert_into_binary_eltwise = [&] (ov::Node* op) -> bool {
|
||||
if (ngraph::is_type<ngraph::op::v0::Convert>(op)) {
|
||||
for (size_t i = 0; i < op->get_output_size(); ++i) {
|
||||
auto convertUsers = op->get_output_target_inputs(i);
|
||||
for (auto user : convertUsers) {
|
||||
if (is_binary_eltwise(user.get_node()) &&
|
||||
is_all_inputs_1d(user.get_node())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// WA to inconsistency between input and const 1d tensors
|
||||
// For Concat along batch we go with batch interpretation
|
||||
// For Gather input we go with batch interpretation
|
||||
@@ -106,17 +140,9 @@ static void CreateConstantOp(Program& p, const std::shared_ptr<ngraph::op::v0::C
|
||||
if (castedOp->get_axis() == 0) {
|
||||
consts[op].needsBatchInterpretation = constDims.size() == 1;
|
||||
}
|
||||
} else if (ngraph::op::is_binary_elementwise_arithmetic(outOp) ||
|
||||
ngraph::op::is_binary_elementwise_logical(outOp) ||
|
||||
ngraph::op::is_binary_elementwise_comparison(outOp) ||
|
||||
ngraph::is_type<ngraph::op::v0::SquaredDifference>(outOp)) {
|
||||
bool all_inputs_1d = true;
|
||||
for (size_t j = 0; j < outOp->get_input_size(); j++) {
|
||||
auto& in_shape = outOp->get_input_partial_shape(j);
|
||||
if (in_shape.size() > 1)
|
||||
all_inputs_1d = false;
|
||||
}
|
||||
consts[op].needsBatchInterpretation = all_inputs_1d && constDims.size() == 1;
|
||||
} else if (((is_binary_eltwise(outOp) || ngraph::is_type<ngraph::op::v0::SquaredDifference>(outOp)) && is_all_inputs_1d(outOp)) ||
|
||||
is_convert_into_binary_eltwise(outOp)) {
|
||||
consts[op].needsBatchInterpretation = constDims.size() == 1;
|
||||
} else if (ngraph::is_type<ngraph::op::v1::Gather>(outOp) ||
|
||||
ngraph::is_type<ngraph::op::v7::Gather>(outOp) ||
|
||||
ngraph::is_type<ngraph::op::v8::Gather>(outOp) ||
|
||||
|
||||
Reference in New Issue
Block a user