From fe6ae4fbdd1a51540adb6fc17c3be4fa4b2c268f Mon Sep 17 00:00:00 2001 From: Maxim Vafin Date: Fri, 24 Nov 2023 12:45:42 +0100 Subject: [PATCH] [PT FE] Get input dtype for aten::sum from graph (#21262) --- src/frontends/pytorch/src/op/sum.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/frontends/pytorch/src/op/sum.cpp b/src/frontends/pytorch/src/op/sum.cpp index 5004804074b..41a699d924b 100644 --- a/src/frontends/pytorch/src/op/sum.cpp +++ b/src/frontends/pytorch/src/op/sum.cpp @@ -17,7 +17,9 @@ OutputVector translate_sum(const NodeContext& context) { bool keep_dims = false; ov::Output axes; auto data = context.get_input(0); - if (data.get_element_type() == element::boolean) { + auto data_dtype = simplified_type_interpret(context.get_input_type(0)); + if (data.get_element_type() == element::boolean || + (data_dtype.is() && data_dtype.as() == element::boolean)) { data = context.mark_node(std::make_shared(data, element::i64)); } if (context.input_is_none(1)) {