support pytorch efrc op (#21567)
* support pytorch efrc op * allow dynamic case & fix code format * fix code style, use space 4
This commit is contained in:
parent
9ab0a6d2f1
commit
942bc8b1ba
45
src/frontends/pytorch/src/op/erfc.cpp
Normal file
45
src/frontends/pytorch/src/op/erfc.cpp
Normal file
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/erf.hpp"
|
||||
#include "openvino/op/subtract.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
using namespace std;
|
||||
using namespace ov::op;
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
OutputVector translate_erfc(const NodeContext& context) {
|
||||
// aten::erf(Tensor self) -> Tensor
|
||||
// aten::erf.out(Tensor self, Tensor(!a) out) -> Tensor(!a)
|
||||
num_inputs_check(context, 1, 2);
|
||||
auto x = context.get_input(0);
|
||||
|
||||
// create 'ones' to use to calculate complementary of Erf output
|
||||
auto ones = context.mark_node(make_shared<v0::Constant>(element::f32, Shape{}, 1.0f))->output(0);
|
||||
|
||||
// align data types of input 'x' and ones
|
||||
align_eltwise_input_types(context, x, ones);
|
||||
|
||||
// apply Erf to the input tensor 'x'
|
||||
auto y = context.mark_node(make_shared<v0::Erf>(x));
|
||||
|
||||
y = context.mark_node(make_shared<v1::Subtract>(ones, y));
|
||||
|
||||
if (!context.input_is_none(1)) {
|
||||
context.mutate_input(1, y);
|
||||
}
|
||||
return {y};
|
||||
};
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -68,6 +68,7 @@ OP_CONVERTER(translate_embedding_bag);
|
||||
OP_CONVERTER(translate_empty);
|
||||
OP_CONVERTER(translate_empty_like);
|
||||
OP_CONVERTER(translate_erf);
|
||||
OP_CONVERTER(translate_erfc);
|
||||
OP_CONVERTER(translate_expand);
|
||||
OP_CONVERTER(translate_expand_as);
|
||||
OP_CONVERTER(translate_eye);
|
||||
@ -348,6 +349,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
|
||||
{"aten::erf", op::translate_erf},
|
||||
{"aten::erf_", op::inplace_op<op::translate_erf>},
|
||||
{"aten::erfc", op::translate_erfc},
|
||||
{"aten::erfc_", op::inplace_op<op::translate_erfc>},
|
||||
{"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
|
||||
{"aten::exp_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>>},
|
||||
{"aten::expand", op::translate_expand},
|
||||
|
57
tests/layer_tests/pytorch_tests/test_erfc.py
Normal file
57
tests/layer_tests/pytorch_tests/test_erfc.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestErfc(PytorchLayerTest):
|
||||
def _prepare_input(self, input_dtype, out=False):
|
||||
import numpy as np
|
||||
x = np.linspace(-3, 3).astype(input_dtype)
|
||||
if not out:
|
||||
return (x, )
|
||||
return (x, np.zeros_like(x).astype(input_dtype))
|
||||
|
||||
def create_model(self, mode="", input_dtype="float32"):
|
||||
import torch
|
||||
dtypes = {
|
||||
"float32": torch.float32,
|
||||
"float64": torch.float64,
|
||||
"int32": torch.int32
|
||||
}
|
||||
|
||||
dtype = dtypes[input_dtype]
|
||||
class aten_erfc(torch.nn.Module):
|
||||
def __init__(self, mode, dtype):
|
||||
super(aten_erfc, self).__init__()
|
||||
self.dtype = dtype
|
||||
if mode == "out":
|
||||
self.forward = self.forward_out
|
||||
elif mode == "inplace":
|
||||
self.forward = self.forward_inplace
|
||||
|
||||
def forward(self, x):
|
||||
return torch.special.erfc(x.to(self.dtype))
|
||||
|
||||
def forward_out(self, x, y):
|
||||
return torch.special.erfc(x.to(self.dtype), out=y), y
|
||||
|
||||
def forward_inplace(self, x):
|
||||
x = x.to(self.dtype)
|
||||
return x.erfc_(), x
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_erfc(mode, dtype), ref_net, "aten::erfc" if mode != "inplace" else "aten::erfc_"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
@pytest.mark.parametrize("mode,input_dtype", [
|
||||
("", "float32"), ("", "float64"), ("", "int32"),
|
||||
("out", "float32"), ("out", "float64"),
|
||||
("inplace", "float32"), ("inplace", "float64")])
|
||||
def test_erfc(self, mode, input_dtype, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(mode, input_dtype), ie_device, precision, ir_version,
|
||||
kwargs_to_prepare_input={"input_dtype": input_dtype, "out": mode == "out"} )
|
Loading…
Reference in New Issue
Block a user