[PT FE] Get input dtype for aten::sum from graph (#21262)

This commit is contained in:
Maxim Vafin 2023-11-24 12:45:42 +01:00 committed by GitHub
parent 71836a959d
commit fe6ae4fbdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,7 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
bool keep_dims = false; bool keep_dims = false;
ov::Output<ov::Node> axes; ov::Output<ov::Node> axes;
auto data = context.get_input(0); auto data = context.get_input(0);
if (data.get_element_type() == element::boolean) { auto data_dtype = simplified_type_interpret(context.get_input_type(0));
if (data.get_element_type() == element::boolean ||
(data_dtype.is<element::Type>() && data_dtype.as<element::Type>() == element::boolean)) {
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64)); data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
} }
if (context.input_is_none(1)) { if (context.input_is_none(1)) {