[PT FE]: aten::gather (#16784)

* [PT FE]: aten::gather

* add detach and sign
This commit is contained in:
Ekaterina Aidova
2023-04-11 14:28:05 +04:00
committed by GitHub
parent d407bc1b3b
commit d41663694c
5 changed files with 171 additions and 0 deletions

View File

@@ -0,0 +1,34 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/gather_elements.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_gather(const NodeContext& context) {
    // aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor
    // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 4, 5);
    const auto data = context.get_input(0);
    const auto dim = context.const_input<int64_t>(1);
    // Normalize the index tensor to i32 for GatherElements.
    const auto index = context.mark_node(std::make_shared<ov::op::v0::Convert>(context.get_input(2), element::i32));
    // Input 3 (sparse_grad) only changes gradient representation during training, so it is intentionally ignored.
    const auto result = context.mark_node(std::make_shared<ov::op::v6::GatherElements>(data, index, dim));
    // aten::gather.out overload: propagate the result into the provided `out` tensor.
    if (!context.input_is_none(4)) {
        context.mutate_input(4, result);
    }
    return {result};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -0,0 +1,31 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/sign.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_sign(const NodeContext& context) {
    // aten::sign(input, *, out=None)
    num_inputs_check(context, 1, 2);
    const auto data = context.get_input(0);
    const auto result = context.mark_node(std::make_shared<ov::op::v0::Sign>(data));
    // Second input, when present, is the `out` tensor of the aten::sign.out overload.
    if (!context.input_is_none(1)) {
        context.mutate_input(1, result);
    }
    return {result};
}
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -49,6 +49,7 @@ OP_CONVERTER(translate_floor_divide);
OP_CONVERTER(translate_floordiv);
OP_CONVERTER(translate_full);
OP_CONVERTER(translate_full_like);
OP_CONVERTER(translate_gather);
OP_CONVERTER(translate_gelu);
OP_CONVERTER(translate_get_attr);
OP_CONVERTER(translate_getitem);
@@ -106,6 +107,7 @@ OP_CONVERTER(translate_rsub);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
OP_CONVERTER(translate_sign);
OP_CONVERTER(translate_size);
OP_CONVERTER(translate_slice);
OP_CONVERTER(translate_softmax);
@@ -197,6 +199,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::cosh", op::translate_1to1_match_1_inputs<opset10::Cosh>},
{"aten::cosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Cosh>>},
{"aten::cumsum", op::translate_cumsum},
{"aten::detach", op::skip_node},
{"aten::dim", op::translate_dim},
{"aten::div", op::translate_div},
{"aten::div_", op::inplace_op<op::translate_div>},
@@ -218,6 +221,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::floordiv", op::translate_floordiv},
{"aten::full", op::translate_full},
{"aten::full_like", op::translate_full_like},
{"aten::gather", op::translate_gather},
{"aten::ge", op::translate_1to1_match_2_inputs_align_types<opset10::GreaterEqual>},
{"aten::gelu", op::translate_gelu},
{"aten::glu", op::translate_glu},
@@ -297,6 +301,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::selu_", op::inplace_op<op::translate_selu>},
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten::sigmoid_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sigmoid>>},
{"aten::sign", op::translate_sign},
{"aten::silu", op::translate_1to1_match_1_inputs<opset10::Swish>},
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
{"aten::sin", op::translate_1to1_match_1_inputs<opset10::Sin>},

View File

@@ -0,0 +1,49 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestGather(PytorchLayerTest):
    """Layer test for aten::gather / aten::gather.out conversion.

    Note: renamed from the copy-pasted `TestGelu` — this file tests gather, not gelu.
    """

    def _prepare_input(self, m, n, max_val, out=False):
        """Build an (m, n) float input and an (m, n) integer index.

        Index values are drawn from [0, max_val), where max_val is the size of
        the gathered axis, so every index is in range for torch.gather.
        When `out` is requested, also return a preallocated output buffer.
        """
        import numpy as np
        index = np.random.randint(0, max_val, (m, n))
        inp = np.random.randn(m, n).astype(np.float32)
        if out:
            # torch.gather's result has the same shape as `index`; here index
            # matches the input shape, so zeros_like(inp) is the correctly
            # shaped `out` buffer. (The previous np.take(...) produced a 3-D
            # array that PyTorch only accepted by silently resizing `out`.)
            return (inp, index, np.zeros_like(inp))
        return (inp, index)

    def create_model(self, axis, out):
        import torch

        class aten_gather(torch.nn.Module):
            def __init__(self, axis, out=False):
                super(aten_gather, self).__init__()
                self.axis = axis
                if out:
                    # Switch to the aten::gather.out overload.
                    self.forward = self.forward_out

            def forward(self, x, index):
                return torch.gather(x, self.axis, index)

            def forward_out(self, x, index, out):
                return torch.gather(x, self.axis, index, out=out)

        ref_net = None
        return aten_gather(axis, out), ref_net, "aten::gather"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("m", [2, 10, 100])
    @pytest.mark.parametrize("n", [2, 10, 100])
    @pytest.mark.parametrize("axis", [0, 1])
    @pytest.mark.parametrize("out", [True, False])
    def test_gather(self, m, n, axis, out, ie_device, precision, ir_version):
        # max_val is the length of the gathered axis so indices stay in range.
        self._test(*self.create_model(axis, out), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"m": m, "n": n, "max_val": m if axis == 0 else n, "out": out})

View File

@@ -0,0 +1,52 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestSign(PytorchLayerTest):
    """Layer test for aten::sign / aten::sign.out conversion.

    Note: renamed from the copy-pasted `TestSilu` — this file tests sign, not silu.
    """

    def _prepare_input(self, inp_type="mixed", out=False):
        """Build a length-10 float input whose sign pattern matches `inp_type`.

        inp_type: "negative" (all < 0), "positive" (all > 0), "zeros" (all 0),
        or "mixed" (a few randomly chosen entries negated).
        When `out` is requested, also return a zeros buffer of the same shape.
        """
        import numpy as np
        inp = np.arange(0, 10).astype(np.float32)
        if inp_type == "negative":
            # Replace the leading 0 so negation yields strictly negative values.
            inp[0] = 1
            inp = -1 * inp
        elif inp_type == "positive":
            # Replace the leading 0 so every value is strictly positive.
            inp[0] = 11
        elif inp_type == "zeros":
            inp *= 0
        else:
            # "mixed": negate a few entries at randomly chosen positions
            # (values 0..9 double as valid indices into the array).
            idx = np.random.choice(inp, 3)
            inp[idx.astype(int)] *= -1
        if out:
            return (inp, np.zeros_like(inp))
        return (inp, )

    def create_model(self, out):
        import torch

        class aten_sign(torch.nn.Module):
            def __init__(self, out):
                super(aten_sign, self).__init__()
                if out:
                    # Switch to the aten::sign.out overload.
                    self.forward = self.forward_out

            def forward(self, x):
                return torch.sign(x)

            def forward_out(self, x, out):
                # Pass `out=` so the .out overload is actually exercised;
                # previously `out` was returned untouched and the overload
                # was never hit (masked by the zeros input case).
                return torch.sign(x, out=out), out

        ref_net = None
        return aten_sign(out), ref_net, "aten::sign"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("input_type", ["zeros", "positive", "negative", "mixed"])
    @pytest.mark.parametrize("out", [True, False])
    def test_sign(self, input_type, out, ie_device, precision, ir_version):
        self._test(*self.create_model(out), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"inp_type": input_type, "out": out})