[PT FE]: support aten::erf and aten::adaptive_avg_pool1d (#20350)

* [PT FE]: support aten::erf and aten::adaptive_avg_pool1d

* align adaptive avg pools for different sizes

* refactor adaptive max pool
This commit is contained in:
Ekaterina Aidova 2023-10-11 17:33:32 +04:00 committed by GitHub
parent 3403e6c028
commit 9bedafb560
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 473 additions and 165 deletions

View File

@ -1,47 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/adaptive_avg_pool.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/tile.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
// Convert aten::adaptive_avg_pool3d to OpenVINO AdaptiveAvgPool-8.
// Inputs: 0 - input tensor, 1 - requested output spatial shape (3 values).
OutputVector translate_adaptive_avg_pool3d(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    // Repeats of all ones: Tile duplicates no data here, but promotes the input
    // rank up to 5D (repeats length), which AdaptiveAvgPool needs for 3d
    // pooling. NOTE(review): presumably this covers unbatched 4D inputs -
    // confirm the expected input ranks with the op spec.
    auto const_tile_params = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    auto const_neg_3 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
    auto input_tensor = context.get_input(0);
    auto given_shape = context.get_input(1);
    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor, element::i32));
    // Leading dims of the original input: everything except the last 3 (spatial) dims.
    auto shape_begin =
        context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, const_neg_3, const_1, const_0));
    // Final shape = original leading dims + requested pooled dims; used to
    // restore the original rank after pooling the rank-promoted tensor.
    auto output_shape = context.mark_node(std::make_shared<v0::Concat>(OutputVector{shape_begin, given_shape}, 0));
    auto tile = context.mark_node(std::make_shared<v0::Tile>(input_tensor, const_tile_params));
    auto adaptive_avg_pool = context.mark_node(std::make_shared<v8::AdaptiveAvgPool>(tile, given_shape));
    auto reshape = context.mark_node(std::make_shared<v1::Reshape>(adaptive_avg_pool, output_shape, false));
    return {reshape};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -1,25 +0,0 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/adaptive_max_pool.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
