[PT FE]: support boolean data type in sum operation (#17715)
This commit is contained in:
parent
ed8333a94c
commit
307b666d99
@ -3,6 +3,7 @@
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/reduce_sum.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
@ -16,6 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
|
||||
bool keep_dims = false;
|
||||
ov::Output<ov::Node> axes;
|
||||
auto data = context.get_input(0);
|
||||
if (data.get_element_type() == element::boolean) {
|
||||
data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
|
||||
}
|
||||
if (context.input_is_none(1)) {
|
||||
axes = get_axes_range(context, 0);
|
||||
} else {
|
||||
|
@ -40,3 +40,39 @@ class TestSum(PytorchLayerTest):
|
||||
def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(axes, keep_dim),
|
||||
ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestSumBool(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randint(0, 2, (1, 3, 20, 24)).astype(bool),)
|
||||
|
||||
def create_model(self, axes, keep_dims):
|
||||
|
||||
import torch
|
||||
|
||||
class aten_sum(torch.nn.Module):
|
||||
def __init__(self, axes=None, keep_dims=None):
|
||||
super(aten_sum, self).__init__()
|
||||
self.axes = axes
|
||||
self.keep_dims = keep_dims
|
||||
|
||||
def forward(self, x):
|
||||
x = x.to(torch.bool)
|
||||
if self.axes is None and self.keep_dims is None:
|
||||
return torch.sum(x)
|
||||
if self.axes is not None and self.keep_dims is None:
|
||||
return torch.sum(x, self.axes)
|
||||
return torch.sum(x, self.axes, self.keep_dims)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_sum(axes, keep_dims), ref_net, "aten::sum"
|
||||
|
||||
@pytest.mark.parametrize("axes,keep_dim",
|
||||
[(None, None), (None, False), (-1, None), (1, None), ((2, 3), False), ((3, 2), True)])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(axes, keep_dim),
|
||||
ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user