From d992c6b9c7ad548cd3f8bba88047db663344c29f Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 10 Feb 2023 14:23:27 +0100
Subject: [PATCH] [PT FE] Fix aten::flatten, add more tests (#15576)

* Fix flatten, add more tests

* Apply review feedback

* Fix code style
---
 src/frontends/pytorch/src/op/flatten.cpp       | 66 +++++++++----------
 src/frontends/pytorch/src/op/linear.cpp        |  5 +-
 src/frontends/pytorch/src/utils.cpp            | 13 +++-
 src/frontends/pytorch/src/utils.hpp            |  5 +-
 .../layer_tests/pytorch_tests/test_flatten.py  | 40 +++++++++++
 .../layer_tests/pytorch_tests/test_linear.py   | 50 ++++++++++++++
 .../pytorch_tests/test_transpose.py            | 36 ++++++++++
 7 files changed, 174 insertions(+), 41 deletions(-)
 create mode 100644 tests/layer_tests/pytorch_tests/test_flatten.py
 create mode 100644 tests/layer_tests/pytorch_tests/test_linear.py
 create mode 100644 tests/layer_tests/pytorch_tests/test_transpose.py

diff --git a/src/frontends/pytorch/src/op/flatten.cpp b/src/frontends/pytorch/src/op/flatten.cpp
index f8b4ee1d20b..3d1550d0824 100644
--- a/src/frontends/pytorch/src/op/flatten.cpp
+++ b/src/frontends/pytorch/src/op/flatten.cpp
@@ -3,59 +3,55 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/opsets/opset10.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/slice.hpp"
+#include "openvino/op/unsqueeze.hpp"
+#include "utils.hpp"
 
 namespace ov {
 namespace frontend {
 namespace pytorch {
 namespace op {
 
+using namespace ov::op;
+
 OutputVector translate_flatten(NodeContext& context) {
+    num_inputs_check(context, 2, 3);
+    auto x = context.get_input(0);
     auto start_dim = context.const_input<int64_t>(1);
     auto end_dim = context.const_input<int64_t>(2);
-    auto shape = std::make_shared<opset10::ShapeOf>(context.get_input(0), element::i32);
-    auto rank_ = std::make_shared<opset10::ShapeOf>(shape, element::i32);
-    auto rank = std::make_shared<opset10::Squeeze>(rank_);
-    // Use opset::If for dim normalization
+    Output<Node> shape;
+    Output<Node> rank;
+    std::tie(shape, rank) = get_shape_rank(context, x, true);
+    // Use opset::If for dim normalization. For now we only have flatten with constant start and end
     auto start_dim_node = context.get_input(1);
     auto end_dim_node = context.get_input(2);
     if (start_dim < 0) {
-        start_dim_node = std::make_shared<opset10::Add>(rank, start_dim_node);
+        start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
     }
     if (end_dim < 0) {
-        end_dim_node = std::make_shared<opset10::Add>(rank, end_dim_node);
+        end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
     }
-    auto delta = std::make_shared<opset10::Subtract>(end_dim_node, start_dim_node);
-    auto rank_delta = std::make_shared<opset10::Subtract>(rank, delta);
-    auto true_const0 = opset10::Constant::create(element::boolean, Shape{}, {1});
-    auto zeros_loop = std::make_shared<opset10::Loop>(rank_delta, true_const0);
-    auto true_const = opset10::Constant::create(element::boolean, Shape{}, {1});
-    auto result_true = std::make_shared<opset10::Result>(true_const);
-    auto zero_const = opset10::Constant::create(element::i32, Shape{1}, {0});
-    auto result_zero = std::make_shared<opset10::Result>(zero_const);
-    auto f = std::make_shared<ov::Model>(ResultVector{result_true, result_zero}, ParameterVector{});
-    zeros_loop->set_function(f);
-    zeros_loop->set_special_body_ports({-1, 0});
-    auto zeros = zeros_loop->get_concatenated_slices(result_zero, 0, 1, 1, -1, 0);
-    auto neg_1_const = opset10::Constant::create(element::i32, Shape{1}, {-1});
-    auto axis_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
-    auto start_dim_node_ = std::make_shared<opset10::Unsqueeze>(start_dim_node, axis_0);
-    auto new_shape = std::make_shared<opset10::ScatterElementsUpdate>(zeros, start_dim_node_, neg_1_const, axis_0);
+    // Slice shape from begin and end, then concat with -1; if a slice returns an empty tensor, concat should still
+    // be able to work with it
+    auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
+    auto one = v0::Constant::create(element::i32, Shape{1}, {1});
+    auto int_max = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
+    auto start_dim_u = std::make_shared<v0::Unsqueeze>(start_dim_node, zero);
+    auto slice_begin = std::make_shared<v8::Slice>(shape, zero, start_dim_u, one);
+    auto neg_1_const = v0::Constant::create(element::i32, Shape{1}, {-1});
+    auto end_dim_u = std::make_shared<v0::Unsqueeze>(end_dim_node, zero);
+    auto end_dim_next = std::make_shared<v1::Add>(end_dim_u, one);
+    auto slice_end = std::make_shared<v8::Slice>(shape, end_dim_next, int_max, one);
+    auto new_shape = std::make_shared<v0::Concat>(OutputVector{slice_begin, neg_1_const, slice_end}, 0);
 
-    context.mark_nodes({shape,
-                        rank_,
-                        rank,
-                        delta,
-                        rank_delta,
-                        true_const0,
-                        zeros_loop,
-                        neg_1_const,
-                        axis_0,
-                        start_dim_node_,
-                        new_shape});
+    context.mark_nodes({zero, one, int_max, start_dim_u, end_dim_u, slice_begin, slice_end, neg_1_const, new_shape});
 
-    return {context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), new_shape, true))};
+    return {context.mark_node(std::make_shared<v1::Reshape>(context.get_input(0), new_shape, true))};
 };
 
 }  // namespace op
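Note on the reworked translate_flatten above: the new decomposition builds the Reshape target shape directly from the input's shape by slicing out the dimensions before start_dim, appending -1, and concatenating the dimensions after end_dim (an empty slice simply contributes nothing to the Concat). The NumPy/PyTorch sketch below mirrors that computation for illustration only; it is not part of the patch, and the helper name flatten_shape is hypothetical.

    import numpy as np
    import torch

    def flatten_shape(shape, start_dim, end_dim):
        # Hypothetical helper mirroring the converter's slice-and-concat logic.
        rank = len(shape)
        if start_dim < 0:  # same normalization the converter applies to negative dims
            start_dim += rank
        if end_dim < 0:
            end_dim += rank
        # dims before start_dim, then -1, then dims after end_dim
        return list(shape[:start_dim]) + [-1] + list(shape[end_dim + 1:])

    x = np.random.randn(2, 3, 4, 5).astype(np.float32)
    new_shape = flatten_shape(x.shape, 1, 2)  # -> [2, -1, 5]
    expected = torch.flatten(torch.from_numpy(x), 1, 2).numpy()  # shape (2, 12, 5)
    assert (x.reshape(new_shape) == expected).all()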
diff --git a/src/frontends/pytorch/src/op/linear.cpp b/src/frontends/pytorch/src/op/linear.cpp
index 956f1d00646..9bf3b7db822 100644
--- a/src/frontends/pytorch/src/op/linear.cpp
+++ b/src/frontends/pytorch/src/op/linear.cpp
@@ -3,7 +3,7 @@
 //
 
 #include "openvino/frontend/pytorch/node_context.hpp"
-#include "openvino/opsets/opset10.hpp"
+#include "openvino/op/matmul.hpp"
 #include "utils.hpp"
 
 namespace ov {
@@ -12,9 +12,10 @@ namespace pytorch {
 namespace op {
 
 OutputVector translate_linear(NodeContext& context) {
+    // schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
     auto x = context.get_input(0);
     auto y = context.get_input(1);
-    auto matmul = std::make_shared<opset10::MatMul>(x, y, false, true);
+    auto matmul = context.mark_node(std::make_shared<ov::op::v0::MatMul>(x, y, false, true));
 
     return {context.mark_output(make_optional_bias(matmul, context, 2))};
 };
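For reference, aten::linear computes input @ weight.T plus an optional bias, which is why the converter emits MatMul with transpose_b set and then wraps it in make_optional_bias. The PyTorch check below is illustrative only and is not part of the patch.

    import torch

    x = torch.randn(2, 3, 9)
    w = torch.randn(10, 9)
    b = torch.randn(10)
    # functional.linear(x, w, b) == x @ w.T + b, matching MatMul(x, w, transpose_a=false, transpose_b=true) + bias
    assert torch.allclose(torch.nn.functional.linear(x, w, b), x @ w.transpose(-1, -2) + b)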
diff --git a/src/frontends/pytorch/src/utils.cpp b/src/frontends/pytorch/src/utils.cpp
index 3e1d97f8d51..c7403c7d7b1 100644
--- a/src/frontends/pytorch/src/utils.cpp
+++ b/src/frontends/pytorch/src/utils.cpp
@@ -56,9 +56,16 @@ Output<Node> reshape_channelwise(const NodeContext& context,
     return context.mark_node(std::make_shared<opset10::Reshape>(data, new_shape, false));
 }
 
-std::shared_ptr<Node> get_rank_node(const Output<Node>& node) {
-    auto shape = std::make_shared<opset10::ShapeOf>(node);
-    return std::make_shared<opset10::ShapeOf>(shape);
+std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context,
+                                                      const Output<Node>& x,
+                                                      bool as_scalar,
+                                                      element::Type output_type) {
+    auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, output_type));
+    Output<Node> rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, output_type));
+    if (as_scalar) {
+        rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
+    }
+    return std::make_tuple(shape, rank);
 }
 
 Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups) {
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index 78f62c7380b..632056aeb28 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -29,7 +29,10 @@ Output<Node> reshape_channelwise(const NodeContext& context,
                                  Output<Node> data,
                                  Output<Node> shape_source);
 
-std::shared_ptr<Node> get_rank_node(const Output<Node>& node);
+std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context,
+                                                      const Output<Node>& x,
+                                                      bool as_scalar = false,
+                                                      element::Type output_type = element::i32);
 
 Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups);
diff --git a/tests/layer_tests/pytorch_tests/test_flatten.py b/tests/layer_tests/pytorch_tests/test_flatten.py
new file mode 100644
index 00000000000..1702d3bf525
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_flatten.py
@@ -0,0 +1,40 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestFlatten(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
+
+    def create_model(self, dim0, dim1):
+        import torch
+
+        class aten_flatten(torch.nn.Module):
+            def __init__(self, dim0, dim1):
+                super(aten_flatten, self).__init__()
+                self.dim0 = dim0
+                self.dim1 = dim1
+
+            def forward(self, x):
+                return torch.flatten(x, self.dim0, self.dim1)
+
+        ref_net = None
+
+        return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
+
+    @pytest.mark.parametrize("dim0,dim1", [[0, 1],
+                                           [0, 2],
+                                           [0, 3],
+                                           [1, 2],
+                                           [1, 3],
+                                           [2, 3]])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_flatten(self, dim0, dim1, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim0, dim1),
+                   ie_device, precision, ir_version)
diff --git a/tests/layer_tests/pytorch_tests/test_linear.py b/tests/layer_tests/pytorch_tests/test_linear.py
new file mode 100644
index 00000000000..4b7d9d58b1b
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_linear.py
@@ -0,0 +1,50 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestMatMul(PytorchLayerTest):
+    def _prepare_input(self, m1_shape=(2, 2), m2_shape=(2, 2), bias_shape=None):
+        import numpy as np
+        if bias_shape is None:
+            return (np.random.randn(*m1_shape).astype(np.float32), np.random.randn(*m2_shape).astype(np.float32))
+        else:
+            return (np.random.randn(*m1_shape).astype(np.float32), np.random.randn(*m2_shape).astype(np.float32), np.random.randn(*bias_shape).astype(np.float32))
+
+    def create_model(self, is_bias):
+        import torch
+
+        class aten_mm(torch.nn.Module):
+            def __init__(self, is_bias):
+                super(aten_mm, self).__init__()
+                self.forward = self.forward2 if is_bias else self.forward1
+
+            def forward1(self, m1, m2):
+                return torch.nn.functional.linear(m1, m2)
+
+            def forward2(self, m1, m2, bias):
+                return torch.nn.functional.linear(m1, m2, bias)
+
+        ref_net = None
+
+        return aten_mm(is_bias), ref_net, "aten::linear"
+
+    @pytest.mark.parametrize("kwargs_to_prepare_input", [
+        {'m1_shape': [9], 'm2_shape': [10, 9]},
+        {'m1_shape': [9], 'm2_shape': [9]},
+        {'m1_shape': [3, 9], 'm2_shape': [10, 9]},
+        {'m1_shape': [3, 9], 'm2_shape': [9]},
+        {'m1_shape': [2, 3, 9], 'm2_shape': [10, 9]},
+        {'m1_shape': [2, 3, 9], 'm2_shape': [9]},
+        {'m1_shape': [9], 'm2_shape': [10, 9], 'bias_shape': [10]},
+        {'m1_shape': [3, 9], 'm2_shape': [10, 9], 'bias_shape': [10]},
+        {'m1_shape': [2, 3, 9], 'm2_shape': [10, 9], 'bias_shape': [10]},
+    ])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_matmul(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
+        self._test(*self.create_model(len(kwargs_to_prepare_input) == 3), ie_device, precision, ir_version,
+                   kwargs_to_prepare_input=kwargs_to_prepare_input)
diff --git a/tests/layer_tests/pytorch_tests/test_transpose.py b/tests/layer_tests/pytorch_tests/test_transpose.py
new file mode 100644
index 00000000000..c1d0bd4f3e7
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_transpose.py
@@ -0,0 +1,36 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import pytest
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestTranspose(PytorchLayerTest):
+    def _prepare_input(self):
+        import numpy as np
+        return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
+
+    def create_model(self, dim0, dim1):
+        import torch
+
+        class aten_transpose(torch.nn.Module):
+            def __init__(self, dim0, dim1):
+                super(aten_transpose, self).__init__()
+                self.dim0 = dim0
+                self.dim1 = dim1
+
+            def forward(self, x):
+                return torch.transpose(x, self.dim0, self.dim1)
+
+        ref_net = None
+
+        return aten_transpose(dim0, dim1), ref_net, "aten::transpose"
+
+    @pytest.mark.parametrize("dim0", [0, 1, 2, 3, -1, -2, -3, -4])
+    @pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_transpose(self, dim0, dim1, ie_device, precision, ir_version):
+        self._test(*self.create_model(dim0, dim1),
+                   ie_device, precision, ir_version)