[PT FE] Support aten::scatter_add (#21633)

This commit is contained in:
Maxim Vafin
2023-12-13 17:54:33 +01:00
committed by GitHub
parent 65439eda5d
commit be7c70c8c4
4 changed files with 85 additions and 3 deletions

View File

@@ -131,6 +131,23 @@ OutputVector translate_scatter_reduce(const NodeContext& context) {
return {scatter_result};
};
OutputVector translate_scatter_add(const NodeContext& context) {
    // aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
    // Accumulates values of `src` into a copy of `self` at the positions given by
    // `index` along dimension `dim`, using ScatterElementsUpdate in SUM mode.
    num_inputs_check(context, 4, 4);
    auto data = context.get_input(0);
    auto axis = context.get_input(1);
    // Indices are explicitly converted to i32 before the scatter op.
    auto indices = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::i32));
    auto updates = context.get_input(3);
    // prepare_source presumably aligns the dtype/shape of `src` with `input`
    // (same helper pattern as translate_scatter_reduce above) — defined elsewhere.
    auto aligned_updates = prepare_source(context, updates, indices, data);
    auto result = context.mark_node(
        std::make_shared<v12::ScatterElementsUpdate>(data,
                                                     indices,
                                                     aligned_updates,
                                                     axis,
                                                     v12::ScatterElementsUpdate::Reduction::SUM));
    return {result};
}
} // namespace op
} // namespace pytorch
} // namespace frontend

View File

@@ -178,6 +178,7 @@ OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_scatter);
// Translator for aten::scatter_add / aten::scatter_add_ (sum-accumulating scatter).
OP_CONVERTER(translate_scatter_add);
OP_CONVERTER(translate_scatter_reduce);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
@@ -502,6 +503,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
{"aten::scatter", op::translate_scatter},
{"aten::scatter_", op::inplace_op<op::translate_scatter>},
// The in-place aten::scatter_add_ reuses the out-of-place translator via inplace_op,
// matching the pattern of the scatter/scatter_reduce registrations around it.
{"aten::scatter_add", op::translate_scatter_add},
{"aten::scatter_add_", op::inplace_op<op::translate_scatter_add>},
{"aten::scatter_reduce", op::translate_scatter_reduce},
{"aten::scatter_reduce_", op::inplace_op<op::translate_scatter_reduce>},
{"aten::select", op::quantizable_op<op::translate_select>},

View File

@@ -219,3 +219,65 @@ class TestScatterReduce(PytorchLayerTest):
kwargs_to_prepare_input={"dtype": dtype, "out": has_out},
freeze_model=freeze
)
class TestScatterAdd(PytorchLayerTest):
    """Layer test for conversion of aten::scatter_add / aten::scatter_add_."""

    def _prepare_input(self, dtype):
        # Model input: a single 6x6 tensor of the requested dtype.
        return (np.random.randn(6, 6).astype(dtype),)

    def create_model(self, dim, index, src, inplace):
        # NOTE: inner class and op name fixed from the copy-pasted
        # "scatter_reduce" naming so failures report the right op.
        class aten_scatter_add(torch.nn.Module):
            def __init__(self, dim, index, src, inplace):
                super(aten_scatter_add, self).__init__()
                self.dim = dim
                self.use_empty_index = False
                if index is None:
                    self.use_empty_index = True
                    # Placeholder only; the real empty index is built in forward().
                    self.index = torch.empty([1])
                else:
                    self.index = index
                self.src = src
                self.inplace = inplace

            def forward(self, x: torch.Tensor):
                if self.use_empty_index:
                    index = torch.empty([0, 0])
                else:
                    index = self.index
                if self.inplace:
                    return x.scatter_add_(self.dim, index, self.src)
                else:
                    return x.scatter_add(self.dim, index, self.src)

        op_name = "aten::scatter_add_" if inplace else "aten::scatter_add"
        return aten_scatter_add(dim, index, src, inplace), None, op_name

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dim", [1, -1, 0])
    @pytest.mark.parametrize(
        "index",
        [
            None,  # Empty tensor scenario.
            torch.tensor([[0, 1, 2, 3]]),
            torch.tensor([[0, 5], [4, 1], [2, 3]]),
        ],
    )
    @pytest.mark.parametrize("src", [torch.arange(1, 26).reshape(5, 5)])
    @pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
    @pytest.mark.parametrize("inplace", [True, False])
    def test_scatter_add(self, dim, index, src, dtype, inplace, ie_device, precision, ir_version):
        # Cast src to the tested dtype so input and updates dtypes agree.
        if isinstance(src, torch.Tensor):
            src = src.to(getattr(torch, dtype))
        if index is None:
            pytest.skip(
                "Cannot test with empty indexes due to issues with empty constant tensor or issues with prim::GetAttr str inputs."
            )
        self._test(
            *self.create_model(dim, index, src, inplace),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"dtype": dtype},
        )

View File

@@ -5,9 +5,9 @@ LearningToPaint,None
Super_SloMo,None,xfail,Unsupported ops aten::l1_loss aten::mse_loss
#alexnet,None - Already tested by torchvision tests
basic_gnn_edgecnn,None,xfail,Accuracy validation failed
basic_gnn_gcn,None,xfail,Unsupported ops aten::pow_ aten::scatter_add_
basic_gnn_gin,None,xfail,Unsupported op aten::scatter_add_
basic_gnn_sage,None,xfail,Unsupported op aten::scatter_add_
basic_gnn_gcn,None,xfail,Unsupported op aten::pow_
basic_gnn_gin,None
basic_gnn_sage,None
#cm3leon_generate,None,skip,No install.py is found
dcgan,None
demucs,None