[PT FE] Add aten::_shape_as_tensor (#17804)
* [PT FE] Add aten::_shape_as_tensor impl
* Update shape_as_tensor.cpp
* [PT FE] Fix headers, add explicit type, comment out shape detection
* [PT FE] Reverse example comments
parent 0944295d61
commit 0d9109acf3
src/frontends/pytorch/src/op/shape_as_tensor.cpp (new file, 26 lines)
@@ -0,0 +1,26 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/shape_of.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_shape_as_tensor(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    auto input = context.get_input(0);
    auto shape = context.mark_node(std::make_shared<v3::ShapeOf>(input, element::i64));
    return {shape};
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
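For context, a small eager-mode sketch (illustrative only, not part of this commit) of what aten::_shape_as_tensor returns in PyTorch: the op simply yields the input's shape as an int64 tensor, which is why the converter above can map it to a single ShapeOf node with element::i64 output.

import torch

x = torch.zeros(2, 3, 4)
shape_t = torch.ops.aten._shape_as_tensor(x)
print(shape_t)        # tensor([2, 3, 4])
print(shape_t.dtype)  # torch.int64, matching the converter's element::i64
print(torch.equal(shape_t, torch.tensor(x.shape)))  # True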
@@ -120,6 +120,7 @@ OP_CONVERTER(translate_scaled_dot_product_attention);
OP_CONVERTER(translate_select);
OP_CONVERTER(translate_set_item);
OP_CONVERTER(translate_selu);
OP_CONVERTER(translate_shape_as_tensor);
OP_CONVERTER(translate_sign);
OP_CONVERTER(translate_size);
OP_CONVERTER(translate_slice);
@@ -163,6 +164,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
        {"aten::_convolution_mode", op::translate_convolution_mode},
        {"aten::_native_multi_head_attention", op::translate_native_multi_head_attention},
        {"aten::_set_item", op::translate_set_item},
        {"aten::_shape_as_tensor", op::translate_shape_as_tensor},
        {"aten::abs", op::translate_1to1_match_1_inputs<opset10::Abs>},
        {"aten::acos", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acos>},
        {"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
tests/layer_tests/pytorch_tests/test_shape_as_tensor.py (new file, 32 lines)
@@ -0,0 +1,32 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from pytorch_layer_test_class import PytorchLayerTest

class aten_shape_as_tensor(torch.nn.Module):
    def __init__(self) -> None:
        torch.nn.Module.__init__(self)

    def forward(self, input_tensor):
        return torch.ops.aten._shape_as_tensor(input_tensor)

class TestShapeAsTensor(PytorchLayerTest):
    def _prepare_input(self):
        return (self.input_tensor,)

    @pytest.mark.parametrize("shape", [
        # (),
        (2,),
        (1,2,3,4),
        (5,4,2,7)
    ])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_all_noparams(self, shape, ie_device, precision, ir_version):
        self.input_tensor = np.zeros(shape)
        self._test(aten_shape_as_tensor(), None, "aten::_shape_as_tensor",
                   ie_device, precision, ir_version)
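If it helps to see what the layer test compares against, below is a minimal standalone sketch (my assumption, not part of this commit; PytorchLayerTest._test itself handles tracing, conversion to OpenVINO, and the output comparison internally) of the eager/traced PyTorch reference that the converted model's output is expected to match.

# Standalone sketch of the reference behaviour; names mirror the test module above.
import numpy as np
import torch

class aten_shape_as_tensor(torch.nn.Module):
    def forward(self, input_tensor):
        return torch.ops.aten._shape_as_tensor(input_tensor)

example = torch.from_numpy(np.zeros((1, 2, 3, 4)))
traced = torch.jit.trace(aten_shape_as_tensor(), example)
print(traced(example))  # tensor([1, 2, 3, 4]) -- the reference result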