Fix aux tensor shape issue
parent f2bea40a7d
commit e56f5a2bfe
@@ -1068,6 +1068,12 @@ void Graph::PullOutputData(std::unordered_map<std::string, ov::Tensor>& out,
     if (it == aux_tensors.end()) {
         OPENVINO_THROW("Output precision has been changed, but cannot find its aux tensor.");
     }
+
+    auto& aux_tensor = it->second;
+    // Dynamic case
+    if (outDims != aux_tensor.get_shape()) {
+        aux_tensor.set_shape(outDims);
+    }
     void* ext_blob_ptr = it->second.data();
     if ((intr_blob_ptr == nullptr) || (ext_blob_ptr == nullptr)) {
         OPENVINO_THROW("Get tensor has no allocated memory");
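For context on the hunk above: ov::Tensor::set_shape() lets the plugin resize the auxiliary (original-precision) output tensor to the dims actually produced by this inference before the data is copied out, which is what makes dynamically shaped outputs work. A minimal standalone sketch of the same pattern, assuming the same element type on both sides (the helper name copy_output_to_aux is illustrative, not from the plugin):

    #include <cstring>
    #include <stdexcept>
    #include <openvino/runtime/tensor.hpp>

    // Sketch only: copy a graph output of runtime-determined shape into a
    // pre-created auxiliary tensor, resizing the aux tensor first if needed.
    void copy_output_to_aux(const ov::Tensor& graph_output, ov::Tensor& aux_tensor) {
        const ov::Shape out_dims = graph_output.get_shape();

        // Dynamic case: the aux tensor may have been created with a stale shape,
        // so align it with the actual output dims before touching its data.
        if (aux_tensor.get_shape() != out_dims) {
            aux_tensor.set_shape(out_dims);
        }

        if (graph_output.data() == nullptr || aux_tensor.data() == nullptr) {
            throw std::runtime_error("tensor has no allocated memory");
        }

        // Same element type assumed here; the plugin additionally converts
        // precision when the original port type differs from the compiled one.
        std::memcpy(aux_tensor.data(), graph_output.data(), graph_output.get_byte_size());
    }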
@@ -468,9 +468,8 @@ ov::Tensor SyncInferRequest::get_tensor(const ov::Output<const ov::Node>& _port)
 
     // If precision has been changed, it need return original precision tensor
-    // port's data will be stored in _aux_tensors, and need converted to compiled tensor
-    // input tensor: will be copied to compiled tensor when sent to do inference
-    // output tensor: need copy compiled tensor to aux tensor and return aux tensor
     bool is_input = ov::op::util::is_parameter(port.get_node());
+    // input tensor: will be copied to compiled tensor before do graph inference
+    // output tensor: has be copied from graph's memory to aux tensor
 
     // Find aux tensor, will create one if cannot find
     if (_aux_tensors.find(port_name) == _aux_tensors.end()) {
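The reworded comments spell out the contract for ports kept in _aux_tensors: input aux tensors are pushed into the compiled tensor before graph inference (see the PushInputData hunk further down), and output aux tensors are filled from graph memory afterwards. The check that tells the two port kinds apart is the ov::op::util::is_parameter call already visible above; a tiny illustrative helper, just to make the convention explicit:

    #include <openvino/core/node.hpp>
    #include <openvino/op/util/op_types.hpp>

    // Illustrative only: a port backed by a Parameter node is a model input;
    // the other ports handled here are outputs (Result nodes).
    bool port_is_input(const ov::Output<const ov::Node>& port) {
        return ov::op::util::is_parameter(port.get_node());
    }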
@@ -478,14 +477,14 @@ ov::Tensor SyncInferRequest::get_tensor(const ov::Output<const ov::Node>& _port)
         if (it == _orig_ports_map.end()) {
             OPENVINO_THROW("Cannot find original port, name: ", port_name);
         }
-        auto external_partial_shape = _orig_ports_map[port_name].get_partial_shape();
-        ov::Shape external_shape;
-        if (external_partial_shape.is_dynamic()) {
-            external_shape = compiled_tensor.get_shape();
+        auto port_shape = _orig_ports_map[port_name].get_partial_shape();
+        ov::Shape aux_shape;
+        if (port_shape.is_dynamic()) {
+            aux_shape = compiled_tensor.get_shape();
         } else {
-            external_shape = _orig_ports_map[port_name].get_shape();
+            aux_shape = _orig_ports_map[port_name].get_shape();
         }
-        _aux_tensors[port_name] = ov::Tensor(_orig_ports_map[port_name].get_element_type(), external_shape);
+        _aux_tensors[port_name] = ov::Tensor(_orig_ports_map[port_name].get_element_type(), aux_shape);
     }
 
     return _aux_tensors[port_name];
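The actual fix in this hunk is the shape used when the aux tensor is first created: a statically shaped original port contributes its own shape, while a dynamically shaped port falls back to the compiled tensor's current runtime shape. A hedged sketch of just that decision, assuming the original port and its compiled tensor have already been looked up (make_aux_tensor is an illustrative name):

    #include <openvino/core/node.hpp>
    #include <openvino/runtime/tensor.hpp>

    // Sketch only: build the auxiliary tensor with the original port's element
    // type, taking the shape from the port when it is static and from the
    // compiled tensor when the port is dynamic.
    ov::Tensor make_aux_tensor(const ov::Output<const ov::Node>& orig_port,
                               const ov::Tensor& compiled_tensor) {
        const ov::PartialShape port_shape = orig_port.get_partial_shape();
        ov::Shape aux_shape;
        if (port_shape.is_dynamic()) {
            aux_shape = compiled_tensor.get_shape();  // runtime shape of this request
        } else {
            aux_shape = orig_port.get_shape();        // static shape from the model
        }
        return ov::Tensor(orig_port.get_element_type(), aux_shape);
    }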
@@ -775,6 +774,10 @@ void SyncInferRequest::PushInputData() {
         auto tensor = get_compiled_tensor(input);
         if (_aux_tensors.find(input_name) != _aux_tensors.end()) {
             auto& aux_tensor = _aux_tensors[input_name];
+
+            if (aux_tensor.get_shape() != tensor.get_shape()) {
+                tensor.set_shape(aux_tensor.get_shape());
+            }
             const void* srcData = aux_tensor.data();
             void* dstData = tensor.data();
             if ((dstData == nullptr) || (srcData == nullptr)) {
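This is the mirror image of the output-side change: before the copy into the compiled input tensor, the destination is resized to the aux (user) tensor's shape, so a dynamically shaped input submitted at its real size propagates into the graph. A minimal sketch under the same assumptions as before (same element type; the plugin's precision conversion is not shown):

    #include <cstring>
    #include <stdexcept>
    #include <openvino/runtime/tensor.hpp>

    // Sketch only: push the user-precision aux tensor into the compiled input
    // tensor before inference, resizing the destination to the source shape.
    void push_aux_to_compiled(const ov::Tensor& aux_tensor, ov::Tensor& compiled_tensor) {
        if (compiled_tensor.get_shape() != aux_tensor.get_shape()) {
            compiled_tensor.set_shape(aux_tensor.get_shape());
        }

        const void* src = aux_tensor.data();
        void* dst = compiled_tensor.data();
        if (src == nullptr || dst == nullptr) {
            throw std::runtime_error("input tensor has no allocated memory");
        }

        std::memcpy(dst, src, aux_tensor.get_byte_size());
    }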