Fix issue with Pow when both inputs are scalars (#17305)

* Fix issue with Pow when both inputs are scalars

* Fix code style
This commit is contained in:
Maxim Vafin 2023-05-02 16:49:42 +02:00 committed by GitHub
parent 35cae6251c
commit 89d3eaa67f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 33 deletions

View File

@ -347,34 +347,25 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
// consider dynamic rank as non scalar
const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0;
const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0;
if (is_lhs_scalar && is_rhs_scalar) {
// if both scalar, align to lhs
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
return;
}
auto lhs_dst_type = lhs_type;
auto rhs_dst_type = rhs_type;
if (is_lhs_scalar) {
if (lhs_type.is_real() && !rhs_type.is_real()) {
if (is_lhs_scalar && lhs_type.is_real() && !rhs_type.is_real()) {
// if div we need to also align float types to highest bitness regardless of scalar
if (!align_scalars)
lhs_dst_type = element::f32;
rhs_dst_type = element::f32;
} else {
} else if (is_rhs_scalar && !lhs_type.is_real() && rhs_type.is_real()) {
lhs_dst_type = element::f32;
// if div we need to also align float types to highest bitness regardless of scalar
if (!align_scalars)
rhs_dst_type = element::f32;
} else if (is_lhs_scalar) {
lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
return;
}
} else if (is_rhs_scalar) {
if (!lhs_type.is_real() && rhs_type.is_real()) {
lhs_dst_type = element::f32;
// if div we need to also align float types to highest bitness regardless of scalar
if (!align_scalars)
rhs_dst_type = element::f32;
} else {
rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
return;
}
}
if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
// Do nothing with bool

View File

@ -110,14 +110,12 @@ class PytorchLayerTest:
# check if results dtypes match
for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
if not isinstance(fw_tensor, torch.Tensor):
if np.isscalar(fw_tensor):
assert fw_tensor == np.array(ov_tensor).item(
), f"{fw_tensor} != {np.array(ov_tensor).item()}"
else:
if isinstance(fw_tensor, list):
ov_tensor = ov_tensor.tolist()
assert ov_tensor == fw_tensor
assert type(fw_tensor) == type(ov_tensor)
fw_type = torch.tensor(fw_tensor).numpy().dtype
ov_type = ov_tensor.dtype
if fw_type in [np.int32, np.int64] and ov_type in [np.int32, np.int64]:
# do not differentiate between int32 and int64
continue
assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
continue
assert torch.tensor(np.array(
ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
@ -199,6 +197,7 @@ class PytorchLayerTest:
ov_inputs[i] = ov_inputs[i].astype(np.int32)
inp = ov_inputs[i]
assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
if params[i].get_node().get_element_type().is_dynamic():
params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
params[i].get_node().set_partial_shape(PartialShape(shape))

View File

@ -104,3 +104,22 @@ class TestPowMixedTypes(PytorchLayerTest):
self.rhs_shape = rhs_shape
self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
ie_device, precision, ir_version)
class TestPowMixedTypesScalars(PytorchLayerTest):
def _prepare_input(self):
return (torch.randn([1,2,3,4]).numpy(),)
def create_model(self):
class aten_pow(torch.nn.Module):
def forward(self, x):
return torch.pow(x.size(2), -0.5)
ref_net = None
return aten_pow(), ref_net, "aten::pow"
@pytest.mark.nightly
@pytest.mark.precommit
def test_pow_mixed_types(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version)