Add method to classification_result for returning results (#1460)

This commit is contained in:
Harald Rotuna
2020-08-21 11:59:12 +03:00
committed by GitHub
parent 762cb8d6ab
commit 84028be544

View File

@@ -29,6 +29,7 @@ private:
const std::vector<std::string> _labels;
const std::vector<strType> _imageNames;
const size_t _batchSize;
std::vector<unsigned> _results;
void printHeader() {
std::cout << _classidStr << " " << _probabilityStr;
@@ -123,20 +124,18 @@ public:
_outBlob(std::move(output_blob)),
_labels(std::move(labels)),
_imageNames(std::move(image_names)),
_batchSize(batch_size) {
_batchSize(batch_size),
_results() {
if (_imageNames.size() != _batchSize) {
throw std::logic_error("Batch size should be equal to the number of images.");
}
topResults(_nTop, *_outBlob, _results);
}
/**
* @brief prints formatted classification results
*/
void print() {
/** This vector stores id's of top N results **/
std::vector<unsigned> results;
topResults(_nTop, *_outBlob, results);
/** Print the result iterating over each batch **/
std::cout << std::endl << "Top " << _nTop << " results:" << std::endl << std::endl;
for (unsigned int image_id = 0; image_id < _batchSize; ++image_id) {
@@ -160,19 +159,26 @@ public:
const auto result = moutputHolder.
as<const InferenceEngine::PrecisionTrait<InferenceEngine::Precision::FP32>::value_type*>()
[results[id] + image_id * (_outBlob->size() / _batchSize)];
[_results[id] + image_id * (_outBlob->size() / _batchSize)];
std::cout << std::setw(static_cast<int>(_classidStr.length())) << std::left << results[id] << " ";
std::cout << std::setw(static_cast<int>(_classidStr.length())) << std::left << _results[id] << " ";
std::cout << std::left << std::setw(static_cast<int>(_probabilityStr.length())) << std::fixed << result;
if (!_labels.empty()) {
std::cout << " " + _labels[results[id]];
std::cout << " " + _labels[_results[id]];
}
std::cout << std::endl;
}
std::cout << std::endl;
}
}
/**
* @brief returns the classification results in a vector
*/
std::vector<unsigned> getResults() {
return _results;
}
};
using ClassificationResult = ClassificationResultT<>;