Fix issue with Pow when both inputs are scalars (#17305)

* Fix issue with Pow when both inputs are scalars

* Fix code style
Maxim Vafin 2023-05-02 16:49:42 +02:00 committed by GitHub
parent 35cae6251c
commit 89d3eaa67f
3 changed files with 42 additions and 33 deletions
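
For context, a minimal sketch of the pattern this commit targets (class and variable names here are illustrative, mirroring the test added below): both inputs of pow() are scalar, rank-0 values, a shape component and a Python float, so neither operand is an ordinary tensor when element types are aligned in the PyTorch frontend.

import torch

class ScalarPow(torch.nn.Module):
    def forward(self, x):
        # x.size(2) is a scalar int, -0.5 is a scalar float:
        # both pow() inputs reach the frontend as rank-0 values
        return torch.pow(x.size(2), -0.5)

# scripting the module shows the aten::pow node fed by two scalar inputs
print(torch.jit.script(ScalarPow()).graph)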


@@ -347,33 +347,24 @@ void align_eltwise_input_types(const NodeContext& context, Output<Node>& lhs, Ou
     // consider dynamic rank as non scalar
     const auto is_lhs_scalar = lhs_rank.is_static() && lhs_rank.get_length() == 0;
     const auto is_rhs_scalar = rhs_rank.is_static() && rhs_rank.get_length() == 0;
+    if (is_lhs_scalar && is_rhs_scalar) {
+        // if both scalar, align to lhs
+        rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
+        return;
+    }
     auto lhs_dst_type = lhs_type;
     auto rhs_dst_type = rhs_type;
-    if (is_lhs_scalar) {
-        if (lhs_type.is_real() && !rhs_type.is_real()) {
-            // if div we need to also align float types to highest bitness regardless of scalar
-            if (!align_scalars)
-                lhs_dst_type = element::f32;
-            rhs_dst_type = element::f32;
-        } else {
-            lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
-            return;
-        }
-    } else if (is_rhs_scalar) {
-        if (!lhs_type.is_real() && rhs_type.is_real()) {
-            lhs_dst_type = element::f32;
-            // if div we need to also align float types to highest bitness regardless of scalar
-            if (!align_scalars)
-                rhs_dst_type = element::f32;
-        } else {
-            rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
-            return;
-        }
+    if (is_lhs_scalar && lhs_type.is_real() && !rhs_type.is_real()) {
+        // if div we need to also align float types to highest bitness regardless of scalar
+        if (!align_scalars)
+            lhs_dst_type = element::f32;
+        rhs_dst_type = element::f32;
+    } else if (is_rhs_scalar && !lhs_type.is_real() && rhs_type.is_real()) {
+        lhs_dst_type = element::f32;
+        // if div we need to also align float types to highest bitness regardless of scalar
+        if (!align_scalars)
+            rhs_dst_type = element::f32;
+    } else if (is_lhs_scalar) {
+        lhs = context.mark_node(std::make_shared<opset10::ConvertLike>(lhs, rhs));
+        return;
+    } else if (is_rhs_scalar) {
+        rhs = context.mark_node(std::make_shared<opset10::ConvertLike>(rhs, lhs));
+        return;
     }
     if (lhs_dst_type == element::boolean || rhs_dst_type == element::boolean) {
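
Paraphrasing the branch structure above as a sketch (dtype strings and an is_real predicate stand in for OpenVINO element types, and the ConvertLike calls are reduced to returning a pair of destination dtypes): the new early return makes a scalar rhs simply follow a scalar lhs, instead of falling into the single-scalar branches.

# illustrative sketch only, not the OpenVINO API
def align_scalar_types(lhs_type, rhs_type, lhs_is_scalar, rhs_is_scalar,
                       is_real, align_scalars=False):
    if lhs_is_scalar and rhs_is_scalar:
        # new case: both scalars align to lhs
        return lhs_type, lhs_type
    if lhs_is_scalar and is_real(lhs_type) and not is_real(rhs_type):
        # float scalar vs int tensor: promote both to f32 (keep lhs if align_scalars)
        return (lhs_type if align_scalars else "f32"), "f32"
    if rhs_is_scalar and not is_real(lhs_type) and is_real(rhs_type):
        # int tensor vs float scalar: promote both to f32 (keep rhs if align_scalars)
        return "f32", (rhs_type if align_scalars else "f32")
    if lhs_is_scalar:
        return rhs_type, rhs_type   # scalar lhs follows the tensor rhs
    if rhs_is_scalar:
        return lhs_type, lhs_type   # scalar rhs follows the tensor lhs
    return lhs_type, rhs_type       # no scalar involved: regular promotion applies

# e.g. pow(int scalar, float scalar) now keeps both operands on the lhs type
print(align_scalar_types("i64", "f32", True, True, is_real=lambda t: t.startswith("f")))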


@@ -110,14 +110,12 @@ class PytorchLayerTest:
         # check if results dtypes match
         for fw_tensor, ov_tensor in zip(flatten_fw_res, output_list):
             if not isinstance(fw_tensor, torch.Tensor):
-                if np.isscalar(fw_tensor):
-                    assert fw_tensor == np.array(ov_tensor).item(
-                    ), f"{fw_tensor} != {np.array(ov_tensor).item()}"
-                else:
-                    if isinstance(fw_tensor, list):
-                        ov_tensor = ov_tensor.tolist()
-                        assert ov_tensor == fw_tensor
-                    assert type(fw_tensor) == type(ov_tensor)
+                fw_type = torch.tensor(fw_tensor).numpy().dtype
+                ov_type = ov_tensor.dtype
+                if fw_type in [np.int32, np.int64] and ov_type in [np.int32, np.int64]:
+                    # do not differentiate between int32 and int64
+                    continue
+                assert ov_type == fw_type, f"dtype validation failed: {ov_type} != {fw_type}"
                 continue
             assert torch.tensor(np.array(
                 ov_tensor)).dtype == fw_tensor.dtype, f"dtype validation failed: {torch.tensor(np.array(ov_tensor)).dtype} != {fw_tensor.dtype}"
@@ -199,7 +197,8 @@ class PytorchLayerTest:
                     ov_inputs[i] = ov_inputs[i].astype(np.int32)
                 inp = ov_inputs[i]
                 assert inp.dtype.name in self._type_map, f"Unknown type {inp.dtype}."
-                params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
+                if params[i].get_node().get_element_type().is_dynamic():
+                    params[i].get_node().set_element_type(self._type_map[inp.dtype.name])
                 shape = [-1] * len(inp.shape) if dynamic_shapes else inp.shape
                 params[i].get_node().set_partial_shape(PartialShape(shape))
             om.validate_nodes_and_infer_types()
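
The relaxed check in the first hunk exists because a non-tensor framework result (for example, a plain Python int) becomes int64 when wrapped in torch.tensor, while the corresponding OpenVINO output may be reported as int32, so the harness treats the two integer widths as equivalent. A small standalone illustration of the comparison it performs (the values here are assumed, not captured from a test run):

import numpy as np
import torch

fw_scalar = 3                                    # e.g. a value like x.size(2)
fw_type = torch.tensor(fw_scalar).numpy().dtype  # int64
ov_type = np.dtype(np.int32)                     # width the runtime might report

# mirrors the harness check: int32 vs int64 is not reported as a dtype mismatch
if fw_type in [np.int32, np.int64] and ov_type in [np.int32, np.int64]:
    print("dtypes considered equivalent")

The second hunk follows the same idea from the other side: the parameter's element type is only overridden when the model left it dynamic, so a type already fixed by the frontend is preserved.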


@@ -104,3 +104,22 @@ class TestPowMixedTypes(PytorchLayerTest):
         self.rhs_shape = rhs_shape
         self._test(*self.create_model(lhs_type, lhs_shape, rhs_type, rhs_shape),
                    ie_device, precision, ir_version)
+
+
+class TestPowMixedTypesScalars(PytorchLayerTest):
+    def _prepare_input(self):
+        return (torch.randn([1, 2, 3, 4]).numpy(),)
+
+    def create_model(self):
+        class aten_pow(torch.nn.Module):
+            def forward(self, x):
+                return torch.pow(x.size(2), -0.5)
+
+        ref_net = None
+
+        return aten_pow(), ref_net, "aten::pow"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_pow_mixed_types(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version)
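
As a quick sanity check on the value the new test exercises (assumed arithmetic, not output captured from the test run): for the prepared input of shape [1, 2, 3, 4], x.size(2) is 3, so the model computes 3 ** -0.5.

import torch

x = torch.randn([1, 2, 3, 4])
dim2 = x.size(2)             # 3
expected = dim2 ** -0.5      # 1 / sqrt(3) ≈ 0.5774
print(expected)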