[PT FE] Support aten::scatter_add (#21633)
This commit is contained in:
@@ -131,6 +131,23 @@ OutputVector translate_scatter_reduce(const NodeContext& context) {
|
||||
return {scatter_result};
|
||||
};
|
||||
|
||||
OutputVector translate_scatter_add(const NodeContext& context) {
|
||||
// aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
|
||||
num_inputs_check(context, 4, 4);
|
||||
auto input = context.get_input(0);
|
||||
auto dim = context.get_input(1);
|
||||
auto index = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::i32));
|
||||
auto src = context.get_input(3);
|
||||
auto src_input_dtype = prepare_source(context, src, index, input);
|
||||
auto scatter_result =
|
||||
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input,
|
||||
index,
|
||||
src_input_dtype,
|
||||
dim,
|
||||
v12::ScatterElementsUpdate::Reduction::SUM));
|
||||
return {scatter_result};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
||||
@@ -178,6 +178,7 @@ OP_CONVERTER(translate_rsqrt);
|
||||
OP_CONVERTER(translate_rsub);
|
||||
OP_CONVERTER(translate_scaled_dot_product_attention);
|
||||
OP_CONVERTER(translate_scatter);
|
||||
OP_CONVERTER(translate_scatter_add);
|
||||
OP_CONVERTER(translate_scatter_reduce);
|
||||
OP_CONVERTER(translate_select);
|
||||
OP_CONVERTER(translate_set_item);
|
||||
@@ -502,6 +503,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
|
||||
{"aten::scatter", op::translate_scatter},
|
||||
{"aten::scatter_", op::inplace_op<op::translate_scatter>},
|
||||
{"aten::scatter_add", op::translate_scatter_add},
|
||||
{"aten::scatter_add_", op::inplace_op<op::translate_scatter_add>},
|
||||
{"aten::scatter_reduce", op::translate_scatter_reduce},
|
||||
{"aten::scatter_reduce_", op::inplace_op<op::translate_scatter_reduce>},
|
||||
{"aten::select", op::quantizable_op<op::translate_select>},
|
||||
|
||||
@@ -219,3 +219,65 @@ class TestScatterReduce(PytorchLayerTest):
|
||||
kwargs_to_prepare_input={"dtype": dtype, "out": has_out},
|
||||
freeze_model=freeze
|
||||
)
|
||||
|
||||
class TestScatterAdd(PytorchLayerTest):
|
||||
def _prepare_input(self, dtype):
|
||||
return (np.random.randn(6, 6).astype(dtype),)
|
||||
|
||||
def create_model(self, dim, index, src, inplace):
|
||||
class aten_scatter_reduce(torch.nn.Module):
|
||||
def __init__(self, dim, index, src, inplace):
|
||||
super(aten_scatter_reduce, self).__init__()
|
||||
self.dim = dim
|
||||
self.use_empty_index = False
|
||||
if index is None:
|
||||
self.use_empty_index = True
|
||||
# Placeholder
|
||||
self.index = torch.empty([1])
|
||||
else:
|
||||
self.index = index
|
||||
self.src = src
|
||||
self.inplace = inplace
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.use_empty_index:
|
||||
index = torch.empty([0, 0])
|
||||
else:
|
||||
index = self.index
|
||||
if self.inplace:
|
||||
return x.scatter_add_(self.dim, index, self.src)
|
||||
else:
|
||||
return x.scatter_add(self.dim, index, self.src)
|
||||
|
||||
op_name = "aten::scatter_add_" if inplace else "aten::scatter_add"
|
||||
|
||||
return aten_scatter_reduce(dim, index, src, inplace), None, op_name
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("dim", [1, -1, 0])
|
||||
@pytest.mark.parametrize(
|
||||
"index",
|
||||
[
|
||||
None, # Empty tensor scenario.
|
||||
torch.tensor([[0, 1, 2, 3]]),
|
||||
torch.tensor([[0, 5], [4, 1], [2, 3]]),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("src", [torch.arange(1, 26).reshape(5, 5)])
|
||||
@pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
def test_scatter_reduce(self, dim, index, src, dtype, inplace, ie_device, precision, ir_version):
|
||||
if isinstance(src, torch.Tensor):
|
||||
src = src.to(getattr(torch, dtype))
|
||||
if index is None:
|
||||
pytest.skip(
|
||||
"Cannot test reduce parameters with empty indexes due to issues with empty constant tensor or issues with prim::GetAttr str inputs."
|
||||
)
|
||||
self._test(
|
||||
*self.create_model(dim, index, src, inplace),
|
||||
ie_device,
|
||||
precision,
|
||||
ir_version,
|
||||
kwargs_to_prepare_input={"dtype": dtype},
|
||||
)
|
||||
|
||||
@@ -5,9 +5,9 @@ LearningToPaint,None
|
||||
Super_SloMo,None,xfail,Unsupported ops aten::l1_loss aten::mse_loss
|
||||
#alexnet,None - Already tested by torchvision tests
|
||||
basic_gnn_edgecnn,None,xfail,Accuracy validation failed
|
||||
basic_gnn_gcn,None,xfail,Unsupported ops aten::pow_ aten::scatter_add_
|
||||
basic_gnn_gin,None,xfail,Unsupported op aten::scatter_add_
|
||||
basic_gnn_sage,None,xfail,Unsupported op aten::scatter_add_
|
||||
basic_gnn_gcn,None,xfail,Unsupported ops aten::pow_
|
||||
basic_gnn_gin,None
|
||||
basic_gnn_sage,None
|
||||
#cm3leon_generate,None,skip,No install.py is found
|
||||
dcgan,None
|
||||
demucs,None
|
||||
|
||||
Reference in New Issue
Block a user