Update IE Python Samples (#5166)

* refactor: update ie python samples

* python samples: change comment about infer request creation (step 5)

* python sample: add the ability to run object_detection_sample_ssd.py with a model with 2 outputs

* Add batch size usage to python style transfer sample

* Change comment about model reading

* Add output queue to classification async sample

* add reshape for output to catch results with more than 2 dimensions (classification samples)

* Set a log output stream to stdout to pass the hello query device test

* Add comments to the hello query device sample

* Set sys.stdout as a logging stream for all python IE samples

* Add batch size usage to ngraph_function_creation_sample

* Return the ability to read an image from a ubyte file

* Add a few comments and function docstrings

* Restore IE python classification samples output

* Add --original_size arg for python style transfer sample

* Change log message to pass tests (object detection ie python sample)

* Return python shebang

* Add comment about a probs array sorting using np.argsort

* Fix the hello query python sample (Ticket: 52937)

* Add color inversion for light images for correct predictions

* Add a few log messages to the python device query sample
Dmitry Pigasin 2021-04-14 13:24:32 +03:00 committed by GitHub
parent 9709432d29
commit 19ace232cf
7 changed files with 760 additions and 728 deletions


@@ -1,160 +1,168 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import print_function
import argparse
import logging as log
import sys
import os
from argparse import ArgumentParser, SUPPRESS
import cv2
import numpy as np
import logging as log
from openvino.inference_engine import IECore
import threading
from openvino.inference_engine import IECore, StatusCode
class InferReqWrap:
def __init__(self, request, id, num_iter):
self.id = id
self.request = request
self.num_iter = num_iter
self.cur_iter = 0
self.cv = threading.Condition()
self.request.set_completion_callback(self.callback, self.id)
def callback(self, statusCode, userdata):
if (userdata != self.id):
log.error(f"Request ID {self.id} does not correspond to user data {userdata}")
elif statusCode != 0:
log.error(f"Request {self.id} failed with status code {statusCode}")
self.cur_iter += 1
log.info(f"Completed {self.cur_iter} Async request execution")
if self.cur_iter < self.num_iter:
# here a user can read output containing inference results and put new input
# to repeat async request again
self.request.async_infer(self.input)
else:
# continue sample execution after last Asynchronous inference request execution
self.cv.acquire()
self.cv.notify()
self.cv.release()
def execute(self, mode, input_data):
if (mode == "async"):
log.info(f"Start inference ({self.num_iter} Asynchronous executions)")
self.input = input_data
# Start async request for the first time. Wait all repetitions of the async request
self.request.async_infer(input_data)
self.cv.acquire()
self.cv.wait()
self.cv.release()
elif (mode == "sync"):
log.info(f"Start inference ({self.num_iter} Synchronous executions)")
for self.cur_iter in range(self.num_iter):
# here we start inference synchronously and wait for
# last inference request execution
self.request.infer(input_data)
log.info(f"Completed {self.cur_iter + 1} Sync request execution")
else:
log.error("wrong inference mode is chosen. Please use \"sync\" or \"async\" mode")
sys.exit(1)
def build_argparser():
parser = ArgumentParser(add_help=False)
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.",
required=True, type=str)
args.add_argument("-i", "--input", help="Required. Path to an image files",
required=True, type=str, nargs="+")
args.add_argument("-l", "--cpu_extension",
help="Optional. Required for CPU custom layers. Absolute path to a shared library with the"
" kernels implementations.", type=str, default=None)
args.add_argument("-d", "--device",
help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL or MYRIAD is "
"acceptable. The sample will look for a suitable plugin for device specified. Default value is CPU",
default="CPU", type=str)
args.add_argument("--labels", help="Optional. Labels mapping file", default=None, type=str)
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to an .xml or .onnx file with a trained model.')
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to an image file(s).')
args.add_argument('-l', '--extension', type=str, default=None,
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
'Absolute path to a shared library with the kernels implementations.')
args.add_argument('-c', '--config', type=str, default=None,
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
'Absolute path to operation description file (.xml).')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. '
'Default value is CPU.')
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
return parser.parse_args()
return parser
def main():
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = parse_args()
# Plugin initialization for specified device and load extensions library if specified
log.info("Creating Inference Engine")
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
if args.cpu_extension and 'CPU' in args.device:
ie.add_extension(args.cpu_extension, "CPU")
# Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
model = args.model
log.info(f"Loading network:\n\t{model}")
net = ie.read_network(model=model)
if args.extension and args.device == 'CPU':
log.info(f'Loading the {args.device} extension: {args.extension}')
ie.add_extension(args.extension, args.device)
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
assert len(net.outputs) == 1, "Sample supports only single output topologies"
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
log.info(f'Loading the {args.device} configuration: {args.config}')
ie.set_config({'CONFIG_FILE': args.config}, args.device)
log.info("Preparing input blobs")
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
log.info(f'Reading the network: {args.model}')
# (.xml and .bin files) or (.onnx file)
net = ie.read_network(model=args.model)
if len(net.input_info) != 1:
log.error('Sample supports only single input topologies')
return -1
if len(net.outputs) != 1:
log.error('Sample supports only single output topologies')
return -1
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get names of input and output blobs
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
net.batch_size = len(args.input)
# Read and pre-process input images
n, c, h, w = net.input_info[input_blob].input_data.shape
images = np.ndarray(shape=(n, c, h, w))
for i in range(n):
# Set input and output precision manually
net.input_info[input_blob].precision = 'U8'
net.outputs[out_blob].precision = 'FP32'
# Get the number of input images
num_of_input = len(args.input)
# Get the number of classes recognized by the model
num_of_classes = max(net.outputs[out_blob].shape)
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device, num_requests=num_of_input)
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests, so the infer requests were already created in the previous step.
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
input_data = []
_, _, h, w = net.input_info[input_blob].input_data.shape
for i in range(num_of_input):
image = cv2.imread(args.input[i])
if image.shape[:-1] != (h, w):
log.warning(f"Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}")
log.warning(f'Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}')
image = cv2.resize(image, (w, h))
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
images[i] = image
log.info(f"Batch size is {n}")
# Loading model to the plugin
log.info("Loading model to the plugin")
exec_net = ie.load_network(network=net, device_name=args.device)
# Change data layout from HWC to CHW
image = image.transpose((2, 0, 1))
# Add N dimension to transform to NCHW
image = np.expand_dims(image, axis=0)
# create one inference request for asynchronous execution
request_id = 0
infer_request = exec_net.requests[request_id]
input_data.append(image)
num_iter = 10
request_wrap = InferReqWrap(infer_request, request_id, num_iter)
# Start inference request execution. Wait for last execution being completed
request_wrap.execute("async", {input_blob: images})
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in asynchronous mode')
for i in range(num_of_input):
exec_net.requests[i].async_infer({input_blob: input_data[i]})
# Processing output blob
log.info("Processing output blob")
res = infer_request.output_blobs[out_blob]
log.info(f"Top {args.number_top} results: ")
# ---------------------------Step 8. Process output--------------------------------------------------------------------
# Generate a label list
if args.labels:
with open(args.labels, 'r') as f:
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
else:
labels_map = None
classid_str = "classid"
probability_str = "probability"
for i, probs in enumerate(res.buffer):
probs = np.squeeze(probs)
top_ind = np.argsort(probs)[-args.number_top:][::-1]
print(f"Image {args.input[i]}\n")
print(classid_str, probability_str)
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
for id in top_ind:
det_label = labels_map[id] if labels_map else f"{id}"
label_length = len(det_label)
space_num_before = (7 - label_length) // 2
space_num_after = 7 - (space_num_before + label_length) + 2
print(f"{' ' * space_num_before}{det_label}"
f"{' ' * space_num_after}{probs[id]:.7f}")
print("\n")
log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
labels = [line.split(',')[0].strip() for line in f]
# Create a list to control the order of outputs
output_queue = list(range(num_of_input))
while True:
for i in output_queue:
# Immediately returns an inference status without blocking or interrupting
infer_status = exec_net.requests[i].wait(0)
if infer_status == StatusCode.RESULT_NOT_READY:
continue
log.info(f'Infer request {i} returned {infer_status}')
if infer_status != StatusCode.OK:
return -2
# Read infer request results from buffer
res = exec_net.requests[i].output_blobs[out_blob].buffer
# Reshape the numpy.ndarray with results to get a one-dimensional array
probs = res.reshape(num_of_classes)
# Get an array of args.number_top class IDs in descending order of probability
top_n_indexes = np.argsort(probs)[-args.number_top:][::-1]
header = 'classid probability'
header = header + ' label' if args.labels else header
log.info(f'Image path: {args.input[i]}')
log.info(f'Top {args.number_top} results: ')
log.info(header)
log.info('-' * len(header))
for class_id in top_n_indexes:
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
label = labels[class_id] if args.labels else ''
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
log.info('')
output_queue.remove(i)
if len(output_queue) == 0:
break
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main() or 0)
sys.exit(main())
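The new asynchronous flow in this sample reduces to a simple pattern: start one request per input, then poll every request with wait(0) until it reports StatusCode.OK. A minimal standalone sketch of that pattern, assuming exec_net was loaded with num_requests equal to the number of prepared input arrays (the function name and variables here are illustrative, not part of the sample):

from openvino.inference_engine import StatusCode

def wait_for_all(exec_net, input_blob, out_blob, inputs):
    # Start one asynchronous request per prepared input array
    for i, data in enumerate(inputs):
        exec_net.requests[i].async_infer({input_blob: data})
    results = {}
    output_queue = list(range(len(inputs)))
    while output_queue:
        for i in output_queue[:]:
            # wait(0) returns immediately with the current request status
            status = exec_net.requests[i].wait(0)
            if status == StatusCode.RESULT_NOT_READY:
                continue
            if status != StatusCode.OK:
                raise RuntimeError(f'Request {i} failed with status {status}')
            # Copy the buffer so the result survives any later reuse of the request
            results[i] = exec_net.requests[i].output_blobs[out_blob].buffer.copy()
            output_queue.remove(i)
    return results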


@@ -1,106 +1,125 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import print_function
import argparse
import logging as log
import sys
import os
from argparse import ArgumentParser, SUPPRESS
import cv2
import numpy as np
import logging as log
from openvino.inference_engine import IECore
def build_argparser():
parser = ArgumentParser(add_help=False)
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.", required=True,
type=str)
args.add_argument("-i", "--input", help="Required. Path to an image file.",
required=True, type=str)
args.add_argument("-l", "--cpu_extension",
help="Optional. Required for CPU custom layers. "
"MKLDNN (CPU)-targeted custom layers. Absolute path to a shared library with the"
" kernels implementations.", type=str, default=None)
args.add_argument("-d", "--device",
help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL, MYRIAD or HETERO: is "
"acceptable. The sample will look for a suitable plugin for device specified. Default "
"value is CPU",
default="CPU", type=str)
args.add_argument("--labels", help="Optional. Path to a labels mapping file", default=None, type=str)
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to an .xml or .onnx file with a trained model.')
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. '
'Default value is CPU.')
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
return parser
return parser.parse_args()
def main():
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = parse_args()
# Plugin initialization for specified device and load extensions library if specified
log.info("Creating Inference Engine")
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
if args.cpu_extension and 'CPU' in args.device:
ie.add_extension(args.cpu_extension, "CPU")
# Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
model = args.model
log.info(f"Loading network:\n\t{model}")
net = ie.read_network(model=model)
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
log.info(f'Reading the network: {args.model}')
# (.xml and .bin files) or (.onnx file)
net = ie.read_network(model=args.model)
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
assert len(net.outputs) == 1, "Sample supports only single output topologies"
if len(net.input_info) != 1:
log.error('Sample supports only single input topologies')
return -1
if len(net.outputs) != 1:
log.error('Sample supports only single output topologies')
return -1
log.info("Preparing input blobs")
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get names of input and output blobs
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
# Read and pre-process input images
n, c, h, w = net.input_info[input_blob].input_data.shape
images = np.ndarray(shape=(n, c, h, w))
for i in range(n):
image = cv2.imread(args.input)
if image.shape[:-1] != (h, w):
log.warning(f"Image {args.input} is resized from {image.shape[:-1]} to {(h, w)}")
image = cv2.resize(image, (w, h))
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
images[i] = image
# Set input and output precision manually
net.input_info[input_blob].precision = 'U8'
net.outputs[out_blob].precision = 'FP32'
# Loading model to the plugin
log.info("Loading model to the plugin")
# Get the number of classes recognized by the model
num_of_classes = max(net.outputs[out_blob].shape)
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device)
# Start sync inference
log.info("Starting inference in synchronous mode")
res = exec_net.infer(inputs={input_blob: images})
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests, so the infer requests were already created in the previous step.
# Processing output blob
log.info("Processing output blob")
res = res[out_blob]
log.info(f"Top {args.number_top} results: ")
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
original_image = cv2.imread(args.input)
image = original_image.copy()
_, _, h, w = net.input_info[input_blob].input_data.shape
if image.shape[:-1] != (h, w):
log.warning(f'Image {args.input} is resized from {image.shape[:-1]} to {(h, w)}')
image = cv2.resize(image, (w, h))
# Change data layout from HWC to CHW
image = image.transpose((2, 0, 1))
# Add N dimension to transform to NCHW
image = np.expand_dims(image, axis=0)
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in synchronous mode')
res = exec_net.infer(inputs={input_blob: image})
# ---------------------------Step 8. Process output--------------------------------------------------------------------
# Generate a label list
if args.labels:
with open(args.labels, 'r') as f:
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
else:
labels_map = None
classid_str = "classid"
probability_str = "probability"
for i, probs in enumerate(res):
probs = np.squeeze(probs)
top_ind = np.argsort(probs)[-args.number_top:][::-1]
print(f"Image {args.input}\n")
print(classid_str, probability_str)
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
for id in top_ind:
det_label = labels_map[id] if labels_map else f"{id}"
label_length = len(det_label)
space_num_before = (len(classid_str) - label_length) // 2
space_num_after = len(classid_str) - (space_num_before + label_length) + 2
print(f"{' ' * space_num_before}{det_label}"
f"{' ' * space_num_after}{probs[id]:.7f}")
print("\n")
log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
labels = [line.split(',')[0].strip() for line in f]
res = res[out_blob]
# Reshape the numpy.ndarray with results to get a one-dimensional array
probs = res.reshape(num_of_classes)
# Get an array of args.number_top class IDs in descending order of probability
top_n_indexes = np.argsort(probs)[-args.number_top:][::-1]
header = 'classid probability'
header = header + ' label' if args.labels else header
log.info(f'Image path: {args.input}')
log.info(f'Top {args.number_top} results: ')
log.info(header)
log.info('-' * len(header))
for class_id in top_n_indexes:
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
label = labels[class_id] if args.labels else ''
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
log.info('')
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main() or 0)
sys.exit(main())
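Both classification samples now select the top-N classes with the same np.argsort idiom that the commit message calls out. A self-contained illustration with made-up probabilities:

import numpy as np

probs = np.array([0.05, 0.70, 0.10, 0.15])
number_top = 2
# np.argsort sorts ascending, so take the last N indices and reverse their order
top_n_indexes = np.argsort(probs)[-number_top:][::-1]
print(top_n_indexes)         # [1 3]
print(probs[top_n_indexes])  # [0.7  0.15]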


@@ -1,43 +1,54 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import logging as log
import sys
from openvino.inference_engine import IECore
def param_to_string(metric):
def param_to_string(metric) -> str:
'''Convert a list / tuple of parameters returned from IE to a string'''
if isinstance(metric, (list, tuple)):
return ", ".join([str(val) for val in metric])
elif isinstance(metric, dict):
str_param_repr = ""
for k, v in metric.items():
str_param_repr += f"{k}: {v}\n"
return str_param_repr
return ', '.join([str(x) for x in metric])
else:
return str(metric)
def main():
ie = IECore()
print("Available devices:")
for device in ie.available_devices:
print(f"\tDevice: {device}")
print("\tMetrics:")
for metric in ie.get_metric(device, "SUPPORTED_METRICS"):
try:
metric_val = ie.get_metric(device, metric)
print(f"\t\t{metric}: {param_to_string(metric_val)}")
except TypeError:
print(f"\t\t{metric}: UNSUPPORTED TYPE")
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
print("\n\tDefault values for device configuration keys:")
for cfg in ie.get_metric(device, "SUPPORTED_CONFIG_KEYS"):
# ---------------------------Initialize inference engine core----------------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
# ---------------------------Get metrics of available devices----------------------------------------------------------
log.info('Available devices:')
for device in ie.available_devices:
log.info(f'{device} :')
log.info('\tSUPPORTED_METRICS:')
for metric in ie.get_metric(device, 'SUPPORTED_METRICS'):
if metric not in ('SUPPORTED_METRICS', 'SUPPORTED_CONFIG_KEYS'):
try:
metric_val = ie.get_metric(device, metric)
except TypeError:
metric_val = 'UNSUPPORTED TYPE'
log.info(f'\t\t{metric}: {param_to_string(metric_val)}')
log.info('')
log.info('\tSUPPORTED_CONFIG_KEYS (default values):')
for config_key in ie.get_metric(device, 'SUPPORTED_CONFIG_KEYS'):
try:
cfg_val = ie.get_config(device, cfg)
print(f"\t\t{cfg}: {param_to_string(cfg_val)}")
config_val = ie.get_config(device, config_key)
except TypeError:
print(f"\t\t{cfg}: UNSUPPORTED TYPE")
config_val = 'UNSUPPORTED TYPE'
log.info(f'\t\t{config_key}: {param_to_string(config_val)}')
log.info('')
# ----------------------------------------------------------------------------------------------------------------------
return 0
if __name__ == '__main__':
sys.exit(main() or 0)
sys.exit(main())
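The try/except TypeError around each query exists because some metric and config values cannot be converted by the Python bindings; param_to_string formats the ones that can. Its behavior on the two common shapes of a value, shown with the function body copied from the sample:

def param_to_string(metric) -> str:
    '''Convert a list / tuple of parameters returned from IE to a string'''
    if isinstance(metric, (list, tuple)):
        return ', '.join([str(x) for x in metric])
    else:
        return str(metric)

print(param_to_string(('FP32', 'FP16')))  # FP32, FP16
print(param_to_string(8))                 # 8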


@@ -1,158 +1,145 @@
#!/usr/bin/env python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import print_function
from argparse import ArgumentParser, SUPPRESS
import argparse
import logging as log
import os
import sys
import cv2
import ngraph as ng
import numpy as np
from openvino.inference_engine import IECore
def build_argparser():
parser = ArgumentParser(add_help=False)
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
args.add_argument('-m', '--model', help='Required. Path to an .xml or .onnx file with a trained model.',
required=True, type=str)
args.add_argument('-i', '--input', help='Required. Path to an image file.',
required=True, type=str)
args.add_argument('-d', '--device',
help='Optional. Specify the target device to infer on; '
'CPU, GPU, FPGA or MYRIAD is acceptable. '
'Sample will look for a suitable plugin for device specified (CPU by default)',
default='CPU', type=str)
return parser
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to an .xml or .onnx file with a trained model.')
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
args.add_argument('-l', '--extension', type=str, default=None,
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
'Absolute path to a shared library with the kernels implementations.')
args.add_argument('-c', '--config', type=str, default=None,
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
'Absolute path to operation description file (.xml).')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. '
'Default value is CPU.')
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
return parser.parse_args()
def main():
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
log.info('Loading Inference Engine')
args = parse_args()
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
# ---1. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format ---
model = args.model
log.info(f'Loading network:')
log.info(f' {model}')
net = ie.read_network(model=model)
# -----------------------------------------------------------------------------------------------------
if args.extension and args.device == 'CPU':
log.info(f'Loading the {args.device} extension: {args.extension}')
ie.add_extension(args.extension, args.device)
# ------------- 2. Load Plugin for inference engine and extensions library if specified --------------
log.info('Device info:')
versions = ie.get_versions(args.device)
log.info(f' {args.device}')
log.info(f' MKLDNNPlugin version ......... {versions[args.device].major}.{versions[args.device].minor}')
log.info(f' Build ........... {versions[args.device].build_number}')
# -----------------------------------------------------------------------------------------------------
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
log.info(f'Loading the {args.device} configuration: {args.config}')
ie.set_config({'CONFIG_FILE': args.config}, args.device)
# --------------------------- 3. Read and preprocess input --------------------------------------------
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
log.info(f'Reading the network: {args.model}')
# (.xml and .bin files) or (.onnx file)
net = ie.read_network(model=args.model)
log.info(f'Inputs number: {len(net.input_info.keys())}')
assert len(net.input_info.keys()) == 1, 'Sample supports clean SSD network with one input'
assert len(net.outputs.keys()) == 1, 'Sample supports clean SSD network with one output'
input_name = list(net.input_info.keys())[0]
input_info = net.input_info[input_name]
supported_input_dims = 4 # Supported input layout - NHWC
if len(net.input_info) != 1:
log.error('Sample supports only single input topologies')
return -1
if len(net.outputs) != 1:
log.error('Sample supports only single output topologies')
return -1
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get names of input and output blobs
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
# Set input and output precision manually
net.input_info[input_blob].precision = 'U8'
net.outputs[out_blob].precision = 'FP32'
original_image = cv2.imread(args.input)
image = original_image.copy()
# Change data layout from HWC to CHW
image = image.transpose((2, 0, 1))
# Add N dimension to transform to NCHW
image = np.expand_dims(image, axis=0)
log.info(f' Input name: {input_name}')
log.info(f' Input shape: {str(input_info.input_data.shape)}')
if len(input_info.input_data.layout) == supported_input_dims:
n, c, h, w = input_info.input_data.shape
assert n == 1, 'Sample supports topologies with one input image only'
else:
raise AssertionError('Sample supports input with NHWC shape only')
image = cv2.imread(args.input)
h_new, w_new = image.shape[:-1]
images = np.ndarray(shape=(n, c, h_new, w_new))
log.info('File was added: ')
log.info(f' {args.input}')
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
images[0] = image
log.info('Reshaping the network to the height and width of the input image')
net.reshape({input_name: [n, c, h_new, w_new]})
log.info(f'Input shape after reshape: {str(net.input_info["data"].input_data.shape)}')
log.info(f'Input shape before reshape: {net.input_info[input_blob].input_data.shape}')
net.reshape({input_blob: image.shape})
log.info(f'Input shape after reshape: {net.input_info[input_blob].input_data.shape}')
# -----------------------------------------------------------------------------------------------------
# --------------------------- 4. Configure input & output ---------------------------------------------
# --------------------------- Prepare input blobs -----------------------------------------------------
log.info('Preparing input blobs')
if len(input_info.layout) == supported_input_dims:
input_info.precision = 'U8'
data = {}
data[input_name] = images
# --------------------------- Prepare output blobs ----------------------------------------------------
log.info('Preparing output blobs')
func = ng.function_from_cnn(net)
ops = func.get_ordered_ops()
output_name, output_info = '', net.outputs[next(iter(net.outputs.keys()))]
output_ops = {op.friendly_name : op for op in ops \
if op.friendly_name in net.outputs and op.get_type_name() == 'DetectionOutput'}
if len(output_ops) == 1:
output_name, output_info = output_ops.popitem()
assert output_name != '', 'Can''t find a DetectionOutput layer in the topology'
output_dims = output_info.shape
assert output_dims != 4, 'Incorrect output dimensions for SSD model'
assert output_dims[3] == 7, 'Output item should have 7 as a last dimension'
output_info.precision = 'FP32'
# -----------------------------------------------------------------------------------------------------
# --------------------------- Performing inference ----------------------------------------------------
log.info('Loading model to the device')
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device)
log.info('Creating infer request and starting inference')
exec_result = exec_net.infer(inputs=data)
# -----------------------------------------------------------------------------------------------------
# --------------------------- Read and postprocess output ---------------------------------------------
log.info('Processing output blobs')
result = exec_result[output_name]
boxes = {}
detections = result[0][0] # [0][0] - location of detections in result blob
for number, proposal in enumerate(detections):
imid, label, confidence, coords = np.int(proposal[0]), np.int(proposal[1]), proposal[2], proposal[3:]
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests, so the infer requests were already created in the previous step.
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
# This sample changes a network input layer shape instead of an image shape. See Step 4.
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in synchronous mode')
res = exec_net.infer(inputs={input_blob: image})
# ---------------------------Step 8. Process output--------------------------------------------------------------------
# Generate a label list
if args.labels:
with open(args.labels, 'r') as f:
labels = [line.split(',')[0].strip() for line in f]
res = res[out_blob]
output_image = original_image.copy()
h, w, _ = output_image.shape
# Reshape the numpy.ndarray with results from [1, 1, N, 7] to [N, 7],
# where N is the number of detected bounding boxes
detections = res.reshape(-1, 7)
for detection in detections:
confidence = detection[2]
if confidence > 0.5:
# correcting coordinates to actual image resolution
xmin, ymin, xmax, ymax = w_new * coords[0], h_new * coords[1], w_new * coords[2], h_new * coords[3]
class_id = int(detection[1])
label = labels[class_id] if args.labels else class_id
log.info(f' [{number},{label}] element, prob = {confidence:.6f}, '
f'bbox = ({xmin:.3f},{ymin:.3f})-({xmax:.3f},{ymax:.3f}), batch id = {imid}')
if not imid in boxes.keys():
boxes[imid] = []
boxes[imid].append([xmin, ymin, xmax, ymax])
xmin = int(detection[3] * w)
ymin = int(detection[4] * h)
xmax = int(detection[5] * w)
ymax = int(detection[6] * h)
imid = 0 # as sample supports one input image only, imid in results will always be 0
log.info(f'Found: label = {label}, confidence = {confidence:.2f}, '
f'coords = ({xmin}, {ymin}), ({xmax}, {ymax})')
tmp_image = cv2.imread(args.input)
for box in boxes[imid]:
# drawing bounding boxes on the output image
cv2.rectangle(
tmp_image,
(np.int(box[0]), np.int(box[1])), (np.int(box[2]), np.int(box[3])),
color=(232, 35, 244), thickness=2)
cv2.imwrite('out.bmp', tmp_image)
log.info('Image out.bmp created!')
# -----------------------------------------------------------------------------------------------------
# Draw a bounding box on the output image
cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
log.info('Execution successful\n')
log.info(
'This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool')
cv2.imwrite('out.bmp', output_image)
log.info('Image out.bmp was created!')
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main() or 0)
sys.exit(main())
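The post-processing above relies on the DetectionOutput layout [1, 1, N, 7], where each of the N rows holds [image_id, label, confidence, xmin, ymin, xmax, ymax] with coordinates normalized to [0, 1]. A minimal sketch of the same parsing on dummy data (the result values are illustrative):

import numpy as np

h, w = 480, 640  # original image height and width in pixels
res = np.array([[[[0, 1, 0.92, 0.10, 0.20, 0.50, 0.80],
                  [0, 2, 0.30, 0.00, 0.00, 0.10, 0.10]]]], dtype=np.float32)
for detection in res.reshape(-1, 7):  # flatten [1, 1, N, 7] to [N, 7]
    confidence = detection[2]
    if confidence > 0.5:  # same threshold as the sample
        class_id = int(detection[1])
        xmin, ymin = int(detection[3] * w), int(detection[4] * h)
        xmax, ymax = int(detection[5] * w), int(detection[6] * h)
        print(f'label={class_id} conf={confidence:.2f} box=({xmin}, {ymin})-({xmax}, {ymax})')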


@@ -1,84 +1,68 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import sys
import os
from argparse import ArgumentParser, SUPPRESS
import cv2
import numpy as np
import argparse
import logging as log
from openvino.inference_engine import IECore, IENetwork
import ngraph
from ngraph.impl import Function
from functools import reduce
import struct as st
import sys
from functools import reduce
import cv2
import ngraph
import numpy as np
from openvino.inference_engine import IECore, IENetwork
def build_argparser() -> ArgumentParser:
parser = ArgumentParser(add_help=False)
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
args.add_argument('-i', '--input', help='Required. Path to a folder with images or path to an image files',
required=True, type=str, nargs="+")
args.add_argument('-m', '--model', help='Required. Path to file where weights for the network are located',
required=True)
args.add_argument('-d', '--device',
help='Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL, MYRIAD or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. Default '
'value is CPU',
default='CPU', type=str)
args.add_argument('--labels', help='Optional. Path to a labels mapping file', default=None, type=str)
args.add_argument('-nt', '--number_top', help='Optional. Number of top results', default=1, type=int)
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to a file with network weights.')
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to an image file.')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. '
'Default value is CPU.')
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
return parser
return parser.parse_args()
def list_input_images(input_dirs: list):
images = []
for input_dir in input_dirs:
if os.path.isdir(input_dir):
for root, directories, filenames in os.walk(input_dir):
for filename in filenames:
images.append(os.path.join(root, filename))
elif os.path.isfile(input_dir):
images.append(input_dir)
def read_image(image_path: str) -> np.ndarray:
'''Read and return an image as grayscale (one channel)'''
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
return images
def read_image(image_path: np):
# try to read image in usual image formats
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
# try to open image as ubyte
# Try to open image as ubyte
if image is None:
with open(image_path, 'rb') as f:
image_file = open(image_path, 'rb')
image_file.seek(0)
st.unpack('>4B', image_file.read(4)) # need to skip 4 bytes
nimg = st.unpack('>I', image_file.read(4))[0] # number of images
nrow = st.unpack('>I', image_file.read(4))[0] # number of rows
ncolumn = st.unpack('>I', image_file.read(4))[0] # number of column
st.unpack('>4B', f.read(4)) # need to skip 4 bytes
nimg = st.unpack('>I', f.read(4))[0] # number of images
nrow = st.unpack('>I', f.read(4))[0] # number of rows
ncolumn = st.unpack('>I', f.read(4))[0] # number of columns
nbytes = nimg * nrow * ncolumn * 1 # each pixel data is 1 byte
assert nimg == 1, log.error('Sample supports ubyte files with 1 image inside')
if nimg != 1:
raise Exception('Sample supports ubyte files with 1 image inside')
image = np.asarray(st.unpack('>' + 'B' * nbytes, image_file.read(nbytes))).reshape(
(nrow, ncolumn))
image = np.asarray(st.unpack('>' + 'B' * nbytes, f.read(nbytes))).reshape((nrow, ncolumn))
return image
def shape_and_length(shape: list):
length = reduce(lambda x, y: x*y, shape)
return shape, length
def create_ngraph_function(args: argparse.Namespace) -> ngraph.impl.Function:
'''Create a network on the fly from the source code using ngraph'''
def shape_and_length(shape: list) -> (list, int):
length = reduce(lambda x, y: x * y, shape)
return shape, length
def create_ngraph_function(args) -> Function:
weights = np.fromfile(args.model, dtype=np.float32)
weights_offset = 0
padding_begin = [0, 0]
padding_end = [0, 0]
padding_begin = padding_end = [0, 0]
# input
input_shape = [64, 1, 28, 28]
@@ -93,7 +77,7 @@ def create_ngraph_function(args) -> Function:
# add 1
add_1_kernel_shape, add_1_kernel_length = shape_and_length([1, 20, 1, 1])
add_1_kernel = ngraph.constant(
weights[weights_offset:weights_offset + add_1_kernel_length].reshape(add_1_kernel_shape)
weights[weights_offset:weights_offset + add_1_kernel_length].reshape(add_1_kernel_shape),
)
weights_offset += add_1_kernel_length
add_1_node = ngraph.add(conv_1_node, add_1_kernel)
@@ -104,7 +88,7 @@ def create_ngraph_function(args) -> Function:
# convolution 2
conv_2_kernel_shape, conv_2_kernel_length = shape_and_length([50, 20, 5, 5])
conv_2_kernel = ngraph.constant(
weights[weights_offset:weights_offset + conv_2_kernel_length].reshape(conv_2_kernel_shape)
weights[weights_offset:weights_offset + conv_2_kernel_length].reshape(conv_2_kernel_shape),
)
weights_offset += conv_2_kernel_length
conv_2_node = ngraph.convolution(maxpool_1_node, conv_2_kernel, [1, 1], padding_begin, padding_end, [1, 1])
@@ -112,7 +96,7 @@ def create_ngraph_function(args) -> Function:
# add 2
add_2_kernel_shape, add_2_kernel_length = shape_and_length([1, 50, 1, 1])
add_2_kernel = ngraph.constant(
weights[weights_offset:weights_offset + add_2_kernel_length].reshape(add_2_kernel_shape)
weights[weights_offset:weights_offset + add_2_kernel_length].reshape(add_2_kernel_shape),
)
weights_offset += add_2_kernel_length
add_2_node = ngraph.add(conv_2_node, add_2_kernel)
@@ -124,16 +108,16 @@ def create_ngraph_function(args) -> Function:
reshape_1_dims, reshape_1_length = shape_and_length([2])
# workaround to get int64 weights from float32 ndarray w/o unnecessary copying
dtype_weights = np.frombuffer(
weights[weights_offset:weights_offset + 2*reshape_1_length], dtype=np.int64
weights[weights_offset:weights_offset + 2 * reshape_1_length], dtype=np.int64,
)
reshape_1_kernel = ngraph.constant(dtype_weights)
weights_offset += 2*reshape_1_length
weights_offset += 2 * reshape_1_length
reshape_1_node = ngraph.reshape(maxpool_2_node, reshape_1_kernel, True)
# matmul 1
matmul_1_kernel_shape, matmul_1_kernel_length = shape_and_length([500, 800])
matmul_1_kernel = ngraph.constant(
weights[weights_offset:weights_offset + matmul_1_kernel_length].reshape(matmul_1_kernel_shape)
weights[weights_offset:weights_offset + matmul_1_kernel_length].reshape(matmul_1_kernel_shape),
)
weights_offset += matmul_1_kernel_length
matmul_1_node = ngraph.matmul(reshape_1_node, matmul_1_kernel, False, True)
@@ -141,7 +125,7 @@ def create_ngraph_function(args) -> Function:
# add 3
add_3_kernel_shape, add_3_kernel_length = shape_and_length([1, 500])
add_3_kernel = ngraph.constant(
weights[weights_offset:weights_offset + add_3_kernel_length].reshape(add_3_kernel_shape)
weights[weights_offset:weights_offset + add_3_kernel_length].reshape(add_3_kernel_shape),
)
weights_offset += add_3_kernel_length
add_3_node = ngraph.add(matmul_1_node, add_3_kernel)
@@ -156,7 +140,7 @@ def create_ngraph_function(args) -> Function:
# matmul 2
matmul_2_kernel_shape, matmul_2_kernel_length = shape_and_length([10, 500])
matmul_2_kernel = ngraph.constant(
weights[weights_offset:weights_offset + matmul_2_kernel_length].reshape(matmul_2_kernel_shape)
weights[weights_offset:weights_offset + matmul_2_kernel_length].reshape(matmul_2_kernel_shape),
)
weights_offset += matmul_2_kernel_length
matmul_2_node = ngraph.matmul(reshape_2_node, matmul_2_kernel, False, True)
@@ -164,7 +148,7 @@ def create_ngraph_function(args) -> Function:
# add 4
add_4_kernel_shape, add_4_kernel_length = shape_and_length([1, 10])
add_4_kernel = ngraph.constant(
weights[weights_offset:weights_offset + add_4_kernel_length].reshape(add_4_kernel_shape)
weights[weights_offset:weights_offset + add_4_kernel_length].reshape(add_4_kernel_shape),
)
weights_offset += add_4_kernel_length
add_4_node = ngraph.add(matmul_2_node, add_4_kernel)
@@ -175,85 +159,101 @@ def create_ngraph_function(args) -> Function:
# result
result_node = ngraph.result(softmax_node)
# nGraph function
function = Function(result_node, [param_node], 'lenet')
return function
return ngraph.impl.Function(result_node, [param_node], 'lenet')
def main():
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
args = parse_args()
input_images = list_input_images(args.input)
# Loading network using ngraph function
ngraph_function = create_ngraph_function(args)
net = IENetwork(Function.to_capsule(ngraph_function))
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
assert len(net.outputs) == 1, "Sample supports only single output topologies"
log.info("Preparing input blobs")
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
net.batch_size = len(input_images)
# Read and pre-process input images
n, c, h, w = net.input_info[input_blob].input_data.shape
images = np.ndarray(shape=(n, c, h, w))
for i in range(n):
image = read_image(input_images[i])
assert image is not None, log.error(f"Can't open an image {input_images[i]}")
assert len(image.shape) == 2, log.error('Sample supports images with 1 channel only')
if image.shape[:] != (w, h):
log.warning(f"Image {input_images[i]} is resized from {image.shape[:]} to {(w, h)}")
image = cv2.resize(image, (w, h))
images[i] = image
log.info(f"Batch size is {n}")
log.info("Creating Inference Engine")
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
log.info('Loading model to the device')
exec_net = ie.load_network(network=net, device_name=args.device.upper())
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation------------------------------
log.info(f'Loading the network using ngraph function with weights from {args.model}')
ngraph_function = create_ngraph_function(args)
net = IENetwork(ngraph.impl.Function.to_capsule(ngraph_function))
# Start sync inference
log.info('Creating infer request and starting inference')
res = exec_net.infer(inputs={input_blob: images})
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get names of input and output blobs
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
# Processing results
log.info("Processing output blob")
res = res[out_blob]
log.info(f"Top {args.number_top} results: ")
# Set input and output precision manually
net.input_info[input_blob].precision = 'U8'
net.outputs[out_blob].precision = 'FP32'
# Read labels file if it is provided as argument
labels_map = None
# Set the batch size equal to the number of input images
net.batch_size = len(args.input)
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device)
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests, so the infer requests were already created in the previous step.
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
n, c, h, w = net.input_info[input_blob].input_data.shape
input_data = np.ndarray(shape=(n, c, h, w))
for i in range(n):
image = read_image(args.input[i])
light_pixel_count = np.count_nonzero(image > 127)
dark_pixel_count = np.count_nonzero(image < 127)
is_light_image = (light_pixel_count - dark_pixel_count) > 0
if is_light_image:
log.warning(f'Image {args.input[i]} is inverted to white over black')
image = cv2.bitwise_not(image)
if image.shape != (h, w):
log.warning(f'Image {args.input[i]} is resized from {image.shape} to {(h, w)}')
image = cv2.resize(image, (w, h))
input_data[i] = image
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in synchronous mode')
res = exec_net.infer(inputs={input_blob: input_data})
# ---------------------------Step 8. Process output--------------------------------------------------------------------
# Generate a label list
if args.labels:
with open(args.labels, 'r') as f:
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
labels = [line.split(',')[0].strip() for line in f]
classid_str = "classid"
probability_str = "probability"
for i, probs in enumerate(res):
probs = np.squeeze(probs)
top_ind = np.argsort(probs)[-args.number_top:][::-1]
print(f"Image {input_images[i]}\n")
print(classid_str, probability_str)
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
for class_id in top_ind:
det_label = labels_map[class_id] if labels_map else f"{class_id}"
label_length = len(det_label)
space_num_before = (len(classid_str) - label_length) // 2
space_num_after = len(classid_str) - (space_num_before + label_length) + 2
print(f"{' ' * space_num_before}{det_label}"
f"{' ' * space_num_after}{probs[class_id]:.7f}")
print("\n")
res = res[out_blob]
log.info('This sample is an API example, for any performance measurements '
'please use the dedicated benchmark_app tool')
for i in range(n):
probs = res[i]
# Get an array of args.number_top class IDs in descending order of probability
top_n_indexes = np.argsort(probs)[-args.number_top:][::-1]
header = 'classid probability'
header = header + ' label' if args.labels else header
log.info(f'Image path: {args.input[i]}')
log.info(f'Top {args.number_top} results: ')
log.info(header)
log.info('-' * len(header))
for class_id in top_n_indexes:
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
label = labels[class_id] if args.labels else ''
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
log.info('')
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main() or 0)
sys.exit(main())
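The color inversion added here is a pixel-count heuristic: when more pixels are brighter than mid-gray than darker, the digit is assumed to be dark-on-light, so the whole image is inverted to match the white-on-black MNIST-style training data. As a standalone function (threshold 127 as in the diff):

import cv2
import numpy as np

def maybe_invert(image: np.ndarray) -> np.ndarray:
    '''Invert a grayscale image that is mostly light (a dark-on-light digit)'''
    light_pixel_count = np.count_nonzero(image > 127)
    dark_pixel_count = np.count_nonzero(image < 127)
    if light_pixel_count - dark_pixel_count > 0:
        return cv2.bitwise_not(image)
    return image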


@@ -1,193 +1,158 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
from __future__ import print_function
import argparse
import logging as log
import sys
import os
from argparse import ArgumentParser, SUPPRESS
import cv2
import numpy as np
import logging as log
from openvino.inference_engine import IECore
import ngraph as ng
def build_argparser():
parser = ArgumentParser(add_help=False)
args = parser.add_argument_group("Options")
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.",
required=True, type=str)
args.add_argument("-i", "--input", help="Required. Path to an image file.",
required=True, type=str)
args.add_argument("-l", "--cpu_extension",
help="Optional. Required for CPU custom layers. "
"Absolute path to a shared library with the kernels implementations.",
type=str, default=None)
args.add_argument("-d", "--device",
help="Optional. Specify the target device to infer on; "
"CPU, GPU, FPGA or MYRIAD is acceptable. "
"Sample will look for a suitable plugin for device specified (CPU by default)",
default="CPU", type=str)
args.add_argument("--labels", help="Optional. Labels mapping file", default=None, type=str)
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to an .xml or .onnx file with a trained model.')
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
args.add_argument('-l', '--extension', type=str, default=None,
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
'Absolute path to a shared library with the kernels implementations.')
args.add_argument('-c', '--config', type=str, default=None,
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
'Absolute path to operation description file (.xml).')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
'is acceptable. The sample will look for a suitable plugin for device specified. '
'Default value is CPU.')
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
return parser
return parser.parse_args()
def main():
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
log.info("Loading Inference Engine")
def main(): # noqa
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = parse_args()
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
# ---1. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format ---
model = args.model
log.info(f"Loading network:\n\t{model}")
net = ie.read_network(model=model)
# -----------------------------------------------------------------------------------------------------
if args.extension and args.device == 'CPU':
log.info(f'Loading the {args.device} extension: {args.extension}')
ie.add_extension(args.extension, args.device)
# ------------- 2. Load Plugin for inference engine and extensions library if specified --------------
log.info("Device info:")
versions = ie.get_versions(args.device)
print(f"{' ' * 8}{args.device}")
print(f"{' ' * 8}MKLDNNPlugin version ......... {versions[args.device].major}.{versions[args.device].minor}")
print(f"{' ' * 8}Build ........... {versions[args.device].build_number}")
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
log.info(f'Loading the {args.device} configuration: {args.config}')
ie.set_config({'CONFIG_FILE': args.config}, args.device)
if args.cpu_extension and "CPU" in args.device:
ie.add_extension(args.cpu_extension, "CPU")
log.info(f"CPU extension loaded: {args.cpu_extension}")
# -----------------------------------------------------------------------------------------------------
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
log.info(f'Reading the network: {args.model}')
# (.xml and .bin files) or (.onnx file)
net = ie.read_network(model=args.model)
if len(net.input_info) != 1:
log.error('The sample supports only single input topologies')
return -1
if len(net.outputs) != 1 and not ('boxes' in net.outputs and 'labels' in net.outputs):
log.error('The sample supports models with 1 output or with 2 outputs named "boxes" and "labels"')
return -1
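# Two-output models are handled in Step 8 below, where 'boxes' supplies the
# coordinates and confidences and 'labels' the class id of each detection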
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get name of input blob
input_blob = next(iter(net.input_info))
# Set input precision manually (output precision is set per output below)
net.input_info[input_blob].precision = 'U8'
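# U8 input precision lets the plugin consume the uint8 BGR data that cv2.imread
# returns without an explicit conversion on the application side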
if len(net.outputs) == 1:
output_blob = next(iter(net.outputs))
net.outputs[output_blob].precision = 'FP32'
else:
net.outputs['boxes'].precision = 'FP32'
net.outputs['labels'].precision = 'U16'
if output_name == "":
log.error("Can't find a DetectionOutput layer in the topology")
output_dims = output_info.shape
if len(output_dims) != 4:
log.error("Incorrect output dimensions for SSD model")
max_proposal_count, object_size = output_dims[2], output_dims[3]
if object_size != 7:
log.error("Output item should have 7 as a last dimension")
output_info.precision = "FP32"
# -----------------------------------------------------------------------------------------------------
# --------------------------- Performing inference ----------------------------------------------------
log.info("Loading model to the device")
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device)
log.info("Creating infer request and starting inference")
res = exec_net.infer(inputs=data)
# -----------------------------------------------------------------------------------------------------
# --------------------------- Read and postprocess output ---------------------------------------------
log.info("Processing output blobs")
res = res[out_blob]
boxes, classes = {}, {}
data = res[0][0]
for number, proposal in enumerate(data):
if proposal[2] > 0:
imid = np.int(proposal[0])
ih, iw = images_hw[imid]
label = np.int(proposal[1])
confidence = proposal[2]
xmin = np.int(iw * proposal[3])
ymin = np.int(ih * proposal[4])
xmax = np.int(iw * proposal[5])
ymax = np.int(ih * proposal[6])
print(f"[{number},{label}] element, prob = {confidence:.6f} ({xmin},{ymin})-({xmax},{ymax}) "
f"batch id : {imid}", end="")
if proposal[2] > 0.5:
print(" WILL BE PRINTED!")
if not imid in boxes.keys():
boxes[imid] = []
boxes[imid].append([xmin, ymin, xmax, ymax])
if not imid in classes.keys():
classes[imid] = []
classes[imid].append(label)
else:
print()
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests. So infer requests were already created in the previous step.
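# For example, the first request is available as exec_net.requests[0] and could be
# run explicitly via exec_net.requests[0].infer({input_blob: image}); the
# exec_net.infer() call below uses such a request under the hood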
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
original_image = cv2.imread(args.input)
image = original_image.copy()
_, _, net_h, net_w = net.input_info[input_blob].input_data.shape
log.info("Execution successful\n")
log.info(
"This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool")
if image.shape[:-1] != (net_h, net_w):
log.warning(f'Image {args.input} is resized from {image.shape[:-1]} to {(net_h, net_w)}')
image = cv2.resize(image, (net_w, net_h))
# Change data layout from HWC to CHW
image = image.transpose((2, 0, 1))
# Add N dimension to transform to NCHW
image = np.expand_dims(image, axis=0)
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in synchronous mode')
res = exec_net.infer(inputs={input_blob: image})
# ---------------------------Step 8. Process output--------------------------------------------------------------------
# Generate a label list
if args.labels:
with open(args.labels, 'r') as f:
labels = [line.split(',')[0].strip() for line in f]
output_image = original_image.copy()
h, w, _ = output_image.shape
if len(net.outputs) == 1:
res = res[output_blob]
# Change a shape of a numpy.ndarray with results ([1, 1, N, 7]) to get another one ([N, 7]),
# where N is the number of detected bounding boxes
detections = res.reshape(-1, 7)
else:
detections = res['boxes']
labels = res['labels']
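# In the single-output case each row of 'detections' is
# [image_id, class_id, confidence, xmin, ymin, xmax, ymax], as unpacked below;
# the two-output case supplies [xmin, ymin, xmax, ymax, confidence] rows plus class ids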
# Redefine scale coefficients
w, h = w / net_w, h / net_h
for i, detection in enumerate(detections):
if len(net.outputs) == 1:
_, class_id, confidence, xmin, ymin, xmax, ymax = detection
else:
class_id = labels[i]
xmin, ymin, xmax, ymax, confidence = detection
if confidence > 0.5:
label = int(labels[class_id]) if args.labels else int(class_id)
xmin = int(xmin * w)
ymin = int(ymin * h)
xmax = int(xmax * w)
ymax = int(ymax * h)
log.info(f'Found: label = {label}, confidence = {confidence:.2f}, '
f'coords = ({xmin}, {ymin}), ({xmax}, {ymax})')
# Draw a bounding box on the output image
cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
cv2.imwrite('out.bmp', output_image)
log.info('Image out.bmp created!')
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main())
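# A typical invocation of this sample might look like the following; the model,
# image and label file names are placeholders, not files shipped with the sample:
#   python object_detection_sample_ssd.py -m ssd_model.xml -i image.jpg -d CPU --labels labels.txt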


@ -1,104 +1,146 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2021 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
import argparse
import logging as log
import sys
import cv2
import numpy as np
from openvino.inference_engine import IECore
def parse_args() -> argparse.Namespace:
'''Parse and return command line arguments'''
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group('Options')
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
args.add_argument('-m', '--model', required=True, type=str,
help='Required. Path to an .xml or .onnx file with a trained model.')
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to one or more image files.')
args.add_argument('-l', '--extension', type=str, default=None,
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
'Absolute path to a shared library with the kernels implementations.')
args.add_argument('-c', '--config', type=str, default=None,
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
'Absolute path to operation description file (.xml).')
args.add_argument('-d', '--device', default='CPU', type=str,
help='Optional. Specify the target device to infer on: CPU, GPU, MYRIAD, HDDL or HETERO '
'is acceptable. The sample will look for a suitable plugin for the device specified. '
'Default value is CPU.')
args.add_argument('--original_size', action='store_true', default=False,
help='Optional. Resize an output image to original image size.')
args.add_argument('--mean_val_r', default=0, type=float,
help='Optional. Mean value of red channel for mean value subtraction in postprocessing.')
args.add_argument('--mean_val_g', default=0, type=float,
help='Optional. Mean value of green channel for mean value subtraction in postprocessing.')
args.add_argument('--mean_val_b', default=0, type=float,
help='Optional. Mean value of blue channel for mean value subtraction in postprocessing.')
return parser.parse_args()
def main():
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
args = build_argparser().parse_args()
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
args = parse_args()
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
log.info('Creating Inference Engine')
ie = IECore()
if args.extension and args.device == 'CPU':
log.info(f'Loading the {args.device} extension: {args.extension}')
ie.add_extension(args.extension, args.device)
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
log.info(f'Loading the {args.device} configuration: {args.config}')
ie.set_config({'CONFIG_FILE': args.config}, args.device)
log.info("Preparing input blobs")
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
log.info(f'Reading the network: {args.model}')
# (.xml and .bin files) or (.onnx file)
net = ie.read_network(model=args.model)
if len(net.input_info) != 1:
log.error('Sample supports only single input topologies')
return -1
if len(net.outputs) != 1:
log.error('Sample supports only single output topologies')
return -1
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
log.info('Configuring input and output blobs')
# Get names of input and output blobs
input_blob = next(iter(net.input_info))
out_blob = next(iter(net.outputs))
# Set input and output precision manually
net.input_info[input_blob].precision = 'U8'
net.outputs[out_blob].precision = 'FP32'
# Set the batch size equal to the number of input images
net.batch_size = len(args.input)
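# Setting batch_size reshapes the batch dimension of the network, so the single
# infer() call below processes all input images at once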
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
log.info('Loading the model to the plugin')
exec_net = ie.load_network(network=net, device_name=args.device)
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
# The load_network() method of the IECore class, called with a specified number of requests (default 1), returns
# an ExecutableNetwork instance that stores infer requests. So infer requests were already created in the previous step.
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
original_images = []
n, c, h, w = net.input_info[input_blob].input_data.shape
input_data = np.ndarray(shape=(n, c, h, w))
for i in range(n):
image = cv2.imread(args.input[i])
original_images.append(image)
if image.shape[:-1] != (h, w):
log.warning(f'Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}')
image = cv2.resize(image, (w, h))
# Change data layout from HWC to CHW
image = image.transpose((2, 0, 1))
input_data[i] = image
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
log.info('Starting inference in synchronous mode')
res = exec_net.infer(inputs={input_blob: input_data})
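# infer() returns a dict that maps each output blob name to a numpy.ndarray;
# here the array shape is [n, c, h, w], one stylized image per input image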
# ---------------------------Step 8. Process output--------------------------------------------------------------------
res = res[out_blob]
for i in range(n):
output_image = res[i]
# Change data layout from CHW to HWC
output_image = output_image.transpose((1, 2, 0))
# Convert BGR color order to RGB
output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
# Apply mean argument values
output_image = output_image[::] - (args.mean_val_r, args.mean_val_g, args.mean_val_b)
# Clip pixel values to the [0, 255] range
output_image = np.clip(output_image, 0, 255)
# Resize the output image to the original image size
if args.original_size:
h, w, _ = original_images[i].shape
output_image = cv2.resize(output_image, (w, h))
cv2.imwrite(f'out_{i}.bmp', output_image)
log.info(f'Image out_{i}.bmp created!')
# ----------------------------------------------------------------------------------------------------------------------
log.info('This sample is an API example, '
'for any performance measurements please use the dedicated benchmark_app tool\n')
return 0
if __name__ == '__main__':
sys.exit(main())
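# A typical invocation of this sample might look like the following; the model and
# image file names are placeholders, not files shipped with the sample:
#   python style_transfer_sample.py -m fast_style_transfer.xml -i img_1.jpg img_2.jpg --original_size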