[PT FE] Align bool types and same bit int types (#19399)

* [PT FE] Align bool types and same bit int types

* Fix max value
Maxim Vafin 2023-08-24 22:30:59 +02:00 committed by GitHub
parent eef6b35bef
commit c5b64e458b
2 changed files with 40 additions and 25 deletions
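
For context, and not part of the commit itself: the combinations this change targets follow PyTorch's own type-promotion rules, where a bool operand adopts the other operand's type and signed/unsigned integers of the same bitwidth promote to the next wider signed integer. A small PyTorch snippet can be used to confirm the reference behavior the frontend is expected to reproduce:

# Illustration only: PyTorch promotion rules for the newly covered dtype pairs.
import torch

pairs = [
    (torch.bool, torch.uint8),     # bool adopts the other type        -> uint8
    (torch.bool, torch.int8),      #                                   -> int8
    (torch.int8, torch.uint8),     # mixed signed/unsigned, same bits  -> int16
    (torch.int32, torch.int64),    # same signedness, different bits   -> int64
    (torch.int32, torch.float32),  # int with float                    -> float32
]
for lhs, rhs in pairs:
    a = torch.zeros(2, dtype=lhs)
    b = torch.zeros(2, dtype=rhs)
    print(lhs, rhs, "->", torch.result_type(a, b))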


@@ -428,34 +428,46 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
         // if div we need to also align float types to highest bitness regardless of scalar
         if (!align_scalars)
             rhs_dst_type = element::f32;
-    } else if (is_lhs_scalar) {
+    } else if (is_lhs_scalar && rhs_type != element::boolean) {
         lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
         return;
-    } else if (is_rhs_scalar) {
+    } else if (is_rhs_scalar && lhs_type != element::boolean) {
         rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
         return;
     }
-    if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
-        // Do nothing with bool
-        return;
-    }
     if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
         lhs_dst_type = element::f32;
     } else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) {
         rhs_dst_type = element::f32;
     }
-    // Align bitness to higher
-    if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) {
-        const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
-        element::Type* type_to_align = &lhs_dst_type;
-        if (rhs_dst_type.bitwidth() < dst_bitness)
-            type_to_align = &rhs_dst_type;
-        if (type_to_align->is_real()) {
-            *type_to_align = bit_to_float.at(dst_bitness);
-        } else {
-            *type_to_align = bit_to_int.at(dst_bitness);
+    // Align bool to other type
+    if (lhs_dst_type == element::boolean) {
+        lhs_dst_type = rhs_dst_type;
+    } else if (rhs_dst_type == element::boolean) {
+        rhs_dst_type = lhs_dst_type;
+    }
+    // At this point we either have both floating point types or both integer types. Align bitness to higher
+    if (lhs_dst_type != rhs_dst_type) {
+        auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
+        // If integer types are mixed signed+unsigned, align to next bitness
+        if (dst_bitness < 64 && lhs_dst_type.is_integral() && rhs_dst_type.is_integral() &&
+            lhs_dst_type.bitwidth() == rhs_dst_type.bitwidth() && lhs_dst_type != rhs_dst_type) {
+            dst_bitness *= 2;
+        }
+        if (lhs_dst_type.bitwidth() != dst_bitness) {
+            if (lhs_dst_type.is_real()) {
+                lhs_dst_type = bit_to_float.at(dst_bitness);
+            } else {
+                lhs_dst_type = bit_to_int.at(dst_bitness);
+            }
+        }
+        if (rhs_dst_type.bitwidth() != dst_bitness) {
+            if (rhs_dst_type.is_real()) {
+                rhs_dst_type = bit_to_float.at(dst_bitness);
+            } else {
+                rhs_dst_type = bit_to_int.at(dst_bitness);
+            }
         }
     }
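
Below is a rough standalone Python model of the promotion rule the new block implements; illustration only, using plain (bits, kind) pairs instead of ov::element::Type, and ignoring the scalar and is_div branches handled earlier in the function.

# kind is one of "f" (float), "i" (signed int), "u" (unsigned int), "b" (bool).
def align(lhs, rhs):
    lhs_bits, lhs_kind = lhs
    rhs_bits, rhs_kind = rhs
    # an integer operand paired with a float operand is first turned into f32
    if lhs_kind in "iu" and rhs_kind == "f":
        lhs_bits, lhs_kind = 32, "f"
    elif lhs_kind == "f" and rhs_kind in "iu":
        rhs_bits, rhs_kind = 32, "f"
    # bool adopts the other operand's type
    if lhs_kind == "b":
        lhs_bits, lhs_kind = rhs_bits, rhs_kind
    elif rhs_kind == "b":
        rhs_bits, rhs_kind = lhs_bits, lhs_kind
    # align bitness to the higher one; mixed signed/unsigned of equal width
    # is bumped to the next wider signed integer
    if (lhs_bits, lhs_kind) != (rhs_bits, rhs_kind):
        bits = max(lhs_bits, rhs_bits)
        if bits < 64 and lhs_kind in "iu" and rhs_kind in "iu" and lhs_bits == rhs_bits:
            bits *= 2
        if lhs_bits != bits:
            lhs_bits, lhs_kind = bits, ("f" if lhs_kind == "f" else "i")
        if rhs_bits != bits:
            rhs_bits, rhs_kind = bits, ("f" if rhs_kind == "f" else "i")
    return (lhs_bits, lhs_kind), (rhs_bits, rhs_kind)

print(align((8, "b"), (8, "u")))    # bool + uint8 -> uint8 on both sides
print(align((8, "i"), (8, "u")))    # int8 + uint8 -> int16 on both sides
print(align((32, "i"), (64, "f")))  # int32 + f64  -> f64 on both sides

These three cases correspond to the bool/uint8, int8/uint8 and int32/float64 parametrizations added to the test below.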


@@ -50,11 +50,11 @@ class TestAddTypes(PytorchLayerTest):
     def _prepare_input(self):
         if len(self.lhs_shape) == 0:
-            return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),)
+            return (torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy(),)
         elif len(self.rhs_shape) == 0:
-            return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
-        return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
-                torch.randn(self.rhs_shape).to(self.rhs_type).numpy())
+            return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),)
+        return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),
+                torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy())
 
     def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
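
The switch from torch.randn to torch.randint in _prepare_input is presumably needed because the new cases include bool and narrow integer dtypes: Gaussian noise is fractional and often negative, which does not survive a cast to uint8 or bool in a useful way, while integers in [0, 10) are exactly representable in every tested dtype. Illustration only:

import torch

shape = (2, 3)
noise = torch.randn(shape)          # fractional, often negative
ints = torch.randint(0, 10, shape)  # integers in [0, 10)

print(noise.to(torch.uint8))  # fractional/negative values do not cast cleanly
print(ints.to(torch.uint8))   # values preserved exactly
print(ints.to(torch.bool))    # 0 -> False, nonzero -> True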
@@ -71,10 +71,10 @@ class TestAddTypes(PytorchLayerTest):
                     self.forward = self.forward3
 
             def forward1(self, rhs):
-                return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
+                return torch.add(torch.tensor(1).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
 
             def forward2(self, lhs):
-                return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2)
+                return torch.add(lhs.to(self.lhs_type), torch.tensor(1).to(self.rhs_type), alpha=2)
 
             def forward3(self, lhs, rhs):
                 return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
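
For reference, torch.add(a, b, alpha=2) computes a + 2 * b; the constant operand in forward1/forward2 is lowered from 3 to 1 (the "Fix max value" part of the commit message), presumably so it stays comfortably within range of the narrowest tested dtypes. A short illustration of the alpha semantics, not part of the commit:

import torch

lhs = torch.tensor([0, 1, 2], dtype=torch.uint8)
rhs = torch.tensor([3, 4, 5], dtype=torch.uint8)

print(torch.add(lhs, rhs, alpha=2))              # lhs + 2 * rhs -> [6, 9, 12], uint8
print(torch.add(torch.tensor(1), rhs, alpha=2))  # 0-dim lhs, as in forward1; result stays uint8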
@@ -84,8 +84,11 @@ class TestAddTypes(PytorchLayerTest):
         return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add"
 
     @pytest.mark.parametrize(("lhs_type", "rhs_type"),
-                             [[torch.int32, torch.int64],
-                              [torch.int32, torch.float32],
+                             [[torch.bool, torch.uint8],
+                              [torch.bool, torch.int8],
+                              [torch.int8, torch.uint8],
+                              [torch.uint8, torch.int8],
+                              [torch.int32, torch.int64],
                               [torch.int32, torch.float64],
                               [torch.int64, torch.int32],
                               [torch.int64, torch.float32],