From 89d3eaa67fba1fc19f8fb6412145353e8a02ec78 Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Tue, 2 May 2023 16:49:42 +0200
Subject: [PATCH] Fix issue with Pow when both inputs are scalars (#17305)

* Fix issue with Pow when both inputs are scalars

* Fix code style
---
 src/frontends/pytorch/src/utils.cpp           | 39 +++++++------------
 .../pytorch_tests/pytorch_layer_test_class.py | 17 ++++----
 tests/layer_tests/pytorch_tests/test_pow.py   | 19 +++++++++
 3 files changed, 42 insertions(+), 33 deletions(-)

diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index e9c67d73f54..4fe1af10922 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -347,33 +347,24 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
     // consider dynamic rank as non scalar
     const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0;
     const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0;
-    if (is_lhs_scalar && is_rhs_scalar) {
-        // if both scalar, align to lhs
-        rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
-        return;
-    }
     auto lhs_dst_type = lhs_type;
     auto rhs_dst_type = rhs_type;
-    if (is_lhs_scalar) {
-        if (lhs_type.is_real() && !rhs_type.is_real()) {
-            // if div we need to also align float types to highest bitness regardless of scalar
-            if (!align_scalars)
-                lhs_dst_type = element::f32;
-            rhs_dst_type = element::f32;
-        } else {
-            lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
-            return;
-        }
-    } else if (is_rhs_scalar) {
-        if (!lhs_type.is_real() && rhs_type.is_real()) {
+    if (is_lhs_scalar && lhs_type.is_real() && !rhs_type.is_real()) {
+        // if div we need to also align float types to highest bitness regardless of scalar
+        if (!align_scalars)
             lhs_dst_type = element::f32;
-            // if div we need to also align float types to highest bitness regardless of scalar
-            if (!align_scalars)
-                rhs_dst_type = element::f32;
-        } else {
-            rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
-            return;
-        }
+        rhs_dst_type = element::f32;
+    } else if (is_rhs_scalar && !lhs_type.is_real() && rhs_type.is_real()) {
+        lhs_dst_type = element::f32;
+        // if div we need to also align float types to highest bitness regardless of scalar
+        if (!align_scalars)
+            rhs_dst_type = element::f32;
+    } else if (is_lhs_scalar) {
+        lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
+        return;
+    } else if (is_rhs_scalar) {
+        rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
+        return;
     }
 
     if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index 77d80af650c..456b1856267 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -110,14 +110,12 @@ class PytorchLayerTest:
         # check if results dtypes match
         for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
             if not isinstance(fw_tensor, torch.Tensor):
-                if np.isscalar(fw_tensor):
-                    assert fw_tensor == np.array(ov_tensor).item(
-                    ), f"{fw_tensor} != {np.array(ov_tensor).item()}"
-                else:
-                    if isinstance(fw_tensor, list):
-                        ov_tensor = ov_tensor.tolist()
-                        assert ov_tensor == fw_tensor
-                    assert type(fw_tensor) == type(ov_tensor)
+                fw_type = torch.tensor(fw_tensor).numpy().dtype
+                ov_type = ov_tensor.dtype
+                if fw_type in [np.int32, np.int64] and ov_type in [np.int32, np.int64]:
+                    # do not differentiate between int32 and int64
+                    continue
+                assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                 continue
             assert torch.tensor(np.array(
                 ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
@@ -199,7 +197,8 @@ class PytorchLayerTest:
                     ov_inputs[i] = ov_inputs[i].astype(np.int32)
                 inp = ov_inputs[i]
             assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
-            params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
+            if params[i].get_node().get_element_type().is_dynamic():
+                params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
             shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
             params[i].get_node().set_partial_shape(PartialShape(shape))
         om.validate_nodes_and_infer_types()
diff --git a/tests/layer_tests/pytorch_tests/test_pow.py b/tests/layer_tests/pytorch_tests/test_pow.py
index b973106d0ac..9cf6468404e 100644
--- a/tests/layer_tests/pytorch_tests/test_pow.py
+++ b/tests/layer_tests/pytorch_tests/test_pow.py
@@ -104,3 +104,22 @@ class TestPowMixedTypes(PytorchLayerTest):
         self.rhs_shape = rhs_shape
         self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
                    ie_device, precision, ir_version)
+
+class TestPowMixedTypesScalars(PytorchLayerTest):
+    def _prepare_input(self):
+        return (torch.randn([1,2,3,4]).numpy(),)
+
+    def create_model(self):
+
+        class aten_pow(torch.nn.Module):
+            def forward(self, x):
+                return torch.pow(x.size(2), -0.5)
+
+        ref_net = None
+
+        return aten_pow(), ref_net, "aten::pow"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_pow_mixed_types(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
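
Reviewer note, outside the patch: below is a minimal sketch of the graph pattern that the new TestPowMixedTypesScalars test exercises, using only the stock torch API. The module name PowOfDimSize is hypothetical; only the forward body mirrors the test above.

import torch


class PowOfDimSize(torch.nn.Module):
    def forward(self, x):
        # x.size(2) is an integer scalar and -0.5 is a float scalar, so the
        # scripted graph contains an aten::pow node whose two inputs are both
        # rank-0 values with different element types.
        return torch.pow(x.size(2), -0.5)


# torch.jit.script keeps x.size(2) symbolic instead of folding it to a constant.
scripted = torch.jit.script(PowOfDimSize())
print(scripted.graph)  # inspect the aten::pow node and its scalar inputs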