[PT FE] Add aten::scatter and inplace for aten::sub translation (#18341)

* Add sub inplace

* Add scatter implementation

* Remove debug var

* Add tests for empty index

* Add reduce support

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Mateusz Mikolajczyk
2023-07-11 11:00:50 +02:00
committed by GitHub
parent 0148076ed7
commit 82c65c25da
4 changed files with 209 additions and 5 deletions

View File

@@ -0,0 +1,77 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
/// \brief Translates aten::scatter / aten::scatter_ into a v12::ScatterElementsUpdate node.
///
/// Handled schemas:
///   Out-of-place:
///     aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor
///     aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor
///     aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor
///     aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor
///   Inplace:
///     aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)
///     aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)
///     aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)
///     aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)
///
/// \param context Node context providing the op inputs: self, dim, index, src/value[, reduce].
/// \return Single-output vector holding the ScatterElementsUpdate result.
OutputVector translate_scatter(const NodeContext& context) {
    num_inputs_check(context, 4, 5);
    auto input = context.get_input(0);
    auto dim = context.get_input(1);
    // ScatterElementsUpdate requires integer indices; convert from torch's int64 index tensor.
    auto index = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::i32));
    auto src = context.get_input(3);

    // Optional trailing "reduce" string selects the accumulation mode.
    auto reduction = v12::ScatterElementsUpdate::Reduction::NONE;
    auto input_num = context.get_input_size();
    if (input_num > 4 && !context.input_is_none(input_num - 1)) {
        auto reduce_mode = context.const_input<std::string>(input_num - 1);
        if (reduce_mode == "add") {
            reduction = v12::ScatterElementsUpdate::Reduction::SUM;
        } else if (reduce_mode == "multiply") {
            reduction = v12::ScatterElementsUpdate::Reduction::PROD;
        }
        // NOTE(review): any other reduce string silently falls back to Reduction::NONE;
        // the aten schema only allows "add"/"multiply", but an explicit error might be safer.
    }

    auto src_partial_shape = src.get_partial_shape();
    auto index_shape_rank = get_shape_rank(context, index);
    auto index_shape = std::get<0>(index_shape_rank);
    auto index_rank = std::get<1>(index_shape_rank);
    // Source input can be either Tensor which should be passed in original shape or Scalar that should be broadcasted
    // into shape of indices.
    // TODO: Figure out way to dynamically broadcast scalar src only, without affecting Tensor src. Current
    // implementation will fail if Scalar source would have dynamic rank.
    if (src_partial_shape.rank().is_static() && src_partial_shape.rank().get_length() == 0) {
        src = context.mark_node(std::make_shared<v3::Broadcast>(src, index_shape));
    }

    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    auto zeros = context.mark_node(std::make_shared<v3::Broadcast>(const_0, index_rank));
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
    auto ones = context.mark_node(std::make_shared<v3::Broadcast>(const_1, index_rank));
    // In torch indices can be of different shape than source tensor. Create slice to trim source tensor to shape of
    // indices.
    auto src_pruned = context.mark_node(std::make_shared<v8::Slice>(src, zeros, index_shape, ones));
    // Align the update tensor's element type with the destination tensor's type.
    auto src_input_dtype = context.mark_node(std::make_shared<v1::ConvertLike>(src_pruned, input));
    return {
        context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input, index, src_input_dtype, dim, reduction))};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -119,6 +119,7 @@ OP_CONVERTER(translate_roll);
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_scatter);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
@@ -332,6 +333,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::rsub", op::translate_rsub},
{"aten::ScalarImplicit", op::skip_node},
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
{"aten::scatter", op::translate_scatter},
{"aten::scatter_", op::inplace_op<op::translate_scatter>},
{"aten::select", op::translate_select},
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},
@@ -352,6 +355,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::square", op::translate_square},
{"aten::squeeze", op::translate_squeeze},
{"aten::sub", op::translate_sub},
{"aten::sub_", op::inplace_op<op::translate_sub>},
{"aten::sum", op::translate_sum},
{"aten::t", op::translate_t},
{"aten::t_", op::inplace_op<op::translate_t>},

View File

