eliminate broadcast node in masked_fill (#19595)
This commit is contained in:
parent
4eadef9e61
commit
1b5f428752
@ -23,11 +23,9 @@ OutputVector translate_masked_fill(const NodeContext& context) {
|
|||||||
auto data = context.get_input(0);
|
auto data = context.get_input(0);
|
||||||
auto mask = context.get_input(1);
|
auto mask = context.get_input(1);
|
||||||
auto value = context.get_input(2);
|
auto value = context.get_input(2);
|
||||||
auto data_shape = context.mark_node(std::make_shared<v3::ShapeOf>(data, element::i32));
|
|
||||||
value = context.mark_node(std::make_shared<v1::ConvertLike>(value, data));
|
value = context.mark_node(std::make_shared<v1::ConvertLike>(value, data));
|
||||||
auto broadcasted_value = context.mark_node(std::make_shared<v3::Broadcast>(value, data_shape));
|
|
||||||
auto bool_mask = context.mark_node(std::make_shared<v0::Convert>(mask, element::boolean));
|
auto bool_mask = context.mark_node(std::make_shared<v0::Convert>(mask, element::boolean));
|
||||||
return {context.mark_node(std::make_shared<v1::Select>(bool_mask, broadcasted_value, data))};
|
return {context.mark_node(std::make_shared<v1::Select>(bool_mask, value, data))};
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace op
|
} // namespace op
|
||||||
|
Loading…
Reference in New Issue
Block a user