[PT FE] Align bool types and same bit int types (#19399)
* [PT FE] Align bool types and same bit int types * Fix max value
This commit is contained in:
parent
eef6b35bef
commit
c5b64e458b
@ -428,34 +428,46 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
|
||||
// if div we need to also align float types to highest bitness regardless of scalar
|
||||
if (!align_scalars)
|
||||
rhs_dst_type = element::f32;
|
||||
} else if (is_lhs_scalar) {
|
||||
} else if (is_lhs_scalar && rhs_type != element::boolean) {
|
||||
lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
|
||||
return;
|
||||
} else if (is_rhs_scalar) {
|
||||
} else if (is_rhs_scalar && lhs_type != element::boolean) {
|
||||
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
|
||||
return;
|
||||
}
|
||||
|
||||
if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
|
||||
// Do nothing with bool
|
||||
return;
|
||||
}
|
||||
|
||||
if (!lhs_dst_type.is_real() && rhs_dst_type.is_real()) {
|
||||
lhs_dst_type = element::f32;
|
||||
} else if (lhs_dst_type.is_real() && !rhs_dst_type.is_real()) {
|
||||
rhs_dst_type = element::f32;
|
||||
}
|
||||
// Align bitness to higher
|
||||
if (lhs_dst_type.bitwidth() != rhs_dst_type.bitwidth()) {
|
||||
const auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
|
||||
element::Type* type_to_align = &lhs_dst_type;
|
||||
if (rhs_dst_type.bitwidth() < dst_bitness)
|
||||
type_to_align = &rhs_dst_type;
|
||||
if (type_to_align->is_real()) {
|
||||
*type_to_align = bit_to_float.at(dst_bitness);
|
||||
} else {
|
||||
*type_to_align = bit_to_int.at(dst_bitness);
|
||||
// Align bool to other type
|
||||
if (lhs_dst_type == element::boolean) {
|
||||
lhs_dst_type = rhs_dst_type;
|
||||
} else if (rhs_dst_type == element::boolean) {
|
||||
rhs_dst_type = lhs_dst_type;
|
||||
}
|
||||
// At this point we either have both floating point type or both integer type. Align bitness to higher
|
||||
if (lhs_dst_type != rhs_dst_type) {
|
||||
auto dst_bitness = std::max(lhs_dst_type.bitwidth(), rhs_dst_type.bitwidth());
|
||||
// If integer type are mixed signed+unsigned align to next bitness
|
||||
if (dst_bitness < 64 && lhs_dst_type.is_integral() && lhs_dst_type.is_integral() &&
|
||||
lhs_dst_type.bitwidth() == rhs_dst_type.bitwidth() && lhs_dst_type != rhs_dst_type) {
|
||||
dst_bitness *= 2;
|
||||
}
|
||||
if (lhs_dst_type.bitwidth() != dst_bitness) {
|
||||
if (lhs_dst_type.is_real()) {
|
||||
lhs_dst_type = bit_to_float.at(dst_bitness);
|
||||
} else {
|
||||
lhs_dst_type = bit_to_int.at(dst_bitness);
|
||||
}
|
||||
}
|
||||
if (rhs_dst_type.bitwidth() != dst_bitness) {
|
||||
if (rhs_dst_type.is_real()) {
|
||||
rhs_dst_type = bit_to_float.at(dst_bitness);
|
||||
} else {
|
||||
rhs_dst_type = bit_to_int.at(dst_bitness);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -50,11 +50,11 @@ class TestAddTypes(PytorchLayerTest):
|
||||
|
||||
def _prepare_input(self):
|
||||
if len(self.lhs_shape) == 0:
|
||||
return (torch.randn(self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
return (torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy(),)
|
||||
elif len(self.rhs_shape) == 0:
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randn(self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randn(self.rhs_shape).to(self.rhs_type).numpy())
|
||||
return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),)
|
||||
return (torch.randint(0, 10, self.lhs_shape).to(self.lhs_type).numpy(),
|
||||
torch.randint(0, 10, self.rhs_shape).to(self.rhs_type).numpy())
|
||||
|
||||
def create_model(self, lhs_type, lhs_shape, rhs_type, rhs_shape):
|
||||
|
||||
@ -71,10 +71,10 @@ class TestAddTypes(PytorchLayerTest):
|
||||
self.forward = self.forward3
|
||||
|
||||
def forward1(self, rhs):
|
||||
return torch.add(torch.tensor(3).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
return torch.add(torch.tensor(1).to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward2(self, lhs):
|
||||
return torch.add(lhs.to(self.lhs_type), torch.tensor(3).to(self.rhs_type), alpha=2)
|
||||
return torch.add(lhs.to(self.lhs_type), torch.tensor(1).to(self.rhs_type), alpha=2)
|
||||
|
||||
def forward3(self, lhs, rhs):
|
||||
return torch.add(lhs.to(self.lhs_type), rhs.to(self.rhs_type), alpha=2)
|
||||
@ -84,8 +84,11 @@ class TestAddTypes(PytorchLayerTest):
|
||||
return aten_add(lhs_type, lhs_shape, rhs_type, rhs_shape), ref_net, "aten::add"
|
||||
|
||||
@pytest.mark.parametrize(("lhs_type", "rhs_type"),
|
||||
[[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float32],
|
||||
[[torch.bool, torch.uint8],
|
||||
[torch.bool, torch.int8],
|
||||
[torch.int8, torch.uint8],
|
||||
[torch.uint8, torch.int8],
|
||||
[torch.int32, torch.int64],
|
||||
[torch.int32, torch.float64],
|
||||
[torch.int64, torch.int32],
|
||||
[torch.int64, torch.float32],
|
||||
|
Loading…
Reference in New Issue
Block a user