[PT FE] Support aten::one_hot (#19779)

* [PT FE] Support aten::one_hot

* Apply code style
This commit is contained in:
Maxim Vafin 2023-09-13 20:37:47 +02:00 committed by GitHub
parent f744869551
commit 4f92676c85
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 84 additions and 0 deletions

View File

@ -0,0 +1,49 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/one_hot.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/greater.hpp"
#include "openvino/op/reduce_max.hpp"
#include "openvino/op/select.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_one_hot(const NodeContext& context) {
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
// aten::one_hot works on LongTensor which means we need to convert all inputs to i64
x = context.mark_node(std::make_shared<v0::Convert>(x, element::i64));
auto on_value = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
auto zero_value = context.mark_node(v0::Constant::create(element::i64, Shape{}, {0}));
Output<Node> num_classes;
if (context.input_is_none(1)) {
num_classes = context.mark_node(v0::Constant::create(element::i64, Shape{}, {-1}));
} else {
num_classes = context.get_input(1);
num_classes = context.mark_node(std::make_shared<v0::Convert>(num_classes, element::i64));
}
auto one = context.mark_node(v0::Constant::create(element::i64, Shape{}, {1}));
auto greater = context.mark_node(std::make_shared<v1::Greater>(num_classes, zero_value));
auto axes = get_axes_range(context, 0);
auto max_class = context.mark_node(std::make_shared<v1::ReduceMax>(x, axes));
max_class = context.mark_node(std::make_shared<v1::Add>(max_class, one));
num_classes = context.mark_node(std::make_shared<v1::Select>(greater, num_classes, max_class));
return {context.mark_node(std::make_shared<v1::OneHot>(x, num_classes, on_value, zero_value, -1))};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -110,6 +110,7 @@ OP_CONVERTER(translate_nms);
OP_CONVERTER(translate_nonzero);
OP_CONVERTER(translate_norm);
OP_CONVERTER(translate_numel);
OP_CONVERTER(translate_one_hot);
OP_CONVERTER(translate_ones);
OP_CONVERTER(translate_ones_like);
OP_CONVERTER(translate_or);
@ -371,6 +372,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::nonzero", op::translate_nonzero},
{"aten::norm", op::translate_norm},
{"aten::numel", op::translate_numel},
{"aten::one_hot", op::translate_one_hot},
{"aten::ones", op::translate_ones},
{"aten::ones_like", op::translate_ones_like},
{"aten::outer", op::translate_outer},

View File

@ -0,0 +1,33 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestOneHot(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randint(0, 100, (1,1000)).astype(np.int32),)
def create_model(self, num_classes):
import torch
import torch.nn.functional as F
class aten_one_hot(torch.nn.Module):
def __init__(self, num_classes):
super(aten_one_hot, self).__init__()
self.num_classes = num_classes
def forward(self, x):
return F.one_hot(torch.arange(0, x.numel()) % 3, self.num_classes)
return aten_one_hot(num_classes), None, "aten::one_hot"
@pytest.mark.parametrize(("num_classes"), [-1, 3, 1000,])
@pytest.mark.nightly
@pytest.mark.precommit
def test_one_hot(self, num_classes, ie_device, precision, ir_version):
self._test(*self.create_model(num_classes),
ie_device, precision, ir_version)