[PT FE] Add aten::scatter and inplace for aten::sub translation (#18341)
* Add sub inplace
* Add scatter implementation
* Remove debug var
* Add tests for empty index
* Add reduce support
---------
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
committed by
GitHub
parent
0148076ed7
commit
82c65c25da
77
src/frontends/pytorch/src/op/scatter.cpp
Normal file
77
src/frontends/pytorch/src/op/scatter.cpp
Normal file
@@ -0,0 +1,77 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/convert_like.hpp"
#include "openvino/op/scatter_elements_update.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
OutputVector translate_scatter(const NodeContext& context) {
|
||||
// Out-of-place schema
|
||||
// aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor:
|
||||
// aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor:
|
||||
// aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor:
|
||||
// aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor:
|
||||
|
||||
// Inplace schema
|
||||
// aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!):
|
||||
// aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!):
|
||||
// aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!):
|
||||
// aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!):
|
||||
num_inputs_check(context, 4, 5);
|
||||
auto input = context.get_input(0);
|
||||
auto dim = context.get_input(1);
|
||||
auto index = context.mark_node(std::make_shared<v0::Convert>(context.get_input(2), element::i32));
|
||||
auto src = context.get_input(3);
|
||||
|
||||
auto reduction = v12::ScatterElementsUpdate::Reduction::NONE;
|
||||
auto input_num = context.get_input_size();
|
||||
if (input_num > 4 && !context.input_is_none(input_num - 1)) {
|
||||
auto reduce_mode = context.const_input<std::string>(input_num - 1);
|
||||
if (reduce_mode == "add") {
|
||||
reduction = v12::ScatterElementsUpdate::Reduction::SUM;
|
||||
} else if (reduce_mode == "multiply") {
|
||||
reduction = v12::ScatterElementsUpdate::Reduction::PROD;
|
||||
}
|
||||
}
|
||||
auto src_partial_shape = src.get_partial_shape();
|
||||
auto index_shape_rank = get_shape_rank(context, index);
|
||||
auto index_shape = std::get<0>(index_shape_rank);
|
||||
auto index_rank = std::get<1>(index_shape_rank);
|
||||
|
||||
// Source input can be either Tensor which should be passed in original shape or Scalar that should be broadcasted
|
||||
// into shape of indices.
|
||||
// TODO: Figure out way to dynamically broadcast scalar src only, without affecting Tensor src. Current
|
||||
// implementation will fail if Scalar source would have dynamic rank.
|
||||
if (src_partial_shape.rank().is_static() && src_partial_shape.rank().get_length() == 0) {
|
||||
src = context.mark_node(std::make_shared<v3::Broadcast>(src, index_shape));
|
||||
}
|
||||
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto zeros = context.mark_node(std::make_shared<v3::Broadcast>(const_0, index_rank));
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto ones = context.mark_node(std::make_shared<v3::Broadcast>(const_1, index_rank));
|
||||
// In torch indices can be of different shape than source tensor. Create slice to trim source tensor to shape of
|
||||
// indices.
|
||||
auto src_pruned = context.mark_node(std::make_shared<v8::Slice>(src, zeros, index_shape, ones));
|
||||
|
||||
auto src_input_dtype = context.mark_node(std::make_shared<v1::ConvertLike>(src_pruned, input));
|
||||
return {
|
||||
context.mark_node(std::make_shared<v12::ScatterElementsUpdate>(input, index, src_input_dtype, dim, reduction))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
||||
@@ -119,6 +119,7 @@ OP_CONVERTER(translate_roll);
|
||||
// Forward declarations of translation handlers (one per supported aten op);
// translate_scatter is the handler added by this change.
OP_CONVERTER(translate_rsqrt);
OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_scatter);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
|
||||
@@ -332,6 +333,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::rsub", op::translate_rsub},
|
||||
{"aten::ScalarImplicit", op::skip_node},
|
||||
{"aten::scaled_dot_product_attention", op::translate_scaled_dot_product_attention},
|
||||
{"aten::scatter", op::translate_scatter},
|
||||
{"aten::scatter_", op::inplace_op<op::translate_scatter>},
|
||||
{"aten::select", op::translate_select},
|
||||
{"aten::selu", op::translate_selu},
|
||||
{"aten::selu_", op::inplace_op<op::translate_selu>},
|
||||
@@ -352,6 +355,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::square", op::translate_square},
|
||||
{"aten::squeeze", op::translate_squeeze},
|
||||
{"aten::sub", op::translate_sub},
|
||||
{"aten::sub_", op::inplace_op<op::translate_sub>},
|
||||
{"aten::sum", op::translate_sum},
|
||||
{"aten::t", op::translate_t},
|
||||
{"aten::t_", op::inplace_op<op::translate_t>},
|
||||
|
||||
108
tests/layer_tests/pytorch_tests/test_scatter.py
Normal file
108
tests/layer_tests/pytorch_tests/test_scatter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestScatter(PytorchLayerTest):
    """Layer test for aten::scatter / aten::scatter_ (src, value, and reduce variants)."""

    def _prepare_input(self, dtype):
        # Single (6, 6) input tensor of the requested numpy dtype.
        data = np.random.randn(6, 6).astype(getattr(np, dtype))
        return (data,)

    def create_model(self, dim, index, src, inplace, reduce):
        class aten_scatter(torch.nn.Module):
            def __init__(self, dim, index, src, inplace, reduce):
                super(aten_scatter, self).__init__()
                self.dim = dim
                self.use_empty_index = index is None
                if self.use_empty_index:
                    # Placeholder only; the real empty index is built inside forward.
                    self.index = torch.empty([1])
                else:
                    self.index = index
                self.src = src
                # Select the forward variant by name: inplace vs out-of-place,
                # with an optional reduce suffix.
                forward_name = "_forward"
                forward_name += "_inplace" if inplace else "_out_of_place"
                if reduce:
                    self.reduce = reduce
                    forward_name += "_reduce"
                self.forward = getattr(self, forward_name)

            def _forward_out_of_place(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return torch.scatter(x, self.dim, index, self.src)

            def _forward_inplace(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return x.scatter_(self.dim, index, self.src)

            def _forward_out_of_place_reduce(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return torch.scatter(x, self.dim, index, self.src, reduce=self.reduce)

            def _forward_inplace_reduce(self, x: torch.Tensor):
                index = torch.empty([0, 0]) if self.use_empty_index else self.index
                return x.scatter_(self.dim, index, self.src, reduce=self.reduce)

        ref_net = None
        op_name = "aten::scatter_" if inplace else "aten::scatter"
        return aten_scatter(dim, index, src, inplace, reduce), ref_net, op_name

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dim", [1, -1, 0])
    @pytest.mark.parametrize(
        "index",
        [
            None,  # Empty tensor scenario.
            torch.tensor([[0, 1, 2, 3]]),
            torch.tensor([[0, 5], [4, 1], [2, 3]]),
        ],
    )
    @pytest.mark.parametrize("src", [torch.arange(1, 26).reshape(5, 5), 1])
    @pytest.mark.parametrize("dtype", ["int32", "int64", "float32", "float64"])
    @pytest.mark.parametrize("inplace", [True, False])
    @pytest.mark.parametrize("reduce", [None, "add", "multiply"])
    def test_scatter(self, dim, index, src, dtype, inplace, reduce, ie_device, precision, ir_version):
        if isinstance(src, torch.Tensor):
            src = src.to(getattr(torch, dtype))
        # Freeze creates empty constant tensor which isn't supported by OV.
        freeze = index is not None
        if (not freeze) and reduce:
            pytest.skip(
                "Cannot test reduce parameters with empty indexes due to issues with empty constant tensor or issues with prim::GetAttr str inputs."
            )
        self._test(
            *self.create_model(dim, index, src, inplace, reduce),
            ie_device,
            precision,
            ir_version,
            kwargs_to_prepare_input={"dtype": dtype},
            freeze_model=freeze
        )
|
||||
@@ -12,16 +12,30 @@ class TestSub(PytorchLayerTest):
|
||||
def _prepare_input(self):
    # Inputs are stashed on the instance by the test method before _test runs.
    return self.input_data
|
||||
|
||||
def create_model(self, inplace):
    """Builds an aten::sub (or aten::sub_) module and the expected op name.

    The diff-rendered source mixed the pre-change and post-change versions of
    this method; this is the post-change version with the `inplace` parameter.
    """

    class aten_sub(torch.nn.Module):
        def __init__(self, inplace) -> None:
            super().__init__()
            # Bind the requested variant as this module's forward.
            if inplace:
                self.forward = self._forward_inplace
            else:
                self.forward = self._forward_out_of_place

        def _forward_out_of_place(self, x, y, alpha: float):
            return torch.sub(x, y, alpha=alpha)

        def _forward_inplace(self, x, y, alpha: float):
            return x.sub_(y, alpha=alpha)

    ref_net = None
    if inplace:
        op_name = "aten::sub_"
    else:
        op_name = "aten::sub"

    return aten_sub(inplace), ref_net, op_name
|
||||
|
||||
@pytest.mark.parametrize('input_data', [(np.random.randn(2, 3, 4).astype(np.float32),
|
||||
np.random.randn(
|
||||
@@ -31,11 +45,12 @@ class TestSub(PytorchLayerTest):
|
||||
np.random.randn(
|
||||
1, 2, 3).astype(np.float32),
|
||||
np.random.randn(1)), ])
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_sub(self, ie_device, precision, ir_version, input_data, inplace):
    # Post-change version (the diff-rendered span also contained the old
    # signature without `inplace`). Stash inputs for _prepare_input, then
    # run the OV-vs-torch comparison for the selected variant.
    self.input_data = input_data
    self._test(*self.create_model(inplace), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestSubTypes(PytorchLayerTest):
|
||||
|
||||
Reference in New Issue
Block a user