From fde054e4a6df0a264ac5a2b065082c8402de911c Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Fri, 22 Sep 2023 09:30:57 +0400 Subject: [PATCH] [PT FE]: support aten::minimum aten::maximum (#19996) --- src/frontends/pytorch/src/op/min_max.cpp | 32 ++++++++ src/frontends/pytorch/src/op_table.cpp | 4 + .../layer_tests/pytorch_tests/test_min_max.py | 74 +++++++++++++++++++ 3 files changed, 110 insertions(+) diff --git a/src/frontends/pytorch/src/op/min_max.cpp b/src/frontends/pytorch/src/op/min_max.cpp index c6b63f11e25..670b4eca4d4 100644 --- a/src/frontends/pytorch/src/op/min_max.cpp +++ b/src/frontends/pytorch/src/op/min_max.cpp @@ -80,6 +80,38 @@ OutputVector translate_min(const NodeContext& context) { return {values, indicies}; }; +OutputVector translate_maximum(const NodeContext& context) { + // aten::maximum(Tensor self, Tensor other) -> Tensor + + // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + + num_inputs_check(context, 2, 3); + auto x = context.get_input(0); + auto y = context.get_input(1); + align_eltwise_input_types(context, x, y, true); + auto res = context.mark_node(std::make_shared(x, y)); + if (!context.input_is_none(2)) { + context.mutate_input(2, res); + } + return {res}; +} + +OutputVector translate_minimum(const NodeContext& context) { + // aten::minimum(Tensor self, Tensor other) -> Tensor + + // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) 
+ + num_inputs_check(context, 2, 3); + auto x = context.get_input(0); + auto y = context.get_input(1); + align_eltwise_input_types(context, x, y, true); + auto res = context.mark_node(std::make_shared(x, y)); + if (!context.input_is_none(2)) { + context.mutate_input(2, res); + } + return {res}; +} + } // namespace op } // namespace pytorch } // namespace frontend diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp index 9cb68a3ea5c..088a2ce9639 100644 --- a/src/frontends/pytorch/src/op_table.cpp +++ b/src/frontends/pytorch/src/op_table.cpp @@ -96,10 +96,12 @@ OP_CONVERTER(translate_loop); OP_CONVERTER(translate_masked_fill); OP_CONVERTER(translate_masked_scatter); OP_CONVERTER(translate_max); +OP_CONVERTER(translate_maximum); OP_CONVERTER(translate_max_poolnd); OP_CONVERTER(translate_mean); OP_CONVERTER(translate_meshgrid); OP_CONVERTER(translate_min); +OP_CONVERTER(translate_minimum); OP_CONVERTER(translate_narrow); OP_CONVERTER(translate_native_multi_head_attention); OP_CONVERTER(translate_neg); @@ -351,12 +353,14 @@ const std::map get_supported_ops_ts() { {"aten::masked_scatter_", op::inplace_op}, {"aten::matmul", op::translate_1to1_match_2_inputs}, {"aten::max", op::translate_max}, + {"aten::maximum", op::translate_maximum}, {"aten::max_pool1d", op::quantizable_op}, {"aten::max_pool2d", op::quantizable_op}, {"aten::max_pool3d", op::quantizable_op}, {"aten::mean", op::quantizable_op}, {"aten::meshgrid", op::translate_meshgrid}, {"aten::min", op::translate_min}, + {"aten::minimum", op::translate_minimum}, {"aten::mm", op::translate_1to1_match_2_inputs}, {"aten::mul", op::translate_1to1_match_2_inputs_align_types}, {"aten::mul_", op::inplace_op>}, diff --git a/tests/layer_tests/pytorch_tests/test_min_max.py b/tests/layer_tests/pytorch_tests/test_min_max.py index 1d03c837380..c32fe41512f 100644 --- a/tests/layer_tests/pytorch_tests/test_min_max.py +++ b/tests/layer_tests/pytorch_tests/test_min_max.py @@ -210,3 +210,77 @@ 
class TestPrimMin(PytorchLayerTest):
    # NOTE(review): only the tail of this class is visible in this chunk;
    # its create_model and parametrize decorators live above this view.
    def test_min(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
        self._test(*self.create_model(case), ie_device, precision, ir_version,
                   kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)


class TestMinimumMaximum(PytorchLayerTest):
    """Layer tests for aten::minimum / aten::maximum, with and without out=."""

    def _prepare_input(self, input_dtype="float32", second_input_dtype="float32", out=False):
        import numpy as np
        lhs = np.random.randn(1, 3, 10, 10).astype(input_dtype)
        rhs = np.random.randn(1, 3, 10, 10).astype(second_input_dtype)
        if out:
            # Third element is the preallocated destination for the out= overload.
            return lhs, rhs, np.zeros_like(lhs).astype(input_dtype)
        return lhs, rhs

    def create_model(self, op_type, dtypes=("float32", "float32"), out=False):
        """Build the traced module, plus (ref_net=None, op_name) for the runner."""
        import torch

        torch_ops = {"maximum": torch.maximum, "minimum": torch.minimum}
        torch_dtypes = {
            "float32": torch.float32,
            "int32": torch.int32,
            "int64": torch.int64,
            "float64": torch.float64,
        }

        class aten_minimum_maximum(torch.nn.Module):
            def __init__(self, op, l_dtype, r_dtype, out):
                super(aten_minimum_maximum, self).__init__()
                self.op = op
                self.l_dtype = l_dtype
                self.r_dtype = r_dtype
                # Swap in the out=-style forward when requested.
                if out:
                    self.forward = self.forward_out

            def forward_out(self, x, y, z):
                return self.op(x.to(self.l_dtype), y.to(self.r_dtype), out=z), z

            def forward(self, x, y):
                return self.op(x.to(self.l_dtype), y.to(self.r_dtype))

        model = aten_minimum_maximum(
            torch_ops[op_type], torch_dtypes[dtypes[0]], torch_dtypes[dtypes[1]], out)
        return model, None, f"aten::{op_type}"

    @pytest.mark.parametrize("op_type", ["minimum", "maximum"])
    @pytest.mark.parametrize("second_input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.parametrize("first_input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_minimum_maximum(
            self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version
    ):
        prepare_kwargs = {"input_dtype": first_input_dtype,
                          "second_input_dtype": second_input_dtype,
                          "out": False}
        self._test(*self.create_model(op_type, dtypes=(first_input_dtype, second_input_dtype), out=False),
                   ie_device, precision, ir_version,
                   kwargs_to_prepare_input=prepare_kwargs)

    @pytest.mark.parametrize("op_type", ['minimum', 'maximum'])
    @pytest.mark.parametrize("input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_minimum_maximum_out(
            self, op_type, input_dtype, ie_device, precision, ir_version
    ):
        prepare_kwargs = {"input_dtype": input_dtype,
                          "second_input_dtype": input_dtype,
                          "out": True}
        self._test(*self.create_model(op_type, dtypes=(input_dtype, input_dtype), out=True),
                   ie_device, precision, ir_version,
                   kwargs_to_prepare_input=prepare_kwargs)