diff --git a/inference-engine/samples/speech_sample/main.cpp b/inference-engine/samples/speech_sample/main.cpp index 30307713b00..58b4b98874b 100644 --- a/inference-engine/samples/speech_sample/main.cpp +++ b/inference-engine/samples/speech_sample/main.cpp @@ -404,6 +404,28 @@ void sumPerformanceCounters(std::map ParseScaleFactors(const std::string& str) { + std::vector scaleFactorInput; + + if (!str.empty()) { + std::string outStr; + std::istringstream stream(str); + int i = 0; + while (getline(stream, outStr, ',')) { + auto floatScaleFactor = std::stof(outStr); + if (floatScaleFactor <= 0.0f) { + throw std::logic_error("Scale factor for input #" + std::to_string(i) + + " (counting from zero) is out of range (must be positive)."); + } + scaleFactorInput.push_back(outStr); + i++; + } + } else { + throw std::logic_error("Scale factor need to be specified via -sf option if you are using -q user"); + } + return scaleFactorInput; +} + bool ParseAndCheckCommandLine(int argc, char *argv[]) { // ---------------------------Parsing and validation of input args-------------------------------------- slog::info << "Parsing input parameters" << slog::endl; @@ -453,11 +475,6 @@ bool ParseAndCheckCommandLine(int argc, char *argv[]) { throw std::logic_error("Specified device is not supported."); } - float scaleFactorInput = static_cast(FLAGS_sf); - if (scaleFactorInput <= 0.0f) { - throw std::logic_error("Scale factor out of range (must be non-negative)."); - } - uint32_t batchSize = (uint32_t) FLAGS_bs; if ((batchSize < 1) || (batchSize > 8)) { throw std::logic_error("Batch size out of range (1..8)."); @@ -515,7 +532,6 @@ int main(int argc, char *argv[]) { bool useHetero = isFeature("HETERO"); std::string deviceStr = useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_"))); - float scaleFactorInput = static_cast(FLAGS_sf); uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs; std::vector inputArkFiles; @@ -586,12 +602,18 @@ int main(int argc, char *argv[]) { } if (FLAGS_q.compare("user") == 0) { - if (numInputArkFiles > 1) { - std::string errMessage("Incorrect use case for multiple input ark files. Please don't use -q 'user' for this case."); + auto scaleFactorInput = ParseScaleFactors(FLAGS_sf); + if (scaleFactorInput.size() != network.getInputsInfo().size()) { + std::string errMessage("Incorrect command line for multiple inputs: " + + std::to_string(scaleFactorInput.size()) + " scale factors provided for " + + std::to_string(network.getInputsInfo().size()) + " inputs."); throw std::logic_error(errMessage); } - slog::info << "Using scale factor of " << FLAGS_sf << slog::endl; - gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf); + for (size_t i = 0; i < scaleFactorInput.size(); ++i) { + slog::info << "For input " << i << " using scale factor of " << scaleFactorInput[i] << slog::endl; + std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i); + gnaPluginConfig[scaleFactorConfigKey] = scaleFactorInput[i]; + } } else { // "static" quantization with calculated scale factor for (size_t i = 0; i < numInputArkFiles; i++) { @@ -608,12 +630,12 @@ int main(int argc, char *argv[]) { &numFrames, &numFrameElements, &numBytesPerElement); - scaleFactorInput = + auto floatScaleFactor = ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements); - slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance." + slog::info << "Using scale factor of " << floatScaleFactor << " calculated from first utterance." << slog::endl; std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i); - gnaPluginConfig[scaleFactorConfigKey] = std::to_string(scaleFactorInput); + gnaPluginConfig[scaleFactorConfigKey] = std::to_string(floatScaleFactor); } } diff --git a/inference-engine/samples/speech_sample/speech_sample.hpp b/inference-engine/samples/speech_sample/speech_sample.hpp index 4547c380249..a76a51c45a7 100644 --- a/inference-engine/samples/speech_sample/speech_sample.hpp +++ b/inference-engine/samples/speech_sample/speech_sample.hpp @@ -61,7 +61,8 @@ static const char quantization_message[] = "Input quantization mode: static (de static const char quantization_bits_message[] = "Weight bits for quantization: 8 or 16 (default)"; /// @brief message for scale factor argument -static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user)."; +static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user). " + "If the network contains multiple inputs, provide scale factors by separating them with commas."; /// @brief message for batch size argument static const char batch_size_message[] = "Batch size 1-8 (default 1)"; @@ -130,7 +131,7 @@ DEFINE_string(q, "static", quantization_message); DEFINE_int32(qb, 16, quantization_bits_message); /// @brief Scale factor for quantization (default 1.0) -DEFINE_double(sf, 1.0, scale_factor_message); +DEFINE_string(sf, "", scale_factor_message); /// @brief Batch size (default 1) DEFINE_int32(bs, 1, batch_size_message);