Fix aux tensor shape issue

This commit is contained in:
River,Li 2023-06-25 19:24:38 +08:00
parent f2bea40a7d
commit e56f5a2bfe
2 changed files with 18 additions and 9 deletions

View File

@ -1068,6 +1068,12 @@ void Graph::PullOutputData(std::unordered_map<std::string, ov::Tensor>& out,
if (it == aux_tensors.end()) {
OPENVINO_THROW("Output precision has been changed, but cannot find its aux tensor.");
}
auto& aux_tensor = it->second;
// Dynamic case
if (outDims != aux_tensor.get_shape()) {
aux_tensor.set_shape(outDims);
}
void* ext_blob_ptr = it->second.data();
if ((intr_blob_ptr == nullptr) || (ext_blob_ptr == nullptr)) {
OPENVINO_THROW("Get tensor has no allocated memory");

View File

@ -468,9 +468,8 @@ ov::Tensor SyncInferRequest::get_tensor(const ov::Output<const ov::Node>& _port)
// If precision has been changed, it need return original precision tensor
// port's data will be stored in _aux_tensors, and need converted to compiled tensor
// input tensor: will be copied to compiled tensor when sent to do inference
// output tensor: need copy compiled tensor to aux tensor and return aux tensor
bool is_input = ov::op::util::is_parameter(port.get_node());
// input tensor: will be copied to compiled tensor before do graph inference
// output tensor: has be copied from graph's memory to aux tensor
// Find aux tensor, will create one if cannot find
if (_aux_tensors.find(port_name) == _aux_tensors.end()) {
@ -478,14 +477,14 @@ ov::Tensor SyncInferRequest::get_tensor(const ov::Output<const ov::Node>& _port)
if (it == _orig_ports_map.end()) {
OPENVINO_THROW("Cannot find original port, name: ", port_name);
}
auto external_partial_shape = _orig_ports_map[port_name].get_partial_shape();
ov::Shape external_shape;
if (external_partial_shape.is_dynamic()) {
external_shape = compiled_tensor.get_shape();
auto port_shape = _orig_ports_map[port_name].get_partial_shape();
ov::Shape aux_shape;
if (port_shape.is_dynamic()) {
aux_shape = compiled_tensor.get_shape();
} else {
external_shape = _orig_ports_map[port_name].get_shape();
aux_shape = _orig_ports_map[port_name].get_shape();
}
_aux_tensors[port_name] = ov::Tensor(_orig_ports_map[port_name].get_element_type(), external_shape);
_aux_tensors[port_name] = ov::Tensor(_orig_ports_map[port_name].get_element_type(), aux_shape);
}
return _aux_tensors[port_name];
@ -775,6 +774,10 @@ void SyncInferRequest::PushInputData() {
auto tensor = get_compiled_tensor(input);
if (_aux_tensors.find(input_name) != _aux_tensors.end()) {
auto& aux_tensor = _aux_tensors[input_name];
if (aux_tensor.get_shape() != tensor.get_shape()) {
tensor.set_shape(aux_tensor.get_shape());
}
const void* srcData = aux_tensor.data();
void* dstData = tensor.data();
if ((dstData == nullptr) || (srcData == nullptr)) {