[PT FE]: support aten::glu and aten::sigmoid_ (#15185)
* [PT FE]: support aten::glu and aten::sigmoid_ * upd headers * Update src/frontends/pytorch/src/op/glu.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * return back opset * Update op_table.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
0d201376df
commit
18bfa727bd
30
src/frontends/pytorch/src/op/glu.cpp
Normal file
30
src/frontends/pytorch/src/op/glu.cpp
Normal file
@ -0,0 +1,30 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/sigmoid.hpp"
|
||||
#include "openvino/op/split.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_glu(NodeContext& context) {
|
||||
auto x = context.get_input(0);
|
||||
auto dim = context.input_is_none(1) ? context.mark_node(ov::op::v0::Constant::create(element::i64, Shape{}, {-1}))
|
||||
: context.get_input(1);
|
||||
auto split = context.mark_node(std::make_shared<ov::op::v1::Split>(x, dim, 2));
|
||||
auto first = split->output(0);
|
||||
auto second = split->output(1);
|
||||
auto sigmoid = context.mark_node(std::make_shared<ov::op::v0::Sigmoid>(second));
|
||||
return {context.mark_node(std::make_shared<ov::op::v1::Multiply>(first, sigmoid))};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -41,6 +41,7 @@ OP_CONVERTER(translate_full);
|
||||
OP_CONVERTER(translate_full_like);
|
||||
OP_CONVERTER(translate_gelu);
|
||||
OP_CONVERTER(translate_get_attr);
|
||||
OP_CONVERTER(translate_glu);
|
||||
OP_CONVERTER(translate_group_norm);
|
||||
OP_CONVERTER(translate_hardtanh);
|
||||
OP_CONVERTER(translate_if);
|
||||
@ -168,6 +169,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::full", op::translate_full},
|
||||
{"aten::full_like", op::translate_full_like},
|
||||
{"aten::gelu", op::translate_gelu},
|
||||
{"aten::glu", op::translate_glu},
|
||||
{"aten::group_norm", op::translate_group_norm},
|
||||
{"aten::ge", op::translate_1to1_match_2_inputs<opset10::GreaterEqual>},
|
||||
{"aten::gt", op::translate_1to1_match_2_inputs<opset10::Greater>},
|
||||
@ -231,6 +233,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
|
||||
{"aten::selu", op::translate_selu},
|
||||
{"aten::selu_", op::inplace_op<op::translate_selu>},
|
||||
{"aten::sigmoid", op::translate_1to1_match_1_inputs<opset10::Sigmoid>},
|
||||
{"aten::sigmoid_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Sigmoid>>},
|
||||
{"aten::silu", op::translate_1to1_match_1_inputs<opset10::Swish>},
|
||||
{"aten::silu_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Swish>>},
|
||||
{"aten::sin", op::translate_1to1_match_1_inputs<opset10::Sin>},
|
||||
|
34
tests/layer_tests/pytorch_tests/test_glu.py
Normal file
34
tests/layer_tests/pytorch_tests/test_glu.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestGlu(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(2, 4, 224, 224).astype(np.float32),)
|
||||
|
||||
def create_model(self, dim):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_glu(torch.nn.Module):
|
||||
def __init__(self, dim):
|
||||
super(aten_glu, self).__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
return F.glu(x, self.dim)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_glu(dim), ref_net, "aten::glu"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("dim", [0, 1, 2, 3, -1, -2])
|
||||
def test_glu(self, dim, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(dim), ie_device, precision, ir_version)
|
34
tests/layer_tests/pytorch_tests/test_sigmoid.py
Normal file
34
tests/layer_tests/pytorch_tests/test_sigmoid.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestSigmoid(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
import numpy as np
|
||||
return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
|
||||
|
||||
def create_model(self, inplace=False):
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
class aten_sigmoid(torch.nn.Module):
|
||||
def __init__(self, inplace):
|
||||
super(aten_sigmoid, self).__init__()
|
||||
self.op = torch.sigmoid if not inplace else torch.sigmoid_
|
||||
|
||||
def forward(self, x):
|
||||
return x, self.op(x)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_sigmoid(inplace), ref_net, "aten::sigmoid" if not inplace else "aten::sigmoid_"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("inplace", [True, False])
|
||||
def test_sigmoid(self, inplace, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(inplace), ie_device, precision, ir_version)
|
Loading…
Reference in New Issue
Block a user