[PT FE]: support boolean data type in sum operation (#17715)

This commit is contained in:
Ekaterina Aidova 2023-05-26 15:44:50 +04:00 committed by GitHub
parent ed8333a94c
commit 307b666d99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 0 deletions

View File

@ -3,6 +3,7 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reduce_sum.hpp"
#include "utils.hpp"
@ -16,6 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
bool keep_dims = false;
ov::Output<ov::Node> axes;
auto data = context.get_input(0);
if (data.get_element_type() == element::boolean) {
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
}
if (context.input_is_none(1)) {
axes = get_axes_range(context, 0);
} else {

View File

@ -40,3 +40,39 @@ class TestSum(PytorchLayerTest):
def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
self._test(*self.create_model(axes, keep_dim),
ie_device, precision, ir_version)
class TestSumBool(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randint(0, 2, (1, 3, 20, 24)).astype(bool),)
def create_model(self, axes, keep_dims):
import torch
class aten_sum(torch.nn.Module):
def __init__(self, axes=None, keep_dims=None):
super(aten_sum, self).__init__()
self.axes = axes
self.keep_dims = keep_dims
def forward(self, x):
x = x.to(torch.bool)
if self.axes is None and self.keep_dims is None:
return torch.sum(x)
if self.axes is not None and self.keep_dims is None:
return torch.sum(x, self.axes)
return torch.sum(x, self.axes, self.keep_dims)
ref_net = None
return aten_sum(axes, keep_dims), ref_net, "aten::sum"
@pytest.mark.parametrize("axes,keep_dim",
[(None, None), (None, False), (-1, None), (1, None), ((2, 3), False), ((3, 2), True)])
@pytest.mark.nightly
@pytest.mark.precommit
def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
self._test(*self.create_model(axes, keep_dim),
ie_device, precision, ir_version)