[PT FE] Remove torch strides compensation (#19129)

Maxim Vafin 2023-08-11 10:46:04 +02:00 committed by GitHub
parent c1c4c4cd51
commit cc08b0091e
9 changed files with 5 additions and 91 deletions
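Context for the change (annotation, not part of the diff): the PyTorch frontend decoders used tensor strides to guess a transpose order for inputs and outputs, compensating for non-contiguous layouts such as channels_last; this commit removes that compensation and instead relies on making tensors contiguous before conversion. A minimal sketch in plain PyTorch (not OpenVINO API) of the stride-derived order the removed code computed:

import torch

# A channels_last tensor reports permuted strides; the removed decoder code
# sorted axes by decreasing stride to derive a transpose order.
t = torch.zeros(1, 3, 4, 5).to(memory_format=torch.channels_last)
print(t.stride())  # (60, 1, 15, 3): the channel axis varies fastest in memory
order = [s[0] for s in sorted(enumerate(t.stride()), key=lambda x: x[1], reverse=True)]
print(order)  # [0, 2, 3, 1]: axes ordered from largest to smallest stride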

View File

@@ -154,28 +154,6 @@ class TorchFXPythonDecoder (Decoder):
         else:
             return OVAny(OVType.f32)
 
-    def get_input_transpose_order(self, index):
-        return []
-        # TODO TBD
-        input = self._raw_input(index)
-        if input.type() is not None and input.type().kind() == 'TensorType':
-            strides = input.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
-    def get_output_transpose_order(self, index):
-        return []
-        # old code
-        output = self._raw_output(index)
-        if output.type() is not None and output.type().kind() == 'TensorType':
-            strides = output.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
     def get_subgraph_size(self):
         if issubclass(type(self.pt_module), torch.fx.Node):
             return 0
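Note that in this FX decoder both methods already returned [] before the stride logic, so everything past the early return was dead code. The deleted list comprehension itself is an argsort of the strides in decreasing order; a small check of the idiom (illustrative only, assumes distinct stride values, since np.argsort and sorted may disagree on ties):

import numpy as np

strides = (60, 1, 15, 3)
order = [s[0] for s in sorted(enumerate(strides), key=lambda x: x[1], reverse=True)]
assert order == list(np.argsort(strides))[::-1]  # [0, 2, 3, 1]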

View File

@@ -261,22 +261,6 @@ class TorchScriptPythonDecoder (Decoder):
         full_type = self._get_known_type_for_value(value.type())
         return full_type
 
-    def get_input_transpose_order(self, index: int) -> list:
-        raw_input = self._raw_input(index)
-        if raw_input.type() is not None and raw_input.type().kind() == "TensorType":
-            strides = raw_input.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
-    def get_output_transpose_order(self, index: int) -> list:
-        output = self._raw_output(index)
-        if output.type() is not None and output.type().kind() == "TensorType":
-            strides = output.type().strides()
-            if strides is not None:
-                return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
-        return []
-
     def get_subgraph_size(self) -> int:
         if isinstance(self.graph_element, torch.Node):
            return len(self.get_subgraphs())
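For a contiguous tensor the strides are already strictly decreasing, so the removed TorchScript variant returned the identity permutation and the frontend inserted no transpose. A quick illustration in plain PyTorch:

import torch

t = torch.zeros(2, 3, 4)
print(t.stride())  # (12, 4, 1): already in decreasing order
order = [s[0] for s in sorted(enumerate(t.stride()), key=lambda x: x[1], reverse=True)]
print(order)  # [0, 1, 2]: identity, i.e. no compensation needed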

View File

@@ -56,20 +56,15 @@ def get_type_from_py_type(value):
 def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
-    torch_t = torch_t.to(memory_format=torch.contiguous_format)
+    torch_t = torch_t.contiguous()
     if torch_t.dtype == torch.bfloat16:
         # reinterpret bfloat16 data as float16 to allow conversion to numpy
         torch_t = torch_t.view(torch.float16)
         narr = torch_t.numpy(force=True)
-        if not narr.flags['C_CONTIGUOUS']:
-            narr = np.ascontiguousarray(narr)
         # TODO: this tensor doesn't share memory with initial tensor
         tensor = Tensor(narr, torch_t.shape, OVType.bf16)
         ov_const = op.Constant(tensor, shared_memory=shared_memory)
     else:
         narr = torch_t.numpy(force=True)
-        if not narr.flags['C_CONTIGUOUS']:
-            narr = np.ascontiguousarray(narr)
         ov_const = op.Constant(narr, shared_memory=shared_memory)
     return ov_const
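With torch_t.contiguous() guaranteeing a dense C-ordered buffer up front, the per-branch C_CONTIGUOUS checks become redundant. The surviving bfloat16 branch relies on a bit-level reinterpretation, since numpy has no native bfloat16 dtype; a standalone sketch of that trick (plain PyTorch, illustration only):

import torch

t = torch.tensor([1.0, 2.0], dtype=torch.bfloat16)
bits = t.view(torch.float16)     # same bytes, relabeled as float16
narr = bits.numpy(force=True)    # 16-bit numpy buffer; values are raw bf16 bits,
                                 # not valid float16 numbers
assert narr.dtype.itemsize == 2  # the diff then relabels this buffer as OVType.bf16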

View File

@@ -38,10 +38,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
         PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
     }
 
-    const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
-        PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_transpose_order, index);
-    }
-
     const std::string& get_output_debug_name(size_t index) const override {
         PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_output_debug_name, index);
     }

@@ -54,10 +50,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
         PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_output_type, index);
     }
 
-    const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
-        PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_output_transpose_order, index);
-    }
-
     bool input_is_none(size_t index) const override {
         PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, input_is_none, index);
     }

View File

@@ -44,9 +44,6 @@ public:
     // (see custom_type.hpp)
     virtual Any get_input_type(size_t index) const = 0;
 
-    // TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
-    virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const = 0;
-
     // Return debug name of the input tensor
     virtual const std::string& get_output_debug_name(size_t index) const = 0;

@@ -56,9 +53,6 @@ public:
     // Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
     // (see custom_type.hpp)
     virtual Any get_output_type(size_t index) const = 0;
 
-    // TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
-    virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const = 0;
-
     // ------------------------------
     // TODO: required? can be implemented in the context of a single node?
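The two C++ hunks above are two sides of one contract: TorchDecoder declares the pure-virtual methods, and PyDecoder is the pybind11 trampoline that forwards each call to the matching Python override via PYBIND11_OVERRIDE_PURE, so both must shrink together. A rough Python analogy of the narrowed contract (hypothetical class name, illustration only):

from abc import ABC, abstractmethod
from typing import Any

class DecoderContract(ABC):
    # Sketch of the decoder interface after this commit: type queries remain...
    @abstractmethod
    def get_input_type(self, index: int) -> Any: ...

    @abstractmethod
    def get_output_type(self, index: int) -> Any: ...

    # ...but get_input_transpose_order / get_output_transpose_order are gone,
    # so decoders no longer report stride-derived layouts to the frontend.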

View File

@@ -118,20 +118,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
             encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)});
             parameters->push_back(parameter);
             input_node = parameter;
-            auto order = pytorch_model->get_input_transpose_order(i);
-            if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
-                FRONT_END_GENERAL_CHECK(pshape.is_static(), "Shape must be static.");  // TODO: make dynamic
-                auto sh = pshape.get_shape();
-                Shape new_shape(sh.size());
-                for (size_t i = 0; i < sh.size(); i++) {
-                    new_shape[order[i]] = sh[i];
-                }
-                auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
-                auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
-                auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
-                auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
-                input_node = transpose;
-            }
         }
         (*tensor_map)[inputs.at(i)] = input_node;
     }

@@ -167,7 +153,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                 dtype = type.as<element::Type>();
             }
             auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
-            // TODO: Missing get_input_transpose_order handling for not trivial layouts
             (*tensor_map)[input] = parameter;
             // set name of parameter to the index of node in the model
             encode_tensor_name(parameter->output(0), input);

@@ -240,9 +225,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
                 (*tensor_map)[id] = parameter;
             }
             auto ov_output = tensor_map->at(id);
-            auto order = pytorch_model->get_output_transpose_order(i);
-            FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()),
-                                    "Output strides have wrong order.");
             FRONT_END_GENERAL_CHECK(ov_output.get_names().size() > 0,
                                     "Tensor doesn't have name, while it should have name: ",
                                     id);
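For reference, the deleted input compensation reshaped the parameter to its physical (memory-order) shape and then transposed it back to the logical shape. A numpy sketch of the equivalent data movement (illustrative only; like the deleted code, it assumes a static shape):

import numpy as np

def compensate(x: np.ndarray, order: list) -> np.ndarray:
    # new_shape[order[i]] = sh[i], as in the deleted C++ loop
    sh = x.shape
    new_shape = [0] * len(sh)
    for i in range(len(sh)):
        new_shape[order[i]] = sh[i]
    # Reshape to the physical layout, then Transpose back to the logical one
    return x.reshape(new_shape).transpose(order)

x = np.arange(12).reshape(3, 4)
assert compensate(x, [1, 0]).shape == x.shape  # logical shape is restored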

View File

@@ -161,9 +161,6 @@ public:
     virtual Any get_input_type(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_input_type);
     }
-    virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
-        FRONT_END_NOT_IMPLEMENTED(get_input_transpose_order);
-    }
     virtual const std::string& get_output_debug_name(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_output_debug_name);
     }

@@ -173,9 +170,6 @@ public:
     virtual Any get_output_type(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(get_output_type);
     }
-    virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
-        FRONT_END_NOT_IMPLEMENTED(get_output_transpose_order);
-    }
     virtual bool input_is_none(size_t index) const override {
         FRONT_END_NOT_IMPLEMENTED(input_is_none);
     }

View File

@@ -136,7 +136,7 @@ class PytorchLayerTest:
             assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
             quant_size = kwargs['quant_size']
         for i in range(len(infer_res)):
-            cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
+            cur_fw_res = flatten_fw_res[i].contiguous().numpy(
             ) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
             if np.array(cur_fw_res).size == 0:
                 continue
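The test helper now calls .contiguous() before .numpy(); this normalizes strides without changing values, keeping the flattened framework results directly comparable. A small demonstration (plain PyTorch, illustration only):

import torch

t = torch.arange(24).reshape(2, 3, 4).permute(2, 0, 1)  # non-contiguous view
assert not t.is_contiguous()
c = t.contiguous()               # same values, dense row-major storage
assert torch.equal(t, c) and c.is_contiguous()
assert (t.numpy() == c.numpy()).all()  # values agree either way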

View File

@@ -41,7 +41,7 @@ class TestQuantizedConv2D(PytorchLayerTest):
                 x_quantized = torch.quantize_per_tensor(
                     x, 1.0, 0, torch.quint8)
                 conv = self.conv(x_quantized)
-                return torch.dequantize(conv).contiguous()
+                return torch.dequantize(conv)
 
         ref_net = None
         if not relu:

@@ -54,13 +54,8 @@ class TestQuantizedConv2D(PytorchLayerTest):
     @pytest.mark.parametrize(
         "params",
         [
-            pytest.param(
-                {"weights_shape": [1, 3, 3, 3], "strides": 1,
-                 "pads": 0, "dilations": 1, "groups": 1},
-                marks=pytest.mark.xfail(
-                    reason="Output channels equal to 1 creates output that fails to cast to contiguous."
-                ),
-            ),
+            {"weights_shape": [1, 3, 3, 3], "strides": 1,
+             "pads": 0, "dilations": 1, "groups": 1},
             {"weights_shape": [2, 3, 3, 3], "strides": 1,
              "pads": 0, "dilations": 1, "groups": 1},
             {"weights_shape": [2, 3, 3, 3], "strides": 2,