[PT FE] Support aten::one_hot (#19779)
* [PT FE] Support aten::one_hot * Apply code style
This commit is contained in:
parent
f744869551
commit
4f92676c85
49
src/frontends/pytorch/src/op/one_hot.cpp
Normal file
49
src/frontends/pytorch/src/op/one_hot.cpp
Normal file
@ -0,0 +1,49 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/op/one_hot.hpp"
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/greater.hpp"
|
||||
#include "openvino/op/reduce_max.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
|
||||
// Translate aten::one_hot(Tensor self, int num_classes=-1) -> Tensor
// into OpenVINO OneHot. Returns a single-output vector with the one-hot
// encoded tensor; the new class dimension is appended as the last axis.
OutputVector translate_one_hot(const NodeContext& context) {
    num_inputs_check(context, 1, 2);
    auto x = context.get_input(0);
    // aten::one_hot works on LongTensor which means we need to convert all inputs to i64
    x = context.mark_node(std::make_shared<v0::Convert>(x, element::i64));
    // Scalar on/off values for OneHot. on_value (i64 scalar 1) is also reused
    // below as the "+1" increment when the depth is inferred from the data,
    // replacing a previously duplicated identical constant.
    auto on_value = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
    auto zero_value = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
    Output<Node> num_classes;
    if (context.input_is_none(1)) {
        // Missing num_classes input is equivalent to PyTorch's default of -1.
        num_classes = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
    } else {
        num_classes = context.get_input(1);
        num_classes = context.mark_node(std::make_shared<v0::Convert>(num_classes, element::i64));
    }
    // If num_classes <= 0 (the -1 default), infer the depth as max(x) + 1,
    // matching torch.nn.functional.one_hot semantics.
    auto greater = context.mark_node(std::make_shared<v1::Greater>(num_classes, zero_value));
    auto axes = get_axes_range(context, 0);
    auto max_class = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes));
    max_class = context.mark_node(std::make_shared<v1::Add>(max_class, on_value));
    num_classes = context.mark_node(std::make_shared<v1::Select>(greater, num_classes, max_class));
    // axis = -1: the one-hot dimension is appended as the innermost axis.
    return {context.mark_node(std::make_shared<v1::OneHot>(x, num_classes, on_value, zero_value, -1))};
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -110,6 +110,7 @@ OP_CONVERTER(translate_nms);
|
||||
OP_CONVERTER(translate_nonzero);
|
||||
OP_CONVERTER(translate_norm);
|
||||
OP_CONVERTER(translate_numel);
|
||||
OP_CONVERTER(translate_one_hot);
|
||||
OP_CONVERTER(translate_ones);
|
||||
OP_CONVERTER(translate_ones_like);
|
||||
OP_CONVERTER(translate_or);
|
||||
@ -371,6 +372,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::nonzero", op::translate_nonzero},
|
||||
{"aten::norm", op::translate_norm},
|
||||
{"aten::numel", op::translate_numel},
|
||||
{"aten::one_hot", op::translate_one_hot},
|
||||
{"aten::ones", op::translate_ones},
|
||||
{"aten::ones_like", op::translate_ones_like},
|
||||
{"aten::outer", op::translate_outer},
|
||||
|
33
tests/layer_tests/pytorch_tests/test_one_hot.py
Normal file
33
tests/layer_tests/pytorch_tests/test_one_hot.py
Normal file
@ -0,0 +1,33 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestOneHot(PytorchLayerTest):
    def _prepare_input(self):
        # A single (1, 1000) tensor of random integer class indices in [0, 100).
        indices = np.random.randint(low=0, high=100, size=(1, 1000)).astype(np.int32)
        return (indices,)

    def create_model(self, num_classes):
        import torch
        import torch.nn.functional as F

        class aten_one_hot(torch.nn.Module):
            def __init__(self, num_classes):
                super().__init__()
                self.num_classes = num_classes

            def forward(self, x):
                # Indices are kept in [0, 3) so every parametrized num_classes
                # (including 3) is a valid depth; -1 exercises depth inference.
                indices = torch.arange(0, x.numel()) % 3
                return F.one_hot(indices, self.num_classes)

        return aten_one_hot(num_classes), None, "aten::one_hot"

    @pytest.mark.parametrize(("num_classes"), [-1, 3, 1000,])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_one_hot(self, num_classes, ie_device, precision, ir_version):
        self._test(*self.create_model(num_classes), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user