[PT FE] Fix aten::flatten, add more tests (#15576)

* Fix flatten, add more tests

* Apply review feedback

* Fix code style
This commit is contained in:
Maxim Vafin 2023-02-10 14:23:27 +01:00 committed by GitHub
parent f48b5278fd
commit d992c6b9c7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 174 additions and 41 deletions

View File

@ -3,59 +3,55 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
OutputVector translate_flatten(NodeContext& context) {
num_inputs_check(context, 2, 3);
auto x = context.get_input(0);
auto start_dim = context.const_input<int64_t>(1);
auto end_dim = context.const_input<int64_t>(2);
auto shape = std::make_shared<opset10::ShapeOf>(context.get_input(0), element::i32);
auto rank_ = std::make_shared<opset10::ShapeOf>(shape, element::i32);
auto rank = std::make_shared<opset10::Squeeze>(rank_);
// Use opset::If for dim normalization
Output<Node> shape;
Output<Node> rank;
std::tie(shape, rank) = get_shape_rank(context, x, true);
// Use opset::If for dim normalization. For now we only have flatten with constant start and end
auto start_dim_node = context.get_input(1);
auto end_dim_node = context.get_input(2);
if (start_dim < 0) {
start_dim_node = std::make_shared<opset10::Add>(rank, start_dim_node);
start_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, start_dim_node));
}
if (end_dim < 0) {
end_dim_node = std::make_shared<opset10::Add>(rank, end_dim_node);
end_dim_node = context.mark_node(std::make_shared<v1::Add>(rank, end_dim_node));
}
auto delta = std::make_shared<opset10::Subtract>(end_dim_node, start_dim_node);
auto rank_delta = std::make_shared<opset10::Subtract>(rank, delta);
auto true_const0 = opset10::Constant::create(element::boolean, Shape{}, {1});
auto zeros_loop = std::make_shared<opset10::Loop>(rank_delta, true_const0);
auto true_const = opset10::Constant::create(element::boolean, Shape{}, {1});
auto result_true = std::make_shared<opset10::Result>(true_const);
auto zero_const = opset10::Constant::create(element::i32, Shape{1}, {0});
auto result_zero = std::make_shared<opset10::Result>(zero_const);
auto f = std::make_shared<ov::Model>(ResultVector{result_true, result_zero}, ParameterVector{});
zeros_loop->set_function(f);
zeros_loop->set_special_body_ports({-1, 0});
auto zeros = zeros_loop->get_concatenated_slices(result_zero, 0, 1, 1, -1, 0);
auto neg_1_const = opset10::Constant::create(element::i32, Shape{1}, {-1});
auto axis_0 = opset10::Constant::create(element::i32, Shape{1}, {0});
auto start_dim_node_ = std::make_shared<opset10::Unsqueeze>(start_dim_node, axis_0);
auto new_shape = std::make_shared<opset10::ScatterElementsUpdate>(zeros, start_dim_node_, neg_1_const, axis_0);
// Slice shape from begin and end, then concat with -1, if slice return empty tensor concat shuold still be able to
// work with it
auto zero = v0::Constant::create(element::i32, Shape{1}, {0});
auto one = v0::Constant::create(element::i32, Shape{1}, {1});
auto int_max = v0::Constant::create(element::i32, Shape{1}, {std::numeric_limits<int32_t>::max()});
auto start_dim_u = std::make_shared<v0::Unsqueeze>(start_dim_node, zero);
auto slice_begin = std::make_shared<v8::Slice>(shape, zero, start_dim_u, one);
auto neg_1_const = v0::Constant::create(element::i32, Shape{1}, {-1});
auto end_dim_u = std::make_shared<v0::Unsqueeze>(end_dim_node, zero);
auto end_dim_next = std::make_shared<v1::Add>(end_dim_u, one);
auto slice_end = std::make_shared<v8::Slice>(shape, end_dim_next, int_max, one);
auto new_shape = std::make_shared<v0::Concat>(OutputVector{slice_begin, neg_1_const, slice_end}, 0);
context.mark_nodes({shape,
rank_,
rank,
delta,
rank_delta,
true_const0,
zeros_loop,
neg_1_const,
axis_0,
start_dim_node_,
new_shape});
context.mark_nodes({zero, one, int_max, start_dim_u, end_dim_u, slice_begin, slice_end, neg_1_const, new_shape});
return {context.mark_node(std::make_shared<opset10::Reshape>(context.get_input(0), new_shape, true))};
return {context.mark_node(std::make_shared<v1::Reshape>(context.get_input(0), new_shape, true))};
};
} // namespace op

View File

