GroupedGatherElimination short circuit (#12380)

* Disable GroupedGatherElimination in case of scalar inputs containing indices

* clang format
This commit is contained in:
Tomasz Dołbniak 2022-08-03 11:47:22 +02:00 committed by GitHub
parent 9d5e799c62
commit 62f79c3222
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,7 +75,16 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
(curr->input_value(0) != next->input_value(0))) {
++i;
continue;
} // curr and next are the same type of gather which takes data from the same source
}
// Scalar inputs are not supported by Concat and we don't want to throw an exception here.
// The transformation should not be applied instead.
if (curr->input_value(1).get_partial_shape().same_scheme(Shape{}) ||
next->input_value(1).get_partial_shape().same_scheme(Shape{})) {
return false;
}
// curr and next are the same type of gather which takes data from the same source
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(
OutputVector{curr->input_value(1), next->input_value(1)},
0);