support pytorch erfc op (#21567)

* support pytorch erfc op

* allow dynamic case & fix code format

* fix code style: use 4-space indentation
Ahmad Chalhoub 2023-12-12 15:54:29 -05:00 committed by GitHub
parent 9ab0a6d2f1
commit 942bc8b1ba
3 changed files with 105 additions and 0 deletions

@@ -0,0 +1,45 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/erf.hpp"
#include "openvino/op/subtract.hpp"
#include "utils.hpp"
using namespace std;
using namespace ov::op;
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
OutputVector translate_erfc(const NodeContext& context) {
// aten::erf(Tensor self) -> Tensor
// aten::erf.out(Tensor self, Tensor(!a) out) -> Tensor(!a)
num_inputs_check(context, 1, 2);
auto x = context.get_input(0);
// create 'ones' to use to calculate complementary of Erf output
auto ones = context.mark_node(make_shared<v0::Constant>(element::f32, Shape{}, 1.0f))->output(0);
// align data types of input 'x' and ones
align_eltwise_input_types(context, x, ones);
// apply Erf to the input tensor 'x'
auto y = context.mark_node(make_shared<v0::Erf>(x));
y = context.mark_node(make_shared<v1::Subtract>(ones, y));
if (!context.input_is_none(1)) {
context.mutate_input(1, y);
}
return {y};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
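The converter builds erfc from existing OpenVINO ops using the identity erfc(x) = 1 - erf(x). A quick standalone sanity check of that identity in PyTorch (illustrative only, not part of this change):

import torch

x = torch.linspace(-3, 3, 50)
# the decomposition used by the converter: erfc(x) == 1 - erf(x)
assert torch.allclose(torch.special.erfc(x), 1.0 - torch.erf(x), atol=1e-6)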

@@ -68,6 +68,7 @@ OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_empty_like);
OP_CONVERTER(translate_erf);
OP_CONVERTER(translate_erfc);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
OP_CONVERTER(translate_eye);
@@ -348,6 +349,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::erf", op::translate_erf},
{"aten::erf_", op::inplace_op<op::translate_erf>},
{"aten::erfc", op::translate_erfc},
{"aten::erfc_", op::inplace_op<op::translate_erfc>},
{"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
{"aten::exp_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>>},
{"aten::expand", op::translate_expand},

@@ -0,0 +1,57 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestErfc(PytorchLayerTest):
    def _prepare_input(self, input_dtype, out=False):
        import numpy as np

        x = np.linspace(-3, 3).astype(input_dtype)
        if not out:
            return (x,)
        return (x, np.zeros_like(x).astype(input_dtype))

    def create_model(self, mode="", input_dtype="float32"):
        import torch

        dtypes = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int32": torch.int32,
        }
        dtype = dtypes[input_dtype]

        class aten_erfc(torch.nn.Module):
            def __init__(self, mode, dtype):
                super(aten_erfc, self).__init__()
                self.dtype = dtype
                if mode == "out":
                    self.forward = self.forward_out
                elif mode == "inplace":
                    self.forward = self.forward_inplace

            def forward(self, x):
                return torch.special.erfc(x.to(self.dtype))

            def forward_out(self, x, y):
                return torch.special.erfc(x.to(self.dtype), out=y), y

            def forward_inplace(self, x):
                x = x.to(self.dtype)
                return x.erfc_(), x

        ref_net = None
        return aten_erfc(mode, dtype), ref_net, "aten::erfc" if mode != "inplace" else "aten::erfc_"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("mode,input_dtype", [
        ("", "float32"), ("", "float64"), ("", "int32"),
        ("out", "float32"), ("out", "float64"),
        ("inplace", "float32"), ("inplace", "float64")])
    def test_erfc(self, mode, input_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(mode, input_dtype), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_dtype": input_dtype, "out": mode == "out"})