@ -3,7 +3,7 @@
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/op/matmul.hpp"
#include "utils.hpp"
namespace ov {
@ -12,9 +12,10 @@ namespace pytorch {
namespace op {
OutputVector translate_linear(NodeContext& context) {
// schema: aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor
auto x = context.get_input(0);
auto y = context.get_input(1);
auto matmul = std::make_shared<opset10::MatMul>(x, y, false, true);
auto matmul = context.mark_node(std::make_shared<ov::op::v0::MatMul>(x, y, false, true));
return {context.mark_output(make_optional_bias(matmul, context, 2))};
};

View File

@ -56,9 +56,16 @@ Output<ov::Node> reshape_channelwise(const NodeContext& context, Output<ov::Node
return context.mark_node(std::make_shared<opset10::Reshape>(data, new_shape, false));
}
std::shared_ptr<Node> get_rank_node(const Output<Node>& node) {
auto shape = std::make_shared<opset10::ShapeOf>(node);
return std::make_shared<opset10::ShapeOf>(shape);
std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context,
const Output<Node>& x,
bool as_scalar,
element::Type output_type) {
auto shape = context.mark_node(std::make_shared<opset10::ShapeOf>(x, output_type));
Output<Node> rank = context.mark_node(std::make_shared<opset10::ShapeOf>(shape, output_type));
if (as_scalar) {
rank = context.mark_node(std::make_shared<opset10::Squeeze>(rank));
}
return std::make_tuple(shape, rank);
}
Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups) {

View File

@ -29,7 +29,10 @@ Output<ov::Node> reshape_channelwise(const NodeContext& context,
Output<ov::Node> data,
Output<ngraph::Node> shape_source);
std::shared_ptr<ov::Node> get_rank_node(const Output<Node>& node);
std::tuple<Output<Node>, Output<Node>> get_shape_rank(const NodeContext& context,
const Output<Node>& x,
bool as_scalar = false,
element::Type output_type = element::i32);
Output<Node> reshape_kernel_for_group(const NodeContext& context, const Output<Node>& kernel, int64_t groups);

View File

@ -0,0 +1,40 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestFlatten(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
def create_model(self, dim0, dim1):
import torch
class aten_flatten(torch.nn.Module):
def __init__(self, dim0, dim1):
super(aten_flatten, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return torch.flatten(x, self.dim0, self.dim1)
ref_net = None
return aten_flatten(dim0, dim1), ref_net, "aten::flatten"
@pytest.mark.parametrize("dim0,dim1", [[0, 1],
[0, 2],
[0, 3],
[1, 2],
[1, 3],
[2, 3]])
@pytest.mark.nightly
@pytest.mark.precommit
def test_relu(self, dim0, dim1, ie_device, precision, ir_version):
self._test(*self.create_model(dim0, dim1),
ie_device, precision, ir_version)

View File

@ -0,0 +1,50 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestMatMul(PytorchLayerTest):
def _prepare_input(self, m1_shape=(2, 2), m2_shape=(2, 2), bias_shape=None):
import numpy as np
if bias_shape is None:
return (np.random.randn(*m1_shape).astype(np.float32), np.random.randn(*m2_shape).astype(np.float32))
else:
return (np.random.randn(*m1_shape).astype(np.float32), np.random.randn(*m2_shape).astype(np.float32), np.random.randn(*bias_shape).astype(np.float32))
def create_model(self, is_bias):
import torch
class aten_mm(torch.nn.Module):
def __init__(self, is_bias):
super(aten_mm, self).__init__()
self.forward = self.forward2 if is_bias else self.forward1
def forward1(self, m1, m2):
return torch.nn.functional.linear(m1, m2)
def forward2(self, m1, m2, bias):
return torch.nn.functional.linear(m1, m2, bias)
ref_net = None
return aten_mm(is_bias), ref_net, "aten::linear"
@pytest.mark.parametrize("kwargs_to_prepare_input", [
{'m1_shape': [9], 'm2_shape': [10, 9]},
{'m1_shape': [9], 'm2_shape': [9]},
{'m1_shape': [3, 9], 'm2_shape': [10, 9]},
{'m1_shape': [3, 9], 'm2_shape': [9]},
{'m1_shape': [2, 3, 9], 'm2_shape': [10, 9]},
{'m1_shape': [2, 3, 9], 'm2_shape': [9]},
{'m1_shape': [9], 'm2_shape': [10, 9], 'bias_shape': [10]},
{'m1_shape': [3, 9], 'm2_shape': [10, 9], 'bias_shape': [10]},
{'m1_shape': [2, 3, 9], 'm2_shape': [10, 9], 'bias_shape': [10]},
])
@pytest.mark.nightly
@pytest.mark.precommit
def test_matmul(self, kwargs_to_prepare_input, ie_device, precision, ir_version):
self._test(*self.create_model(len(kwargs_to_prepare_input) == 3), ie_device, precision, ir_version,
kwargs_to_prepare_input=kwargs_to_prepare_input)

View File

@ -0,0 +1,36 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import pytest
from pytorch_layer_test_class import PytorchLayerTest
class TestTranspose(PytorchLayerTest):
def _prepare_input(self):
import numpy as np
return (np.random.randn(2, 3, 4, 5).astype(np.float32),)
def create_model(self, dim0, dim1):
import torch
class aten_transpose(torch.nn.Module):
def __init__(self, dim0, dim1):
super(aten_transpose, self).__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x):
return torch.transpose(x, self.dim0, self.dim1)
ref_net = None
return aten_transpose(dim0, dim1), ref_net, "aten::transpose"
@pytest.mark.parametrize("dim0", [0, 1, 2, 3, -1, -2, -3, -4])
@pytest.mark.parametrize("dim1", [0, 1, 2, 3, -1, -2, -3, -4])
@pytest.mark.nightly
@pytest.mark.precommit
def test_relu(self, dim0, dim1, ie_device, precision, ir_version):
self._test(*self.create_model(dim0, dim1),
ie_device, precision, ir_version)