Add new parameters to compile-tool (#1153)
This commit is contained in:
parent
8da90f8890
commit
bd3b6bfc5e
@ -50,6 +50,11 @@ static constexpr char iop_message[] = "Optional. Specifies precision for input a
|
||||
" Notice that quotes are required.\n"
|
||||
" Overwrites precision from ip and op options for specified layers.";
|
||||
|
||||
static constexpr char inputs_layout_message[] = "Optional. Specifies layout for all input layers of the network."
|
||||
" Supported values: NCHW, NHWC, NC, C.";
|
||||
static constexpr char outputs_layout_message[] = "Optional. Specifies layout for all input layers of the network."
|
||||
" Supported values: NCHW, NHWC, NC, C.";
|
||||
|
||||
static constexpr char dla_arch_name[] = "Optional. Specify architecture name used to compile executable network for FPGA device.";
|
||||
|
||||
DEFINE_bool(h, false, help_message);
|
||||
@ -60,6 +65,8 @@ DEFINE_string(c, "config", config_message);
|
||||
DEFINE_string(ip, "", inputs_precision_message);
|
||||
DEFINE_string(op, "", outputs_precision_message);
|
||||
DEFINE_string(iop, "", iop_message);
|
||||
DEFINE_string(il, "", inputs_layout_message);
|
||||
DEFINE_string(ol, "", outputs_layout_message);
|
||||
DEFINE_string(VPU_MYRIAD_PLATFORM, "", platform_message);
|
||||
DEFINE_string(VPU_NUMBER_OF_SHAVES, "", number_of_shaves_message);
|
||||
DEFINE_string(VPU_NUMBER_OF_CMX_SLICES, "", number_of_cmx_slices_message);
|
||||
@ -76,7 +83,9 @@ static void showUsage() {
|
||||
std::cout << " -c <value> " << config_message << std::endl;
|
||||
std::cout << " -ip <value> " << inputs_precision_message << std::endl;
|
||||
std::cout << " -op <value> " << outputs_precision_message << std::endl;
|
||||
std::cout << " -iop \"<value>\" " << iop_message << std::endl;
|
||||
std::cout << " -iop \"<value>\" " << iop_message << std::endl;
|
||||
std::cout << " -il <value> " << inputs_layout_message << std::endl;
|
||||
std::cout << " -ol <value> " << outputs_layout_message << std::endl;
|
||||
std::cout << " " << std::endl;
|
||||
std::cout << " VPU options: " << std::endl;
|
||||
std::cout << " -VPU_MYRIAD_PLATFORM <value> " << platform_message << std::endl;
|
||||
@ -184,6 +193,20 @@ static std::map<std::string, std::string> parsePrecisions(const std::string &iop
|
||||
}
|
||||
|
||||
using supported_precisions_t = std::unordered_map<std::string, InferenceEngine::Precision>;
|
||||
using supported_layouts_t = std::unordered_map<std::string, InferenceEngine::Layout>;
|
||||
using matchLayoutToDims_t = std::unordered_map<size_t, size_t>;
|
||||
|
||||
static InferenceEngine::Layout getLayout(const std::string &value,
|
||||
const supported_layouts_t &supported_layouts) {
|
||||
std::string upper_value = value;
|
||||
std::transform(value.begin(), value.end(), upper_value.begin(), ::toupper);
|
||||
auto layout = supported_layouts.find(upper_value);
|
||||
if (layout == supported_layouts.end()) {
|
||||
throw std::logic_error("\"" + value + "\"" + " is not a valid layout.");
|
||||
}
|
||||
|
||||
return layout->second;
|
||||
}
|
||||
|
||||
static InferenceEngine::Precision getPrecision(const std::string &value,
|
||||
const supported_precisions_t &supported_precisions,
|
||||
@ -216,6 +239,33 @@ static InferenceEngine::Precision getOutputPrecision(const std::string &value) {
|
||||
return getPrecision(value, supported_precisions, "for output layer");
|
||||
}
|
||||
|
||||
static InferenceEngine::Layout getLayout(const std::string &value) {
|
||||
static const supported_layouts_t supported_layouts = {
|
||||
{ "NCHW", InferenceEngine::Layout::NCHW },
|
||||
{ "NHWC", InferenceEngine::Layout::NHWC },
|
||||
{ "CHW", InferenceEngine::Layout::CHW },
|
||||
{ "NC", InferenceEngine::Layout::NC },
|
||||
{ "C", InferenceEngine::Layout::C }
|
||||
};
|
||||
return getLayout(value, supported_layouts);
|
||||
}
|
||||
|
||||
static bool isMatchLayoutToDims(const InferenceEngine::Layout& layout, const size_t dimension) {
|
||||
static const matchLayoutToDims_t matchLayoutToDims = {
|
||||
{static_cast<size_t>(InferenceEngine::Layout::NCHW), 4 },
|
||||
{static_cast<size_t>(InferenceEngine::Layout::NHWC), 4 },
|
||||
{static_cast<size_t>(InferenceEngine::Layout::CHW), 3 },
|
||||
{static_cast<size_t>(InferenceEngine::Layout::NC), 2 },
|
||||
{static_cast<size_t>(InferenceEngine::Layout::C), 1 }};
|
||||
|
||||
auto dims = matchLayoutToDims.find(static_cast<size_t>(layout));
|
||||
if (dims == matchLayoutToDims.end()) {
|
||||
throw std::logic_error("Layout is not valid.");
|
||||
}
|
||||
|
||||
return dimension == dims->second;
|
||||
}
|
||||
|
||||
bool isFP16(InferenceEngine::Precision precision) {
|
||||
return precision == InferenceEngine::Precision::FP16;
|
||||
}
|
||||
@ -289,7 +339,7 @@ static void processPrecisions(InferenceEngine::CNNNetwork &network,
|
||||
for (auto &&layer : network.getInputsInfo()) {
|
||||
const auto layerPrecision = layer.second->getPrecision();
|
||||
if ((isFloat(layerPrecision) && isFloat(precision)) ||
|
||||
(isFP16(layerPrecision) && isU8(precision))) {
|
||||
(isFloat(layerPrecision) && isU8(precision))) {
|
||||
layer.second->setPrecision(precision);
|
||||
}
|
||||
}
|
||||
@ -310,6 +360,27 @@ static void processPrecisions(InferenceEngine::CNNNetwork &network,
|
||||
}
|
||||
}
|
||||
|
||||
static void processLayout(InferenceEngine::CNNNetwork &network,
|
||||
const std::string &inputs_layout, const std::string &outputs_layout) {
|
||||
if (!inputs_layout.empty()) {
|
||||
auto layout = getLayout(inputs_layout);
|
||||
for (auto &&layer : network.getInputsInfo()) {
|
||||
if (isMatchLayoutToDims(layout, layer.second->getTensorDesc().getDims().size())) {
|
||||
layer.second->setLayout(layout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!outputs_layout.empty()) {
|
||||
auto layout = getLayout(outputs_layout);
|
||||
for (auto &&layer : network.getOutputsInfo()) {
|
||||
if (isMatchLayoutToDims(layout, layer.second->getTensorDesc().getDims().size())) {
|
||||
layer.second->setLayout(layout);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string getFileNameFromPath(const std::string& path,
|
||||
#if defined(_WIN32)
|
||||
const std::string sep = "\\") {
|
||||
@ -341,6 +412,7 @@ int main(int argc, char *argv[]) {
|
||||
|
||||
setDefaultIOPrecisions(network, FLAGS_d);
|
||||
processPrecisions(network, FLAGS_ip, FLAGS_op, FLAGS_iop);
|
||||
processLayout(network, FLAGS_il, FLAGS_ol);
|
||||
|
||||
auto timeBeforeLoadNetwork = std::chrono::steady_clock::now();
|
||||
auto executableNetwork = ie.LoadNetwork(network, FLAGS_d, configure(FLAGS_c, FLAGS_m));
|
||||
|
Loading…
Reference in New Issue
Block a user