[IE Samples] Changed input's tensor preprocessing for speech sample (#10552)

* Changed input's tensor preprocessing

* improved processing
This commit is contained in:
Maxim Gordeev 2022-02-21 23:29:38 +03:00 committed by GitHub
parent d26fd3aa22
commit e7145bd343
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -506,13 +506,42 @@ int main(int argc, char* argv[]) {
inferRequest.frameIndex = -1;
continue;
}
ptrInputBlobs.clear();
if (FLAGS_iname.empty()) {
for (auto& input : cInputInfo) {
ptrInputBlobs.push_back(inferRequest.inferRequest.get_tensor(input));
}
} else {
std::vector<std::string> inputNameBlobs = convert_str_to_vector(FLAGS_iname);
for (const auto& input : inputNameBlobs) {
ov::Tensor blob = inferRequests.begin()->inferRequest.get_tensor(input);
if (!blob) {
std::string errMessage("No blob with name : " + input);
throw std::logic_error(errMessage);
}
ptrInputBlobs.push_back(blob);
}
}
/** Iterate over all the input blobs **/
for (size_t i = 0; i < numInputFiles; ++i) {
ov::Tensor minput = ptrInputBlobs[i];
if (!minput) {
std::string errMessage("We expect ptrInputBlobs[" + std::to_string(i) +
"] to be inherited from Tensor, " +
"but in fact we were not able to cast input to Tensor");
throw std::logic_error(errMessage);
}
memcpy(minput.data<float>(), inputFrame[i], minput.get_byte_size());
// Used to infer fewer frames than the batch size
if (batchSize != numFramesThisBatch) {
memset(minput.data<float>() + numFramesThisBatch * numFrameElementsInput[i],
0,
(batchSize - numFramesThisBatch) * numFrameElementsInput[i]);
}
}
// -----------------------------------------------------------------------------------------------------
int index = static_cast<int>(frameIndex) - (FLAGS_cw_l + FLAGS_cw_r);
for (int i = 0; i < executableNet.inputs().size(); i++) {
inferRequest.inferRequest.set_input_tensor(
i,
ov::Tensor(ov::element::f32, executableNet.inputs()[i].get_shape(), inputFrame[i]));
}
/* Starting inference in asynchronous mode*/
inferRequest.inferRequest.start_async();
inferRequest.frameIndex = index < 0 ? -2 : index;