[PT FE] Get input dtype for aten::sum from graph (#21262)
This commit is contained in:
parent
71836a959d
commit
fe6ae4fbdd
@ -17,7 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
|
||||
bool keep_dims = false;
|
||||
ov::Output<ov::Node> axes;
|
||||
auto data = context.get_input(0);
|
||||
if (data.get_element_type() == element::boolean) {
|
||||
auto data_dtype = simplified_type_interpret(context.get_input_type(0));
|
||||
if (data.get_element_type() == element::boolean ||
|
||||
(data_dtype.is<element::Type>() && data_dtype.as<element::Type>() == element::boolean)) {
|
||||
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
|
||||
}
|
||||
if (context.input_is_none(1)) {
|
||||
|
Loading…
Reference in New Issue
Block a user