diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
index 11d5991e700..f7a398bf67e 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -107,9 +107,10 @@ class TorchScriptPythonDecoder (Decoder):
                     gptq.unpatch_model(pt_module)
 
         if not skip_freeze:
+            ops_kind_no_freeze = ["quantize", "aten::as_strided"]
             for n in scripted.inlined_graph.nodes():
                 # TODO: switch off freezing for all traced models
-                if "quantize" in n.kind():
+                if any(kind in n.kind() for kind in ops_kind_no_freeze):
                     # do not freeze quantized models
                     skip_freeze = True
                     break
@@ -150,6 +151,16 @@ class TorchScriptPythonDecoder (Decoder):
         raw_input = self._raw_input(index)
         return self.get_shape_for_value(raw_input)
 
+    def get_input_strides(self, index: int) -> typing.List[int]:
+        raw_input = self._raw_input(index)
+        if isinstance(raw_input, torch.Value):
+            inp_type = raw_input.type()
+            if isinstance(inp_type, torch.TensorType):
+                strides = inp_type.strides()
+                if strides:
+                    return strides
+        return []
+
     def get_input_type(self, index: int):
         raw_input = self._raw_input(index)
         return self.get_type_for_value(raw_input)
diff --git a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
index a1136e4cda6..024b03b2ff4 100644
--- a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
+++ b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
@@ -34,6 +34,10 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
         PYBIND11_OVERRIDE_PURE(ov::PartialShape, TorchDecoder, get_input_shape, index);
     }
 
+    const std::vector<size_t>& get_input_strides(size_t index) const override {
+        PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_strides, index);
+    }
+
     ov::Any get_input_type(size_t index) const override {
         PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
     }
diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
index 066c203e3a1..d5878783c31 100644
--- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
+++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
@@ -40,6 +40,9 @@ public:
     // Return shape if inputs has torch::Tensor type in the original model, otherwise returns the shape [] of a scalar
     virtual PartialShape get_input_shape(size_t index) const = 0;
 
+    // Return strides if the input has torch::Tensor type in the original model, otherwise return [].
+    virtual const std::vector<size_t>& get_input_strides(size_t index) const = 0;
+
     // Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
     // (see custom_type.hpp)
     virtual Any get_input_type(size_t index) const = 0;
diff --git a/src/frontends/pytorch/src/op/as_strided.cpp b/src/frontends/pytorch/src/op/as_strided.cpp
new file mode 100644
index 00000000000..5d1dfe38bda
--- /dev/null
+++ b/src/frontends/pytorch/src/op/as_strided.cpp
@@ -0,0 +1,106 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "openvino/frontend/pytorch/node_context.hpp"
+#include "openvino/op/add.hpp"
+#include "openvino/op/constant.hpp"
+#include "openvino/op/gather.hpp"
+#include "openvino/op/multiply.hpp"
+#include "openvino/op/range.hpp"
+#include "openvino/op/reshape.hpp"
+#include "openvino/op/scatter_update.hpp"
+#include "openvino/op/tile.hpp"
+#include "openvino/op/transpose.hpp"
+#include "utils.hpp"
+
+namespace ov {
+namespace frontend {
+namespace pytorch {
+namespace op {
+
+using namespace ov::op;
+bool compare_strides(const std::tuple<size_t, size_t>& a, const std::tuple<size_t, size_t>& b) {
+    return std::get<0>(a) > std::get<0>(b);
+}
+OutputVector translate_as_strided(const NodeContext& context) {
+    // "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)"
+    num_inputs_check(context, 3, 4);
+    auto decoder = context.get_decoder();
+    auto input = context.get_input(0);
+    auto const_1 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {1}));
+    auto const_0 = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
+    auto const_neg_1 = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
+    auto input_strides = decoder->get_input_strides(0);
+    FRONT_END_OP_CONVERSION_CHECK(input_strides.size() != 0,
+                                  "aten::as_strided: Couldn't retrieve input stride information from torchscript.");
+
+    std::vector<size_t> idxs(input_strides.size());
+    iota(idxs.begin(), idxs.end(), 0);
+    std::vector<std::tuple<size_t, size_t>> stride_idxs(idxs.size());
+    std::for_each(idxs.rbegin(), idxs.rend(), [&](size_t& idx) {
+        stride_idxs[idx] = {input_strides[idx], idx};
+    });
+
+    std::sort(stride_idxs.begin(), stride_idxs.end(), compare_strides);
+    std::vector<uint64_t> transpose_idx(idxs.size());
+    int transpose_counter = 0;
+    std::for_each(stride_idxs.begin(), stride_idxs.end(), [&](std::tuple<size_t, size_t>& pair) {
+        transpose_idx[transpose_counter] = uint64_t(std::get<1>(pair));
+        transpose_counter++;
+    });
+    auto transpose_idx_const =
+        context.mark_node(v0::Constant::create(element::i32, Shape{transpose_idx.size()}, transpose_idx));
+    auto transposed_input = context.mark_node(std::make_shared<v1::Transpose>(input, transpose_idx_const));
+    auto flat_input = context.mark_node(std::make_shared<v1::Reshape>(transposed_input, const_neg_1, false));
+    std::deque<Output<Node>> sizes;
+    std::deque<Output<Node>> strides;
+    if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(1).get_node_shared_ptr())) {
+        auto input_vector = context.const_input<std::vector<int64_t>>(1);
+        std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
+            auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
+            sizes.push_front(const_input);
+        });
+    } else {
+        sizes = get_list_as_outputs(context.get_input(1));
+    }
+    if (std::dynamic_pointer_cast<v0::Constant>(context.get_input_from_visible_context(2).get_node_shared_ptr())) {
+        auto input_vector = context.const_input<std::vector<int64_t>>(2);
+        std::for_each(input_vector.rbegin(), input_vector.rend(), [&](int64_t input_val) {
+            auto const_input = context.mark_node(v0::Constant::create(element::i32, Shape{}, {input_val}));
+            strides.push_front(const_input);
+        });
+    } else {
+        strides = get_list_as_outputs(context.get_input(2));
+    }
+    auto offset = const_0->output(0);
+    if (!context.input_is_none(3)) {
+        offset = context.get_input(3);
+    }
+    FRONT_END_OP_CONVERSION_CHECK(sizes.size() == strides.size(),
+                                  "aten::as_strided: Vectors for strides and sizes need to have equal length.");
+    auto strides_size = strides.size() - 1;
+    auto i = 0;
+    auto strides_length_const = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides.size()}));
+    auto ones_strides_len = context.mark_node(std::make_shared<v0::Tile>(const_1, strides_length_const));
+    auto indices = const_0;
+    std::for_each(strides.rbegin(), strides.rend(), [&](Output<Node>& stride) {
+        auto const_num_iter = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {strides_size - i}));
+        stride = context.mark_node(std::make_shared<v0::Convert>(stride, element::i32));
+        auto size = sizes.at(strides_size - i);
+        auto range = context.mark_node(std::make_shared<v4::Range>(const_0, size, const_1, element::i32));
+        range = context.mark_node(std::make_shared<v1::Multiply>(range, stride));
+        auto iteration_shape = context.mark_node(
+            std::make_shared<v3::ScatterUpdate>(ones_strides_len, const_num_iter, const_neg_1, const_0));
+        range = context.mark_node(std::make_shared<v1::Reshape>(range, iteration_shape, false));
+        indices = context.mark_node(std::make_shared<v1::Add>(indices, range));
+        i++;
+    });
+    indices = context.mark_node(std::make_shared<v1::Add>(indices, offset));
+    auto gather = context.mark_node(std::make_shared<v8::Gather>(flat_input, indices, const_0));
+    return {gather};
+};
+}  // namespace op
+}  // namespace pytorch
+}  // namespace frontend
+}  // namespace ov
diff --git a/src/frontends/pytorch/src/op_table.cpp b/src/frontends/pytorch/src/op_table.cpp
index 5614a3881c3..d9ac0aff6af 100644
--- a/src/frontends/pytorch/src/op_table.cpp
+++ b/src/frontends/pytorch/src/op_table.cpp
@@ -34,6 +34,7 @@ OP_CONVERTER(translate_argmax);
 OP_CONVERTER(translate_argsort);
 OP_CONVERTER(translate_argmax);
 OP_CONVERTER(translate_argmin);
+OP_CONVERTER(translate_as_strided);
 OP_CONVERTER(translate_as_tensor);
 OP_CONVERTER(translate_avg_poolnd);
 OP_CONVERTER(translate_bool);
@@ -256,6 +257,7 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
         {"aten::argmax", op::translate_argmax},
         {"aten::argmin", op::translate_argmin},
         {"aten::argsort", op::translate_argsort},
+        {"aten::as_strided", op::translate_as_strided},
         {"aten::as_tensor", op::translate_as_tensor},
         {"aten::asin", op::translate_1to1_match_1_inputs_with_fp32_type_alignment<v0::Asin>},
         {"aten::asin_", op::inplace_op<op::translate_1to1_match_1_inputs_with_fp32_type_alignment<v0::Asin>>},
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index 1635296e612..b4a37118961 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -158,6 +158,9 @@ public:
     virtual PartialShape get_input_shape(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_input_shape);
     }
+    virtual const std::vector<size_t>& get_input_strides(size_t index) const override {
+        FRONT_END_NOT_IMPLEMENTED(get_input_strides);
+    }
     virtual Any get_input_type(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_input_type);
     }
diff --git a/tests/layer_tests/pytorch_tests/test_as_strided.py b/tests/layer_tests/pytorch_tests/test_as_strided.py
new file mode 100644
index 00000000000..9bfaa66d3a7
--- /dev/null
+++ b/tests/layer_tests/pytorch_tests/test_as_strided.py
@@ -0,0 +1,125 @@
+# Copyright (C) 2018-2023 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import numpy as np
+import pytest
+import torch
+
+from pytorch_layer_test_class import PytorchLayerTest
+
+
+class TestAsStrided(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(8, 8).astype(np.float32),)
+
+    def create_model(self, size, stride, offset):
+        class aten_as_strided(torch.nn.Module):
+            def __init__(self, size, stride, offset):
+                super().__init__()
+                self.size = size
+                self.stride = stride
+                self.offset = offset
+
+            def forward(self, x):
+                return torch.as_strided(x, self.size, self.stride, self.offset)
+
+        ref_net = None
+
+        return aten_as_strided(size, stride, offset), ref_net, "aten::as_strided"
+
+    @pytest.mark.parametrize(
+        "size,stride",
+        [
+            ([1], [1]),
+            ([2, 2], [1, 1]),
+            ([5, 4, 3], [1, 3, 7]),
+            ([5, 5, 5], [5, 0, 5]),
+            ([1, 2, 3, 4], [4, 3, 2, 1]),
+        ],
+    )
+    @pytest.mark.parametrize("offset", [None, 1, 3, 7])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_as_strided(self, size, stride, offset, ie_device, precision, ir_version):
+        self._test(*self.create_model(size, stride, offset), ie_device, precision, ir_version, trace_model=True)
+
+
+class TestAsStridedListConstruct(PytorchLayerTest):
+    def _prepare_input(self, size_shape_tensor=[1], stride_shape_tensor=[1]):
+        return (
+            np.random.randn(8, 8).astype(np.float32),
+            np.ones(size_shape_tensor),
+            np.ones(stride_shape_tensor),
+        )
+
+    def create_model(self, size, stride, offset, mode):
+        class aten_as_strided(torch.nn.Module):
+            def __init__(self, size, stride, offset, mode):
+                super().__init__()
+                self.size = size
+                self.stride = stride
+                self.size_shape_tensor = torch.empty(size)
+                self.stride_shape_tensor = torch.empty(stride)
+                self.offset = offset
+                modes = {
+                    "no_const": self.forward_no_const,
+                    "stride_const": self.forward_stride_const,
+                    "size_const": self.forward_size_const,
+                }
+                self.forward = modes.get(mode)
+
+            def forward_no_const(self, x, size_shape_tensor, stride_shape_tensor):
+                sz1, sz2, sz3 = size_shape_tensor.shape
+                st1, st2, st3 = stride_shape_tensor.shape
+                return torch.as_strided(x, [sz1, sz2, sz3], [st1, st2, st3], self.offset)
+
+            def forward_stride_const(self, x, size_shape_tensor, stride_shape_tensor):
+                sz1, sz2, sz3 = size_shape_tensor.shape
+                return torch.as_strided(x, [sz1, sz2, sz3], self.stride, self.offset)
+
+            def forward_size_const(self, x, size_shape_tensor, stride_shape_tensor):
+                st1, st2, st3 = stride_shape_tensor.shape
+                return torch.as_strided(x, self.size, [st1, st2, st3], self.offset)
+
+        ref_net = None
+
+        return aten_as_strided(size, stride, offset, mode), ref_net, ["aten::as_strided", "prim::ListConstruct"]
+
+    @pytest.mark.parametrize("size,stride", [([5, 4, 3], [1, 3, 7]), ([5, 5, 5], [5, 0, 5])])
+    @pytest.mark.parametrize("offset", [None, 7])
+    @pytest.mark.parametrize("mode", ["no_const", "stride_const", "size_const"])
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_as_strided_list_construct(self, size, stride, offset, mode, ie_device, precision, ir_version):
+        inp_kwargs = {"size_shape_tensor": size, "stride_shape_tensor": stride}
+        self._test(
+            *self.create_model(size, stride, offset, mode),
+            ie_device,
+            precision,
+            ir_version,
+            kwargs_to_prepare_input=inp_kwargs,
+            trace_model=True
+        )
+
+
+class TestAsStridedLongformer(PytorchLayerTest):
+    def _prepare_input(self):
+        return (np.random.randn(1, 10, 20, 40).astype(np.float32).transpose([0, 2, 3, 1]),)
+
+    def create_model(self):
+        class aten_as_strided_lf(torch.nn.Module):
+            def forward(self, x):
+                chunk_size = list(x.size())
+                chunk_size[1] = chunk_size[1] * 2 - 1
+                chunk_stride = list(x.stride())
+                chunk_stride[1] = chunk_stride[1] // 2
+                return x.as_strided(size=chunk_size, stride=chunk_stride)
+
+        ref_net = None
+
+        return aten_as_strided_lf(), ref_net, "aten::as_strided"
+
+    @pytest.mark.nightly
+    @pytest.mark.precommit
+    def test_as_strided_lf(self, ie_device, precision, ir_version):
+        self._test(*self.create_model(), ie_device, precision, ir_version, trace_model=True, freeze_model=False)
diff --git a/tests/model_hub_tests/torch_tests/hf_transformers_models b/tests/model_hub_tests/torch_tests/hf_transformers_models
index 0618d98a4d9..56deedc29b7 100644
--- a/tests/model_hub_tests/torch_tests/hf_transformers_models
+++ b/tests/model_hub_tests/torch_tests/hf_transformers_models
@@ -10,7 +10,6 @@ albert-base-v2,albert
 AlekseyKorshuk/test_reward_model,reward_model,skip,Load problem
 alibaba-damo/mgp-str-base,mgp-str,xfail,Compile error: unsupported Einsum
 allenai/hvila-block-layoutlm-finetuned-docbank,hierarchical_model,skip,Load problem
-allenai/longformer-base-4096,longformer,xfail,Unsupported op aten::as_strided
 ameya772/sentence-t5-base-atis-fine-tuned,T5,skip,Load problem
 andreasmadsen/efficient_mlm_m0.40,roberta-prelayernorm
 anton-l/emformer-base-librispeech,emformer,skip,Load problem
@@ -301,7 +300,6 @@ pie/example-re-textclf-tacred,TransformerTextClassificationModel,skip,Load probl
 pleisto/yuren-baichuan-7b,multimodal_llama,skip,Load problem
 predictia/europe_reanalysis_downscaler_convbaseline,convbilinear,skip,Load problem
 predictia/europe_reanalysis_downscaler_convswin2sr,conv_swin2sr,skip,Load problem
-pszemraj/led-large-book-summary,led,xfail,Unsupported op aten::as_strided
 qmeeus/whisper-small-ner-combined,whisper_for_slu,skip,Load problem
 raman-ai/pcqv2-tokengt-lap16,tokengt,skip,Load problem
 range3/pegasus-gpt2-medium,pegasusgpt2,skip,Load problem
diff --git a/tests/model_hub_tests/torch_tests/test_hf_transformers.py b/tests/model_hub_tests/torch_tests/test_hf_transformers.py
index 184e725a04f..caeb2e0ff2a 100644
--- a/tests/model_hub_tests/torch_tests/test_hf_transformers.py
+++ b/tests/model_hub_tests/torch_tests/test_hf_transformers.py
@@ -292,7 +292,8 @@ class TestTransformersModel(TestConvertModel):
         cleanup_dir(hf_hub_cache_dir)
         super().teardown_method()
 
-    @pytest.mark.parametrize("name,type", [("bert-base-uncased", "bert"),
+    @pytest.mark.parametrize("name,type", [("allenai/led-base-16384", "led"),
+                                           ("bert-base-uncased", "bert"),
                                            ("facebook/bart-large-mnli", "bart"),
                                            ("google/flan-t5-base", "t5"),
                                            ("google/tapas-large-finetuned-wtq", "tapas"),
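
Note (illustrative only, not part of the patch): translate_as_strided lowers aten::as_strided to a Gather over the flattened input, where the gather indices are offset + sum over dims of range(size[d]) * stride[d]; the transpose by descending input stride makes the flattened order match the storage order for non-contiguous inputs. The minimal NumPy sketch below mirrors that indexing scheme for a contiguous input; the helper name as_strided_via_gather is hypothetical and exists only for this example.

import numpy as np
import torch


def as_strided_via_gather(x, size, stride, offset=0):
    # Flatten the (contiguous) input; the converter additionally transposes
    # non-contiguous inputs so that flattening matches the storage order.
    flat = x.reshape(-1)
    # Start from the storage offset and add one broadcast range per output
    # dimension, scaled by that dimension's stride (Range -> Multiply -> Add
    # in the OpenVINO subgraph).
    indices = np.full((), offset, dtype=np.int64)
    for d, (sz, st) in enumerate(zip(size, stride)):
        shape = [1] * len(size)
        shape[d] = sz
        indices = indices + (np.arange(sz, dtype=np.int64) * st).reshape(shape)
    return flat[indices]  # Gather along axis 0


x = np.arange(64, dtype=np.float32).reshape(8, 8)
expected = torch.as_strided(torch.from_numpy(x), [5, 4, 3], [1, 3, 7], 2).numpy()
assert np.array_equal(as_strided_via_gather(x, [5, 4, 3], [1, 3, 7], 2), expected)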