Fix issue with Pow when both inputs are scalars (#17305)
* Fix issue with Pow when both inputs are scalars
* Fix code style
parent 35cae6251c
commit 89d3eaa67f
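For context, a minimal sketch of the pattern that used to fail; it mirrors the regression test added below, where aten::pow receives two scalar inputs, an integer size and a float exponent (class name and invocation are illustrative only):

    import torch

    class PowScalars(torch.nn.Module):
        def forward(self, x):
            # x.size(2) is an int scalar and -0.5 a float scalar, so both
            # inputs of aten::pow end up as rank-0 values after conversion
            return torch.pow(x.size(2), -0.5)

    # scripting keeps x.size(2) as a scalar (assumption: the layer-test
    # harness exercises the PyTorch frontend through scripting by default)
    scripted = torch.jit.script(PowScalars())
    print(scripted(torch.randn(1, 2, 3, 4)))  # 3 ** -0.5, about 0.577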
@@ -347,34 +347,25 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Output<Node>& rhs, bool align_scalars)
     // consider dynamic rank as non scalar
     const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0;
     const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0;
+    if (is_lhs_scalar && is_rhs_scalar) {
+        // if both scalar, align to lhs
+        rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
+        return;
+    }
     auto lhs_dst_type = lhs_type;
     auto rhs_dst_type = rhs_type;
-    if (is_lhs_scalar) {
-        if (lhs_type.is_real() && !rhs_type.is_real()) {
+    if (is_lhs_scalar && lhs_type.is_real() && !rhs_type.is_real()) {
         // if div we need to also align float types to highest bitness regardless of scalar
         if (!align_scalars)
             lhs_dst_type = element::f32;
         rhs_dst_type = element::f32;
-        } else {
+    } else if (is_rhs_scalar && !lhs_type.is_real() && rhs_type.is_real()) {
+        lhs_dst_type = element::f32;
+        // if div we need to also align float types to highest bitness regardless of scalar
+        if (!align_scalars)
+            rhs_dst_type = element::f32;
+    } else if (is_lhs_scalar) {
         lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
         return;
-        }
     } else if (is_rhs_scalar) {
-        if (!lhs_type.is_real() && rhs_type.is_real()) {
-            lhs_dst_type = element::f32;
-            // if div we need to also align float types to highest bitness regardless of scalar
-            if (!align_scalars)
-                rhs_dst_type = element::f32;
-        } else {
         rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
         return;
-        }
     }

     if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
         // Do nothing with bool
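With this change, two scalar inputs are handled up front: the right-hand side is converted to the left-hand side's type and alignment stops there, and the remaining scalar promotion cases are flattened into a single if/else-if chain. A rough Python restatement of the resulting decision logic, as a hypothetical helper rather than the frontend's actual API:

    def align_decision(lhs_scalar, rhs_scalar, lhs_real, rhs_real, align_scalars):
        # follows the branch order of the C++ above and returns the action taken
        if lhs_scalar and rhs_scalar:
            return "convert rhs to lhs type and return"  # new early exit
        if lhs_scalar and lhs_real and not rhs_real:
            # the integer rhs always goes to f32; the float scalar lhs is
            # promoted too unless align_scalars keeps it for later alignment
            return "rhs -> f32" + ("" if align_scalars else ", lhs -> f32")
        if rhs_scalar and not lhs_real and rhs_real:
            return "lhs -> f32" + ("" if align_scalars else ", rhs -> f32")
        if lhs_scalar:
            return "convert lhs to rhs type and return"
        if rhs_scalar:
            return "convert rhs to lhs type and return"
        return "fall through to the usual type promotion"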
@@ -110,14 +110,12 @@ class PytorchLayerTest:
             # check if results dtypes match
             for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
                 if not isinstance(fw_tensor, torch.Tensor):
+                    if np.isscalar(fw_tensor):
+                        assert fw_tensor == np.array(ov_tensor).item(
+                        ), f"{fw_tensor} != {np.array(ov_tensor).item()}"
+                    else:
                         if isinstance(fw_tensor, list):
                             ov_tensor = ov_tensor.tolist()
                         assert ov_tensor == fw_tensor
                         assert type(fw_tensor) == type(ov_tensor)
-                    fw_type = torch.tensor(fw_tensor).numpy().dtype
-                    ov_type = ov_tensor.dtype
-                    if fw_type in [np.int32, np.int64] and ov_type in [np.int32, np.int64]:
-                        # do not differentiate between int32 and int64
-                        continue
-                    assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                     continue
                 assert torch.tensor(np.array(
                     ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
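The result check now accepts scalar (non-Tensor, non-list) framework outputs and compares them to the matching 0-d OpenVINO output by value instead of going through the old per-dtype bookkeeping. A small standalone illustration of that comparison, with made-up values:

    import numpy as np

    fw_scalar = 0.5                              # e.g. a Python float returned by the model
    ov_scalar = np.array(0.5, dtype=np.float32)  # 0-d array produced by OpenVINO

    # same check the harness performs: compare by value via .item()
    assert fw_scalar == np.array(ov_scalar).item(), \
        f"{fw_scalar} != {np.array(ov_scalar).item()}"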
@@ -199,6 +197,7 @@ class PytorchLayerTest:
                         ov_inputs[i] = ov_inputs[i].astype(np.int32)
                     inp = ov_inputs[i]
                 assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
+                if params[i].get_node().get_element_type().is_dynamic():
                     params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
                 shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
                 params[i].get_node().set_partial_shape(PartialShape(shape))
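A quick standalone illustration of the shape expression used above: with dynamic_shapes enabled every dimension is left as -1, otherwise the static shape of the input is kept (the input array here is made up):

    import numpy as np

    inp = np.zeros((1, 2, 3, 4), dtype=np.float32)
    for dynamic_shapes in (True, False):
        shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
        print(dynamic_shapes, shape)
    # True  -> [-1, -1, -1, -1]
    # False -> (1, 2, 3, 4)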
@@ -104,3 +104,22 @@ class TestPowMixedTypes(PytorchLayerTest):
         self.rhs_shape = rhs_shape
         self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
                    ie_device, precision, ir_version)
+
+class TestPowMixedTypesScalars(PytorchLayerTest):
+    def _prepare_input(self):
+        return (torch.randn([1,2,3,4]).numpy(),)
+
+    def create_model(self):
+
+        class aten_pow(torch.nn.Module):
+            def forward(self, x):
+                return torch.pow(x.size(2), -0.5)
+
+        ref_net = None
+
+        return aten_pow(), ref_net, "aten::pow"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_pow_mixed_types(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
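For reference, the new test's input shape is [1, 2, 3, 4], so x.size(2) is 3 and the framework reference the harness compares against is just a scalar power:

    # expected reference value for the new scalar-pow case: x.size(2) == 3
    print(3 ** -0.5)  # 0.5773502691896258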