diff --git a/src/frontends/pytorch/src/op/sum.cpp b/src/frontends/pytorch/src/op/sum.cpp index 5004804074b..41a699d924b 100644 --- a/src/frontends/pytorch/src/op/sum.cpp +++ b/src/frontends/pytorch/src/op/sum.cpp @@ -17,7 +17,9 @@ OutputVector translate_sum(const NodeContext& context) { bool keep_dims = false; ov::Output axes; auto data = context.get_input(0); - if (data.get_element_type() == element::boolean) { + auto data_dtype = simplified_type_interpret(context.get_input_type(0)); + if (data.get_element_type() == element::boolean || + (data_dtype.is() && data_dtype.as() == element::boolean)) { data = context.mark_node(std::make_shared(data, element::i64)); } if (context.input_is_none(1)) {