// Convert aten::adaptive_max_pool2d. Returns both outputs of the OpenVINO op:
// the pooled tensor and the element indices (computed in i32).
OutputVector translate_adaptive_max_pool2d(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    const auto input = context.get_input(0);
    const auto target_shape = context.get_input(1);
    auto pool_node =
        context.mark_node(std::make_shared<ov::op::v8::AdaptiveMaxPool>(input, target_shape, ov::element::i32));
    return {pool_node->output(0), pool_node->output(1)};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,123 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/adaptive_avg_pool.hpp"
#include "openvino/op/adaptive_max_pool.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/tile.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
namespace {
// Shared helper for the adaptive pooling conversions below.
// Promotes input_tensor with Tile(tile_shape) - the repeats are all ones, so
// no data is duplicated, only the rank is raised to what AdaptiveAvgPool /
// AdaptiveMaxPool expect - and builds the shape the pooled result must be
// reshaped back to: [input dims up to slice_end] + given_shape.
// Returns {rank-promoted tensor, target output shape}.
std::tuple<Output<Node>, Output<Node>> get_tile_input_and_output_shape(const NodeContext& context,
                                                                       const Output<Node>& input_tensor,
                                                                       const Output<Node>& given_shape,
                                                                       const Output<Node>& tile_shape,
                                                                       const Output<Node>& slice_end) {
    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    auto input_shape = context.mark_node(std::make_shared<v3::ShapeOf>(input_tensor, element::i32));
    // Keep all leading (non-spatial) dims of the original input.
    auto shape_begin =
        context.mark_node(std::make_shared<v8::Slice>(input_shape, const_0, slice_end, const_1, const_0));
    Output<Node> output_shape =
        context.mark_node(std::make_shared<v0::Concat>(OutputVector{shape_begin, given_shape}, 0));
    Output<Node> tile = context.mark_node(std::make_shared<v0::Tile>(input_tensor, tile_shape));
    return std::make_tuple(tile, output_shape);
};
// Common conversion path for aten::adaptive_avg_pool{1,2,3}d.
// tile_shape and slice_end encode the rank-specific parameters.
OutputVector translate_adaptive_avg_pool_base(const NodeContext& context,
                                              const Output<Node>& tile_shape,
                                              const Output<Node>& slice_end) {
    num_inputs_check(context, 2, 2);
    const auto input = context.get_input(0);
    const auto target_shape = context.get_input(1);
    Output<Node> promoted_input;
    Output<Node> result_shape;
    std::tie(promoted_input, result_shape) =
        get_tile_input_and_output_shape(context, input, target_shape, tile_shape, slice_end);
    auto pooled = context.mark_node(std::make_shared<v8::AdaptiveAvgPool>(promoted_input, target_shape));
    // Restore the caller-visible rank of the result.
    return {context.mark_node(std::make_shared<v1::Reshape>(pooled, result_shape, false))};
};
// Common conversion path for aten::adaptive_max_pool{1,2,3}d.
// tile_shape and slice_end encode the rank-specific parameters.
OutputVector translate_adaptive_max_pool_base(const NodeContext& context,
                                              const Output<Node>& tile_shape,
                                              const Output<Node>& slice_end) {
    num_inputs_check(context, 2, 2);
    auto input_tensor = context.get_input(0);
    auto given_shape = context.get_input(1);
    Output<Node> tile_input;
    Output<Node> output_shape;
    std::tie(tile_input, output_shape) =
        get_tile_input_and_output_shape(context, input_tensor, given_shape, tile_shape, slice_end);
    auto adaptive_max_pool =
        context.mark_node(std::make_shared<v8::AdaptiveMaxPool>(tile_input, given_shape, element::i32));
    auto pooled_tensor = adaptive_max_pool->output(0);
    auto pooled_indices = adaptive_max_pool->output(1);
    // adaptive max pool in torch always returns indices in i64; indices_element_type i64 is not
    // implemented on the ov runtime side, so the op computes them in i32 and we convert afterwards.
    pooled_indices = context.mark_node(std::make_shared<v0::Convert>(pooled_indices, element::i64));
    pooled_tensor = context.mark_node(std::make_shared<v1::Reshape>(pooled_tensor, output_shape, false));
    pooled_indices = context.mark_node(std::make_shared<v1::Reshape>(pooled_indices, output_shape, false));
    // aten::adaptive_max_pool{n}d always returns a tuple of 2 tensors: pooled tensor and indices.
    // Selecting only the first one (or preserving both) is done outside of this
    // translation, driven by the return_indices flag.
    return {pooled_tensor, pooled_indices};
};
} // namespace
// Rank-specific entry points. Each fixes two constants for the shared base:
// the Tile repeats (all ones, length = the input rank AdaptiveXPool expects)
// and how many trailing spatial dims to strip from the input shape.
OutputVector translate_adaptive_avg_pool3d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
    return translate_adaptive_avg_pool_base(context, repeats, spatial_dims);
};
OutputVector translate_adaptive_avg_pool2d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-2}));
    return translate_adaptive_avg_pool_base(context, repeats, spatial_dims);
};
OutputVector translate_adaptive_avg_pool1d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    return translate_adaptive_avg_pool_base(context, repeats, spatial_dims);
};
OutputVector translate_adaptive_max_pool3d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{5}, {1, 1, 1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
    return translate_adaptive_max_pool_base(context, repeats, spatial_dims);
};
OutputVector translate_adaptive_max_pool2d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{4}, {1, 1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-2}));
    return translate_adaptive_max_pool_base(context, repeats, spatial_dims);
};
OutputVector translate_adaptive_max_pool1d(const NodeContext& context) {
    auto repeats = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {1, 1, 1}));
    auto spatial_dims = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    return translate_adaptive_max_pool_base(context, repeats, spatial_dims);
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -0,0 +1,37 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/op/erf.hpp"
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/convert.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
// Convert aten::erf / aten::erf.out to OpenVINO Erf.
OutputVector translate_erf(const NodeContext& context) {
    // aten::erf(Tensor self) -> Tensor
    // aten::erf.out(Tensor self, Tensor(!a) out) -> Tensor(!a)
    num_inputs_check(context, 1, 2);
    auto x = context.get_input(0);
    auto xdtype = x.get_element_type();
    // In torch, erf always returns a floating-point dtype, while ov keeps the
    // input dtype - so cast integral (and dynamically-typed) inputs to f32 first.
    if (xdtype.is_dynamic() || !xdtype.is_real()) {
        x = context.mark_node(std::make_shared<ov::op::v0::Convert>(x, element::f32));
    }
    auto y = context.mark_node(std::make_shared<ov::op::v0::Erf>(x));
    // Second input, when present, is the `out` tensor of the .out overload.
    if (!context.input_is_none(1)) {
        context.mutate_input(1, y);
    }
    return {y};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@ -17,7 +17,11 @@ namespace op {
// TorchScript translations
OP_CONVERTER(translate_adaptive_avg_pool3d);
OP_CONVERTER(translate_adaptive_avg_pool2d);
OP_CONVERTER(translate_adaptive_avg_pool1d);
OP_CONVERTER(translate_adaptive_max_pool3d);
OP_CONVERTER(translate_adaptive_max_pool2d);
OP_CONVERTER(translate_adaptive_max_pool1d);
OP_CONVERTER(translate_add);
OP_CONVERTER(translate_addcmul);
OP_CONVERTER(translate_addmm);
@ -56,6 +60,7 @@ OP_CONVERTER(translate_elu);
OP_CONVERTER(translate_embedding);
OP_CONVERTER(translate_embedding_bag);
OP_CONVERTER(translate_empty);
OP_CONVERTER(translate_erf);
OP_CONVERTER(translate_expand);
OP_CONVERTER(translate_expand_as);
OP_CONVERTER(translate_eye);
@ -232,9 +237,12 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::acos_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acos>>},
{"aten::acosh", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Acosh>},
{"aten::acosh_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Acosh>>},
{"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_1to1_match_2_inputs<opset10::AdaptiveAvgPool>>},
{"aten::adaptive_avg_pool1d", op::quantizable_op<op::translate_adaptive_avg_pool1d>},
{"aten::adaptive_avg_pool2d", op::quantizable_op<op::translate_adaptive_avg_pool2d>},
{"aten::adaptive_avg_pool3d", op::quantizable_op<op::translate_adaptive_avg_pool3d>},
{"aten::adaptive_max_pool1d", op::quantizable_op<op::translate_adaptive_max_pool1d>},
{"aten::adaptive_max_pool2d", op::quantizable_op<op::translate_adaptive_max_pool2d>},
{"aten::adaptive_max_pool3d", op::quantizable_op<op::translate_adaptive_max_pool3d>},
{"aten::add", op::translate_add},
{"aten::add_", op::inplace_op<op::translate_add>},
{"aten::addcmul", op::translate_addcmul},
@ -305,6 +313,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::embedding_bag", op::translate_embedding_bag},
{"aten::empty", op::translate_empty},
{"aten::eq", op::translate_1to1_match_2_inputs_align_types<opset10::Equal>},
{"aten::erf", op::translate_erf},
{"aten::erf_", op::inplace_op<op::translate_erf>},
{"aten::exp", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>},
{"aten::exp_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Exp>>},
{"aten::expand", op::translate_expand},

View File

@ -0,0 +1,101 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
@pytest.mark.parametrize('input_tensor', (np.random.randn(1, 2, 8, 9, 10).astype(np.float32),
                                          np.random.randn(2, 8, 9, 10).astype(np.float32)))
@pytest.mark.parametrize('output_size', ([5, 7, 9], 7))
class TestAdaptiveAvgPool3D(PytorchLayerTest):
    """Layer tests for aten::adaptive_avg_pool3d (batched and unbatched inputs)."""

    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size):
        class aten_adaptive_avg_pool3d(torch.nn.Module):
            def __init__(self, output_size) -> None:
                super().__init__()
                self.output_size = output_size

            def forward(self, input_tensor):
                return torch.nn.functional.adaptive_avg_pool3d(input_tensor, self.output_size)

        # (model, reference net, traced op to look for)
        return aten_adaptive_avg_pool3d(output_size), None, "aten::adaptive_avg_pool3d"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_avg_pool3d(self, ie_device, precision, ir_version, input_tensor, output_size):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size), ie_device, precision, ir_version)
@pytest.mark.parametrize('input_tensor', [np.random.randn(2, 8, 9, 10).astype(np.float32),
                                          np.random.randn(8, 9, 10).astype(np.float32)])
@pytest.mark.parametrize('output_size', ([7, 9], 7))
class TestAdaptiveAvgPool2D(PytorchLayerTest):
    """Layer tests for aten::adaptive_avg_pool2d (batched and unbatched inputs)."""

    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size):
        class aten_adaptive_avg_pool2d(torch.nn.Module):
            def __init__(self, output_size) -> None:
                super().__init__()
                self.output_size = output_size

            def forward(self, input_tensor):
                return torch.nn.functional.adaptive_avg_pool2d(input_tensor, self.output_size)

        # (model, reference net, traced op to look for)
        return aten_adaptive_avg_pool2d(output_size), None, "aten::adaptive_avg_pool2d"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_avg_pool2d(self, ie_device, precision, ir_version, input_tensor, output_size):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size), ie_device, precision, ir_version)
