eliminate broadcast node in masked_fill (#19595)
This commit is contained in:
parent
4eadef9e61
commit
1b5f428752
@ -23,11 +23,9 @@ OutputVector translate_masked_fill(const NodeContext& context) {
|
||||
auto data = context.get_input(0);
|
||||
auto mask = context.get_input(1);
|
||||
auto value = context.get_input(2);
|
||||
auto data_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
|
||||
value = context.mark_node(std::make_shared<v1::ConvertLike>(value, data));
|
||||
auto broadcasted_value = context.mark_node(std::make_shared<v3::Broadcast>(value, data_shape));
|
||||
auto bool_mask = context.mark_node(std::make_shared<v0::Convert>(mask, element::boolean));
|
||||
return {context.mark_node(std::make_shared<v1::Select>(bool_mask, broadcasted_value, data))};
|
||||
return {context.mark_node(std::make_shared<v1::Select>(bool_mask, value, data))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
|
Loading…
Reference in New Issue
Block a user