[PT FE]: support aten::amax, aten::amin, aten::clip, aten::clamp_ (#20338)

Ekaterina Aidova 2023-10-10 15:05:10 +04:00 committed by GitHub
parent 1454e77bbf
commit a5b6606132
4 changed files with 108 additions and 7 deletions


@@ -112,6 +112,36 @@ OutputVector translate_minimum(const NodeContext& context) {
     return {res};
 }
 
+OutputVector translate_amin(const NodeContext& context) {
+    // aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+    // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto x = context.get_input(0);
+    auto dims = context.get_input(1);
+    auto keep_dims = context.const_input<bool>(2);
+    auto res = context.mark_node(std::make_shared<v1::ReduceMin>(x, dims, keep_dims));
+    if (!context.input_is_none(3)) {
+        context.mutate_input(3, res);
+    }
+    return {res};
+}
+
+OutputVector translate_amax(const NodeContext& context) {
+    // aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+    // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto x = context.get_input(0);
+    auto dims = context.get_input(1);
+    auto keep_dims = context.const_input<bool>(2);
+    auto res = context.mark_node(std::make_shared<v1::ReduceMax>(x, dims, keep_dims));
+    if (!context.input_is_none(3)) {
+        context.mutate_input(3, res);
+    }
+    return {res};
+}
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
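
For reference, a minimal eager-PyTorch sketch (tensor names are illustrative, not from the commit) of the aten::amin/aten::amax semantics these converters map onto ReduceMin/ReduceMax:

import torch

x = torch.randn(1, 3, 10, 10)

# amin/amax reduce over a list of dims; keepdim keeps reduced axes as size 1.
mins = torch.amin(x, dim=[1, 2], keepdim=True)   # shape: (1, 1, 1, 10)
maxs = torch.amax(x, dim=[1, 2], keepdim=False)  # shape: (1, 10)

# The .out overload writes into a preallocated tensor; the converter forwards
# this through mutate_input on the optional fourth input.
dst = torch.empty_like(maxs)
torch.amax(x, dim=[1, 2], keepdim=False, out=dst)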


@@ -22,6 +22,8 @@ OP_CONVERTER(translate_add);
 OP_CONVERTER(translate_addcmul);
 OP_CONVERTER(translate_addmm);
 OP_CONVERTER(translate_all);
+OP_CONVERTER(translate_amax);
+OP_CONVERTER(translate_amin);
 OP_CONVERTER(translate_and);
 OP_CONVERTER(translate_arange);
 OP_CONVERTER(translate_argmax);
@@ -237,6 +239,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::addcmul", op::translate_addcmul},
         {"aten::addmm", op::translate_addmm},
         {"aten::all", op::translate_all},
+        {"aten::amax", op::translate_amax},
+        {"aten::amin", op::translate_amin},
         {"aten::arange", op::translate_arange},
         {"aten::argmax", op::translate_argmax},
         {"aten::argmin", op::translate_argmin},
@@ -266,8 +270,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
         {"aten::channel_shuffle", op::translate_channel_shuffle},
         {"aten::clamp", op::translate_clamp},
+        {"aten::clamp_", op::inplace_op<op::translate_clamp>},
         {"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
         {"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},
+        {"aten::clip", op::translate_clamp},
+        {"aten::clip_", op::inplace_op<op::translate_clamp>},
         {"aten::clone", op::skip_node},       // ignore clone operators that are inserted by PyTorch autograd
         {"aten::contiguous", op::skip_node},  // In openvino how tensors are stored in memory is internal plugin detail,
                                               // we assume all tensors are contiguous
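
A quick sanity check (illustrative snippet, eager PyTorch) of why these table entries hold: clip is an alias of clamp, and clamp_min/clamp_max are elementwise maximum/minimum against the bound, hence the 1-to-1 matches to opset10::Maximum/Minimum:

import torch

x = torch.randn(4)

# clip aliases clamp, so both dispatch to translate_clamp.
assert torch.equal(torch.clip(x, -0.5, 0.5), torch.clamp(x, -0.5, 0.5))

# clamp_min(x, v) == maximum(x, v); clamp_max(x, v) == minimum(x, v).
assert torch.equal(torch.clamp_min(x, 0.0), torch.maximum(x, torch.tensor(0.0)))
assert torch.equal(torch.clamp_max(x, 0.0), torch.minimum(x, torch.tensor(0.0)))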


@@ -11,11 +11,11 @@ class TestClamp(PytorchLayerTest):
         import numpy as np
         return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
 
-    def create_model(self, minimum, maximum, as_tensors=False):
+    def create_model(self, minimum, maximum, as_tensors=False, op_type='clamp'):
         import torch
 
         class aten_clamp(torch.nn.Module):
-            def __init__(self, minimum, maximum, as_tensors):
+            def __init__(self, minimum, maximum, as_tensors, op_type="clamp"):
                 super(aten_clamp, self).__init__()
                 if minimum is not None and as_tensors:
                     minimum = torch.tensor(minimum)
@@ -23,20 +23,31 @@ class TestClamp(PytorchLayerTest):
                 if maximum is not None and as_tensors:
                     maximum = torch.tensor(maximum)
                 self.max = maximum
+                self.forward = getattr(self, f"forward_{op_type}")
 
-            def forward(self, x):
+            def forward_clamp(self, x):
                 return torch.clamp(x, self.min, self.max)
 
+            def forward_clip(self, x):
+                return torch.clip(x, self.min, self.max)
+
+            def forward_clamp_(self, x):
+                return x.clamp_(self.min, self.max), x
+
+            def forward_clip_(self, x):
+                return x.clip_(self.min, self.max), x
+
         ref_net = None
-        op_name = "aten::clamp"
-        return aten_clamp(minimum, maximum, as_tensors), ref_net, op_name
+        op_name = f"aten::{op_type}"
+        return aten_clamp(minimum, maximum, as_tensors, op_type), ref_net, op_name
 
     @pytest.mark.parametrize("minimum,maximum",
                              [(0., 1.), (-0.5, 1.5), (None, 10.), (None, -10.), (10., None), (-10., None), (100, 200)])
     @pytest.mark.parametrize("as_tensors", [True, False])
+    @pytest.mark.parametrize("op_type", ["clamp", "clamp_"])
     @pytest.mark.nightly
-    def test_clamp(self, minimum, maximum, as_tensors, ie_device, precision, ir_version):
-        self._test(*self.create_model(minimum, maximum, as_tensors), ie_device, precision, ir_version)
+    def test_clamp(self, minimum, maximum, as_tensors, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(minimum, maximum, as_tensors, op_type), ie_device, precision, ir_version)
 
     @pytest.mark.xfail(reason='OpenVINO clamp does not support min > max')
     def test_clamp_min_greater(self, ie_device, precision, ir_version):
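
Context for the xfail above (illustrative snippet): PyTorch defines clamp with min > max so that every element takes the value of max, while OpenVINO's Clamp requires min <= max, so this case cannot be translated directly:

import torch

x = torch.randn(5)
y = torch.clamp(x, min=1.0, max=0.0)  # min > max: every element becomes max
assert torch.all(y == 0.0)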


@@ -283,4 +283,57 @@ class TestMinimumMaximum(PytorchLayerTest):
                    ie_device, precision, ir_version, kwargs_to_prepare_input=
                    {"input_dtype": input_dtype, "second_input_dtype": input_dtype,
                     "out": True}
                    )
+
+
+class TestAminAmax(PytorchLayerTest):
+    def _prepare_input(self, input_dtype="float32", out=False, axes=None, keep_dims=False):
+        import numpy as np
+        x = np.random.randn(1, 3, 10, 10).astype(input_dtype)
+        if not out:
+            return (x,)
+        if isinstance(axes, list):
+            axes = tuple(axes)
+        out = np.zeros_like(np.max(x, axis=axes, keepdims=keep_dims), dtype=input_dtype)
+        return (x, out)
+
+    def create_model(self, op_type, axis, keep_dims, out=False):
+        import torch
+
+        op_types = {
+            "amax": torch.amax,
+            "amin": torch.amin
+        }
+        op = op_types[op_type]
+
+        class aten_amin_amax(torch.nn.Module):
+            def __init__(self, op, axis, keep_dims, out):
+                super().__init__()
+                self.op = op
+                self.axis = axis
+                self.keep_dims = keep_dims
+                if out:
+                    self.forward = self.forward_out
+
+            def forward_out(self, x, y):
+                return self.op(x, self.axis, self.keep_dims, out=y), y
+
+            def forward(self, x):
+                return self.op(x, self.axis, self.keep_dims)
+
+        model_cls = aten_amin_amax(op, axis, keep_dims, out)
+        return model_cls, None, f"aten::{op_type}"
+
+    @pytest.mark.parametrize("op_type", ["amin", "amax"])
+    @pytest.mark.parametrize("axis", [0, -1, 1, [1, 2], [-1, -2], [2, 0, -1], [0, 1, 2, 3]])
+    @pytest.mark.parametrize("keep_dims", [True, False])
+    @pytest.mark.parametrize("out", [True, False])
+    @pytest.mark.parametrize("input_dtype", ['float32', 'int32', 'int64', 'float64'])
+    def test_amin_amax(self, op_type, input_dtype, axis, keep_dims, out, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, axis, keep_dims, out),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input=
+                   {"input_dtype": input_dtype, "out": out, "axes": axis, "keep_dims": keep_dims}
+                   )
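
As a sanity check outside the test suite (illustrative snippet): the numpy reference that _prepare_input uses to shape the preallocated out tensor agrees with torch.amax in both values and shape:

import numpy as np
import torch

x = np.random.randn(1, 3, 10, 10).astype("float32")
ref = np.max(x, axis=(1, 2), keepdims=True)
got = torch.amax(torch.from_numpy(x), dim=(1, 2), keepdim=True).numpy()
assert np.array_equal(got, ref) and got.shape == (1, 1, 1, 10)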