add support for multiple scale factors in speech sample (#835)
Co-authored-by: Anna Alberska <anna.alberska@intel.com>
This commit is contained in:
parent
d4e880de3d
commit
ef8a8dd309
@ -404,6 +404,28 @@ void sumPerformanceCounters(std::map<std::string, InferenceEngine::InferenceEngi
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * @brief Splits the value of the -sf option into per-input scale factors and validates each one.
 *
 * The input is a comma-separated list (e.g. "2048,1024.5"). Every token must be a
 * complete, strictly positive floating point number. Tokens are returned as the
 * original strings, in the order given, so they can be stored in the plugin config.
 *
 * @param str comma-separated scale factors
 * @return one string per scale factor, in input order
 * @throws std::logic_error if @p str is empty or a scale factor is not positive
 * @throws std::invalid_argument (a std::logic_error) if a token is not a number
 *         or has trailing non-numeric characters
 * @throws std::out_of_range (a std::logic_error) if a token does not fit into a float
 */
std::vector<std::string> ParseScaleFactors(const std::string& str) {
    if (str.empty()) {
        throw std::logic_error("Scale factor need to be specified via -sf option if you are using -q user");
    }

    std::vector<std::string> scaleFactorInput;
    std::istringstream stream(str);
    std::string token;
    size_t inputIndex = 0;

    while (std::getline(stream, token, ',')) {
        const std::string inputLabel =
            "Scale factor for input #" + std::to_string(inputIndex) + " (counting from zero)";

        float floatScaleFactor = 0.0f;
        size_t parsedChars = 0;
        try {
            floatScaleFactor = std::stof(token, &parsedChars);
        } catch (const std::invalid_argument&) {
            // Re-throw with the same type but a message that names the offending input.
            throw std::invalid_argument(inputLabel + " is not a number: '" + token + "'.");
        } catch (const std::out_of_range&) {
            throw std::out_of_range(inputLabel + " does not fit into a float: '" + token + "'.");
        }

        // std::stof stops at the first character it cannot parse; anything left over
        // other than whitespace means the token is malformed (e.g. "1.5abc").
        if (token.find_first_not_of(" \t\r\n", parsedChars) != std::string::npos) {
            throw std::invalid_argument(inputLabel + " has unexpected trailing characters: '" + token + "'.");
        }

        // Written as a negated '>' so that NaN is rejected as well.
        if (!(floatScaleFactor > 0.0f)) {
            throw std::logic_error(inputLabel + " is out of range (must be positive).");
        }

        scaleFactorInput.push_back(token);
        ++inputIndex;
    }

    return scaleFactorInput;
}
|
||||
|
||||
bool ParseAndCheckCommandLine(int argc, char *argv[]) {
|
||||
// ---------------------------Parsing and validation of input args--------------------------------------
|
||||
slog::info << "Parsing input parameters" << slog::endl;
|
||||
@ -453,11 +475,6 @@ bool ParseAndCheckCommandLine(int argc, char *argv[]) {
|
||||
throw std::logic_error("Specified device is not supported.");
|
||||
}
|
||||
|
||||
float scaleFactorInput = static_cast<float>(FLAGS_sf);
|
||||
if (scaleFactorInput <= 0.0f) {
|
||||
throw std::logic_error("Scale factor out of range (must be non-negative).");
|
||||
}
|
||||
|
||||
uint32_t batchSize = (uint32_t) FLAGS_bs;
|
||||
if ((batchSize < 1) || (batchSize > 8)) {
|
||||
throw std::logic_error("Batch size out of range (1..8).");
|
||||
@ -515,7 +532,6 @@ int main(int argc, char *argv[]) {
|
||||
bool useHetero = isFeature("HETERO");
|
||||
std::string deviceStr =
|
||||
useHetero && useGna ? "HETERO:GNA,CPU" : FLAGS_d.substr(0, (FLAGS_d.find("_")));
|
||||
float scaleFactorInput = static_cast<float>(FLAGS_sf);
|
||||
uint32_t batchSize = (FLAGS_cw_r > 0 || FLAGS_cw_l > 0) ? 1 : (uint32_t) FLAGS_bs;
|
||||
|
||||
std::vector<std::string> inputArkFiles;
|
||||
@ -586,12 +602,18 @@ int main(int argc, char *argv[]) {
|
||||
}
|
||||
|
||||
if (FLAGS_q.compare("user") == 0) {
|
||||
if (numInputArkFiles > 1) {
|
||||
std::string errMessage("Incorrect use case for multiple input ark files. Please don't use -q 'user' for this case.");
|
||||
auto scaleFactorInput = ParseScaleFactors(FLAGS_sf);
|
||||
if (scaleFactorInput.size() != network.getInputsInfo().size()) {
|
||||
std::string errMessage("Incorrect command line for multiple inputs: "
|
||||
+ std::to_string(scaleFactorInput.size()) + " scale factors provided for "
|
||||
+ std::to_string(network.getInputsInfo().size()) + " inputs.");
|
||||
throw std::logic_error(errMessage);
|
||||
}
|
||||
slog::info << "Using scale factor of " << FLAGS_sf << slog::endl;
|
||||
gnaPluginConfig[GNA_CONFIG_KEY(SCALE_FACTOR)] = std::to_string(FLAGS_sf);
|
||||
for (size_t i = 0; i < scaleFactorInput.size(); ++i) {
|
||||
slog::info << "For input " << i << " using scale factor of " << scaleFactorInput[i] << slog::endl;
|
||||
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
|
||||
gnaPluginConfig[scaleFactorConfigKey] = scaleFactorInput[i];
|
||||
}
|
||||
} else {
|
||||
// "static" quantization with calculated scale factor
|
||||
for (size_t i = 0; i < numInputArkFiles; i++) {
|
||||
@ -608,12 +630,12 @@ int main(int argc, char *argv[]) {
|
||||
&numFrames,
|
||||
&numFrameElements,
|
||||
&numBytesPerElement);
|
||||
scaleFactorInput =
|
||||
auto floatScaleFactor =
|
||||
ScaleFactorForQuantization(ptrFeatures.data(), MAX_VAL_2B_FEAT, numFrames * numFrameElements);
|
||||
slog::info << "Using scale factor of " << scaleFactorInput << " calculated from first utterance."
|
||||
slog::info << "Using scale factor of " << floatScaleFactor << " calculated from first utterance."
|
||||
<< slog::endl;
|
||||
std::string scaleFactorConfigKey = GNA_CONFIG_KEY(SCALE_FACTOR) + std::string("_") + std::to_string(i);
|
||||
gnaPluginConfig[scaleFactorConfigKey] = std::to_string(scaleFactorInput);
|
||||
gnaPluginConfig[scaleFactorConfigKey] = std::to_string(floatScaleFactor);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -61,7 +61,8 @@ static const char quantization_message[] = "Input quantization mode: static (de
|
||||
static const char quantization_bits_message[] = "Weight bits for quantization: 8 or 16 (default)";
|
||||
|
||||
/// @brief message for scale factor argument
|
||||
static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user).";
|
||||
static const char scale_factor_message[] = "Optional user-specified input scale factor for quantization (use with -q user). "
|
||||
"If the network contains multiple inputs, provide scale factors by separating them with commas.";
|
||||
|
||||
/// @brief message for batch size argument
|
||||
static const char batch_size_message[] = "Batch size 1-8 (default 1)";
|
||||
@ -130,7 +131,7 @@ DEFINE_string(q, "static", quantization_message);
|
||||
DEFINE_int32(qb, 16, quantization_bits_message);
|
||||
|
||||
/// @brief Scale factor for quantization (default 1.0)
|
||||
DEFINE_double(sf, 1.0, scale_factor_message);
|
||||
DEFINE_string(sf, "", scale_factor_message);
|
||||
|
||||
/// @brief Batch size (default 1)
|
||||
DEFINE_int32(bs, 1, batch_size_message);
|
||||
|
Loading…
Reference in New Issue
Block a user