[PT FE] Add aten::as_strided (#19482)

* Add aten::as_strided

* rm commented code

* Update src/frontends/pytorch/src/op/as_strided.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Update src/frontends/pytorch/src/op/as_strided.cpp

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>

* Fix CI error

* Fix CI issues

* mark_node for remaining constants

* Add test reproducing issue

* Use strides from torchscript

* Add led model to test suite

* Add suggested changes

---------

Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
Mateusz Mikolajczyk 2023-10-20 12:24:10 +02:00 committed by GitHub
parent 73d25a0f99
commit 891f79ac84
9 changed files with 257 additions and 4 deletions


@@ -107,9 +107,10 @@ class TorchScriptPythonDecoder (Decoder):
gptq.unpatch_model(pt_module)
if not skip_freeze:
ops_kind_no_freeze = ["quantize", "aten::as_strided"]
for n in scripted.inlined_graph.nodes():
# TODO: switch off freezing for all traced models
if "quantize" in n.kind():
if any(kind in n.kind() for kind in ops_kind_no_freeze):
# do not freeze quantized models
skip_freeze = True
break
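
Freezing is now skipped whenever the inlined graph contains any of the listed op kinds, since frozen modules can lose the tensor metadata (such as strides) that the as_strided translator reads. A quick way to see what this scan matches, as a standalone illustration (not part of the diff; the exact node list varies):

import torch

@torch.jit.script
def f(x: torch.Tensor):
    return x.as_strided([2, 2], [2, 1])

# n.kind() yields strings like "prim::Constant" or "aten::as_strided"
print(any("aten::as_strided" in n.kind() for n in f.inlined_graph.nodes()))  # True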
@@ -150,6 +151,16 @@ class TorchScriptPythonDecoder (Decoder):
raw_input = self._raw_input(index)
return self.get_shape_for_value(raw_input)
def get_input_strides(self, index: int) -> typing.List[int]:
raw_input = self._raw_input(index)
if isinstance(raw_input, torch.Value):
inp_type = raw_input.type()
if isinstance(inp_type, torch.TensorType):
strides = inp_type.strides()
if strides:
return strides
return []
def get_input_type(self, index: int):
raw_input = self._raw_input(index)
return self.get_type_for_value(raw_input)
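
For reference, this is the stride information the new get_input_strides accessor reads off the TorchScript value type. A minimal illustration in plain PyTorch (not part of the diff):

import torch

x = torch.randn(8, 8)
print(x.stride())                   # (8, 1): contiguous, row step 8, column step 1
print(x.transpose(0, 1).stride())   # (1, 8): same storage, permuted strides
print(x[:, 0].stride())             # (8,): a non-contiguous view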


@@ -34,6 +34,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
}
const std::vector<size_t>& get_input_strides(size_t index) const override {
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_strides, index);
}
ov::Any get_input_type(size_t index) const override {
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
}


@@ -40,6 +40,9 @@ public:
// Return shape if the input has torch::Tensor type in the original model, otherwise return the shape [] of a scalar
virtual PartialShape get_input_shape(size_t index) const = 0;
// Return strides if the input has torch::Tensor type in the original model, otherwise return [].
virtual const std::vector<size_t>& get_input_strides(size_t index) const = 0;
// Return element::Type when the original type can be represented, otherwise return a PT-specific data type object
// (see custom_type.hpp)
virtual Any get_input_type(size_t index) const = 0;


@@ -0,0 +1,106 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/scatter_update.hpp"
#include "openvino/op/tile.hpp"
#include "openvino/op/transpose.hpp"
#include "utils.hpp"
namespace ov {
namespace frontend {
namespace pytorch {
namespace op {
using namespace ov::op;
bool compare_strides(const std::tuple<size_t, size_t>& a, const std::tuple<size_t, size_t>& b) {
return std::get<0>(a) > std::get<0>(b);
}
OutputVector translate_as_strided(const NodeContext& context) {
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
num_inputs_check(context, 3, 4);
auto decoder = context.get_decoder();
auto input = context.get_input(0);
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
auto input_strides = decoder->get_input_strides(0);
FRONT_END_OP_CONVERSION_CHECK(input_strides.size() != 0,
"aten::as_strided: Couldn't retrive input stride information from torchscript.");
std::vector<size_t> idxs(input_strides.size());
iota(idxs.begin(), idxs.end(), 0);
std::vector<std::tuple<size_t, size_t>> stride_idxs(idxs.size());
std::for_each(idxs.rbegin(), idxs.rend(), [&](size_t& idx) {
stride_idxs[idx] = {input_strides[idx], idx};
});
std::sort(stride_idxs.begin(), stride_idxs.end(), compare_strides);
std::vector<uint64_t> transpose_idx(idxs.size());
int transpose_counter = 0;
std::for_each(stride_idxs.begin(), stride_idxs.end(), [&](std::tuple<size_t, size_t>& pair) {
transpose_idx[transpose_counter] = uint64_t(std::get<1>(pair));
transpose_counter++;
});
auto transpose_idx_const =
context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx));
auto transposed_input = context.mark_node(std::make_shared<v1::Transpose>(input, transpose_idx_const));
auto flat_input = context.mark_node(std::make_shared<v1::Reshape>(transposed_input, const_neg_1, false));
std::deque<Output<Node>> sizes;
std::deque<Output<Node>> strides;
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(1);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
sizes.push_front(const_input);
});
} else {
sizes = get_list_as_outputs(context.get_input(1));
}
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
auto input_vector = context.const_input<std::vector<int64_t>>(2);
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
strides.push_front(const_input);
});
} else {
strides = get_list_as_outputs(context.get_input(2));
}
auto offset = const_0->output(0);
if (!context.input_is_none(3)) {
offset = context.get_input(3);
}
FRONT_END_OP_CONVERSION_CHECK(sizes.size() == strides.size(),
"aten::as_strided: Vector for strides and sizes need to have equal length.");
auto strides_size = strides.size() - 1;
auto i = 0;
auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()}));
auto ones_strides_len = context.mark_node(std::make_shared<v0::Tile>(const_1, strides_length_const));
auto indices = const_0;
std::for_each(strides.rbegin(), strides.rend(), [&](Output<Node>& stride) {
auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i}));
stride = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
auto size = sizes.at(strides_size - i);
auto range = context.mark_node(std::make_shared<v4::Range>(const_0, size, const_1, element::i32));
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride));
auto iteration_shape = context.mark_node(
std::make_shared<v3::ScatterUpdate>(ones_strides_len, const_num_iter, const_neg_1, const_0));
range = context.mark_node(std::make_shared<v1::Reshape>(range, iteration_shape, false));
indices = context.mark_node(std::make_shared<v1::Add>(indices, range));
i++;
});
indices = context.mark_node(std::make_shared<v1::Add>(indices, offset));
auto gather = context.mark_node(std::make_shared<v8::Gather>(flat_input, indices, const_0));
return {gather};
};
} // namespace op
} // namespace pytorch
} // namespace frontend
} // namespace ov
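
In effect, the conversion above emulates as_strided with a Gather: the input is transposed into decreasing-stride (storage) order and flattened, then an index tensor offset + sum_d arange(size[d]) * stride[d] (per-dimension ranges broadcast against each other) selects elements from the flat buffer. A rough NumPy model of that computation, assuming a contiguous input so the transpose step is a no-op (illustration, not part of the commit):

import numpy as np

def as_strided_via_gather(x, size, stride, offset=0):
    flat = x.reshape(-1)  # storage-order view; the real conversion transposes first
    indices = np.full(size, offset, dtype=np.int64)
    for d, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[d] = sz  # put this dimension's range on axis d so the adds broadcast
        indices = indices + (np.arange(sz, dtype=np.int64) * st).reshape(shape)
    return flat[indices]  # the final Gather along axis 0

x = np.arange(16.0).reshape(4, 4)
print(as_strided_via_gather(x, [2, 3], [4, 1], offset=1))
# [[1. 2. 3.]
#  [5. 6. 7.]]  == torch.as_strided(torch.from_numpy(x), (2, 3), (4, 1), 1)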


@@ -34,6 +34,7 @@ OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argsort);
OP_CONVERTER(translate_argmax);
OP_CONVERTER(translate_argmin);
OP_CONVERTER(translate_as_strided);
OP_CONVERTER(translate_as_tensor);
OP_CONVERTER(translate_avg_poolnd);
OP_CONVERTER(translate_bool);
@@ -256,6 +257,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::argmax", op::translate_argmax},
{"aten::argmin", op::translate_argmin},
{"aten::argsort", op::translate_argsort},
{"aten::as_strided", op::translate_as_strided},
{"aten::as_tensor", op::translate_as_tensor},
{"aten::asin", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>},
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asin>>},


@@ -158,6 +158,9 @@ public:
virtual PartialShape get_input_shape(size_t index) const override {
FRONT_END_NOT_IMPLEMENTED(get_input_shape);
}
virtual const std::vector<size_t>& get_input_strides(size_t index) const override {
FRONT_END_NOT_IMPLEMENTED(get_input_strides);
}
virtual Any get_input_type(size_t index) const override {
FRONT_END_NOT_IMPLEMENTED(get_input_type);
}


@@ -0,0 +1,125 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from pytorch_layer_test_class import PytorchLayerTest
class TestAsStrided(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(8, 8).astype(np.float32),)
def create_model(self, size, stride, offset):
class aten_as_strided(torch.nn.Module):
def __init__(self, size, stride, offset):
super().__init__()
self.size = size
self.stride = stride
self.offset = offset
def forward(self, x):
return torch.as_strided(x, self.size, self.stride, self.offset)
ref_net = None
return aten_as_strided(size, stride, offset), ref_net, "aten::as_strided"
@pytest.mark.parametrize(
"size,stride",
[
([1], [1]),
([2, 2], [1, 1]),
([5, 4, 3], [1, 3, 7]),
([5, 5, 5], [5, 0, 5]),
([1, 2, 3, 4], [4, 3, 2, 1]),
],
)
@pytest.mark.parametrize("offset", [None, 1, 3, 7])
@pytest.mark.nightly
@pytest.mark.precommit
def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version):
self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True)
class TestAsStridedListConstruct(PytorchLayerTest):
def _prepare_input(self, size_shape_tensor=[1], stride_shape_tensor=[1]):
return (
np.random.randn(8, 8).astype(np.float32),
np.ones(size_shape_tensor),
np.ones(stride_shape_tensor),
)
def create_model(self, size, stride, offset, mode):
class aten_as_strided(torch.nn.Module):
def __init__(self, size, stride, offset, mode):
super().__init__()
self.size = size
self.stride = stride
self.size_shape_tensor = torch.empty(size)
self.stride_shape_tensor = torch.empty(stride)
self.offset = offset
modes = {
"no_const": self.forward_no_const,
"stride_const": self.forward_stride_const,
"size_const": self.forward_size_const,
}
self.forward = modes.get(mode)
def forward_no_const(self, x, size_shape_tensor, stride_shape_tensor):
sz1, sz2, sz3 = size_shape_tensor.shape
st1, st2, st3 = stride_shape_tensor.shape
return torch.as_strided(x, [sz1, sz2, sz3], [st1, st2, st3], self.offset)
def forward_stride_const(self, x, size_shape_tensor, stride_shape_tensor):
sz1, sz2, sz3 = size_shape_tensor.shape
return torch.as_strided(x, [sz1, sz2, sz3], self.stride, self.offset)
def forward_size_const(self, x, size_shape_tensor, stride_shape_tensor):
st1, st2, st3 = stride_shape_tensor.shape
return torch.as_strided(x, self.size, [st1, st2, st3], self.offset)
ref_net = None
return aten_as_strided(size, stride, offset, mode), ref_net, ["aten::as_strided", "prim::ListConstruct"]
@pytest.mark.parametrize("size,stride", [([5, 4, 3], [1, 3, 7]), ([5, 5, 5], [5, 0, 5])])
@pytest.mark.parametrize("offset", [None, 7])
@pytest.mark.parametrize("mode", ["no_const", "stride_const", "size_const"])
@pytest.mark.nightly
@pytest.mark.precommit
def test_as_strided_list_construct(self, size, stride, offset, mode, ie_device, precision, ir_version):
inp_kwargs = {"size_shape_tensor": size, "stride_shape_tensor": stride}
self._test(
*self.create_model(size, stride, offset, mode),
ie_device,
precision,
ir_version,
kwargs_to_prepare_input=inp_kwargs,
trace_model=True
)
class TestAsStridedLongformer(PytorchLayerTest):
def _prepare_input(self):
return (np.random.randn(1, 10, 20, 40).astype(np.float32).transpose([0, 2, 3, 1]),)
def create_model(self):
class aten_as_strided_lf(torch.nn.Module):
def forward(self, x):
chunk_size = list(x.size())
chunk_size[1] = chunk_size[1] * 2 - 1
chunk_stride = list(x.stride())
chunk_stride[1] = chunk_stride[1] // 2
return x.as_strided(size=chunk_size, stride=chunk_stride)
ref_net = None
return aten_as_strided_lf(), ref_net, "aten::as_strided"
@pytest.mark.nightly
@pytest.mark.precommit
def test_as_strided_lf(self, ie_device, precision, ir_version):
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False)
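
The TestAsStridedLongformer case reproduces the chunking trick from the longformer attention code: halving one stride while nearly doubling the corresponding size makes consecutive chunks overlap. A standalone sketch of that pattern (plain PyTorch, not part of the diff):

import torch

x = torch.arange(8.0)                             # storage: 0..7
chunks = x.as_strided(size=(3, 4), stride=(2, 1)) # rows start every 2 elements
print(chunks)
# tensor([[0., 1., 2., 3.],
#         [2., 3., 4., 5.],
#         [4., 5., 6., 7.]])  # consecutive rows overlap by half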


@@ -10,7 +10,6 @@ albert-base-v2,albert
AlekseyKorshuk/test_reward_model,reward_model,skip,Load problem
alibaba-damo/mgp-str-base,mgp-str,xfail,Compile error: unsupported Einsum
allenai/hvila-block-layoutlm-finetuned-docbank,hierarchical_model,skip,Load problem
-allenai/longformer-base-4096,longformer,xfail,Unsupported op aten::as_strided
ameya772/sentence-t5-base-atis-fine-tuned,T5,skip,Load problem
andreasmadsen/efficient_mlm_m0.40,roberta-prelayernorm
anton-l/emformer-base-librispeech,emformer,skip,Load problem
@@ -301,7 +300,6 @@ pie/example-re-textclf-tacred,TransformerTextClassificationModel,skip,Load probl
pleisto/yuren-baichuan-7b,multimodal_llama,skip,Load problem
predictia/europe_reanalysis_downscaler_convbaseline,convbilinear,skip,Load problem
predictia/europe_reanalysis_downscaler_convswin2sr,conv_swin2sr,skip,Load problem
-pszemraj/led-large-book-summary,led,xfail,Unsupported op aten::as_strided
qmeeus/whisper-small-ner-combined,whisper_for_slu,skip,Load problem
raman-ai/pcqv2-tokengt-lap16,tokengt,skip,Load problem
range3/pegasus-gpt2-medium,pegasusgpt2,skip,Load problem


@@ -292,7 +292,8 @@ class TestTransformersModel(TestConvertModel):
cleanup_dir(hf_hub_cache_dir)
super().teardown_method()
@pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
@pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
("bert-base-uncased", "bert"),
("facebook/bart-large-mnli", "bart"),
("google/flan-t5-base", "t5"),
("google/tapas-large-finetuned-wtq", "tapas"),