[PT FE]: support aten::take_along_dim (#21625)
This commit is contained in:
parent
9f6c3e997f
commit
27bf494355
56
src/frontends/pytorch/src/op/take_along_dim.cpp
Normal file
56
src/frontends/pytorch/src/op/take_along_dim.cpp
Normal file
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/broadcast.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/gather_elements.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/scatter_update.hpp"
|
||||
#include "openvino/op/shape_of.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_take_along_dim(const NodeContext& context) {
|
||||
// aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor
|
||||
// aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)
|
||||
num_inputs_check(context, 3, 4);
|
||||
auto x = context.get_input(0);
|
||||
auto index = context.get_input(1);
|
||||
index = context.mark_node(std::make_shared<ov::op::v0::Convert>(index, element::i32));
|
||||
int64_t axis = 0;
|
||||
|
||||
if (context.input_is_none(2)) {
|
||||
// if dimension is not provided, flattenize input first
|
||||
auto minus_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
x = context.mark_node(std::make_shared<ov::op::v1::Reshape>(x, minus_1, false));
|
||||
} else {
|
||||
axis = context.const_input<int64_t>(2);
|
||||
// OpenVINO GatherElements requires to have equal dims between index and input except specified axis
|
||||
// while PyTorch requires to have them broadcastable
|
||||
auto axis_node = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {axis}));
|
||||
auto const_1 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {1}));
|
||||
auto const_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{1}, {0}));
|
||||
auto x_shape = context.mark_node(std::make_shared<ov::op::v3::ShapeOf>(x, element::i32));
|
||||
auto broadcast_shape =
|
||||
context.mark_node(std::make_shared<ov::op::v3::ScatterUpdate>(x_shape, axis_node, const_1, const_0));
|
||||
index = context.mark_node(
|
||||
std::make_shared<ov::op::v3::Broadcast>(index, broadcast_shape, ov::op::BroadcastType::BIDIRECTIONAL));
|
||||
}
|
||||
auto gather_elements = context.mark_node(std::make_shared<ov::op::v6::GatherElements>(x, index, axis));
|
||||
if (!context.input_is_none(3)) {
|
||||
context.mutate_input(3, gather_elements);
|
||||
}
|
||||
return {gather_elements};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -197,6 +197,7 @@ OP_CONVERTER(translate_sub);
|
||||
OP_CONVERTER(translate_sub_);
|
||||
OP_CONVERTER(translate_sum);
|
||||
OP_CONVERTER(translate_t);
|
||||
OP_CONVERTER(translate_take_along_dim);
|
||||
OP_CONVERTER(translate_to);
|
||||
OP_CONVERTER(translate_topk);
|
||||
OP_CONVERTER(translate_transpose);
|
||||
@ -536,6 +537,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::swapaxes", op::quantizable_op<op::translate_transpose>},
|
||||
{"aten::t", op::translate_t},
|
||||
{"aten::t_", op::inplace_op<op::translate_t>},
|
||||
{"aten::take_along_dim", op::translate_take_along_dim},
|
||||
{"aten::tan", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tan>},
|
||||
{"aten::tan_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Tan>>},
|
||||
{"aten::tanh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Tanh>},
|
||||
|
56
tests/layer_tests/pytorch_tests/test_take_along_dim.py
Normal file
56
tests/layer_tests/pytorch_tests/test_take_along_dim.py
Normal file
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
class TestTakeAlongDim(PytorchLayerTest):
|
||||
def _prepare_input(self, m, n, max_val, out=False, flattenize=False):
|
||||
import numpy as np
|
||||
index = np.random.randint(0, max_val, (m, n) if not flattenize else (m*n, ))
|
||||
inp = np.random.randn(m, n).astype(np.float32)
|
||||
if out:
|
||||
axis = int(max_val == n)
|
||||
if flattenize:
|
||||
out = np.zeros_like(np.take(inp, index))
|
||||
else:
|
||||
out = np.zeros_like(np.take(inp, index, axis))
|
||||
return (inp, index, out)
|
||||
return (inp, index)
|
||||
|
||||
def create_model(self, axis, out):
|
||||
import torch
|
||||
|
||||
class aten_take_along_dim(torch.nn.Module):
|
||||
def __init__(self, axis, out=False):
|
||||
super(aten_take_along_dim, self).__init__()
|
||||
self.axis = axis
|
||||
if self.axis is None:
|
||||
self.forward = self.forward_no_dim
|
||||
if out:
|
||||
self.forward = self.forward_out if self.axis is not None else self.forward_no_dim_out
|
||||
|
||||
def forward(self, x, index):
|
||||
return torch.take_along_dim(x, index, dim=self.axis)
|
||||
|
||||
def forward_out(self, x, index, out):
|
||||
return torch.take_along_dim(x, index, dim=self.axis, out=out), out
|
||||
|
||||
def forward_no_dim(self, x, index):
|
||||
return torch.take_along_dim(x, index)
|
||||
|
||||
def forward_no_dim_out(self, x, index, out):
|
||||
return torch.take_along_dim(x, index, out=out)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_take_along_dim(axis, out), ref_net, "aten::take_along_dim"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("m", [2, 10, 100])
|
||||
@pytest.mark.parametrize("n", [2, 10, 100])
|
||||
@pytest.mark.parametrize("axis", [0, 1, None])
|
||||
@pytest.mark.parametrize("out", [True, False])
|
||||
def test_gather(self, m, n, axis, out, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={
|
||||
"m": m, "n": n, "max_val": m if axis == 0 else n, "out": out, "flattenize": axis is None
|
||||
})
|
Loading…
Reference in New Issue
Block a user