GroupedGatherElimination short circuit (#12380)
* Disable GroupedGatherElimination in case of scalar inputs containing indices * clang format
This commit is contained in:
parent
9d5e799c62
commit
62f79c3222
@ -75,7 +75,16 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
|
||||
(curr->input_value(0) != next->input_value(0))) {
|
||||
++i;
|
||||
continue;
|
||||
} // curr and next are the same type of gather which takes data from the same source
|
||||
}
|
||||
|
||||
// Scalar inputs are not supported by Concat and we don't want to throw an exception here.
|
||||
// The transformation should not be applied instead.
|
||||
if (curr->input_value(1).get_partial_shape().same_scheme(Shape{}) ||
|
||||
next->input_value(1).get_partial_shape().same_scheme(Shape{})) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// curr and next are the same type of gather which takes data from the same source
|
||||
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(
|
||||
OutputVector{curr->input_value(1), next->input_value(1)},
|
||||
0);
|
||||
|
Loading…
Reference in New Issue
Block a user