[PT FE]: support flip operation (#17705)
* [PT FE]: support flip operation
* more tests
src/frontends/pytorch/src/op/flip.cpp (new file, 39 lines)
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <limits>

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/slice.hpp"
#include "utils.hpp"

namespace ov {
namespace frontend {
namespace pytorch {
namespace op {

using namespace ov::op;

OutputVector translate_flip(const NodeContext& context) {
    // Inputs: x, dims, and (optionally) an out tensor to write the result into.
    num_inputs_check(context, 2, 3);
    auto x = context.get_input(0);
    auto axis = context.get_input(1);
    // Flip is expressed as a Slice with step = -1 along each requested axis:
    // start = -1 walks backwards from the last element, and stop = INT_MIN is
    // guaranteed to lie past index 0 for any dimension size.
    auto minus_one = context.mark_node(v0::Constant::create(element::i32, Shape{}, {-1}));
    auto minimum_int =
        context.mark_node(v0::Constant::create(element::i32, Shape{}, {std::numeric_limits<int>::min()}));
    // Broadcast the scalar start/stop values to one entry per flipped axis.
    auto axis_shape = context.mark_node(std::make_shared<v3::ShapeOf>(axis, element::i32));
    auto start = context.mark_node(std::make_shared<v3::Broadcast>(minus_one, axis_shape));
    auto stop = context.mark_node(std::make_shared<v3::Broadcast>(minimum_int, axis_shape));
    // start doubles as the step input, since both are -1 for every axis.
    auto slice = context.mark_node(std::make_shared<v8::Slice>(x, start, stop, start, axis));
    if (!context.input_is_none(2)) {
        context.mutate_input(2, slice);
    }
    return {slice};
}

}  // namespace op
}  // namespace pytorch
}  // namespace frontend
}  // namespace ov
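The whole translation hinges on one observation: flipping along an axis is a strided slice with start = -1, step = -1, and a stop that lies beyond the beginning of the dimension. INT_MIN works as that stop because Slice clamps out-of-range bounds, so it is valid for any dimension size. A minimal NumPy sketch of the equivalence (illustrative only, not OpenVINO API):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)
INT_MIN = np.iinfo(np.int32).min  # stands in for std::numeric_limits<int>::min()

# start=-1, stop=INT_MIN, step=-1 reverses axis 1; out-of-range stops are
# clamped, which is the same semantics the converter relies on in Slice.
assert np.array_equal(x[:, -1:INT_MIN:-1, :], np.flip(x, axis=1))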
@@ -51,6 +51,7 @@ OP_CONVERTER(translate_expand_as);
OP_CONVERTER(translate_eye);
OP_CONVERTER(translate_fill_);
OP_CONVERTER(translate_flatten);
OP_CONVERTER(translate_flip);
OP_CONVERTER(translate_floor_divide);
OP_CONVERTER(translate_floordiv);
OP_CONVERTER(translate_frobenius_norm);

@@ -233,6 +234,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops() {
    {"aten::eye", op::translate_eye},
    {"aten::fill_", op::inplace_op<op::translate_fill_>},
    {"aten::flatten", op::translate_flatten},
    {"aten::flip", op::translate_flip},
    {"aten::floor", op::translate_1to1_match_1_inputs<opset10::Floor>},
    {"aten::floor_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Floor>>},
    {"aten::floor_divide", op::translate_floor_divide},
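Registration is mechanical: translate_flip is declared alongside the other converters and added to the map returned by get_supported_ops(), so nodes typed "aten::flip" dispatch to it. A rough Python analogue of that lookup, with hypothetical names standing in for the C++ machinery:

# Hypothetical sketch only: op-name -> converter callable, as in the C++ table.
def translate_flip(context):
    ...  # would build the Slice subgraph shown above

SUPPORTED_OPS = {"aten::flip": translate_flip}

def convert_node(op_type, context):
    # Ops absent from the table go down the frontend's fallback path.
    return SUPPORTED_OPS[op_type](context)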
tests/layer_tests/pytorch_tests/test_flip.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestFlip(PytorchLayerTest):
    def _prepare_input(self, out=False, dtype="float32"):
        import numpy as np

        x = np.random.randn(2, 3, 4, 5).astype(dtype)
        if not out:
            return (x,)
        return (x, np.zeros_like(x).astype(dtype))

    def create_model(self, axis, out):
        import torch

        class aten_flip(torch.nn.Module):
            def __init__(self, dim, out):
                super(aten_flip, self).__init__()
                self.dim = dim
                # For the out variant, swap in the forward that passes out=
                # and returns the mutated tensor alongside the result.
                if out:
                    self.forward = self.forward_out

            def forward(self, x):
                return torch.flip(x, self.dim)

            def forward_out(self, x, y):
                return torch.flip(x, self.dim, out=y), y

        ref_net = None

        return aten_flip(axis, out), ref_net, "aten::flip"

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("axis", [[0], [1], [-1], [1, 2], [2, 3], [1, 2, 3]])
    @pytest.mark.parametrize("out", [True, False])
    @pytest.mark.parametrize("dtype", ["float32", "float64", "int32", "int64", "uint8"])
    def test_flip(self, axis, out, dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(axis, out),
                   ie_device, precision, ir_version,
                   kwargs_to_prepare_input={"out": out, "dtype": dtype})
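For reference, the behaviour the tests exercise, with expected values as comments (a quick sanity check, not part of the PR):

import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]
torch.flip(x, [0])     # [[3, 4, 5], [0, 1, 2]]: rows reversed
torch.flip(x, [0, 1])  # [[5, 4, 3], [2, 1, 0]]: both axes reversed

The suite runs like the other layer tests, e.g. pytest tests/layer_tests/pytorch_tests/test_flip.py, assuming the layer-test requirements are installed.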