diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp index 79dbdb156c6..386ac393d35 100644 --- a/src/frontends/pytorch/src/utils.cpp +++ b/src/frontends/pytorch/src/utils.cpp @@ -428,34 +428,46 @@ void align_eltwise_input_types(const NodeContext& context, Output& lhs, Ou // if div we need to also align float types to highest bitness regardless of scalar if (!align_scalars) rhs_dst_type = element::f32; - } else if (is_lhs_scalar) { + } else if (is_lhs_scalar && rhs_type != element::boolean) { lhs = context.mark_node(std::make_shared(lhs, rhs)); return; - } else if (is_rhs_scalar) { + } else if (is_rhs_scalar && lhs_type != element::boolean) { rhs = context.mark_node(std::make_shared(rhs, lhs)); return; } - if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) { - // Do nothing with bool - return; - } - if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) { lhs_dst_type = element::f32; } else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) { rhs_dst_type = element::f32; } - // Align bitness to higher - if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) { - const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth()); - element::Type* type_to_align = &lhs_dst_type; - if (rhs_dst_type.bitwidth() < dst_bitness) - type_to_align = &rhs_dst_type; - if (type_to_align->is_real()) { - *type_to_align = bit_to_float.at(dst_bitness); - } else { - *type_to_align = bit_to_int.at(dst_bitness); + // Align bool to other type + if (lhs_dst_type == element::boolean) { + lhs_dst_type = rhs_dst_type; + } else if (rhs_dst_type == element::boolean) { + rhs_dst_type = lhs_dst_type; + } + // At this point we either have both floating point type or both integer type. Align bitness to higher + if (lhs_dst_type != rhs_dst_type) { + auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth()); + // If integer type are mixed signed+unsigned align to next bitness + if (dst_bitness < 64 && lhs_dst_type.is_integral() && lhs_dst_type.is_integral() && + lhs_dst_type.bitwidth() == rhs_dst_type.bitwidth() && lhs_dst_type != rhs_dst_type) { + dst_bitness *= 2; + } + if (lhs_dst_type.bitwidth() != dst_bitness) { + if (lhs_dst_type.is_real()) { + lhs_dst_type = bit_to_float.at(dst_bitness); + } else { + lhs_dst_type = bit_to_int.at(dst_bitness); + } + } + if (rhs_dst_type.bitwidth() != dst_bitness) { + if (rhs_dst_type.is_real()) { + rhs_dst_type = bit_to_float.at(dst_bitness); + } else { + rhs_dst_type = bit_to_int.at(dst_bitness); + } } } diff --git a/tests/layer_tests/pytorch_tests/test_add.py b/tests/layer_tests/pytorch_tests/test_add.py index 0f72ef08d69..c13cfbcd363 100644 --- a/tests/layer_tests/pytorch_tests/test_add.py +++ b/tests/layer_tests/pytorch_tests/test_add.py @@ -50,11 +50,11 @@ class TestAddTypes(PytorchLayerTest): def _prepare_input(self): if len(self.lhs_shape) == 0: - return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),) + return (torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy(),) elif len(self.rhs_shape) == 0: - return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) - return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), - torch.randn(self.rhs_shape).to(self.rhs_type).numpy()) + return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),) + return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(), + torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy()) def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): @@ -71,10 +71,10 @@ class TestAddTypes(PytorchLayerTest): self.forward = self.forward3 def forward1(self, rhs): - return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) + return torch.add(torch.tensor(1).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) def forward2(self, lhs): - return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2) + return torch.add(lhs.to(self.lhs_type), torch.tensor(1).to(self.rhs_type), alpha=2) def forward3(self, lhs, rhs): return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) @@ -84,8 +84,11 @@ class TestAddTypes(PytorchLayerTest): return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add" @pytest.mark.parametrize(("lhs_type", "rhs_type"), - [[torch.int32, torch.int64], - [torch.int32, torch.float32], + [[torch.bool, torch.uint8], + [torch.bool, torch.int8], + [torch.int8, torch.uint8], + [torch.uint8, torch.int8], + [torch.int32, torch.int64], [torch.int32, torch.float64], [torch.int64, torch.int32], [torch.int64, torch.float32],