GroupedGatherElimination short circuit (#12380)
* Disable GroupedGatherElimination in case of scalar inputs containing indices * clang format
This commit is contained in:
parent
9d5e799c62
commit
62f79c3222
@ -75,7 +75,16 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
|
|||||||
(curr->input_value(0) != next->input_value(0))) {
|
(curr->input_value(0) != next->input_value(0))) {
|
||||||
++i;
|
++i;
|
||||||
continue;
|
continue;
|
||||||
} // curr and next are the same type of gather which takes data from the same source
|
}
|
||||||
|
|
||||||
|
// Scalar inputs are not supported by Concat and we don't want to throw an exception here.
|
||||||
|
// The transformation should not be applied instead.
|
||||||
|
if (curr->input_value(1).get_partial_shape().same_scheme(Shape{}) ||
|
||||||
|
next->input_value(1).get_partial_shape().same_scheme(Shape{})) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// curr and next are the same type of gather which takes data from the same source
|
||||||
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(
|
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(
|
||||||
OutputVector{curr->input_value(1), next->input_value(1)},
|
OutputVector{curr->input_value(1), next->input_value(1)},
|
||||||
0);
|
0);
|
||||||
|
Loading…
Reference in New Issue
Block a user