From cc08b0091e99f25ca60457231ed8f508423a89ab Mon Sep 17 00:00:00 2001
From: Maxim Vafin
Date: Fri, 11 Aug 2023 10:46:04 +0200
Subject: [PATCH] [PT FE] Remove torch strides compensation (#19129)
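
Remove get_input_transpose_order/get_output_transpose_order from the decoders
and the Reshape + Transpose compensation from TranslateSession;
torch_tensor_to_ov_const switches to tensor.contiguous() and drops the
np.ascontiguousarray fallback. For reference, the removed compensation derived
a permutation from the strides reported by TorchScript roughly as follows
(sketch based on the removed decoder code; the example stride values are
hypothetical):

    strides = (60, 1, 15, 3)  # e.g. a channels_last tensor of shape (1, 3, 4, 5)
    order = [s[0] for s in sorted(enumerate(strides), key=lambda x: x[1], reverse=True)]
    # order == [0, 2, 3, 1]; a non-sorted order used to trigger an extra Reshape + Transpose
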
---
 .../openvino/frontend/pytorch/fx_decoder.py   | 22 -------------------
 .../openvino/frontend/pytorch/ts_decoder.py   | 16 --------------
 .../src/openvino/frontend/pytorch/utils.py    |  7 +-----
 .../pyopenvino/frontend/pytorch/decoder.hpp   |  8 -------
 .../openvino/frontend/pytorch/decoder.hpp     |  6 -----
 .../pytorch/src/translate_session.cpp         | 18 ---------------
 src/frontends/pytorch/src/utils.hpp           |  6 -----
 .../pytorch_tests/pytorch_layer_test_class.py |  2 +-
 .../pytorch_tests/test_quantized_convnd.py    | 11 +++-------
 9 files changed, 5 insertions(+), 91 deletions(-)

diff --git a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py
index 47ee16d8282..dfa795b5d0f 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py
@@ -154,28 +154,6 @@ class TorchFXPythonDecoder (Decoder):
         else:
             return OVAny(OVType.f32)

-    def get_input_transpose_order(self, index):
-        return []
-        # TODO TBD
-
-        input = self._raw_input(index)
-        if input.type() is not None and input.type().kind() == 'TensorType':
-            strides = input.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
-    def get_output_transpose_order(self, index):
-        return []
-
-        # old code
-        output = self._raw_output(index)
-        if output.type() is not None and output.type().kind() == 'TensorType':
-            strides = output.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
     def get_subgraph_size(self):
         if issubclass(type(self.pt_module), torch.fx.Node):
             return 0
diff --git a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
index 1090dce0163..da4942bb5a7 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py
@@ -261,22 +261,6 @@ class TorchScriptPythonDecoder (Decoder):
         full_type = self._get_known_type_for_value(value.type())
         return full_type

-    def get_input_transpose_order(self, index: int) -> list:
-        raw_input = self._raw_input(index)
-        if raw_input.type() is not None and raw_input.type().kind() == "TensorType":
-            strides = raw_input.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
-    def get_output_transpose_order(self, index: int) -> list:
-        output = self._raw_output(index)
-        if output.type() is not None and output.type().kind() == "TensorType":
-            strides = output.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
     def get_subgraph_size(self) -> int:
         if isinstance(self.graph_element, torch.Node):
             return len(self.get_subgraphs())
diff --git a/src/bindings/python/src/openvino/frontend/pytorch/utils.py b/src/bindings/python/src/openvino/frontend/pytorch/utils.py
index 0e7ffd66780..3c658119bb1 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/utils.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/utils.py
@@ -56,20 +56,15 @@ def get_type_from_py_type(value):


 def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
-    torch_t = torch_t.to(memory_format=torch.contiguous_format)
+    torch_t = torch_t.contiguous()
     if torch_t.dtype == torch.bfloat16:
         # reinterpret bfloat16 data as float16 to allow conversion to numpy
         torch_t = torch_t.view(torch.float16)
         narr = torch_t.numpy(force=True)
-        if not narr.flags['C_CONTIGUOUS']:
-            narr = np.ascontiguousarray(narr)
-        # TODO: this tensor doesn't share memory with initial tensor
         tensor = Tensor(narr, torch_t.shape, OVType.bf16)
         ov_const = op.Constant(tensor, shared_memory=shared_memory)
     else:
         narr = torch_t.numpy(force=True)
-        if not narr.flags['C_CONTIGUOUS']:
-            narr = np.ascontiguousarray(narr)
         ov_const = op.Constant(narr, shared_memory=shared_memory)
     return ov_const
diff --git a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
index 4b32c1b6b0d..004fc19b209 100644
--- a/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
+++ b/src/bindings/python/src/pyopenvino/frontend/pytorch/decoder.hpp
@@ -38,10 +38,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
         PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
     }

-    const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
-        PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_transpose_order, index);
-    }
-
     const std::string& get_output_debug_name(size_t index) const override {
         PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_output_debug_name, index);
     }
@@ -54,10 +50,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
         PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_output_type, index);
     }

-    const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
-        PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_output_transpose_order, index);
-    }
-
     bool input_is_none(size_t index) const override {
         PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, input_is_none, index);
     }
diff --git a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
index 21517bea278..abab389a9c4 100644
--- a/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
+++ b/src/frontends/pytorch/include/openvino/frontend/pytorch/decoder.hpp
@@ -44,9 +44,6 @@ public:
     // (see custom_type.hpp)
     virtual Any get_input_type(size_t index) const = 0;

-    // TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
-    virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const = 0;
-
     // Return debug name of the input tensor
     virtual const std::string& get_output_debug_name(size_t index) const = 0;

@@ -56,9 +53,6 @@ public:
     // Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
     // (see custom_type.hpp)
     virtual Any get_output_type(size_t index) const = 0;
-
-    // TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
-    virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const = 0;

     // ------------------------------
     // TODO: required? can be implemented in the context of a single node?
diff --git a/src/frontends/pytorch/src/translate_session.cpp b/src/frontends/pytorch/src/translate_session.cpp
index 894b6bd3f15..c075373f017 100644
--- a/src/frontends/pytorch/src/translate_session.cpp
+++ b/src/frontends/pytorch/src/translate_session.cpp
@@ -118,20 +118,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
             encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)});
             parameters->push_back(parameter);
             input_node = parameter;
-            auto order = pytorch_model->get_input_transpose_order(i);
-            if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
-                FRONT_END_GENERAL_CHECK(pshape.is_static(), "Shape must be static.");  // TODO: make dynamic
-                auto sh = pshape.get_shape();
-                Shape new_shape(sh.size());
-                for (size_t i = 0; i < sh.size(); i++) {
-                    new_shape[order[i]] = sh[i];
-                }
-                auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
-                auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
-                auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
-                auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
-                input_node = transpose;
-            }
         }
         (*tensor_map)[inputs.at(i)] = input_node;
     }
@@ -167,7 +153,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                 dtype = type.as<element::Type>();
             }
             auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
-            // TODO: Missing get_input_transpose_order handling for not trivial layouts
             (*tensor_map)[input] = parameter;
             // set name of parameter to the index of node in the model
             encode_tensor_name(parameter->output(0), input);
@@ -240,9 +225,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                 (*tensor_map)[id] = parameter;
             }
             auto ov_output = tensor_map->at(id);
-            auto order = pytorch_model->get_output_transpose_order(i);
-            FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()),
-                                    "Output strides have wrong order.");
             FRONT_END_GENERAL_CHECK(ov_output.get_names().size() > 0,
                                     "Tensor doesn't have name, while it should have name: ",
                                     id);
diff --git a/src/frontends/pytorch/src/utils.hpp b/src/frontends/pytorch/src/utils.hpp
index 565476e7974..376ce027768 100644
--- a/src/frontends/pytorch/src/utils.hpp
+++ b/src/frontends/pytorch/src/utils.hpp
@@ -161,9 +161,6 @@ public:
     virtual Any get_input_type(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_input_type);
     }
-    virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
-        FRONT_END_NOT_IMPLEMENTED(get_input_transpose_order);
-    }
     virtual const std::string& get_output_debug_name(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_output_debug_name);
     }
@@ -173,9 +170,6 @@ public:
     virtual Any get_output_type(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_output_type);
     }
-    virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
-        FRONT_END_NOT_IMPLEMENTED(get_output_transpose_order);
-    }
     virtual bool input_is_none(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(input_is_none);
     }
diff --git a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
index 851036c0306..68bad824d7d 100644
--- a/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
+++ b/tests/layer_tests/pytorch_tests/pytorch_layer_test_class.py
@@ -136,7 +136,7 @@ class PytorchLayerTest:
             assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
             quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
-            cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
+            cur_fw_res = flatten_fw_res[i].contiguous().numpy(
             ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
diff --git a/tests/layer_tests/pytorch_tests/test_quantized_convnd.py b/tests/layer_tests/pytorch_tests/test_quantized_convnd.py
index 45ef28c47ce..e0caaa31c4f 100644
--- a/tests/layer_tests/pytorch_tests/test_quantized_convnd.py
+++ b/tests/layer_tests/pytorch_tests/test_quantized_convnd.py
@@ -41,7 +41,7 @@ class TestQuantizedConv2D(PytorchLayerTest):
                 x_quantized = torch.quantize_per_tensor(
                     x, 1.0, 0, torch.quint8)
                 conv = self.conv(x_quantized)
-                return torch.dequantize(conv).contiguous()
+                return torch.dequantize(conv)

         ref_net = None
         if not relu:
@@ -54,13 +54,8 @@ class TestQuantizedConv2D(PytorchLayerTest):
     @pytest.mark.parametrize(
         "params",
         [
-            pytest.param(
-                {"weights_shape": [1, 3, 3, 3], "strides": 1,
-                    "pads": 0, "dilations": 1, "groups": 1},
-                marks=pytest.mark.xfail(
-                    reason="Output channels equal to 1 creates output that fails to cast to contiguous."
-                ),
-            ),
+            {"weights_shape": [1, 3, 3, 3], "strides": 1,
+             "pads": 0, "dilations": 1, "groups": 1},
             {"weights_shape": [2, 3, 3, 3], "strides": 1,
                 "pads": 0, "dilations": 1, "groups": 1},
             {"weights_shape": [2, 3, 3, 3], "strides": 2,