add support for multiple scale factors in speech sample (#835)

Co-authored-by: Anna Alberska <anna.alberska@intel.com>
This commit is contained in:
Denis Orlov 2020-06-09 14:36:28 +03:00 committed by GitHub
parent d4e880de3d
commit ef8a8dd309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 38 additions and 15 deletions

View File

@ -404,6 +404,28 @@ void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngi
}
}
/**
 * @brief Splits the comma-separated value of the -sf option into one scale factor string per input
 *        and validates every entry.
 *
 * Each token must be a number that is strictly greater than zero. The tokens are returned
 * exactly as they were written on the command line (they are not re-formatted).
 *
 * @param str comma-separated list of scale factors, e.g. "2048" or "1024.5,512"
 * @return validated scale factor strings in the order they were given
 * @throws std::logic_error if @p str is empty, or if any token is not a number, has trailing
 *         non-whitespace characters, is out of the float range, is NaN, or is not positive
 */
std::vector<std::string> ParseScaleFactors(const std::string& str) {
    if (str.empty()) {
        throw std::logic_error("Scale factor need to be specified via -sf option if you are using -q user");
    }

    std::vector<std::string> scaleFactorInput;
    std::string outStr;
    std::istringstream stream(str);
    int i = 0;
    while (std::getline(stream, outStr, ',')) {
        const std::string inputLabel = "Scale factor for input #" + std::to_string(i) + " (counting from zero)";

        float floatScaleFactor = 0.0f;
        std::size_t parsedLength = 0;
        try {
            floatScaleFactor = std::stof(outStr, &parsedLength);
        } catch (const std::invalid_argument&) {
            // Re-throw with the input index; the bare std::stof message does not say which value is bad.
            throw std::logic_error(inputLabel + " is not a valid number: '" + outStr + "'.");
        } catch (const std::out_of_range&) {
            throw std::logic_error(inputLabel + " does not fit into a float: '" + outStr + "'.");
        }

        // std::stof stops at the first character it cannot use, so "1.0abc" would otherwise be
        // accepted and the malformed token forwarded as is. Trailing whitespace is tolerated.
        if (outStr.find_first_not_of(" \t\r\n\v\f", parsedLength) != std::string::npos) {
            throw std::logic_error(inputLabel + " has unexpected trailing characters: '" + outStr + "'.");
        }

        // Negated comparison on purpose: it also rejects NaN, for which "<= 0" is false.
        if (!(floatScaleFactor > 0.0f)) {
            throw std::logic_error(inputLabel + " is out of range (must be positive).");
        }

        scaleFactorInput.push_back(outStr);
        i++;
    }
    return scaleFactorInput;
}
bool ParseAndCheckCommandLine(int argc, char *argv[]) {
// ---------------------------Parsing and validation of input args--------------------------------------
slog::info << "Parsing input parameters" << slog::endl;
@ -453,11 +475,6 @@ bool ParseAndCheckCommandLine(int argc, char *argv[]) {
throw std::logic_error("Specified device is not supported.");
}
float scaleFactorInput = static_cast<float>(FLAGS_sf);
if (scaleFactorInput <= 0.0f) {
throw std::logic_error("Scale factor out of range (must be non-negative).");
}
uint32_t batchSize = (uint32_t) FLAGS_bs;
if ((batchSize < 1) || (batchSize > 8)) {
throw std::logic_error("Batch size out of range (1..8).");
@ -515,7 +532,6 @@ int main(int argc, char *argv[]) {
bool useHetero = isFeature("HETERO");
std::string deviceStr =
useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
float scaleFactorInput = static_cast<float>(FLAGS_sf);
uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs;
std::vector<std::string> inputArkFiles;
@ -586,12 +602,18 @@ int main(int argc, char *argv[]) {
}
if (FLAGS_q.compare("user") == 0) {
if (numInputArkFiles > 1) {
std::string errMessage("Incorrect use case for multiple input ark files. Please don't use -q 'user' for this case.");
auto scaleFactorInput = ParseScaleFactors(FLAGS_sf);
if (scaleFactorInput.size() != network.getInputsInfo().size()) {
std::string errMessage("Incorrect command line for multiple inputs: "
+ std::to_string(scaleFactorInput.size()) + " scale factors provided for "
+ std::to_string(network.getInputsInfo().size()) + " inputs.");
throw std::logic_error(errMessage);
}
slog::info << "Using scale factor of " << FLAGS_sf << slog::endl;
gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
for (size_t i = 0; i < scaleFactorInput.size(); ++i) {
slog::info << "For input " << i << " using scale factor of " << scaleFactorInput[i] << slog::endl;
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
gnaPluginConfig[scaleFactorConfigKey] = scaleFactorInput[i];
}
} else {
// "static" quantization with calculated scale factor
for (size_t i = 0; i < numInputArkFiles; i++) {
@ -608,12 +630,12 @@ int main(int argc, char *argv[]) {
&numFrames,
&numFrameElements,
&numBytesPerElement);
scaleFactorInput =
auto floatScaleFactor =
ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
slog::info << "Using scale factor of " << floatScaleFactor << " calculated from first utterance."
<< slog::endl;
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
gnaPluginConfig[scaleFactorConfigKey] = std::to_string(scaleFactorInput);
gnaPluginConfig[scaleFactorConfigKey] = std::to_string(floatScaleFactor);
}
}

View File

@ -61,7 +61,8 @@ static const char quantization_message[] = "Input quantization mode: static (de
static const char quantization_bits_message[] = "Weight bits for quantization: 8 or 16 (default)";
/// @brief message for scale factor argument
static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user).";
static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user). "
"If the network contains multiple inputs, provide scale factors by separating them with commas.";
/// @brief message for batch size argument
static const char batch_size_message[] = "Batch size 1-8 (default 1)";
@ -130,7 +131,7 @@ DEFINE_string(q, "static", quantization_message);
DEFINE_int32(qb, 16, quantization_bits_message);
/// @brief Comma-separated scale factor(s) for quantization, one per network input (empty by default; must be set when -q user is used)
DEFINE_double(sf, 1.0, scale_factor_message);
DEFINE_string(sf, "", scale_factor_message);
/// @brief Batch size (default 1)
DEFINE_int32(bs, 1, batch_size_message);