Decompose/flatten tuple inputs (#18092)
* prim::TupleUnpack and prim::ListUnpack removing transformation in PT FE to flatten input list and tuples * Enabled tuples and lists as items in example_inputs * Applied code style * Added tests for tuples as inputs and extended test infrastructure to support it * Negligible performance optimizations Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> * Fixed duplicated names of test classes * Added description for tuple flattening transformation * Removed any support for list flattening on inputs; fixed layer tests * Fixed style * Fixed order of new Parameters and Results while flattening tuples * Fixed style * Better diagnostics when not all prim::TupleUnpack ops after Parameters are decomposed * Small fix in diagnostics message --------- Co-authored-by: Ekaterina Aidova <ekaterina.aidova@intel.com> Co-authored-by: Andrei Kochin <andrei.kochin@intel.com> Co-authored-by: Roman Kazantsev <roman.kazantsev@intel.com> Co-authored-by: Alina Kladieva <alina.kladieva@intel.com>
This commit is contained in:
parent
553dab43b4
commit
dfba702c74
@ -34,6 +34,7 @@
|
||||
#include "transforms/prim_list_construct_pad.hpp"
|
||||
#include "transforms/prim_list_tuple_construct_replacer.hpp"
|
||||
#include "transforms/prim_list_unpack_replacer.hpp"
|
||||
#include "transforms/prim_tuple_unpack_parameter_replacer.hpp"
|
||||
#include "transforms/rfftn_complex_replacer.hpp"
|
||||
#include "transforms/string_equality_replacer.hpp"
|
||||
#include "transforms/tuple_unpack_replacer.hpp"
|
||||
@ -174,6 +175,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
|
||||
manager.register_pass<ov::frontend::pytorch::pass::AtenEinsumListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::MinMaxPrimListConstructReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::StringEqualityReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::DecomposeTupleParameters>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::RFFTNComplexReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::IRFFTNComplexReplacer>();
|
||||
manager.register_pass<ov::frontend::pytorch::pass::PrimTupleUnpackReplacer>();
|
||||
|
@ -3,7 +3,7 @@
|
||||
//
|
||||
#include "prim_list_tuple_construct_replacer.hpp"
|
||||
|
||||
#include <queue>
|
||||
#include <deque>
|
||||
|
||||
#include "openvino/frontend/pytorch/decoder.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
@ -17,21 +17,24 @@ namespace pass {
|
||||
|
||||
bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model) {
|
||||
bool at_least_one_decomposed = false;
|
||||
std::queue<std::shared_ptr<ov::op::v0::Result>> results;
|
||||
for (auto res : model->get_results()) {
|
||||
results.push(res);
|
||||
}
|
||||
const auto& orig_results = model->get_results();
|
||||
std::deque<std::shared_ptr<ov::op::v0::Result>> results(orig_results.begin(), orig_results.end());
|
||||
ov::ResultVector updated_results; // will hold final fully unpacked results list
|
||||
|
||||
while (!results.empty()) {
|
||||
auto result = results.front();
|
||||
results.pop();
|
||||
results.pop_front();
|
||||
auto input_node = result->get_input_node_shared_ptr(0);
|
||||
auto tuple_construct = cast_fw_node(input_node, "prim::TupleConstruct");
|
||||
auto list_construct = cast_fw_node(input_node, "prim::ListConstruct");
|
||||
if (!tuple_construct && !list_construct) {
|
||||
updated_results.push_back(result);
|
||||
continue;
|
||||
}
|
||||
for (const auto& input : input_node->inputs()) {
|
||||
const auto& out = input.get_source_output();
|
||||
const auto& inputs = input_node->inputs();
|
||||
// enumerating inputs in reverse order because of results.push_front below
|
||||
for (auto pinput = inputs.rbegin(); pinput != inputs.rend(); ++pinput) {
|
||||
const auto& out = pinput->get_source_output();
|
||||
if (const auto& fw_node = cast_fw_node(out.get_node_shared_ptr(), "prim::Constant")) {
|
||||
const auto& attrs = fw_node->get_attrs();
|
||||
if (attrs.find("none_value") != attrs.end()) {
|
||||
@ -42,13 +45,19 @@ bool DecomposeListTupleResults::run_on_model(const std::shared_ptr<Model>& model
|
||||
}
|
||||
}
|
||||
auto new_result = std::make_shared<ov::op::v0::Result>(out);
|
||||
model->add_results({new_result});
|
||||
results.push(new_result);
|
||||
model->remove_result(result);
|
||||
results.push_front(new_result);
|
||||
at_least_one_decomposed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (at_least_one_decomposed) {
|
||||
// remove all results
|
||||
while (!model->get_results().empty())
|
||||
model->remove_result(model->get_results()[0]);
|
||||
// and replace them all by updated list of results
|
||||
model->add_results(updated_results);
|
||||
}
|
||||
|
||||
return at_least_one_decomposed;
|
||||
};
|
||||
} // namespace pass
|
||||
|
@ -0,0 +1,122 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include "prim_tuple_unpack_parameter_replacer.hpp"
|
||||
|
||||
#include <deque>
|
||||
#include <sstream>
|
||||
|
||||
#include "openvino/frontend/pytorch/decoder.hpp"
|
||||
#include "openvino/op/result.hpp"
|
||||
#include "openvino/op/util/framework_node.hpp"
|
||||
#include "utils.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
bool DecomposeTupleParameters::run_on_model(const std::shared_ptr<Model>& model) {
|
||||
bool at_least_one_decomposed = false;
|
||||
const auto& orig_parameters = model->get_parameters();
|
||||
std::deque<std::shared_ptr<ov::op::v0::Parameter>> parameters(orig_parameters.begin(), orig_parameters.end());
|
||||
ov::ParameterVector updated_parameters; // will hold final fully unpacked parameters list
|
||||
|
||||
while (!parameters.empty()) {
|
||||
auto parameter = parameters.front();
|
||||
parameters.pop_front();
|
||||
auto consumers = parameter->get_output_target_inputs(0);
|
||||
size_t num_outputs = 0; // number of outputs in each unpack consumer should match
|
||||
bool all_unpacks = true;
|
||||
|
||||
// collects all outputs per each consumer operation for this tuple Parameter
|
||||
std::vector<OutputVector> consumer_outputs;
|
||||
|
||||
// The following vector track consumer nodes having prim::TupleUnpack type to form a detailed
|
||||
// error message in case when parameter replacement is required but not possible.
|
||||
std::vector<std::shared_ptr<Node>> consumer_unpacks;
|
||||
|
||||
for (const auto& consumer : consumers) {
|
||||
auto node = consumer.get_node()->shared_from_this();
|
||||
auto tuple_unpack = cast_fw_node(node, "prim::TupleUnpack");
|
||||
if (!tuple_unpack) {
|
||||
all_unpacks = false;
|
||||
continue; // need to look at all consumers to form good diagnostics
|
||||
}
|
||||
consumer_unpacks.push_back(node);
|
||||
if (num_outputs == 0) {
|
||||
num_outputs = node->get_output_size();
|
||||
} else if (num_outputs != node->get_output_size()) {
|
||||
std::stringstream message;
|
||||
message << "Unpack node " << node
|
||||
<< " as one of the consumers of a tuple, which is introduced by parameter "
|
||||
<< parameter->output(0) << ", has number of outputs " << node->get_output_size()
|
||||
<< " not matching number of outputs " << num_outputs << " for other consumer(s) found earlier.";
|
||||
add_exception_to_fw_node(node, message.str());
|
||||
all_unpacks = false;
|
||||
break;
|
||||
}
|
||||
consumer_outputs.push_back(node->outputs());
|
||||
}
|
||||
|
||||
if (!all_unpacks || consumer_outputs.empty()) {
|
||||
// if at least one consumer is not an unpack-like op or there are not matching number of unpacked objects,
|
||||
// we cannot replace other unpacks even if they exist, leaving Unpack-op(s) in the graph for this Parameter
|
||||
|
||||
updated_parameters.push_back(parameter);
|
||||
// In case if at least one Unpack exists there is an opportinity to attach diagnostics
|
||||
for (const auto& consumer : consumer_unpacks) {
|
||||
std::stringstream message;
|
||||
message << "Not prim::TupleUnpack operations exist except this one: " << consumer
|
||||
<< " found as one of the consumers of a tuple, which is introduced by parameter "
|
||||
<< parameter->output(0) << ".";
|
||||
add_exception_to_fw_node(consumer, message.str());
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// enumerating outputs in reverse order because of parameters.push_front below
|
||||
for (size_t i = num_outputs; i--;) {
|
||||
// Merged partial shape and element type among all the consumers of i-th result of unpack ops
|
||||
PartialShape ps = PartialShape::dynamic();
|
||||
element::Type et = element::dynamic;
|
||||
std::set<Input<Node>> inputs;
|
||||
|
||||
for (const auto& outputs : consumer_outputs) {
|
||||
auto output = outputs[i];
|
||||
OPENVINO_ASSERT(PartialShape::merge_into(ps, output.get_partial_shape()),
|
||||
"Consumers for unpack op have incompatible shape");
|
||||
OPENVINO_ASSERT(element::Type::merge(et, et, output.get_element_type()),
|
||||
"Consumers for unpack op have incompatible types");
|
||||
auto target_inputs = output.get_target_inputs();
|
||||
inputs.insert(target_inputs.begin(), target_inputs.end());
|
||||
}
|
||||
|
||||
auto new_parameter = std::make_shared<ov::op::v0::Parameter>(et, ps);
|
||||
|
||||
for (auto input : inputs) {
|
||||
auto names = input.get_tensor().get_names();
|
||||
input.replace_source_output(new_parameter->output(0));
|
||||
new_parameter->output(0).add_names(names);
|
||||
}
|
||||
|
||||
// TODO: Assign correct names
|
||||
parameters.push_front(new_parameter);
|
||||
at_least_one_decomposed = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (at_least_one_decomposed) {
|
||||
// remove all parameters
|
||||
while (!model->get_parameters().empty())
|
||||
model->remove_parameter(model->get_parameters()[0]);
|
||||
// and replace them by updated list of parameters
|
||||
model->add_parameters(updated_parameters);
|
||||
}
|
||||
|
||||
return at_least_one_decomposed;
|
||||
};
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -0,0 +1,37 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
#include "openvino/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace frontend {
|
||||
namespace pytorch {
|
||||
namespace pass {
|
||||
|
||||
// This transformation replaces all prim::TupleUnpack operations coming after Parameters with
|
||||
// more Parameters -- one new parameter for each prim::TupleUnpack output. The original Parameter
|
||||
// is replaced with these new Parameters preserving the order relative to other Parameters in a model.
|
||||
// Order of new parameters is the same as the order of prim::TupleUnpack outputs.
|
||||
// If prim::TupleUnpack has a consumer that is also prim::TupleUnpack, the transformation applies
|
||||
// the replacement recursively until all prim::TupleUnpacks that take a Parameter output are eliminated.
|
||||
//
|
||||
// For example, if a model has the following signature: a, (b, (c, d)), e, where a, b, c, d, and e are
|
||||
// tensors, and (x1, x2) means tuple consisting two elements x1 and x2, then the resulting model
|
||||
// after the transformation will have a, b, c, d, e as inputs (without tuples, flattened).
|
||||
// Note, that there is no special 'tuple' type of an input, tuple structure is restored by
|
||||
// following prim::TupleUnpack operations in the graph only assuming that they can be applied on
|
||||
// tuples only and the most nested objects in those tuples are tensors.
|
||||
class DecomposeTupleParameters : public ov::pass::ModelPass {
|
||||
public:
|
||||
OPENVINO_RTTI("ov::frontend::pytorch::pass::DecomposeTupleParameters");
|
||||
bool run_on_model(const std::shared_ptr<Model>& model) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace pytorch
|
||||
} // namespace frontend
|
||||
} // namespace ov
|
@ -50,8 +50,15 @@ class PytorchLayerTest:
|
||||
else:
|
||||
inputs = self._prepare_input()
|
||||
|
||||
torch_inputs = [torch.from_numpy(inp) if isinstance(
|
||||
inp, np.ndarray) else inp for inp in inputs]
|
||||
def numpy_to_torch_recursively(x):
|
||||
if isinstance(x, tuple):
|
||||
return tuple(numpy_to_torch_recursively(y) for y in x)
|
||||
elif isinstance(x, np.ndarray):
|
||||
return torch.from_numpy(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
torch_inputs = [numpy_to_torch_recursively(inp) for inp in inputs]
|
||||
|
||||
if 'custom_eps' in kwargs and kwargs['custom_eps'] is not None:
|
||||
custom_eps = kwargs['custom_eps']
|
||||
@ -61,6 +68,8 @@ class PytorchLayerTest:
|
||||
def use_ts_backend():
|
||||
return(os.environ.get('USE_TS_BACKEND', False))
|
||||
|
||||
ov_inputs = flattenize_inputs(inputs)
|
||||
|
||||
if use_ts_backend():
|
||||
self.ts_backend_test(model, torch_inputs, custom_eps)
|
||||
else:
|
||||
@ -68,7 +77,7 @@ class PytorchLayerTest:
|
||||
model.eval()
|
||||
trace_model = kwargs.get('trace_model', False)
|
||||
freeze_model = kwargs.get('freeze_model', True)
|
||||
model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, inputs, freeze_model)
|
||||
model, converted_model = self.convert_directly_via_frontend(model, torch_inputs, trace_model, dynamic_shapes, ov_inputs, freeze_model)
|
||||
graph = model.inlined_graph
|
||||
|
||||
if kind is not None and not isinstance(kind, (tuple, list)):
|
||||
@ -80,7 +89,7 @@ class PytorchLayerTest:
|
||||
# OV infer:
|
||||
core = Core()
|
||||
compiled = core.compile_model(converted_model, ie_device)
|
||||
infer_res = compiled(deepcopy(inputs))
|
||||
infer_res = compiled(deepcopy(ov_inputs))
|
||||
|
||||
if hasattr(self, 'skip_framework') and self.skip_framework:
|
||||
warnings.warn('Framework is skipped')
|
||||
@ -266,25 +275,33 @@ def get_params(ie_device=None, precision=None):
|
||||
return test_args
|
||||
|
||||
|
||||
def flattenize_dict_outputs(res):
|
||||
def flattenize_dict_outputs(res, types):
|
||||
if isinstance(res, dict):
|
||||
return flattenize_outputs(res.values())
|
||||
return flattenize(res.values(), types)
|
||||
|
||||
|
||||
def flattenize_outputs(res):
|
||||
def flattenize(res, types: list):
|
||||
results = []
|
||||
for res_item in res:
|
||||
# if None is at output we skip it
|
||||
if res_item is None:
|
||||
continue
|
||||
# If input is list or tuple flattenize it
|
||||
if isinstance(res_item, (list, tuple)):
|
||||
decomposed_res = flattenize_outputs(res_item)
|
||||
if isinstance(res_item, (list, tuple)) and type(res_item) in types:
|
||||
decomposed_res = flattenize(res_item, types)
|
||||
results.extend(decomposed_res)
|
||||
continue
|
||||
if isinstance(res_item, dict):
|
||||
decomposed_res = flattenize_dict_outputs(res_item)
|
||||
if isinstance(res_item, dict) and type(res_item) in types:
|
||||
decomposed_res = flattenize_dict_outputs(res_item, types)
|
||||
results.extend(decomposed_res)
|
||||
continue
|
||||
results.append(res_item)
|
||||
return results
|
||||
|
||||
|
||||
def flattenize_outputs(res):
|
||||
return flattenize(res, [list, tuple, dict])
|
||||
|
||||
|
||||
def flattenize_inputs(res):
|
||||
return flattenize(res, [tuple])
|
||||
|
@ -33,6 +33,11 @@ class TestTupleConstruct(PytorchLayerTest):
|
||||
def forward(self, x):
|
||||
return (x, [None, x + x], None)
|
||||
|
||||
class prim_tuple_construct_with_tensor_tail(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return ((x, x + x), x + x + x)
|
||||
|
||||
class prim_tuple_construct_with_list_and_tuple(torch.nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
@ -43,6 +48,7 @@ class TestTupleConstruct(PytorchLayerTest):
|
||||
"multiple": prim_tuple_construct,
|
||||
"none": prim_tuple_construct_with_none,
|
||||
"list": prim_tuple_construct_with_list,
|
||||
"tensor_tail": prim_tuple_construct_with_tensor_tail,
|
||||
"list_and_tuple": prim_tuple_construct_with_list_and_tuple
|
||||
}
|
||||
|
||||
@ -51,11 +57,11 @@ class TestTupleConstruct(PytorchLayerTest):
|
||||
|
||||
return model(), ref_net, "prim::TupleConstruct"
|
||||
|
||||
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "list_and_tuple"])
|
||||
@pytest.mark.parametrize("case", ["single", "multiple", "none", "list", "tensor_tail", "list_and_tuple"])
|
||||
@pytest.mark.nightly
|
||||
def test_tuple_construct(self, case, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(case), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
|
||||
class TestTupleConstructTupleUnpack(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
@ -69,7 +75,7 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
|
||||
def forward(self, x):
|
||||
x1, x2, x3, x4, x5 = self.prepare_input(x)
|
||||
return x1, x2, x3, x4, x5
|
||||
|
||||
|
||||
def prepare_input(self, x):
|
||||
return x, x + 2, None, x.reshape(-1), (x * 10).to(torch.int32)
|
||||
|
||||
@ -80,4 +86,106 @@ class TestTupleConstructTupleUnpack(PytorchLayerTest):
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test_tuple_construct_unpack(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version, freeze_model=False)
|
||||
|
||||
|
||||
class TestTupleUnpackParameterSingle(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
def tensor_gen():
|
||||
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
|
||||
return ( (tensor_gen(), tensor_gen()), )
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
class model(torch.nn.Module):
|
||||
|
||||
def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
|
||||
x1, x2 = x
|
||||
return x1, x2
|
||||
|
||||
|
||||
return model(), None, ["prim::TupleUnpack"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTupleUnpackParameterSingleMixed(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
def tensor_gen():
|
||||
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
|
||||
# generate tensor with a different shape for easier mismatch detection in case of mixed input order
|
||||
def tensor_gen_2():
|
||||
return np.random.uniform(0, 50, (2, 3)).astype(np.float32)
|
||||
return (tensor_gen_2(), (tensor_gen(), tensor_gen()), tensor_gen_2())
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
class model(torch.nn.Module):
|
||||
|
||||
def forward(self, y1, x: Tuple[torch.Tensor, torch.Tensor], y2):
|
||||
x1, x2 = x
|
||||
return x1, x2, y1, y2
|
||||
|
||||
|
||||
return model(), None, ["prim::TupleUnpack"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTupleUnpackParameterNested(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
def tensor_gen():
|
||||
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
|
||||
return ( ((tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen())), )
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
class model(torch.nn.Module):
|
||||
|
||||
def forward(self, x: Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]):
|
||||
x1, x2 = x
|
||||
y1, y2 = x1
|
||||
y3, y4 = x2
|
||||
return y1, y2, y3, y4
|
||||
|
||||
|
||||
return model(), None, ["prim::TupleUnpack"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
||||
|
||||
class TestTupleUnpackParameterMultiple(PytorchLayerTest):
|
||||
def _prepare_input(self):
|
||||
def tensor_gen():
|
||||
return np.random.uniform(0, 50, (1, 2, 10)).astype(np.float32)
|
||||
return ( (tensor_gen(), tensor_gen()), (tensor_gen(), tensor_gen()) )
|
||||
|
||||
def create_model(self):
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
class model(torch.nn.Module):
|
||||
|
||||
def forward(self, x: Tuple[torch.Tensor, torch.Tensor], y: Tuple[torch.Tensor, torch.Tensor]):
|
||||
z1, z2 = x
|
||||
z3, z4 = y
|
||||
return z1, z2, z3, z4
|
||||
|
||||
|
||||
return model(), None, ["prim::TupleUnpack"]
|
||||
|
||||
@pytest.mark.nightly
|
||||
def test(self, ie_device, precision, ir_version):
|
||||
self._test(*self.create_model(), ie_device, precision, ir_version)
|
||||
|
@ -43,7 +43,7 @@ def update_list_or_dict(container, name, idx, value):
|
||||
container[idx] = value
|
||||
return
|
||||
|
||||
|
||||
|
||||
def get_value_from_list_or_dict(container, name, idx):
|
||||
if isinstance(container, dict):
|
||||
if name is None:
|
||||
@ -87,8 +87,8 @@ def extract_input_info_from_example(args, inputs):
|
||||
example_dtype = pt_to_ov_type_map.get(str(dtype))
|
||||
user_dtype = get_value_from_list_or_dict(data_types, input_name, input_id)
|
||||
if user_dtype is not None and example_dtype.to_dtype() != user_dtype:
|
||||
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
|
||||
|
||||
raise Error(f"Defined input type {user_dtype} is not equal to provided example_input type {example_dtype.to_dtype()}")
|
||||
|
||||
data_rank = getattr(example_input, "ndim", 0)
|
||||
user_input_shape = get_value_from_list_or_dict(input_shapes, input_name, input_id)
|
||||
if user_input_shape.rank.get_length() != data_rank:
|
||||
@ -108,7 +108,7 @@ def extract_input_info_from_example(args, inputs):
|
||||
input_name = input_names[input_id] if input_names else None
|
||||
update_list_or_dict(input_shapes, input_name, input_id, input_shape)
|
||||
update_list_or_dict(data_types, input_name, input_id, ov_dtype.to_dtype())
|
||||
|
||||
|
||||
args.placeholder_data_types = data_types
|
||||
args.placeholder_shapes = input_shapes
|
||||
if not args.input and input_names:
|
||||
@ -126,6 +126,9 @@ def to_torch_tensor(tensor):
|
||||
return torch.tensor(tensor.data)
|
||||
if isinstance(tensor, (float, int, bool)):
|
||||
return tensor
|
||||
if isinstance(tensor, tuple):
|
||||
# TODO: Function to_torch_tensor should be renamed as it handles not only a tensor
|
||||
return tuple(to_torch_tensor(x) for x in tensor)
|
||||
else:
|
||||
raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
|
||||
"Got {}".format(type(tensor)))
|
||||
|
Loading…
Reference in New Issue
Block a user