From a5b6606132a19aa7f3be2c7dbf18b66eea8fb84e Mon Sep 17 00:00:00 2001
From: Ekaterina Aidova
Date: Tue, 10 Oct 2023 15:05:10 +0400
Subject: [PATCH] [PT FE]: support aten::amax, aten::amin, aten::clip, aten::clamp_ (#20338)

---
 src/frontends/pytorch/src/op/min_max.cpp      | 30 +++++++++++
 src/frontends/pytorch/src/op_table.cpp        |  7 +++
 tests/layer_tests/pytorch_tests/test_clamp.py | 25 ++++---
 .../layer_tests/pytorch_tests/test_min_max.py | 53 +++++++++++++++++++
 4 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/src/frontends/pytorch/src/op/min_max.cpp b/src/frontends/pytorch/src/op/min_max.cpp
index 670b4eca4d4..45b4f5f0155 100644
--- a/src/frontends/pytorch/src/op/min_max.cpp
+++ b/src/frontends/pytorch/src/op/min_max.cpp
@@ -112,6 +112,36 @@ OutputVector translate_minimum(const NodeContext& context) {
     return {res};
 }
 
+OutputVector translate_amin(const NodeContext& context) {
+    // aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+
+    // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto x = context.get_input(0);
+    auto dims = context.get_input(1);
+    auto keep_dims = context.const_input<bool>(2);
+    auto res = context.mark_node(std::make_shared<v1::ReduceMin>(x, dims, keep_dims));
+    if (!context.input_is_none(3)) {
+        context.mutate_input(3, res);
+    }
+    return {res};
+}
+
+OutputVector translate_amax(const NodeContext& context) {
+    // aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor
+
+    // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
+    num_inputs_check(context, 3, 4);
+    auto x = context.get_input(0);
+    auto dims = context.get_input(1);
+    auto keep_dims = context.const_input<bool>(2);
+    auto res = context.mark_node(std::make_shared<v1::ReduceMax>(x, dims, keep_dims));
+    if (!context.input_is_none(3)) {
+        context.mutate_input(3, res);
+    }
+    return {res};
+}
+
 }  // namespace op
 }  // namespace pytorch
 }  // namespace frontend
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index c420a1b16e1..b168775acc0 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -22,6 +22,8 @@ OP_CONVERTER(translate_add);
 OP_CONVERTER(translate_addcmul);
 OP_CONVERTER(translate_addmm);
 OP_CONVERTER(translate_all);
+OP_CONVERTER(translate_amax);
+OP_CONVERTER(translate_amin);
 OP_CONVERTER(translate_and);
 OP_CONVERTER(translate_arange);
 OP_CONVERTER(translate_argmax);
@@ -237,6 +239,8 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::addcmul", op::translate_addcmul},
         {"aten::addmm", op::translate_addmm},
         {"aten::all", op::translate_all},
+        {"aten::amax", op::translate_amax},
+        {"aten::amin", op::translate_amin},
         {"aten::arange", op::translate_arange},
         {"aten::argmax", op::translate_argmax},
         {"aten::argmin", op::translate_argmin},
@@ -266,8 +270,11 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::ceil_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Ceiling>>},
         {"aten::channel_shuffle", op::translate_channel_shuffle},
         {"aten::clamp", op::translate_clamp},
+        {"aten::clamp_", op::inplace_op<op::translate_clamp>},
         {"aten::clamp_max", op::translate_1to1_match_2_inputs<opset10::Minimum>},
         {"aten::clamp_min", op::translate_1to1_match_2_inputs<opset10::Maximum>},
+        {"aten::clip", op::translate_clamp},
+        {"aten::clip_", op::inplace_op<op::translate_clamp>},
         {"aten::clone", op::skip_node},  // ignore clone operators that are inserted by PyTorch autograd
         {"aten::contiguous", op::skip_node},  // In openvino how tensors are stored in memory is internal plugin detail,
                                               // we assume all tensors are contiguous
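Note (illustrative, not part of the patch): with the converters and table entries above, the new ops go through the frontend's standard TorchScript conversion path. A minimal sketch, assuming the `openvino.convert_model` entry point available in 2023.x releases; the module name and shapes below are made up for demonstration:

    # Illustrative sketch: exercise the newly supported aten::amax and
    # aten::clip_ conversions end to end via openvino.convert_model.
    import torch
    import openvino as ov

    class AmaxClip(torch.nn.Module):
        def forward(self, x):
            y = torch.amax(x, dim=(1, 2), keepdim=True)  # dispatched to translate_amax -> ReduceMax
            return y.clip_(-1.0, 1.0)                    # aten::clip_ -> inplace_op<translate_clamp>

    # Tracing path used by the frontend; the input shape is arbitrary.
    ov_model = ov.convert_model(AmaxClip(), example_input=torch.randn(2, 3, 4))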
diff --git a/tests/layer_tests/pytorch_tests/test_clamp.py b/tests/layer_tests/pytorch_tests/test_clamp.py
index 346b47c3d1f..ad869d6211e 100644
--- a/tests/layer_tests/pytorch_tests/test_clamp.py
+++ b/tests/layer_tests/pytorch_tests/test_clamp.py
@@ -11,11 +11,11 @@ class TestClamp(PytorchLayerTest):
         import numpy as np
         return (np.random.randn(1, 3, 224, 224).astype(np.float32),)
 
-    def create_model(self, minimum, maximum, as_tensors=False):
+    def create_model(self, minimum, maximum, as_tensors=False, op_type='clamp'):
         import torch
 
         class aten_clamp(torch.nn.Module):
-            def __init__(self, minimum, maximum, as_tensors):
+            def __init__(self, minimum, maximum, as_tensors, op_type="clamp"):
                 super(aten_clamp, self).__init__()
                 if minimum is not None and as_tensors:
                     minimum = torch.tensor(minimum)
@@ -23,20 +23,31 @@ class TestClamp(PytorchLayerTest):
                 if maximum is not None and as_tensors:
                     maximum = torch.tensor(maximum)
                 self.max = maximum
+                self.forward = getattr(self, f"forward_{op_type}")
 
-            def forward(self, x):
+            def forward_clamp(self, x):
                 return torch.clamp(x, self.min, self.max)
 
+            def forward_clip(self, x):
+                return torch.clip(x, self.min, self.max)
+
+            def forward_clamp_(self, x):
+                return x.clamp_(self.min, self.max), x
+
+            def forward_clip_(self, x):
+                return x.clip_(self.min, self.max), x
+
         ref_net = None
-        op_name = "aten::clamp"
-        return aten_clamp(minimum, maximum, as_tensors), ref_net, op_name
+        op_name = f"aten::{op_type}"
+        return aten_clamp(minimum, maximum, as_tensors, op_type), ref_net, op_name
 
     @pytest.mark.parametrize("minimum,maximum", [(0., 1.), (-0.5, 1.5), (None, 10.), (None, -10.), (10., None), (-10., None), (100, 200)])
     @pytest.mark.parametrize("as_tensors", [True, False])
+    @pytest.mark.parametrize("op_type", ["clamp", "clamp_"])
     @pytest.mark.nightly
-    def test_clamp(self, minimum, maximum, as_tensors, ie_device, precision, ir_version):
-        self._test(*self.create_model(minimum, maximum, as_tensors), ie_device, precision, ir_version)
+    def test_clamp(self, minimum, maximum, as_tensors, op_type, ie_device, precision, ir_version):
+        self._test(*self.create_model(minimum, maximum, as_tensors, op_type), ie_device, precision, ir_version)
 
     @pytest.mark.xfail(reason='OpenVINO clamp does not support min > max')
     def test_clamp_min_greater(self, ie_device, precision, ir_version):
diff --git a/tests/layer_tests/pytorch_tests/test_min_max.py b/tests/layer_tests/pytorch_tests/test_min_max.py
index c32fe41512f..3a624d534fa 100644
--- a/tests/layer_tests/pytorch_tests/test_min_max.py
+++ b/tests/layer_tests/pytorch_tests/test_min_max.py
@@ -283,4 +283,57 @@ class TestMinimumMaximum(PytorchLayerTest):
             ie_device, precision, ir_version, kwargs_to_prepare_input=
             {"input_dtype": input_dtype, "second_input_dtype": input_dtype, "out": True}
             )
+
+
+class TestAminAmax(PytorchLayerTest):
+    def _prepare_input(self, input_dtype="float32", out=False, axes=None, keep_dims=False):
+        import numpy as np
+
+        x = np.random.randn(1, 3, 10, 10).astype(input_dtype)
+        if not out:
+            return (x,)
+        if isinstance(axes, list):
+            axes = tuple(axes)
+        out = np.zeros_like(np.max(x, axis=axes, keepdims=keep_dims), dtype=input_dtype)
+        return (x, out)
+
+    def create_model(self, op_type, axis, keep_dims, out=False):
+        import torch
+
+        op_types = {
+            "amax": torch.amax,
+            "amin": torch.amin
+        }
+
+        op = op_types[op_type]
+
+        class aten_amin_amax(torch.nn.Module):
+            def __init__(self, op, axis, keep_dims, out):
+                super().__init__()
+                self.op = op
+                self.axis = axis
+                self.keep_dims = keep_dims
+                if out:
+                    self.forward = self.forward_out
+
+            def forward_out(self, x, y):
+                return self.op(x, self.axis, self.keep_dims, out=y), y
+
+            def forward(self, x):
+                return self.op(x, self.axis, self.keep_dims)
+
+        model_cls = aten_amin_amax(op, axis, keep_dims, out)
+
+        return model_cls, None, f"aten::{op_type}"
+
+    @pytest.mark.parametrize("op_type", ["amin", "amax"])
+    @pytest.mark.parametrize("axis", [0, -1, 1, [1, 2], [-1, -2], [2, 0, -1], [0, 1, 2, 3]])
+    @pytest.mark.parametrize("keep_dims", [True, False])
+    @pytest.mark.parametrize("out", [True, False])
+    @pytest.mark.parametrize("input_dtype", ['float32', 'int32', 'int64', 'float64'])
+    def test_amin_amax(self, op_type, input_dtype, axis, keep_dims, out, ie_device, precision, ir_version):
+        self._test(*self.create_model(op_type, axis, keep_dims, out),
+                   ie_device, precision, ir_version, kwargs_to_prepare_input=
+                   {"input_dtype": input_dtype, "out": out, "axes": axis, "keep_dims": keep_dims}
+                   )
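Note (illustrative, not part of the patch): the `out=True` branch above covers the `.out` overloads, which the new converters handle by writing the result back through `context.mutate_input(3, res)` into the caller-provided tensor. A short sketch of the PyTorch-side semantics being checked; the shapes are assumed to match the reduced result, mirroring `_prepare_input`:

    import torch

    x = torch.randn(1, 3, 10, 10)
    # Preallocate the destination with the reduced shape, as _prepare_input does.
    dst = torch.zeros_like(torch.amax(x, dim=(1, 2), keepdim=True))
    torch.amax(x, dim=(1, 2), keepdim=True, out=dst)  # result is written into dst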