[PT FE] Get input dtype for aten::sum from graph (#21262)
This commit is contained in:
parent
71836a959d
commit
fe6ae4fbdd
@ -17,7 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
|
|||||||
bool keep_dims = false;
|
bool keep_dims = false;
|
||||||
ov::Output<ov::Node> axes;
|
ov::Output<ov::Node> axes;
|
||||||
auto data = context.get_input(0);
|
auto data = context.get_input(0);
|
||||||
if (data.get_element_type() == element::boolean) {
|
auto data_dtype = simplified_type_interpret(context.get_input_type(0));
|
||||||
|
if (data.get_element_type() == element::boolean ||
|
||||||
|
(data_dtype.is<element::Type>() && data_dtype.as<element::Type>() == element::boolean)) {
|
||||||
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
|
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
|
||||||
}
|
}
|
||||||
if (context.input_is_none(1)) {
|
if (context.input_is_none(1)) {
|
||||||
|
Loading…
Reference in New Issue
Block a user