@pytest.mark.parametrize('input_tensor', [np.random.randn(8, 9, 10).astype(np.float32),
                                          np.random.randn(9, 10).astype(np.float32)])
@pytest.mark.parametrize('output_size', (7,))
class TestAdaptiveAvgPool1D(PytorchLayerTest):
    """Layer tests for aten::adaptive_avg_pool1d (batched and unbatched inputs)."""

    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size):
        class aten_adaptive_avg_pool1d(torch.nn.Module):
            def __init__(self, output_size) -> None:
                super().__init__()
                self.output_size = output_size

            def forward(self, input_tensor):
                return torch.nn.functional.adaptive_avg_pool1d(input_tensor, self.output_size)

        # (model, reference net, traced op to look for)
        return aten_adaptive_avg_pool1d(output_size), None, "aten::adaptive_avg_pool1d"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_avg_pool1d(self, ie_device, precision, ir_version, input_tensor, output_size):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size), ie_device, precision, ir_version)

View File

@ -1,39 +0,0 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
@pytest.mark.parametrize('input_tensor', (np.random.randn(1, 2, 8, 9, 10).astype(np.float32),
                                          np.random.randn(2, 8, 9, 10).astype(np.float32)))
@pytest.mark.parametrize('output_size', ([5, 7, 9], 7))
class TestAdaptiveAvgPool3D(PytorchLayerTest):
    # Layer tests for aten::adaptive_avg_pool3d, covering batched (5D) and
    # unbatched (4D) inputs and both list/int output sizes.
    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size):
        class aten_adaptive_avg_pool3d(torch.nn.Module):
            def __init__(self, output_size) -> None:
                super().__init__()
                self.output_size = output_size

            def forward(self, input_tensor):
                return torch.nn.functional.adaptive_avg_pool3d(input_tensor, self.output_size)

        ref_net = None
        # Returns (model, reference net, traced op to look for).
        return aten_adaptive_avg_pool3d(output_size), ref_net, "aten::adaptive_avg_pool3d"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_avg_pool3d(self, ie_device, precision, ir_version, input_tensor, output_size):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size), ie_device, precision, ir_version)

