[PT FE]: support boolean data type in sum operation (#17715)
commit 307b666d99
parent ed8333a94c
@@ -3,6 +3,7 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/reduce_sum.hpp"
 #include "utils.hpp"
 
@@ -16,6 +17,9 @@ OutputVector translate_sum(const NodeContext& context) {
     bool keep_dims = false;
     ov::Output<ov::Node> axes;
     auto data = context.get_input(0);
+    if (data.get_element_type() == element::boolean) {
+        data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::i64));
+    }
     if (context.input_is_none(1)) {
         axes = get_axes_range(context, 0);
     } else {
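For reference, a minimal plain-PyTorch sketch (not part of the PR) of the behavior the added Convert node reproduces: torch.sum promotes a boolean tensor to int64 and returns the count of True elements, which is why the frontend converts boolean input to i64 before ReduceSum.

import torch

# torch.sum on a bool tensor counts True elements and yields an int64 result.
x = torch.tensor([[True, False], [True, True]])
print(torch.sum(x))        # tensor(3)
print(torch.sum(x).dtype)  # torch.int64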
@@ -40,3 +40,39 @@ class TestSum(PytorchLayerTest):
     def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
         self._test(*self.create_model(axes, keep_dim),
                    ie_device, precision, ir_version)
+
+
+class TestSumBool(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randint(0, 2, (1, 3, 20, 24)).astype(bool),)
+
+    def create_model(self, axes, keep_dims):
+
+        import torch
+
+        class aten_sum(torch.nn.Module):
+            def __init__(self, axes=None, keep_dims=None):
+                super(aten_sum, self).__init__()
+                self.axes = axes
+                self.keep_dims = keep_dims
+
+            def forward(self, x):
+                x = x.to(torch.bool)
+                if self.axes is None and self.keep_dims is None:
+                    return torch.sum(x)
+                if self.axes is not None and self.keep_dims is None:
+                    return torch.sum(x, self.axes)
+                return torch.sum(x, self.axes, self.keep_dims)
+
+        ref_net = None
+
+        return aten_sum(axes, keep_dims), ref_net, "aten::sum"
+
+    @pytest.mark.parametrize("axes,keep_dim",
+                             [(None, None), (None, False), (-1, None), (1, None), ((2, 3), False), ((3, 2), True)])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_sum(self, axes, keep_dim, ie_device, precision, ir_version):
+        self._test(*self.create_model(axes, keep_dim),
+                   ie_device, precision, ir_version)
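For context, a quick plain-PyTorch sketch (not part of the PR) of what the new TestSumBool cases compute: summing a boolean tensor counts its True elements, and the axes/keep_dim parameters behave as in the parametrization above.

import numpy as np
import torch

# Same input as _prepare_input above: a random boolean tensor of shape (1, 3, 20, 24).
x = torch.from_numpy(np.random.randint(0, 2, (1, 3, 20, 24)).astype(bool))

print(torch.sum(x).dtype)                 # torch.int64: True elements are counted as integers
print(torch.sum(x, (2, 3), False).shape)  # torch.Size([1, 3]): reduced over the last two axes
print(torch.sum(x, (3, 2), True).shape)   # torch.Size([1, 3, 1, 1]): reduced axes kept with size 1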