[PT FE]: support aten::minimum aten::maximum (#19996)

This commit is contained in:
Ekaterina Aidova
2023-09-22 09:30:57 +04:00
committed by GitHub
parent f1b8abe55a
commit fde054e4a6
3 changed files with 110 additions and 0 deletions

View File

@@ -80,6 +80,38 @@ OutputVector translate_min(const NodeContext& context) {
return {values, indicies};
};
OutputVector translate_maximum(const NodeContext& context) {
    // aten::maximum(Tensor self, Tensor other) -> Tensor
    // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 2, 3);
    auto lhs = context.get_input(0);
    auto rhs = context.get_input(1);
    // PyTorch promotes mixed element types for binary element-wise ops; align both sides first.
    align_eltwise_input_types(context, lhs, rhs, true);
    auto maximum = context.mark_node(std::make_shared<v1::Maximum>(lhs, rhs));
    // The .out overload additionally writes the result into the tensor passed as input 2.
    if (!context.input_is_none(2)) {
        context.mutate_input(2, maximum);
    }
    return {maximum};
}
OutputVector translate_minimum(const NodeContext& context) {
    // aten::minimum(Tensor self, Tensor other) -> Tensor
    // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
    num_inputs_check(context, 2, 3);
    auto lhs = context.get_input(0);
    auto rhs = context.get_input(1);
    // PyTorch promotes mixed element types for binary element-wise ops; align both sides first.
    align_eltwise_input_types(context, lhs, rhs, true);
    auto minimum = context.mark_node(std::make_shared<v1::Minimum>(lhs, rhs));
    // The .out overload additionally writes the result into the tensor passed as input 2.
    if (!context.input_is_none(2)) {
        context.mutate_input(2, minimum);
    }
    return {minimum};
}
} // namespace op
} // namespace pytorch
} // namespace frontend

View File

@@ -96,10 +96,12 @@ OP_CONVERTER(translate_loop);
OP_CONVERTER(translate_masked_fill);
OP_CONVERTER(translate_masked_scatter);
OP_CONVERTER(translate_max);
OP_CONVERTER(translate_maximum);
OP_CONVERTER(translate_max_poolnd);
OP_CONVERTER(translate_mean);
OP_CONVERTER(translate_meshgrid);
OP_CONVERTER(translate_min);
OP_CONVERTER(translate_minimum);
OP_CONVERTER(translate_narrow);
OP_CONVERTER(translate_native_multi_head_attention);
OP_CONVERTER(translate_neg);
@@ -351,12 +353,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::masked_scatter_", op::inplace_op<op::translate_masked_scatter>},
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::max", op::translate_max},
{"aten::maximum", op::translate_maximum},
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
{"aten::mean", op::quantizable_op<op::translate_mean>},
{"aten::meshgrid", op::translate_meshgrid},
{"aten::min", op::translate_min},
{"aten::minimum", op::translate_minimum},
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},

View File

@@ -210,3 +210,77 @@ class TestPrimMin(PytorchLayerTest):
def test_min(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
# Drives the framework comparison for the aten::min variant selected by `case`.
# use_mo_convert=False — presumably this op path is not supported through the MO convert route; TODO confirm.
self._test(*self.create_model(case),
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)
class TestMinimumMaximum(PytorchLayerTest):
    """Layer tests for aten::minimum / aten::maximum, including the .out overloads."""

    def _prepare_input(self, input_dtype="float32", second_input_dtype="float32", out=False):
        import numpy as np

        first = np.random.randn(1, 3, 10, 10).astype(input_dtype)
        second = np.random.randn(1, 3, 10, 10).astype(second_input_dtype)
        if out:
            # Third tensor is the preallocated destination for the .out overload.
            return (first, second, np.zeros_like(first).astype(input_dtype))
        return first, second

    def create_model(self, op_type, dtypes=("float32", "float32"), out=False):
        import torch

        torch_ops = {"minimum": torch.minimum, "maximum": torch.maximum}
        torch_dtypes = {
            "float32": torch.float32,
            "float64": torch.float64,
            "int32": torch.int32,
            "int64": torch.int64,
        }

        class aten_minimum_maximum(torch.nn.Module):
            def __init__(self, op, l_dtype, r_dtype, out):
                super(aten_minimum_maximum, self).__init__()
                self.op = op
                self.l_dtype = l_dtype
                self.r_dtype = r_dtype
                # Swap in the three-argument forward so tracing sees the .out call signature.
                if out:
                    self.forward = self.forward_out

            def forward_out(self, x, y, z):
                return self.op(x.to(self.l_dtype), y.to(self.r_dtype), out=z), z

            def forward(self, x, y):
                return self.op(x.to(self.l_dtype), y.to(self.r_dtype))

        model = aten_minimum_maximum(torch_ops[op_type], torch_dtypes[dtypes[0]], torch_dtypes[dtypes[1]], out)
        return model, None, f"aten::{op_type}"

    @pytest.mark.parametrize("op_type", ["minimum", "maximum"])
    @pytest.mark.parametrize("second_input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.parametrize("first_input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_minimum_maximum(
        self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version
    ):
        # Functional overload: exercised across all dtype pairs to cover type alignment.
        input_kwargs = {
            "input_dtype": first_input_dtype,
            "second_input_dtype": second_input_dtype,
            "out": False,
        }
        self._test(*self.create_model(op_type, dtypes=(first_input_dtype, second_input_dtype), out=False),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=input_kwargs)

    @pytest.mark.parametrize("op_type", ["minimum", "maximum"])
    @pytest.mark.parametrize("input_dtype", ["float32", "int32", "int64", "float64"])
    @pytest.mark.nightly
    @pytest.mark.precommit
    def test_minimum_maximum_out(
        self, op_type, input_dtype, ie_device, precision, ir_version
    ):
        # .out overload: both operands share a dtype so the destination tensor matches the result type.
        input_kwargs = {
            "input_dtype": input_dtype,
            "second_input_dtype": input_dtype,
            "out": True,
        }
        self._test(*self.create_model(op_type, dtypes=(input_dtype, input_dtype), out=True),
                   ie_device, precision, ir_version, kwargs_to_prepare_input=input_kwargs)