[PT FE] Add aten::LogSoftmax (#17629)
* [PT FE] Add aten::LogSoftmax implementation & tests
* Update log_softmax.cpp
* Update src/frontends/pytorch/src/op/log_softmax.cpp
* [PT FE] Add recommended comment, replace get_input_tensor with new implementation
* [PT FE] Align to f32 if no dtype provided
* [PT FE] Revert type align

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
src/frontends/pytorch/src/op/log_softmax.cpp (new file, 44 lines)
@@ -0,0 +1,44 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_log_softmax(const NodeContext& context) {
    /*
    aten::log_softmax(
        Tensor input,
        int64 dim,
        dtype dtype = None
    )
    */
    num_inputs_check(context, 2, 3);
    auto input = context.get_input(0);
    const auto dim = context.const_input<int64_t>(1);

    if (!context.input_is_none(2)) {
        const auto elem_type = input.get_element_type();
        const auto target_dtype_i64 = context.const_input<int64_t>(2);
        const auto target_dtype = convert_dtype(target_dtype_i64);
        if (elem_type != target_dtype) {
            input = context.mark_node(std::make_shared<opset10::Convert>(input, target_dtype));
        }
    }

    const auto log_softmax = context.mark_node(std::make_shared<opset10::LogSoftmax>(input, dim));
    return {log_softmax};
};

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
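For reference, a minimal end-to-end sketch of what this translator enables, assuming a recent OpenVINO build where openvino.convert_model accepts a torch.nn.Module with an example_input argument (illustrative only; this snippet is not part of the commit):

import torch
import openvino as ov

class LogSoftmaxModel(torch.nn.Module):
    def forward(self, x):
        # Dispatches to aten::log_softmax; with an explicit dtype the
        # translator emits a Convert followed by LogSoftmax.
        return torch.nn.functional.log_softmax(x, dim=-1, dtype=torch.float64)

# The PyTorch frontend traces the module, finds "aten::log_softmax" in the
# op table, and calls translate_log_softmax for that node.
ov_model = ov.convert_model(LogSoftmaxModel(), example_input=torch.randn(2, 3))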
src/frontends/pytorch/src/op_table.cpp

@@ -79,6 +79,7 @@ OP_CONVERTER(translate_linalg_vector_norm);
 OP_CONVERTER(translate_linear);
 OP_CONVERTER(translate_list_construct);
 OP_CONVERTER(translate_log);
+OP_CONVERTER(translate_log_softmax);
 OP_CONVERTER(translate_log2);
 OP_CONVERTER(translate_loop);
 OP_CONVERTER(translate_masked_fill);

@@ -273,6 +274,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
     {"aten::linear", op::translate_linear},
     {"aten::log", op::translate_log},
     {"aten::log_", op::inplace_op<op::translate_log>},
+    {"aten::log_softmax", op::translate_log_softmax},
     {"aten::log2", op::translate_log2},
     {"aten::log2_", op::inplace_op<op::translate_log2>},
     {"aten::lt", op::translate_1to1_match_2_inputs_align_types<opset10::Less>},
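The two hunks above follow the frontend's standard registration pattern: declare the converter, then map the TorchScript op name to it in get_supported_ops(). A schematic Python analogue of that string-keyed dispatch (illustrative only; the real table is the C++ std::map<std::string, CreatorFunction>, and these translator stubs are stand-ins):

def translate_log(context):
    return f"Log({context})"

def translate_log_softmax(context):
    return f"LogSoftmax({context})"

SUPPORTED_OPS = {
    "aten::log": translate_log,
    "aten::log_softmax": translate_log_softmax,
}

def convert_node(kind, context):
    # A missing key means the frontend does not support the op.
    return SUPPORTED_OPS[kind](context)

print(convert_node("aten::log_softmax", "x"))  # LogSoftmax(x)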
tests/layer_tests/pytorch_tests/test_log_softmax.py (new file, 45 lines)
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch
import torch.nn.functional as F

from pytorch_layer_test_class import PytorchLayerTest

class aten_log_softmax(torch.nn.Module):
    def __init__(self, dim, dtype) -> None:
        super().__init__()
        self.dim = dim
        self.dtype = dtype

    def forward(self, input_tensor):
        return F.log_softmax(input_tensor, dim=self.dim, dtype=self.dtype)

class TestLogSoftmax(PytorchLayerTest):
    def _prepare_input(self):
        if self.input_dtype == torch.float:
            self.input_tensor = np.random.randn(5, 9, 7)
        else:
            self.input_tensor = np.random.randint(-100, 100, (5, 9, 7))
        return (self.input_tensor,)

    @pytest.mark.parametrize(["input_dtype", "convert_dtype"], [
        # convert_dtype cannot be an integer type: PyTorch only accepts
        # floating-point dtypes for log_softmax's dtype argument
        [torch.int, torch.float32],
        [torch.int, torch.float64],
        [torch.float, None],
        [torch.float, torch.float64],
    ])
    @pytest.mark.parametrize("dim", [0, 1, -1])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_log_softmax(self, input_dtype, convert_dtype, dim, ie_device, precision, ir_version):
        self.input_dtype = input_dtype
        self._test(aten_log_softmax(dim, convert_dtype), None, "aten::log_softmax",
                   ie_device, precision, ir_version)
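The parametrization above reflects PyTorch's own dtype rules, which the translator mirrors: dtype must be floating-point, and passing it is equivalent to casting the input before the op. A standalone sanity check of that equivalence (illustrative; not part of the test suite):

import torch
import torch.nn.functional as F

x = torch.randint(-100, 100, (5, 9, 7)).to(torch.int32)

# PyTorch rejects an integer dtype here, which is why the test never
# parametrizes convert_dtype with an int type.
out = F.log_softmax(x, dim=-1, dtype=torch.float64)

# Equivalent to casting first and applying log_softmax without dtype,
# mirroring the Convert + LogSoftmax pair the translator emits.
ref = F.log_softmax(x.to(torch.float64), dim=-1)
assert out.dtype == torch.float64
assert torch.allclose(out, ref)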