[PT FE] Add aten::as_strided (#19482)
* Add aten::as_strided * rm commented code * Update src/frontends/pytorch/src/op/as_strided.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * Update src/frontends/pytorch/src/op/as_strided.cpp Co-authored-by: Maxim Vafin <maxim.vafin@intel.com> * Fix CI error * Fix CI issues * mark_node for remaining constants * Add test reproducing issue * Use strides from torchscript * Add led model to test suite * Add sugested changes --------- Co-authored-by: Maxim Vafin <maxim.vafin@intel.com>
This commit is contained in:
parent
73d25a0f99
commit
891f79ac84
@ -107,9 +107,10 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
gptq.unpatch_model(pt_module)
|
||||
|
||||
if not skip_freeze:
|
||||
ops_kind_no_freeze = ["quantize", "aten::as_strided"]
|
||||
for n in scripted.inlined_graph.nodes():
|
||||
# TODO: switch off freezing for all traced models
|
||||
if "quantize" in n.kind():
|
||||
if any(kind in n.kind() for kind in ops_kind_no_freeze):
|
||||
# do not freeze quantized models
|
||||
skip_freeze = True
|
||||
break
|
||||
@ -150,6 +151,16 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
raw_input = self._raw_input(index)
|
||||
return self.get_shape_for_value(raw_input)
|
||||
|
||||
def get_input_strides(self, index: int) -> typing.List[int]:
|
||||
raw_input = self._raw_input(index)
|
||||
if isinstance(raw_input, torch.Value):
|
||||
inp_type = raw_input.type()
|
||||
if isinstance(inp_type, torch.TensorType):
|
||||
strides = inp_type.strides()
|
||||
if strides:
|
||||
return strides
|
||||
return []
|
||||
|
||||
def get_input_type(self, index: int):
|
||||
raw_input = self._raw_input(index)
|
||||
return self.get_type_for_value(raw_input)
|
||||
|
@ -34,6 +34,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
|
||||
PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
|
||||
}
|
||||
|
||||
const std::vector<size_t>& get_input_strides(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_strides, index);
|
||||
}
|
||||
|
||||
ov::Any get_input_type(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
|
||||
}
|
||||
|
@ -40,6 +40,9 @@ public:
|
||||
// Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
|
||||
virtual PartialShape get_input_shape(size_t index) const = 0;
|
||||
|
||||
// Return strides if inputs has torch::Tensor type in original model, otherwise return [].
|
||||
virtual const std::vector<size_t>& get_input_strides(size_t index) const = 0;
|
||||
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
|
||||
// (see custom_type.hpp)
|
||||
virtual Any get_input_type(size_t index) const = 0;
|
||||
|
106
src/frontends/pytorch/src/op/as_strided.cpp
Normal file
106
src/frontends/pytorch/src/op/as_strided.cpp
Normal file
@ -0,0 +1,106 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "openvino/frontend/pytorch/node_context.hpp"
|
||||
#include "openvino/op/add.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/gather.hpp"
|
||||
#include "openvino/op/multiply.hpp"
|
||||
#include "openvino/op/range.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/scatter_update.hpp"
|
||||
#include "openvino/op/tile.hpp"
|
||||
#include "openvino/op/transpose.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace op {
|
||||
|
||||
using namespace ov::op;
|
||||
bool compare_strides(const std::tuple<size_t, size_t>& a, const std::tuple<size_t, size_t>& b) {
|
||||
return std::get<0>(a) > std::get<0>(b);
|
||||
}
|
||||
OutputVector translate_as_strided(const NodeContext& context) {
|
||||
// "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
|
||||
num_inputs_check(context, 3, 4);
|
||||
auto decoder = context.get_decoder();
|
||||
auto input = context.get_input(0);
|
||||
auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
|
||||
auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
|
||||
auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
|
||||
auto input_strides = decoder->get_input_strides(0);
|
||||
FRONT_END_OP_CONVERSION_CHECK(input_strides.size() != 0,
|
||||
"aten::as_strided: Couldn't retrive input stride information from torchscript.");
|
||||
|
||||
std::vector<size_t> idxs(input_strides.size());
|
||||
iota(idxs.begin(), idxs.end(), 0);
|
||||
std::vector<std::tuple<size_t, size_t>> stride_idxs(idxs.size());
|
||||
std::for_each(idxs.rbegin(), idxs.rend(), [&](size_t& idx) {
|
||||
stride_idxs[idx] = {input_strides[idx], idx};
|
||||
});
|
||||
|
||||
std::sort(stride_idxs.begin(), stride_idxs.end(), compare_strides);
|
||||
std::vector<uint64_t> transpose_idx(idxs.size());
|
||||
int transpose_counter = 0;
|
||||
std::for_each(stride_idxs.begin(), stride_idxs.end(), [&](std::tuple<size_t, size_t>& pair) {
|
||||
transpose_idx[transpose_counter] = uint64_t(std::get<1>(pair));
|
||||
transpose_counter++;
|
||||
});
|
||||
auto transpose_idx_const =
|
||||
context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx));
|
||||
auto transposed_input = context.mark_node(std::make_shared<v1::Transpose>(input, transpose_idx_const));
|
||||
auto flat_input = context.mark_node(std::make_shared<v1::Reshape>(transposed_input, const_neg_1, false));
|
||||
std::deque<Output<Node>> sizes;
|
||||
std::deque<Output<Node>> strides;
|
||||
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
|
||||
auto input_vector = context.const_input<std::vector<int64_t>>(1);
|
||||
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
|
||||
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
|
||||
sizes.push_front(const_input);
|
||||
});
|
||||
} else {
|
||||
sizes = get_list_as_outputs(context.get_input(1));
|
||||
}
|
||||
if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
|
||||
auto input_vector = context.const_input<std::vector<int64_t>>(2);
|
||||
std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
|
||||
auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
|
||||
strides.push_front(const_input);
|
||||
});
|
||||
} else {
|
||||
strides = get_list_as_outputs(context.get_input(2));
|
||||
}
|
||||
auto offset = const_0->output(0);
|
||||
if (!context.input_is_none(3)) {
|
||||
offset = context.get_input(3);
|
||||
}
|
||||
FRONT_END_OP_CONVERSION_CHECK(sizes.size() == strides.size(),
|
||||
"aten::as_strided: Vector for strides and sizes need to have equal length.");
|
||||
auto strides_size = strides.size() - 1;
|
||||
auto i = 0;
|
||||
auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()}));
|
||||
auto ones_strides_len = context.mark_node(std::make_shared<v0::Tile>(const_1, strides_length_const));
|
||||
auto indices = const_0;
|
||||
std::for_each(strides.rbegin(), strides.rend(), [&](Output<Node>& stride) {
|
||||
auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i}));
|
||||
stride = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
|
||||
auto size = sizes.at(strides_size - i);
|
||||
auto range = context.mark_node(std::make_shared<v4::Range>(const_0, size, const_1, element::i32));
|
||||
range = context.mark_node(std::make_shared<v1::Multiply>(range, stride));
|
||||
auto iteration_shape = context.mark_node(
|
||||
std::make_shared<v3::ScatterUpdate>(ones_strides_len, const_num_iter, const_neg_1, const_0));
|
||||
range = context.mark_node(std::make_shared<v1::Reshape>(range, iteration_shape, false));
|
||||
indices = context.mark_node(std::make_shared<v1::Add>(indices, range));
|
||||
i++;
|
||||
});
|
||||
indices = context.mark_node(std::make_shared<v1::Add>(indices, offset));
|
||||
auto gather = context.mark_node(std::make_shared<v8::Gather>(flat_input, indices, const_0));
|
||||
return {gather};
|
||||
};
|
||||
} // namespace op
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -34,6 +34,7 @@ OP_CONVERTER(translate_argmax);
|
||||
OP_CONVERTER(translate_argsort);
|
||||
OP_CONVERTER(translate_argmax);
|
||||
OP_CONVERTER(translate_argmin);
|
||||
OP_CONVERTER(translate_as_strided);
|
||||
OP_CONVERTER(translate_as_tensor);
|
||||
OP_CONVERTER(translate_avg_poolnd);
|
||||
OP_CONVERTER(translate_bool);
|
||||
@ -256,6 +257,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
|
||||
{"aten::argmax", op::translate_argmax},
|
||||
{"aten::argmin", op::translate_argmin},
|
||||
{"aten::argsort", op::translate_argsort},
|
||||
{"aten::as_strided", op::translate_as_strided},
|
||||
{"aten::as_tensor", op::translate_as_tensor},
|
||||
{"aten::asin", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<opset10::Asin>},
|
||||
{"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs<opset10::Asin>>},
|
||||
|
@ -158,6 +158,9 @@ public:
|
||||
virtual PartialShape get_input_shape(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_input_shape);
|
||||
}
|
||||
virtual const std::vector<size_t>& get_input_strides(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_input_strides);
|
||||
}
|
||||
virtual Any get_input_type(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_input_type);
|
||||
}
|
||||
|
125
tests/layer_tests/pytorch_tests/test_as_strided.py
Normal file
125
tests/layer_tests/pytorch_tests/test_as_strided.py
Normal file
@ -0,0 +1,125 @@
|
||||
# Copyright (C) 2018-2023 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from pytorch_layer_test_class import PytorchLayerTest
|
||||
|
||||
|
||||
class TestAsStrided(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(8, 8).astype(np.float32),)
|
||||
|
||||
def create_model(self, size, stride, offset):
|
||||
class aten_as_strided(torch.nn.Module):
|
||||
def __init__(self, size, stride, offset):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.stride = stride
|
||||
self.offset = offset
|
||||
|
||||
def forward(self, x):
|
||||
return torch.as_strided(x, self.size, self.stride, self.offset)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_as_strided(size, stride, offset), ref_net, "aten::as_strided"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"size,stride",
|
||||
[
|
||||
([1], [1]),
|
||||
([2, 2], [1, 1]),
|
||||
([5, 4, 3], [1, 3, 7]),
|
||||
([5, 5, 5], [5, 0, 5]),
|
||||
([1, 2, 3, 4], [4, 3, 2, 1]),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("offset", [None, 1, 3, 7])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True)
|
||||
|
||||
|
||||
class TestAsStridedListConstruct(PytorchLayerTest):
|
||||
def _prepare_input(self, size_shape_tensor=[1], stride_shape_tensor=[1]):
|
||||
return (
|
||||
np.random.randn(8, 8).astype(np.float32),
|
||||
np.ones(size_shape_tensor),
|
||||
np.ones(stride_shape_tensor),
|
||||
)
|
||||
|
||||
def create_model(self, size, stride, offset, mode):
|
||||
class aten_as_strided(torch.nn.Module):
|
||||
def __init__(self, size, stride, offset, mode):
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.stride = stride
|
||||
self.size_shape_tensor = torch.empty(size)
|
||||
self.stride_shape_tensor = torch.empty(stride)
|
||||
self.offset = offset
|
||||
modes = {
|
||||
"no_const": self.forward_no_const,
|
||||
"stride_const": self.forward_stride_const,
|
||||
"size_const": self.forward_size_const,
|
||||
}
|
||||
self.forward = modes.get(mode)
|
||||
|
||||
def forward_no_const(self, x, size_shape_tensor, stride_shape_tensor):
|
||||
sz1, sz2, sz3 = size_shape_tensor.shape
|
||||
st1, st2, st3 = stride_shape_tensor.shape
|
||||
return torch.as_strided(x, [sz1, sz2, sz3], [st1, st2, st3], self.offset)
|
||||
|
||||
def forward_stride_const(self, x, size_shape_tensor, stride_shape_tensor):
|
||||
sz1, sz2, sz3 = size_shape_tensor.shape
|
||||
return torch.as_strided(x, [sz1, sz2, sz3], self.stride, self.offset)
|
||||
|
||||
def forward_size_const(self, x, size_shape_tensor, stride_shape_tensor):
|
||||
st1, st2, st3 = stride_shape_tensor.shape
|
||||
return torch.as_strided(x, self.size, [st1, st2, st3], self.offset)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_as_strided(size, stride, offset, mode), ref_net, ["aten::as_strided", "prim::ListConstruct"]
|
||||
|
||||
@pytest.mark.parametrize("size,stride", [([5, 4, 3], [1, 3, 7]), ([5, 5, 5], [5, 0, 5])])
|
||||
@pytest.mark.parametrize("offset", [None, 7])
|
||||
@pytest.mark.parametrize("mode", ["no_const", "stride_const", "size_const"])
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_as_strided_list_construct(self, size, stride, offset, mode, ie_device, precision, ir_version):
|
||||
inp_kwargs = {"size_shape_tensor": size, "stride_shape_tensor": stride}
|
||||
self._test(
|
||||
*self.create_model(size, stride, offset, mode),
|
||||
ie_device,
|
||||
precision,
|
||||
ir_version,
|
||||
kwargs_to_prepare_input=inp_kwargs,
|
||||
trace_model=True
|
||||
)
|
||||
|
||||
|
||||
class TestAsStridedLongformer(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
return (np.random.randn(1, 10, 20, 40).astype(np.float32).transpose([0, 2, 3, 1]),)
|
||||
|
||||
def create_model(self):
|
||||
class aten_as_strided_lf(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
chunk_size = list(x.size())
|
||||
chunk_size[1] = chunk_size[1] * 2 - 1
|
||||
chunk_stride = list(x.stride())
|
||||
chunk_stride[1] = chunk_stride[1] // 2
|
||||
return x.as_strided(size=chunk_size, stride=chunk_stride)
|
||||
|
||||
ref_net = None
|
||||
|
||||
return aten_as_strided_lf(), ref_net, "aten::as_strided"
|
||||
|
||||
@pytest.mark.nightly
|
||||
@pytest.mark.precommit
|
||||
def test_as_strided_lf(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False)
|
@ -10,7 +10,6 @@ albert-base-v2,albert
|
||||
AlekseyKorshuk/test_reward_model,reward_model,skip,Load problem
|
||||
alibaba-damo/mgp-str-base,mgp-str,xfail,Compile error: unsupported Einsum
|
||||
allenai/hvila-block-layoutlm-finetuned-docbank,hierarchical_model,skip,Load problem
|
||||
allenai/longformer-base-4096,longformer,xfail,Unsupported op aten::as_strided
|
||||
ameya772/sentence-t5-base-atis-fine-tuned,T5,skip,Load problem
|
||||
andreasmadsen/efficient_mlm_m0.40,roberta-prelayernorm
|
||||
anton-l/emformer-base-librispeech,emformer,skip,Load problem
|
||||
@ -301,7 +300,6 @@ pie/example-re-textclf-tacred,TransformerTextClassificationModel,skip,Load probl
|
||||
pleisto/yuren-baichuan-7b,multimodal_llama,skip,Load problem
|
||||
predictia/europe_reanalysis_downscaler_convbaseline,convbilinear,skip,Load problem
|
||||
predictia/europe_reanalysis_downscaler_convswin2sr,conv_swin2sr,skip,Load problem
|
||||
pszemraj/led-large-book-summary,led,xfail,Unsupported op aten::as_strided
|
||||
qmeeus/whisper-small-ner-combined,whisper_for_slu,skip,Load problem
|
||||
raman-ai/pcqv2-tokengt-lap16,tokengt,skip,Load problem
|
||||
range3/pegasus-gpt2-medium,pegasusgpt2,skip,Load problem
|
||||
|
@ -292,7 +292,8 @@ class TestTransformersModel(TestConvertModel):
|
||||
cleanup_dir(hf_hub_cache_dir)
|
||||
super().teardown_method()
|
||||
|
||||
@pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
|
||||
@pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
|
||||
("bert-base-uncased", "bert"),
|
||||
("facebook/bart-large-mnli", "bart"),
|
||||
("google/flan-t5-base", "t5"),
|
||||
("google/tapas-large-finetuned-wtq", "tapas"),
|
||||
|
Loading…
Reference in New Issue
Block a user