diff --git a/inference-engine/samples/common/samples/classification_results.h b/inference-engine/samples/common/samples/classification_results.h index 12b4c10fcaf..f623d2c0f24 100644 --- a/inference-engine/samples/common/samples/classification_results.h +++ b/inference-engine/samples/common/samples/classification_results.h @@ -29,6 +29,7 @@ private: const std::vector _labels; const std::vector _imageNames; const size_t _batchSize; + std::vector _results; void printHeader() { std::cout << _classidStr << " " << _probabilityStr; @@ -123,20 +124,18 @@ public: _outBlob(std::move(output_blob)), _labels(std::move(labels)), _imageNames(std::move(image_names)), - _batchSize(batch_size) { + _batchSize(batch_size), + _results() { if (_imageNames.size() != _batchSize) { throw std::logic_error("Batch size should be equal to the number of images."); } + topResults(_nTop, *_outBlob, _results); } /** * @brief prints formatted classification results */ void print() { - /** This vector stores id's of top N results **/ - std::vector results; - topResults(_nTop, *_outBlob, results); - /** Print the result iterating over each batch **/ std::cout << std::endl << "Top " << _nTop << " results:" << std::endl << std::endl; for (unsigned int image_id = 0; image_id < _batchSize; ++image_id) { @@ -160,19 +159,26 @@ public: const auto result = moutputHolder. as::value_type*>() - [results[id] + image_id * (_outBlob->size() / _batchSize)]; + [_results[id] + image_id * (_outBlob->size() / _batchSize)]; - std::cout << std::setw(static_cast(_classidStr.length())) << std::left << results[id] << " "; + std::cout << std::setw(static_cast(_classidStr.length())) << std::left << _results[id] << " "; std::cout << std::left << std::setw(static_cast(_probabilityStr.length())) << std::fixed << result; if (!_labels.empty()) { - std::cout << " " + _labels[results[id]]; + std::cout << " " + _labels[_results[id]]; } std::cout << std::endl; } std::cout << std::endl; } } + + /** + * @brief returns the classification results in a vector + */ + std::vector getResults() { + return _results; + } }; using ClassificationResult = ClassificationResultT<>;