diff --git a/inference-engine/tools/compile_tool/main.cpp b/inference-engine/tools/compile_tool/main.cpp index 12bf644502f..8c7a2261504 100644 --- a/inference-engine/tools/compile_tool/main.cpp +++ b/inference-engine/tools/compile_tool/main.cpp @@ -50,6 +50,11 @@ static constexpr char iop_message[] = "Optional. Specifies precision for input a " Notice that quotes are required.\n" " Overwrites precision from ip and op options for specified layers."; +static constexpr char inputs_layout_message[] = "Optional. Specifies layout for all input layers of the network." + " Supported values: NCHW, NHWC, NC, C."; +static constexpr char outputs_layout_message[] = "Optional. Specifies layout for all input layers of the network." + " Supported values: NCHW, NHWC, NC, C."; + static constexpr char dla_arch_name[] = "Optional. Specify architecture name used to compile executable network for FPGA device."; DEFINE_bool(h, false, help_message); @@ -60,6 +65,8 @@ DEFINE_string(c, "config", config_message); DEFINE_string(ip, "", inputs_precision_message); DEFINE_string(op, "", outputs_precision_message); DEFINE_string(iop, "", iop_message); +DEFINE_string(il, "", inputs_layout_message); +DEFINE_string(ol, "", outputs_layout_message); DEFINE_string(VPU_MYRIAD_PLATFORM, "", platform_message); DEFINE_string(VPU_NUMBER_OF_SHAVES, "", number_of_shaves_message); DEFINE_string(VPU_NUMBER_OF_CMX_SLICES, "", number_of_cmx_slices_message); @@ -76,7 +83,9 @@ static void showUsage() { std::cout << " -c " << config_message << std::endl; std::cout << " -ip " << inputs_precision_message << std::endl; std::cout << " -op " << outputs_precision_message << std::endl; - std::cout << " -iop \"\" " << iop_message << std::endl; + std::cout << " -iop \"\" " << iop_message << std::endl; + std::cout << " -il " << inputs_layout_message << std::endl; + std::cout << " -ol " << outputs_layout_message << std::endl; std::cout << " " << std::endl; std::cout << " VPU options: " << std::endl; std::cout << " -VPU_MYRIAD_PLATFORM " << platform_message << std::endl; @@ -184,6 +193,20 @@ static std::map parsePrecisions(const std::string &iop } using supported_precisions_t = std::unordered_map; +using supported_layouts_t = std::unordered_map; +using matchLayoutToDims_t = std::unordered_map; + +static InferenceEngine::Layout getLayout(const std::string &value, + const supported_layouts_t &supported_layouts) { + std::string upper_value = value; + std::transform(value.begin(), value.end(), upper_value.begin(), ::toupper); + auto layout = supported_layouts.find(upper_value); + if (layout == supported_layouts.end()) { + throw std::logic_error("\"" + value + "\"" + " is not a valid layout."); + } + + return layout->second; +} static InferenceEngine::Precision getPrecision(const std::string &value, const supported_precisions_t &supported_precisions, @@ -216,6 +239,33 @@ static InferenceEngine::Precision getOutputPrecision(const std::string &value) { return getPrecision(value, supported_precisions, "for output layer"); } +static InferenceEngine::Layout getLayout(const std::string &value) { + static const supported_layouts_t supported_layouts = { + { "NCHW", InferenceEngine::Layout::NCHW }, + { "NHWC", InferenceEngine::Layout::NHWC }, + { "CHW", InferenceEngine::Layout::CHW }, + { "NC", InferenceEngine::Layout::NC }, + { "C", InferenceEngine::Layout::C } + }; + return getLayout(value, supported_layouts); +} + +static bool isMatchLayoutToDims(const InferenceEngine::Layout& layout, const size_t dimension) { + static const matchLayoutToDims_t matchLayoutToDims = { + {static_cast(InferenceEngine::Layout::NCHW), 4 }, + {static_cast(InferenceEngine::Layout::NHWC), 4 }, + {static_cast(InferenceEngine::Layout::CHW), 3 }, + {static_cast(InferenceEngine::Layout::NC), 2 }, + {static_cast(InferenceEngine::Layout::C), 1 }}; + + auto dims = matchLayoutToDims.find(static_cast(layout)); + if (dims == matchLayoutToDims.end()) { + throw std::logic_error("Layout is not valid."); + } + + return dimension == dims->second; +} + bool isFP16(InferenceEngine::Precision precision) { return precision == InferenceEngine::Precision::FP16; } @@ -289,7 +339,7 @@ static void processPrecisions(InferenceEngine::CNNNetwork &network, for (auto &&layer : network.getInputsInfo()) { const auto layerPrecision = layer.second->getPrecision(); if ((isFloat(layerPrecision) && isFloat(precision)) || - (isFP16(layerPrecision) && isU8(precision))) { + (isFloat(layerPrecision) && isU8(precision))) { layer.second->setPrecision(precision); } } @@ -310,6 +360,27 @@ static void processPrecisions(InferenceEngine::CNNNetwork &network, } } +static void processLayout(InferenceEngine::CNNNetwork &network, + const std::string &inputs_layout, const std::string &outputs_layout) { + if (!inputs_layout.empty()) { + auto layout = getLayout(inputs_layout); + for (auto &&layer : network.getInputsInfo()) { + if (isMatchLayoutToDims(layout, layer.second->getTensorDesc().getDims().size())) { + layer.second->setLayout(layout); + } + } + } + + if (!outputs_layout.empty()) { + auto layout = getLayout(outputs_layout); + for (auto &&layer : network.getOutputsInfo()) { + if (isMatchLayoutToDims(layout, layer.second->getTensorDesc().getDims().size())) { + layer.second->setLayout(layout); + } + } + } +} + std::string getFileNameFromPath(const std::string& path, #if defined(_WIN32) const std::string sep = "\\") { @@ -341,6 +412,7 @@ int main(int argc, char *argv[]) { setDefaultIOPrecisions(network, FLAGS_d); processPrecisions(network, FLAGS_ip, FLAGS_op, FLAGS_iop); + processLayout(network, FLAGS_il, FLAGS_ol); auto timeBeforeLoadNetwork = std::chrono::steady_clock::now(); auto executableNetwork = ie.LoadNetwork(network, FLAGS_d, configure(FLAGS_c, FLAGS_m));