[PT FE] Align bool types and same bit int types (#19399)

* [PT FE] Align bool types and same bit int types

* Fix max value
This commit is contained in:
Maxim Vafin 2023-08-24 22:30:59 +02:00 committed by GitHub
parent eef6b35bef
commit c5b64e458b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 25 deletions

View File

@ -428,34 +428,46 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
// if div we need to also align float types to highest bitness regardless of scalar // if div we need to also align float types to highest bitness regardless of scalar
if (!align_scalars) if (!align_scalars)
rhs_dst_type = element::f32; rhs_dst_type = element::f32;
} else if (is_lhs_scalar) { } else if (is_lhs_scalar && rhs_type != element::boolean) {
lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs)); lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
return; return;
} else if (is_rhs_scalar) { } else if (is_rhs_scalar && lhs_type != element::boolean) {
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs)); rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
return; return;
} }
if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
// Do nothing with bool
return;
}
if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) { if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
lhs_dst_type = element::f32; lhs_dst_type = element::f32;
} else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) { } else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) {
rhs_dst_type = element::f32; rhs_dst_type = element::f32;
} }
// Align bitness to higher // Align bool to other type
if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) { if (lhs_dst_type == element::boolean) {
const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth()); lhs_dst_type = rhs_dst_type;
element::Type* type_to_align = &lhs_dst_type; } else if (rhs_dst_type == element::boolean) {
if (rhs_dst_type.bitwidth() < dst_bitness) rhs_dst_type = lhs_dst_type;
type_to_align = &rhs_dst_type; }
if (type_to_align->is_real()) { // At this point we either have both floating point type or both integer type. Align bitness to higher
*type_to_align = bit_to_float.at(dst_bitness); if (lhs_dst_type != rhs_dst_type) {
} else { auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
*type_to_align = bit_to_int.at(dst_bitness); // If integer type are mixed signed+unsigned align to next bitness
if (dst_bitness < 64 && lhs_dst_type.is_integral() && lhs_dst_type.is_integral() &&
lhs_dst_type.bitwidth() == rhs_dst_type.bitwidth() && lhs_dst_type != rhs_dst_type) {
dst_bitness *= 2;
}
if (lhs_dst_type.bitwidth() != dst_bitness) {
if (lhs_dst_type.is_real()) {
lhs_dst_type = bit_to_float.at(dst_bitness);
} else {
lhs_dst_type = bit_to_int.at(dst_bitness);
}
}
if (rhs_dst_type.bitwidth() != dst_bitness) {
if (rhs_dst_type.is_real()) {
rhs_dst_type = bit_to_float.at(dst_bitness);
} else {
rhs_dst_type = bit_to_int.at(dst_bitness);
}
} }
} }

View File

@ -50,11 +50,11 @@ class TestAddTypes(PytorchLayerTest):
def _prepare_input(self): def _prepare_input(self):
if len(self.lhs_shape) == 0: if len(self.lhs_shape) == 0:
return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),) return (torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy(),)
elif len(self.rhs_shape) == 0: elif len(self.rhs_shape) == 0:
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),) return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),)
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(), return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),
torch.randn(self.rhs_shape).to(self.rhs_type).numpy()) torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy())
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape): def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
@ -71,10 +71,10 @@ class TestAddTypes(PytorchLayerTest):
self.forward = self.forward3 self.forward = self.forward3
def forward1(self, rhs): def forward1(self, rhs):
return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) return torch.add(torch.tensor(1).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
def forward2(self, lhs): def forward2(self, lhs):
return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2) return torch.add(lhs.to(self.lhs_type), torch.tensor(1).to(self.rhs_type), alpha=2)
def forward3(self, lhs, rhs): def forward3(self, lhs, rhs):
return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2) return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
@ -84,8 +84,11 @@ class TestAddTypes(PytorchLayerTest):
return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add" return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add"
@pytest.mark.parametrize(("lhs_type", "rhs_type"), @pytest.mark.parametrize(("lhs_type", "rhs_type"),
[[torch.int32, torch.int64], [[torch.bool, torch.uint8],
[torch.int32, torch.float32], [torch.bool, torch.int8],
[torch.int8, torch.uint8],
[torch.uint8, torch.int8],
[torch.int32, torch.int64],
[torch.int32, torch.float64], [torch.int32, torch.float64],
[torch.int64, torch.int32], [torch.int64, torch.int32],
[torch.int64, torch.float32], [torch.int64, torch.float32],