[PT FE] Remove torch strides compensation (#19129)
This commit is contained in:
parent
c1c4c4cd51
commit
cc08b0091e
@ -154,28 +154,6 @@ class TorchFXPythonDecoder (Decoder):
|
||||
else:
|
||||
return OVAny(OVType.f32)
|
||||
|
||||
def get_input_transpose_order(self, index):
|
||||
return []
|
||||
# TODO TBD
|
||||
|
||||
input = self._raw_input(index)
|
||||
if input.type() is not None and input.type().kind() == 'TensorType':
|
||||
strides = input.type().strides()
|
||||
if strides is not None:
|
||||
return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
|
||||
return []
|
||||
|
||||
def get_output_transpose_order(self, index):
|
||||
return []
|
||||
|
||||
# old code
|
||||
output = self._raw_output(index)
|
||||
if output.type() is not None and output.type().kind() == 'TensorType':
|
||||
strides = output.type().strides()
|
||||
if strides is not None:
|
||||
return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
|
||||
return []
|
||||
|
||||
def get_subgraph_size(self):
|
||||
if issubclass(type(self.pt_module), torch.fx.Node):
|
||||
return 0
|
||||
|
@ -261,22 +261,6 @@ class TorchScriptPythonDecoder (Decoder):
|
||||
full_type = self._get_known_type_for_value(value.type())
|
||||
return full_type
|
||||
|
||||
def get_input_transpose_order(self, index: int) -> list:
|
||||
raw_input = self._raw_input(index)
|
||||
if raw_input.type() is not None and raw_input.type().kind() == "TensorType":
|
||||
strides = raw_input.type().strides()
|
||||
if strides is not None:
|
||||
return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
|
||||
return []
|
||||
|
||||
def get_output_transpose_order(self, index: int) -> list:
|
||||
output = self._raw_output(index)
|
||||
if output.type() is not None and output.type().kind() == "TensorType":
|
||||
strides = output.type().strides()
|
||||
if strides is not None:
|
||||
return [s[0] for s in sorted(enumerate(strides), key=lambda x:x[1], reverse=True)]
|
||||
return []
|
||||
|
||||
def get_subgraph_size(self) -> int:
|
||||
if isinstance(self.graph_element, torch.Node):
|
||||
return len(self.get_subgraphs())
|
||||
|
@ -56,20 +56,15 @@ def get_type_from_py_type(value):
|
||||
|
||||
|
||||
def torch_tensor_to_ov_const(torch_t: torch.Tensor, shared_memory=True):
|
||||
torch_t = torch_t.to(memory_format=torch.contiguous_format)
|
||||
torch_t = torch_t.contiguous()
|
||||
if torch_t.dtype == torch.bfloat16:
|
||||
# reinterpret bfloat16 data as float16 to allow conversion to numpy
|
||||
torch_t = torch_t.view(torch.float16)
|
||||
narr = torch_t.numpy(force=True)
|
||||
if not narr.flags['C_CONTIGUOUS']:
|
||||
narr = np.ascontiguousarray(narr)
|
||||
# TODO: this tensor doesn't share memory with initial tensor
|
||||
tensor = Tensor(narr, torch_t.shape, OVType.bf16)
|
||||
ov_const = op.Constant(tensor, shared_memory=shared_memory)
|
||||
else:
|
||||
narr = torch_t.numpy(force=True)
|
||||
if not narr.flags['C_CONTIGUOUS']:
|
||||
narr = np.ascontiguousarray(narr)
|
||||
ov_const = op.Constant(narr, shared_memory=shared_memory)
|
||||
return ov_const
|
||||
|
||||
|
@ -38,10 +38,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
|
||||
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_input_type, index);
|
||||
}
|
||||
|
||||
const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_input_transpose_order, index);
|
||||
}
|
||||
|
||||
const std::string& get_output_debug_name(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::string&, TorchDecoder, get_output_debug_name, index);
|
||||
}
|
||||
@ -54,10 +50,6 @@ class PyDecoder : public ov::frontend::pytorch::TorchDecoder {
|
||||
PYBIND11_OVERRIDE_PURE(ov::Any, TorchDecoder, get_output_type, index);
|
||||
}
|
||||
|
||||
const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(const std::vector<size_t>&, TorchDecoder, get_output_transpose_order, index);
|
||||
}
|
||||
|
||||
bool input_is_none(size_t index) const override {
|
||||
PYBIND11_OVERRIDE_PURE(bool, TorchDecoder, input_is_none, index);
|
||||
}
|
||||
|
@ -44,9 +44,6 @@ public:
|
||||
// (see custom_type.hpp)
|
||||
virtual Any get_input_type(size_t index) const = 0;
|
||||
|
||||
// TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
|
||||
virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const = 0;
|
||||
|
||||
// Return debug name of the input tensor
|
||||
virtual const std::string& get_output_debug_name(size_t index) const = 0;
|
||||
|
||||
@ -56,9 +53,6 @@ public:
|
||||
// Return element::Type when it the original type can be represented, otherwise returns PT-specific data type object
|
||||
// (see custom_type.hpp)
|
||||
virtual Any get_output_type(size_t index) const = 0;
|
||||
|
||||
// TODO: Consider deleting this method, probably it doesn't make sence outside Torch JIT execution
|
||||
virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const = 0;
|
||||
// ------------------------------
|
||||
|
||||
// TODO: required? can be implemented in the context of a single node?
|
||||
|
@ -118,20 +118,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
encode_tensor_name(parameter->output(0), inputs.at(i), {pytorch_model->get_input_debug_name(i)});
|
||||
parameters->push_back(parameter);
|
||||
input_node = parameter;
|
||||
auto order = pytorch_model->get_input_transpose_order(i);
|
||||
if (order.size() > 0 && !std::is_sorted(order.begin(), order.end())) {
|
||||
FRONT_END_GENERAL_CHECK(pshape.is_static(), "Shape must be static."); // TODO: make dynamic
|
||||
auto sh = pshape.get_shape();
|
||||
Shape new_shape(sh.size());
|
||||
for (size_t i = 0; i < sh.size(); i++) {
|
||||
new_shape[order[i]] = sh[i];
|
||||
}
|
||||
auto shape_const = v0::Constant::create(element::i32, {new_shape.size()}, new_shape);
|
||||
auto reshape = std::make_shared<v1::Reshape>(parameter, shape_const, false);
|
||||
auto order_const = v0::Constant::create(element::i32, {order.size()}, order);
|
||||
auto transpose = std::make_shared<v1::Transpose>(reshape, order_const);
|
||||
input_node = transpose;
|
||||
}
|
||||
}
|
||||
(*tensor_map)[inputs.at(i)] = input_node;
|
||||
}
|
||||
@ -167,7 +153,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
dtype = type.as<element::Type>();
|
||||
}
|
||||
auto parameter = std::make_shared<v0::Parameter>(dtype, ps);
|
||||
// TODO: Missing get_input_transpose_order handling for not trivial layouts
|
||||
(*tensor_map)[input] = parameter;
|
||||
// set name of parameter to the index of node in the model
|
||||
encode_tensor_name(parameter->output(0), input);
|
||||
@ -240,9 +225,6 @@ std::shared_ptr<Model> TranslateSession::convert_pytorch_model(
|
||||
(*tensor_map)[id] = parameter;
|
||||
}
|
||||
auto ov_output = tensor_map->at(id);
|
||||
auto order = pytorch_model->get_output_transpose_order(i);
|
||||
FRONT_END_GENERAL_CHECK(order.size() == 0 || std::is_sorted(order.begin(), order.end()),
|
||||
"Output strides have wrong order.");
|
||||
FRONT_END_GENERAL_CHECK(ov_output.get_names().size() > 0,
|
||||
"Tensor doesn't have name, while it should have name: ",
|
||||
id);
|
||||
|
@ -161,9 +161,6 @@ public:
|
||||
virtual Any get_input_type(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_input_type);
|
||||
}
|
||||
virtual const std::vector<size_t>& get_input_transpose_order(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_input_transpose_order);
|
||||
}
|
||||
virtual const std::string& get_output_debug_name(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_output_debug_name);
|
||||
}
|
||||
@ -173,9 +170,6 @@ public:
|
||||
virtual Any get_output_type(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_output_type);
|
||||
}
|
||||
virtual const std::vector<size_t>& get_output_transpose_order(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(get_output_transpose_order);
|
||||
}
|
||||
virtual bool input_is_none(size_t index) const override {
|
||||
FRONT_END_NOT_IMPLEMENTED(input_is_none);
|
||||
}
|
||||
|
@ -136,7 +136,7 @@ class PytorchLayerTest:
|
||||
assert 'quant_size' in kwargs, "quant size must be specified for quantized_ops flag"
|
||||
quant_size = kwargs['quant_size']
|
||||
for i in range(len(infer_res)):
|
||||
cur_fw_res = flatten_fw_res[i].to(memory_format=torch.contiguous_format).numpy(
|
||||
cur_fw_res = flatten_fw_res[i].contiguous().numpy(
|
||||
) if isinstance(flatten_fw_res[i], torch.Tensor) else flatten_fw_res[i]
|
||||
if np.array(cur_fw_res).size == 0:
|
||||
continue
|
||||
|
@ -41,7 +41,7 @@ class TestQuantizedConv2D(PytorchLayerTest):
|
||||
x_quantized = torch.quantize_per_tensor(
|
||||
x, 1.0, 0, torch.quint8)
|
||||
conv = self.conv(x_quantized)
|
||||
return torch.dequantize(conv).contiguous()
|
||||
return torch.dequantize(conv)
|
||||
|
||||
ref_net = None
|
||||
if not relu:
|
||||
@ -54,13 +54,8 @@ class TestQuantizedConv2D(PytorchLayerTest):
|
||||
@pytest.mark.parametrize(
|
||||
"params",
|
||||
[
|
||||
pytest.param(
|
||||
{"weights_shape": [1, 3, 3, 3], "strides": 1,
|
||||
"pads": 0, "dilations": 1, "groups": 1},
|
||||
marks=pytest.mark.xfail(
|
||||
reason="Output channels equal to 1 creates output that fails to cast to contiguous."
|
||||
),
|
||||
),
|
||||
{"weights_shape": [1, 3, 3, 3], "strides": 1,
|
||||
"pads": 0, "dilations": 1, "groups": 1},
|
||||
{"weights_shape": [2, 3, 3, 3], "strides": 1,
|
||||
"pads": 0, "dilations": 1, "groups": 1},
|
||||
{"weights_shape": [2, 3, 3, 3], "strides": 2,
|
||||
|
Loading…
Reference in New Issue
Block a user