[PT FE]: support aten::glu and aten::sigmoid_ (#15185)

* [PT FE]: support aten::glu and aten::sigmoid_

* update headers

* Update src/frontends/pytorch/src/op/glu.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* restore opset

* Update op_table.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Authored by Ekaterina Aidova on 2023-01-21 22:05:20 +04:00; committed by GitHub
parent 0d201376df
commit 18bfa727bd
4 changed files with 101 additions and 0 deletions

src/frontends/pytorch/src/op/glu.cpp (new file)

@@ -0,0 +1,30 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/sigmoid.hpp"
#include "openvino/op/split.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

OutputVector translate_glu(NodeContext& context) {
    // aten::glu(x, dim): split x into two halves a and b along `dim`
    // (defaulting to -1 when dim is not supplied) and return a * sigmoid(b)
    auto x = context.get_input(0);
    auto dim = context.input_is_none(1) ? context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {-1}))
                                        : context.get_input(1);
    auto split = context.mark_node(std::make_shared<ov::op::v1::Split>(x, dim, 2));
    auto first = split->output(0);
    auto second = split->output(1);
    auto sigmoid = context.mark_node(std::make_shared<ov::op::v0::Sigmoid>(second));
    return {context.mark_node(std::make_shared<ov::op::v1::Multiply>(first, sigmoid))};
}

} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
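
For reference, the decomposition implemented above matches PyTorch's own
aten::glu semantics: split the input in two along `dim`, then gate the first
half with the sigmoid of the second. A minimal sanity check in stock PyTorch
(illustrative only, not part of the diff):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 6, 5)
    dim = 1
    a, b = torch.chunk(x, 2, dim=dim)   # the two halves of x along dim
    assert torch.allclose(F.glu(x, dim=dim), a * torch.sigmoid(b))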

src/frontends/pytorch/src/op_table.cpp

@@ -41,6 +41,7 @@ OP_CONVERTER(translate_full);
OP_CONVERTER(translate_full_like);
OP_CONVERTER(translate_gelu);
OP_CONVERTER(translate_get_attr);
OP_CONVERTER(translate_glu);
OP_CONVERTER(translate_group_norm);
OP_CONVERTER(translate_hardtanh);
OP_CONVERTER(translate_if);
@@ -168,6 +169,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::full", op::translate_full},
{"aten::full_like", op::translate_full_like},
{"aten::gelu", op::translate_gelu},
{"aten::glu", op::translate_glu},
{"aten::group_norm", op::translate_group_norm},
{"aten::ge", op::translate_1to1_match_2_inputs<opset10::GreaterEqual>},
{"aten::gt", op::translate_1to1_match_2_inputs<opset10::Greater>},
@@ -231,6 +233,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::selu", op::translate_selu},
{"aten::selu_", op::inplace_op<op::translate_selu>},
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
{"aten::sigmoid_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sigmoid>>},
{"aten::silu", op::translate_1to1_match_1_inputs<opset10::Swish>},
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
{"aten::sin", op::translate_1to1_match_1_inputs<opset10::Sin>},

New PyTorch layer test for aten::glu

@@ -0,0 +1,34 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestGlu(PytorchLayerTest):
    def _prepare_input(self):
        import numpy as np
        # GLU needs an even size along the split dim; every dim of
        # (2, 4, 224, 224) is even, so all parametrized dims are valid
        return (np.random.randn(2, 4, 224, 224).astype(np.float32),)

    def create_model(self, dim):
        import torch
        import torch.nn.functional as F

        class aten_glu(torch.nn.Module):
            def __init__(self, dim):
                super(aten_glu, self).__init__()
                self.dim = dim

            def forward(self, x):
                return F.glu(x, self.dim)

        ref_net = None

        return aten_glu(dim), ref_net, "aten::glu"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2])
    def test_glu(self, dim, ie_device, precision, ir_version):
        self._test(*self.create_model(dim), ie_device, precision, ir_version)
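
ref_net is left as None here, presumably leaving the comparison to the traced
PyTorch model itself. Were a standalone reference needed, a NumPy sketch of
GLU could look like this (glu_ref is illustrative, not part of the commit):

    import numpy as np

    def glu_ref(x, dim):
        # split x in half along `dim`, then gate the first half with
        # sigmoid of the second: a * sigmoid(b) == a / (1 + exp(-b))
        a, b = np.split(x, 2, axis=dim)
        return a / (1.0 + np.exp(-b))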

New PyTorch layer test for aten::sigmoid and aten::sigmoid_

@@ -0,0 +1,34 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestSigmoid(PytorchLayerTest):
    def _prepare_input(self):
        import numpy as np
        return (np.random.randn(1, 3, 224, 224).astype(np.float32),)

    def create_model(self, inplace=False):
        import torch

        class aten_sigmoid(torch.nn.Module):
            def __init__(self, inplace):
                super(aten_sigmoid, self).__init__()
                self.op = torch.sigmoid if not inplace else torch.sigmoid_

            def forward(self, x):
                return x, self.op(x)

        ref_net = None

        return aten_sigmoid(inplace), ref_net, "aten::sigmoid" if not inplace else "aten::sigmoid_"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("inplace", [True, False])
    def test_sigmoid(self, inplace, ie_device, precision, ir_version):
        self._test(*self.create_model(inplace), ie_device, precision, ir_version)
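
Note that the model returns both x and self.op(x): in the in-place case the
two outputs share storage, so the test verifies that the input tensor was
actually mutated, not just that the returned value is correct. The expected
op name ("aten::sigmoid" vs "aten::sigmoid_") is switched accordingly.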