View File

@ -0,0 +1,144 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from pytorch_layer_test_class import PytorchLayerTest
class TestAdaptiveMaxPool3D(PytorchLayerTest):
    # Layer tests for aten::adaptive_max_pool3d with and without return_indices.
    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size=None, return_indices=False):
        class aten_adaptive_max_pool3d(torch.nn.Module):
            def __init__(self, output_size=None, return_indices=False) -> None:
                super().__init__()
                self.output_size = output_size
                self.return_indices = return_indices

            def forward(self, input_tensor):
                if self.return_indices:
                    output, indices = F.adaptive_max_pool3d(input_tensor, self.output_size, True)
                    return output, indices
                # Second output is a dummy i64 tensor so the model returns two
                # tensors in both modes - presumably to keep the output count
                # uniform for the test harness; confirm against PytorchLayerTest.
                return F.adaptive_max_pool3d(input_tensor, self.output_size, False), input_tensor.to(torch.int64)

        ref_net = None
        # Returns (model, reference net, traced op to look for).
        return aten_adaptive_max_pool3d(output_size, return_indices), ref_net, "aten::adaptive_max_pool3d"

    @pytest.mark.parametrize('input_tensor', ([
        np.random.randn(2, 1, 1, 4, 4).astype(np.float32),
        np.random.randn(4, 1, 3, 32, 32).astype(np.float32),
        np.random.randn(1, 3, 32, 32).astype(np.float32)
    ]))
    @pytest.mark.parametrize('output_size', ([
        [2, 2, 2],
        [4, 4, 4],
    ]))
    @pytest.mark.parametrize('return_indices', ([
        False,
        True,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_max_pool3d(self, ie_device, precision, ir_version, input_tensor, output_size, return_indices):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size, return_indices), ie_device, precision, ir_version)
class TestAdaptiveMaxPool2D(PytorchLayerTest):
    """Layer tests for aten::adaptive_max_pool2d with and without return_indices."""

    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size=None, return_indices=False):
        class aten_adaptive_max_pool2d(torch.nn.Module):
            def __init__(self, output_size=None, return_indices=False) -> None:
                super().__init__()
                self.output_size = output_size
                self.return_indices = return_indices

            def forward(self, input_tensor):
                if not self.return_indices:
                    # Dummy i64 second output - presumably keeps the model's
                    # output count uniform in both modes; confirm with harness.
                    pooled = F.adaptive_max_pool2d(input_tensor, self.output_size, False)
                    return pooled, input_tensor.to(torch.int64)
                output, indices = F.adaptive_max_pool2d(input_tensor, self.output_size, True)
                return output, indices

        # (model, reference net, traced op to look for)
        return aten_adaptive_max_pool2d(output_size, return_indices), None, "aten::adaptive_max_pool2d"

    @pytest.mark.parametrize('input_tensor', ([
        np.random.randn(2, 1, 4, 4).astype(np.float32),
        np.random.randn(1, 3, 32, 32).astype(np.float32),
        np.random.randn(3, 32, 32).astype(np.float32)
    ]))
    @pytest.mark.parametrize('output_size', ([
        [2, 2],
        [4, 4],
    ]))
    @pytest.mark.parametrize('return_indices', ([
        False,
        True,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_max_pool2d(self, ie_device, precision, ir_version, input_tensor, output_size, return_indices):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size, return_indices), ie_device, precision, ir_version)
class TestAdaptiveMaxPool1D(PytorchLayerTest):
    # Layer tests for aten::adaptive_max_pool1d with and without return_indices.
    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size=None, return_indices=False):
        class aten_adaptive_max_pool1d(torch.nn.Module):
            def __init__(self, output_size=None, return_indices=False) -> None:
                super().__init__()
                self.output_size = output_size
                self.return_indices = return_indices

            def forward(self, input_tensor):
                if self.return_indices:
                    output, indices = F.adaptive_max_pool1d(input_tensor, self.output_size, True)
                    return output, indices
                # Second output is a dummy i64 tensor so the model returns two
                # tensors in both modes - presumably to keep the output count
                # uniform for the test harness; confirm against PytorchLayerTest.
                return F.adaptive_max_pool1d(input_tensor, self.output_size, False), input_tensor.to(torch.int64)

        ref_net = None
        # Returns (model, reference net, traced op to look for).
        return aten_adaptive_max_pool1d(output_size, return_indices), ref_net, "aten::adaptive_max_pool1d"

    @pytest.mark.parametrize('input_tensor', ([
        np.random.randn(1, 4, 4).astype(np.float32),
        np.random.randn(3, 32, 32).astype(np.float32),
        np.random.randn(16, 8).astype(np.float32),
    ]))
    @pytest.mark.parametrize('output_size', ([
        2,
        4,
    ]))
    @pytest.mark.parametrize('return_indices', ([
        False,
        True,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_max_pool1d(self, ie_device, precision, ir_version, input_tensor, output_size, return_indices):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size, return_indices), ie_device, precision, ir_version)

View File

@ -1,53 +0,0 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from pytorch_layer_test_class import PytorchLayerTest
class TestAdaptiveMaxPool2D(PytorchLayerTest):
    # Layer tests for aten::adaptive_max_pool2d.
    # NOTE(review): in the return_indices=True branch the model returns only
    # `output` and discards `indices`, so the indices output is never actually
    # checked by this test.
    def _prepare_input(self):
        # The tensor is stashed on the instance by the test method.
        return (self.input_tensor,)

    def create_model(self, output_size=None, return_indices=False):
        class aten_adaptive_max_pool2d(torch.nn.Module):
            def __init__(self, output_size=None, return_indices=False) -> None:
                super().__init__()
                self.output_size = output_size
                self.return_indices = return_indices

            def forward(self, input_tensor):
                if self.return_indices:
                    output, indices = F.adaptive_max_pool2d(input_tensor, self.output_size, True)
                    return output
                return F.adaptive_max_pool2d(input_tensor, self.output_size, False)

        ref_net = None
        # Returns (model, reference net, traced op to look for).
        return aten_adaptive_max_pool2d(output_size, return_indices), ref_net, "aten::adaptive_max_pool2d"

    @pytest.mark.parametrize('input_tensor', ([
        np.random.randn(1, 1, 4, 4).astype(np.float32),
        np.random.randn(1, 3, 32, 32).astype(np.float32)
    ]))
    @pytest.mark.parametrize('output_size', ([
        [2, 2],
        [4, 4],
    ]))
    @pytest.mark.parametrize('return_indices', ([
        False,
        True,
    ]))
    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.precommit_ts_backend
    @pytest.mark.precommit_fx_backend
    def test_adaptive_max_pool2d(self, ie_device, precision, ir_version, input_tensor, output_size, return_indices):
        self.input_tensor = input_tensor
        self._test(*self.create_model(output_size, return_indices), ie_device, precision, ir_version)

View File

@ -0,0 +1,57 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestErf(PytorchLayerTest):
    """Layer tests for aten::erf, aten::erf.out and aten::erf_."""

    def _prepare_input(self, input_dtype, out=False):
        import numpy as np
        x = np.linspace(-3, 3).astype(input_dtype)
        if out:
            # The .out overload also needs a destination tensor.
            return (x, np.zeros_like(x).astype(input_dtype))
        return (x,)

    def create_model(self, mode="", input_dtype="float32"):
        import torch
        dtype = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int32": torch.int32,
        }[input_dtype]

        class aten_erf(torch.nn.Module):
            def __init__(self, mode, dtype):
                super().__init__()
                self.dtype = dtype
                # Rebind forward to exercise the requested overload.
                if mode == "out":
                    self.forward = self.forward_out
                elif mode == "inplace":
                    self.forward = self.forward_inplace

            def forward(self, x):
                return torch.special.erf(x.to(self.dtype))

            def forward_out(self, x, y):
                return torch.special.erf(x.to(self.dtype), out=y), y

            def forward_inplace(self, x):
                x = x.to(self.dtype)
                return x.erf_(), x

        op_name = "aten::erf_" if mode == "inplace" else "aten::erf"
        # (model, reference net, traced op to look for)
        return aten_erf(mode, dtype), None, op_name

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("mode,input_dtype", [
        ("", "float32"), ("", "float64"), ("", "int32"),
        ("out", "float32"), ("out", "float64"),
        ("inplace", "float32"), ("inplace", "float64")])
    def test_erf(self, mode, input_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(mode, input_dtype), ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"input_dtype": input_dtype, "out": mode == "out"})