@@ -0,0 +1,108 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestScatter(PytorchLayerTest):
    """Layer tests for aten::scatter and its inplace variant aten::scatter_."""

    def _prepare_input(self, dtype):
        # Single random 6x6 input tensor cast to the requested numpy dtype.
        inp = np.random.randn(6, 6).astype(getattr(np, dtype))
        return (inp,)

    def create_model(self, dim, index, src, inplace, reduce):
        class aten_scatter(torch.nn.Module):
            def __init__(self, dim, index, src, inplace, reduce):
                super(aten_scatter, self).__init__()
                self.dim = dim
                # index=None means the "empty index" scenario; the real empty
                # tensor is created inside forward, so store a placeholder here.
                self.use_empty_index = index is None
                if self.use_empty_index:
                    # Placeholder
                    self.index = torch.empty([1])
                else:
                    self.index = index
                self.src = src
                # Pick the forward implementation matching the requested
                # inplace/reduce combination by name.
                suffix = "_inplace" if inplace else "_out_of_place"
                if reduce:
                    self.reduce = reduce
                    suffix += "_reduce"
                self.forward = getattr(self, "_forward" + suffix)

            def _forward_out_of_place(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return torch.scatter(x, self.dim, index, self.src)

            def _forward_inplace(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return x.scatter_(self.dim, index, self.src)

            def _forward_out_of_place_reduce(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return torch.scatter(x, self.dim, index, self.src, reduce=self.reduce)

            def _forward_inplace_reduce(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return x.scatter_(self.dim, index, self.src, reduce=self.reduce)

        ref_net = None
        op_name = "aten::scatter_" if inplace else "aten::scatter"
        return aten_scatter(dim, index, src, inplace, reduce), ref_net, op_name

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dim", [1, -1, 0])
    @pytest.mark.parametrize(
        "index",
        [
            None,  # Empty tensor scenario.
            torch.tensor([[0, 1, 2, 3]]),
            torch.tensor([[0, 5], [4, 1], [2, 3]]),
        ],
    )
    @pytest.mark.parametrize("src", [torch.arange(1, 26).reshape(5, 5), 1])
    @pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
    @pytest.mark.parametrize("inplace", [True, False])
    @pytest.mark.parametrize("reduce", [None, "add", "multiply"])
    def test_scatter(self, dim, index, src, dtype, inplace, reduce, ie_device, precision, ir_version):
        if isinstance(src, torch.Tensor):
            src = src.to(getattr(torch, dtype))
        # Freeze creates empty constant tensor which isn't supported by OV,
        # so disable freezing for the empty-index scenario.
        freeze = index is not None
        if (not freeze) and reduce:
            pytest.skip(
                "Cannot test reduce parameters with empty indexes due to issues with empty constant tensor or issues with prim::GetAttr str inputs."
            )
        self._test(
            *self.create_model(dim, index, src, inplace, reduce),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"dtype": dtype},
            freeze_model=freeze
        )

View File

@@ -12,16 +12,30 @@ class TestSub(PytorchLayerTest):
def _prepare_input(self):
return self.input_data
def create_model(self):
def create_model(self, inplace):
class aten_sub(torch.nn.Module):
def __init__(self, inplace) -> None:
super().__init__()
if inplace:
self.forward = self._forward_inplace
else:
self.forward = self._forward_out_of_place
def forward(self, x, y, alpha: float):
def _forward_out_of_place(self, x, y, alpha: float):
return torch.sub(x, y, alpha=alpha)
def _forward_inplace(self, x, y, alpha: float):
return x.sub_(y, alpha=alpha)
ref_net = None
return aten_sub(), ref_net, "aten::sub"
if inplace:
op_name = "aten::sub_"
else:
op_name = "aten::sub"
return aten_sub(inplace), ref_net, op_name
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 4).astype(np.float32),
np.random.randn(
@@ -31,11 +45,12 @@ class TestSub(PytorchLayerTest):
np.random.randn(
1, 2, 3).astype(np.float32),
np.random.randn(1)), ])
@pytest.mark.parametrize("inplace", [True, False])
@pytest.mark.nightly
@pytest.mark.precommit
def test_sub(self, ie_device, precision, ir_version, input_data):
def test_sub(self, ie_device, precision, ir_version, input_data, inplace):
self.input_data = input_data
self._test(*self.create_model(), ie_device, precision, ir_version)
self._test(*self.create_model(inplace), ie_device, precision, ir_version)
class TestSubTypes(PytorchLayerTest):