[PT FE]: support aten::minimum aten::maximum (#19996)
This commit is contained in:
@@ -80,6 +80,38 @@ OutputVector translate_min(const NodeContext& context) {
|
||||
return {values, indicies};
|
||||
};
|
||||
|
||||
OutputVector translate_maximum(const NodeContext& context) {
|
||||
// aten::maximum(Tensor self, Tensor other) -> Tensor
|
||||
|
||||
// aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
align_eltwise_input_types(context, x, y, true);
|
||||
auto res = context.mark_node(std::make_shared<v1::Maximum>(x, y));
|
||||
if (!context.input_is_none(2)) {
|
||||
context.mutate_input(2, res);
|
||||
}
|
||||
return {res};
|
||||
}
|
||||
|
||||
OutputVector translate_minimum(const NodeContext& context) {
|
||||
// aten::minimum(Tensor self, Tensor other) -> Tensor
|
||||
|
||||
// aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
|
||||
num_inputs_check(context, 2, 3);
|
||||
auto x = context.get_input(0);
|
||||
auto y = context.get_input(1);
|
||||
align_eltwise_input_types(context, x, y, true);
|
||||
auto res = context.mark_node(std::make_shared<v1::Minimum>(x, y));
|
||||
if (!context.input_is_none(2)) {
|
||||
context.mutate_input(2, res);
|
||||
}
|
||||
return {res};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
|
||||
@@ -96,10 +96,12 @@ OP_CONVERTER(translate_loop);
|
||||
OP_CONVERTER(translate_masked_fill);
|
||||
OP_CONVERTER(translate_masked_scatter);
|
||||
OP_CONVERTER(translate_max);
|
||||
OP_CONVERTER(translate_maximum);
|
||||
OP_CONVERTER(translate_max_poolnd);
|
||||
OP_CONVERTER(translate_mean);
|
||||
OP_CONVERTER(translate_meshgrid);
|
||||
OP_CONVERTER(translate_min);
|
||||
OP_CONVERTER(translate_minimum);
|
||||
OP_CONVERTER(translate_narrow);
|
||||
OP_CONVERTER(translate_native_multi_head_attention);
|
||||
OP_CONVERTER(translate_neg);
|
||||
@@ -351,12 +353,14 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::masked_scatter_", op::inplace_op<op::translate_masked_scatter>},
|
||||
{"aten::matmul", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::max", op::translate_max},
|
||||
{"aten::maximum", op::translate_maximum},
|
||||
{"aten::max_pool1d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool2d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::max_pool3d", op::quantizable_op<op::translate_max_poolnd>},
|
||||
{"aten::mean", op::quantizable_op<op::translate_mean>},
|
||||
{"aten::meshgrid", op::translate_meshgrid},
|
||||
{"aten::min", op::translate_min},
|
||||
{"aten::minimum", op::translate_minimum},
|
||||
{"aten::mm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
|
||||
{"aten::mul", op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>},
|
||||
{"aten::mul_", op::inplace_op<op::translate_1to1_match_2_inputs_align_types<opset10::Multiply>>},
|
||||
|
||||
@@ -210,3 +210,77 @@ class TestPrimMin(PytorchLayerTest):
|
||||
def test_min(self, case, kwargs_to_prepare_input, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input=kwargs_to_prepare_input, use_mo_convert=False)
|
||||
|
||||
|
||||
class TestMinimumMaximum(PytorchLayerTest):
|
||||
def _prepare_input(self, input_dtype="float32", second_input_dtype="float32", out=False):
|
||||
import numpy as np
|
||||
x = np.random.randn(1, 3, 10, 10).astype(input_dtype)
|
||||
y = np.random.randn(1, 3, 10, 10).astype(second_input_dtype)
|
||||
if not out:
|
||||
return x, y
|
||||
return (x, y, np.zeros_like(x).astype(input_dtype))
|
||||
|
||||
def create_model(self, op_type, dtypes=("float32", "float32"), out=False):
|
||||
import torch
|
||||
op_types = {
|
||||
"maximum": torch.maximum,
|
||||
"minimum": torch.minimum
|
||||
}
|
||||
|
||||
dtypes_map = {
|
||||
"float32": torch.float32,
|
||||
"int32": torch.int32,
|
||||
"int64": torch.int64,
|
||||
"float64": torch.float64
|
||||
}
|
||||
|
||||
op = op_types[op_type]
|
||||
|
||||
class aten_minimum_maximum(torch.nn.Module):
|
||||
def __init__(self, op, l_dtype, r_dtype, out):
|
||||
super(aten_minimum_maximum, self).__init__()
|
||||
self.op = op
|
||||
self.l_dtype = l_dtype
|
||||
self.r_dtype = r_dtype
|
||||
if out:
|
||||
self.forward = self.forward_out
|
||||
|
||||
def forward_out(self, x, y, z):
|
||||
return self.op(x.to(self.l_dtype), y.to(self.r_dtype), out=z), z
|
||||
|
||||
def forward(self, x, y):
|
||||
return self.op(x.to(self.l_dtype), y.to(self.r_dtype))
|
||||
|
||||
l_dtype = dtypes_map[dtypes[0]]
|
||||
r_dtype = dtypes_map[dtypes[1]]
|
||||
model_cls = aten_minimum_maximum(op, l_dtype, r_dtype, out)
|
||||
|
||||
return model_cls, None, f"aten::{op_type}"
|
||||
|
||||
@pytest.mark.parametrize("op_type", ["minimum", "maximum"])
|
||||
@pytest.mark.parametrize("second_input_dtype", ["float32", "int32", "int64", "float64"])
|
||||
@pytest.mark.parametrize("first_input_dtype", ["float32", "int32", "int64", "float64"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_minimum_maximum(
|
||||
self, op_type, first_input_dtype, second_input_dtype, ie_device, precision, ir_version
|
||||
):
|
||||
self._test(*self.create_model(op_type, dtypes=(first_input_dtype, second_input_dtype), out=False),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input=
|
||||
{"input_dtype": first_input_dtype, "second_input_dtype": second_input_dtype, "out": False}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("op_type", ['minimum', 'maximum'])
|
||||
@pytest.mark.parametrize("input_dtype", ["float32", "int32", "int64", "float64"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_minimum_maximum_out(
|
||||
self, op_type, input_dtype, ie_device, precision, ir_version
|
||||
):
|
||||
self._test(*self.create_model(op_type, dtypes=(input_dtype, input_dtype), out=True),
|
||||
ie_device, precision, ir_version, kwargs_to_prepare_input=
|
||||
{"input_dtype": input_dtype, "second_input_dtype": input_dtype,
|
||||
"out": True}
|
||||
)
|
||||
Reference in New Issue
Block a user