[PT FE]: support flip operation (#17705)

* [PT FE]: support flip operation

* more tests
This commit is contained in:
Ekaterina Aidova
2023-05-31 14:07:24 +04:00
committed by GitHub
parent 2d7db5e3d3
commit fb4efe7203
3 changed files with 84 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_flip(const NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto axis = context.get_input(1);
auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
auto minimum_int =
context.mark_node(v0::Constant::create(element::i32, Shape{}, {std::numeric_limits<int>::min()}));
auto axis_shape = context.mark_node(std::make_shared<v3::ShapeOf>(axis, element::i32));
auto start = context.mark_node(std::make_shared<v3::Broadcast>(minus_one, axis_shape));
auto stop = context.mark_node(std::make_shared<v3::Broadcast>(minimum_int, axis_shape));
auto slice = context.mark_node(std::make_shared<v8::Slice>(x, start, stop, start, axis));
if (!context.input_is_none(2)) {
context.mutate_input(2, slice);
}
return {slice};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov

View File

@@ -51,6 +51,7 @@ OP_CONVERTER(translate_expand_as);
OP_CONVERTER(translate_eye);
OP_CONVERTER(translate_fill_);
OP_CONVERTER(translate_flatten);
OP_CONVERTER(translate_flip);
OP_CONVERTER(translate_floor_divide);
OP_CONVERTER(translate_floordiv);
OP_CONVERTER(translate_frobenius_norm);
@@ -233,6 +234,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
{"aten::eye", op::translate_eye},
{"aten::fill_", op::inplace_op<op::translate_fill_>},
{"aten::flatten", op::translate_flatten},
{"aten::flip", op::translate_flip},
{"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
{"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
{"aten::floor_divide", op::translate_floor_divide},

View File

@@ -0,0 +1,43 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestFlip(PytorchLayerTest):
def _prepare_input(self, out=False, dtype="float32"):
import numpy as np
x = np.random.randn(2, 3, 4, 5).astype(dtype)
if not out:
return (x,)
return (x, np.zeros_like(x).astype(dtype))
def create_model(self, axis, out):
import torch
class aten_flip(torch.nn.Module):
def __init__(self, dim, out):
super(aten_flip, self).__init__()
self.dim = dim
if out:
self.forward = self.forward_out
def forward(self, x):
return torch.flip(x, self.dim)
def forward_out(self, x, y):
return torch.flip(x, self.dim, out=y), y
ref_net = None
return aten_flip(axis, out), ref_net, "aten::flip"
@pytest.mark.nightly
@pytest.mark.precommit
@pytest.mark.parametrize("axis", [[0], [1], [-1], [1, 2], [2, 3], [1, 2, 3]])
@pytest.mark.parametrize("out", [True, False])
@pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "uint8"])
def test_flip(self, axis, out, dtype, ie_device, precision, ir_version):
self._test(*self.create_model(axis, out), ie_device, precision, ir_version, kwargs_to_prepare_input={"out": out, "dtype": dtype})