Update IE Python Samples (#5166)
* refactor: update ie python samples * python samples: change comment about infer request creation (step 5) * python sample: add the ability to run object_detection_sample_ssd.py with a model with 2 outputs * Add batch size usage to python style transfer sample * Change comment about model reading * Add output queue to classification async sample * add reshape for output to catch results with more than 2 dimensions (classification samples) * Set a log output stream to stdout to pass the hello query device test * Add comments to the hello query device sample * Set sys.stdout as a logging stream for all python IE samples * Add batch size usage to ngraph_function_creation_sample * Return the ability to read an image from a ubyte file * Add few comments and function docstrings * Restore IE python classification samples output * Add --original_size arg for python style transfer sample * Change log message to pass tests (object detection ie python sample) * Return python shebang * Add comment about a probs array sorting using np.argsort * Fix the hello query python sample (Ticket: 52937) * Add color inversion for light images for correct predictions * Add few log messages to the python device query sample
This commit is contained in:
parent
9709432d29
commit
19ace232cf
266
inference-engine/ie_bridges/python/sample/classification_sample_async/classification_sample_async.py
Normal file → Executable file
266
inference-engine/ie_bridges/python/sample/classification_sample_async/classification_sample_async.py
Normal file → Executable file
@ -1,160 +1,168 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging as log
|
||||
from openvino.inference_engine import IECore
|
||||
import threading
|
||||
from openvino.inference_engine import IECore, StatusCode
|
||||
|
||||
|
||||
class InferReqWrap:
|
||||
def __init__(self, request, id, num_iter):
|
||||
self.id = id
|
||||
self.request = request
|
||||
self.num_iter = num_iter
|
||||
self.cur_iter = 0
|
||||
self.cv = threading.Condition()
|
||||
self.request.set_completion_callback(self.callback, self.id)
|
||||
|
||||
def callback(self, statusCode, userdata):
|
||||
if (userdata != self.id):
|
||||
log.error(f"Request ID {self.id} does not correspond to user data {userdata}")
|
||||
elif statusCode != 0:
|
||||
log.error(f"Request {self.id} failed with status code {statusCode}")
|
||||
self.cur_iter += 1
|
||||
log.info(f"Completed {self.cur_iter} Async request execution")
|
||||
if self.cur_iter < self.num_iter:
|
||||
# here a user can read output containing inference results and put new input
|
||||
# to repeat async request again
|
||||
self.request.async_infer(self.input)
|
||||
else:
|
||||
# continue sample execution after last Asynchronous inference request execution
|
||||
self.cv.acquire()
|
||||
self.cv.notify()
|
||||
self.cv.release()
|
||||
|
||||
def execute(self, mode, input_data):
|
||||
if (mode == "async"):
|
||||
log.info(f"Start inference ({self.num_iter} Asynchronous executions)")
|
||||
self.input = input_data
|
||||
# Start async request for the first time. Wait all repetitions of the async request
|
||||
self.request.async_infer(input_data)
|
||||
self.cv.acquire()
|
||||
self.cv.wait()
|
||||
self.cv.release()
|
||||
elif (mode == "sync"):
|
||||
log.info(f"Start inference ({self.num_iter} Synchronous executions)")
|
||||
for self.cur_iter in range(self.num_iter):
|
||||
# here we start inference synchronously and wait for
|
||||
# last inference request execution
|
||||
self.request.infer(input_data)
|
||||
log.info(f"Completed {self.cur_iter + 1} Sync request execution")
|
||||
else:
|
||||
log.error("wrong inference mode is chosen. Please use \"sync\" or \"async\" mode")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
def build_argparser():
|
||||
parser = ArgumentParser(add_help=False)
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.",
|
||||
required=True, type=str)
|
||||
args.add_argument("-i", "--input", help="Required. Path to an image files",
|
||||
required=True, type=str, nargs="+")
|
||||
args.add_argument("-l", "--cpu_extension",
|
||||
help="Optional. Required for CPU custom layers. Absolute path to a shared library with the"
|
||||
" kernels implementations.", type=str, default=None)
|
||||
args.add_argument("-d", "--device",
|
||||
help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL or MYRIAD is "
|
||||
"acceptable. The sample will look for a suitable plugin for device specified. Default value is CPU",
|
||||
default="CPU", type=str)
|
||||
args.add_argument("--labels", help="Optional. Labels mapping file", default=None, type=str)
|
||||
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to an .xml or .onnx file with a trained model.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to an image file(s).')
|
||||
args.add_argument('-l', '--extension', type=str, default=None,
|
||||
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
|
||||
'Absolute path to a shared library with the kernels implementations.')
|
||||
args.add_argument('-c', '--config', type=str, default=None,
|
||||
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
|
||||
'Absolute path to operation description file (.xml).')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
|
||||
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
return parser
|
||||
|
||||
def main():
|
||||
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = parse_args()
|
||||
|
||||
# Plugin initialization for specified device and load extensions library if specified
|
||||
log.info("Creating Inference Engine")
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
if args.cpu_extension and 'CPU' in args.device:
|
||||
ie.add_extension(args.cpu_extension, "CPU")
|
||||
|
||||
# Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
|
||||
model = args.model
|
||||
log.info(f"Loading network:\n\t{model}")
|
||||
net = ie.read_network(model=model)
|
||||
if args.extension and args.device == 'CPU':
|
||||
log.info(f'Loading the {args.device} extension: {args.extension}')
|
||||
ie.add_extension(args.extension, args.device)
|
||||
|
||||
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
|
||||
assert len(net.outputs) == 1, "Sample supports only single output topologies"
|
||||
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
|
||||
log.info(f'Loading the {args.device} configuration: {args.config}')
|
||||
ie.set_config({'CONFIG_FILE': args.config}, args.device)
|
||||
|
||||
log.info("Preparing input blobs")
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
|
||||
log.info(f'Reading the network: {args.model}')
|
||||
# (.xml and .bin files) or (.onnx file)
|
||||
net = ie.read_network(model=args.model)
|
||||
|
||||
if len(net.input_info) != 1:
|
||||
log.error('Sample supports only single input topologies')
|
||||
return -1
|
||||
if len(net.outputs) != 1:
|
||||
log.error('Sample supports only single output topologies')
|
||||
return -1
|
||||
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get names of input and output blobs
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
net.batch_size = len(args.input)
|
||||
|
||||
# Read and pre-process input images
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
images = np.ndarray(shape=(n, c, h, w))
|
||||
for i in range(n):
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
net.outputs[out_blob].precision = 'FP32'
|
||||
|
||||
# Get a number of input images
|
||||
num_of_input = len(args.input)
|
||||
# Get a number of classes recognized by a model
|
||||
num_of_classes = max(net.outputs[out_blob].shape)
|
||||
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device, num_requests=num_of_input)
|
||||
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
input_data = []
|
||||
_, _, h, w = net.input_info[input_blob].input_data.shape
|
||||
|
||||
for i in range(num_of_input):
|
||||
image = cv2.imread(args.input[i])
|
||||
|
||||
if image.shape[:-1] != (h, w):
|
||||
log.warning(f"Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}")
|
||||
log.warning(f'Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}')
|
||||
image = cv2.resize(image, (w, h))
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[i] = image
|
||||
log.info(f"Batch size is {n}")
|
||||
|
||||
# Loading model to the plugin
|
||||
log.info("Loading model to the plugin")
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
# Change data layout from HWC to CHW
|
||||
image = image.transpose((2, 0, 1))
|
||||
# Add N dimension to transform to NCHW
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
||||
# create one inference request for asynchronous execution
|
||||
request_id = 0
|
||||
infer_request = exec_net.requests[request_id]
|
||||
input_data.append(image)
|
||||
|
||||
num_iter = 10
|
||||
request_wrap = InferReqWrap(infer_request, request_id, num_iter)
|
||||
# Start inference request execution. Wait for last execution being completed
|
||||
request_wrap.execute("async", {input_blob: images})
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in asynchronous mode')
|
||||
for i in range(num_of_input):
|
||||
exec_net.requests[i].async_infer({input_blob: input_data[i]})
|
||||
|
||||
# Processing output blob
|
||||
log.info("Processing output blob")
|
||||
res = infer_request.output_blobs[out_blob]
|
||||
log.info(f"Top {args.number_top} results: ")
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
# Generate a label list
|
||||
if args.labels:
|
||||
with open(args.labels, 'r') as f:
|
||||
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
|
||||
else:
|
||||
labels_map = None
|
||||
classid_str = "classid"
|
||||
probability_str = "probability"
|
||||
for i, probs in enumerate(res.buffer):
|
||||
probs = np.squeeze(probs)
|
||||
top_ind = np.argsort(probs)[-args.number_top:][::-1]
|
||||
print(f"Image {args.input[i]}\n")
|
||||
print(classid_str, probability_str)
|
||||
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
|
||||
for id in top_ind:
|
||||
det_label = labels_map[id] if labels_map else f"{id}"
|
||||
label_length = len(det_label)
|
||||
space_num_before = (7 - label_length) // 2
|
||||
space_num_after = 7 - (space_num_before + label_length) + 2
|
||||
print(f"{' ' * space_num_before}{det_label}"
|
||||
f"{' ' * space_num_after}{probs[id]:.7f}")
|
||||
print("\n")
|
||||
log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
|
||||
labels = [line.split(',')[0].strip() for line in f]
|
||||
|
||||
# Create a list to control a order of output
|
||||
output_queue = list(range(num_of_input))
|
||||
|
||||
while True:
|
||||
for i in output_queue:
|
||||
# Immediately returns a inference status without blocking or interrupting
|
||||
infer_status = exec_net.requests[i].wait(0)
|
||||
|
||||
if infer_status == StatusCode.RESULT_NOT_READY:
|
||||
continue
|
||||
|
||||
log.info(f'Infer request {i} returned {infer_status}')
|
||||
|
||||
if infer_status != StatusCode.OK:
|
||||
return -2
|
||||
|
||||
# Read infer request results from buffer
|
||||
res = exec_net.requests[i].output_blobs[out_blob].buffer
|
||||
# Change a shape of a numpy.ndarray with results to get another one with one dimension
|
||||
probs = res.reshape(num_of_classes)
|
||||
# Get an array of args.number_top class IDs in descending order of probability
|
||||
top_n_idexes = np.argsort(probs)[-args.number_top:][::-1]
|
||||
|
||||
header = 'classid probability'
|
||||
header = header + ' label' if args.labels else header
|
||||
|
||||
log.info(f'Image path: {args.input[i]}')
|
||||
log.info(f'Top {args.number_top} results: ')
|
||||
log.info(header)
|
||||
log.info('-' * len(header))
|
||||
|
||||
for class_id in top_n_idexes:
|
||||
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
|
||||
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
|
||||
label = labels[class_id] if args.labels else ''
|
||||
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
|
||||
log.info('')
|
||||
|
||||
output_queue.remove(i)
|
||||
|
||||
if len(output_queue) == 0:
|
||||
break
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
173
inference-engine/ie_bridges/python/sample/hello_classification/hello_classification.py
Normal file → Executable file
173
inference-engine/ie_bridges/python/sample/hello_classification/hello_classification.py
Normal file → Executable file
@ -1,106 +1,125 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging as log
|
||||
from openvino.inference_engine import IECore
|
||||
|
||||
|
||||
def build_argparser():
|
||||
parser = ArgumentParser(add_help=False)
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.", required=True,
|
||||
type=str)
|
||||
args.add_argument("-i", "--input", help="Required. Path to an image file.",
|
||||
required=True, type=str)
|
||||
args.add_argument("-l", "--cpu_extension",
|
||||
help="Optional. Required for CPU custom layers. "
|
||||
"MKLDNN (CPU)-targeted custom layers. Absolute path to a shared library with the"
|
||||
" kernels implementations.", type=str, default=None)
|
||||
args.add_argument("-d", "--device",
|
||||
help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL, MYRIAD or HETERO: is "
|
||||
"acceptable. The sample will look for a suitable plugin for device specified. Default "
|
||||
"value is CPU",
|
||||
default="CPU", type=str)
|
||||
args.add_argument("--labels", help="Optional. Path to a labels mapping file", default=None, type=str)
|
||||
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to an .xml or .onnx file with a trained model.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
|
||||
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
|
||||
|
||||
return parser
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = parse_args()
|
||||
|
||||
# Plugin initialization for specified device and load extensions library if specified
|
||||
log.info("Creating Inference Engine")
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
if args.cpu_extension and 'CPU' in args.device:
|
||||
ie.add_extension(args.cpu_extension, "CPU")
|
||||
|
||||
# Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
|
||||
model = args.model
|
||||
log.info(f"Loading network:\n\t{model}")
|
||||
net = ie.read_network(model=model)
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
|
||||
log.info(f'Reading the network: {args.model}')
|
||||
# (.xml and .bin files) or (.onnx file)
|
||||
net = ie.read_network(model=args.model)
|
||||
|
||||
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
|
||||
assert len(net.outputs) == 1, "Sample supports only single output topologies"
|
||||
if len(net.input_info) != 1:
|
||||
log.error('Sample supports only single input topologies')
|
||||
return -1
|
||||
if len(net.outputs) != 1:
|
||||
log.error('Sample supports only single output topologies')
|
||||
return -1
|
||||
|
||||
log.info("Preparing input blobs")
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get names of input and output blobs
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
|
||||
# Read and pre-process input images
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
images = np.ndarray(shape=(n, c, h, w))
|
||||
for i in range(n):
|
||||
image = cv2.imread(args.input)
|
||||
if image.shape[:-1] != (h, w):
|
||||
log.warning(f"Image {args.input} is resized from {image.shape[:-1]} to {(h, w)}")
|
||||
image = cv2.resize(image, (w, h))
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[i] = image
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
net.outputs[out_blob].precision = 'FP32'
|
||||
|
||||
# Loading model to the plugin
|
||||
log.info("Loading model to the plugin")
|
||||
# Get a number of classes recognized by a model
|
||||
num_of_classes = max(net.outputs[out_blob].shape)
|
||||
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
|
||||
# Start sync inference
|
||||
log.info("Starting inference in synchronous mode")
|
||||
res = exec_net.infer(inputs={input_blob: images})
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
# Processing output blob
|
||||
log.info("Processing output blob")
|
||||
res = res[out_blob]
|
||||
log.info(f"Top {args.number_top} results: ")
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
original_image = cv2.imread(args.input)
|
||||
image = original_image.copy()
|
||||
_, _, h, w = net.input_info[input_blob].input_data.shape
|
||||
|
||||
if image.shape[:-1] != (h, w):
|
||||
log.warning(f'Image {args.input} is resized from {image.shape[:-1]} to {(h, w)}')
|
||||
image = cv2.resize(image, (w, h))
|
||||
|
||||
# Change data layout from HWC to CHW
|
||||
image = image.transpose((2, 0, 1))
|
||||
# Add N dimension to transform to NCHW
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in synchronous mode')
|
||||
res = exec_net.infer(inputs={input_blob: image})
|
||||
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
# Generate a label list
|
||||
if args.labels:
|
||||
with open(args.labels, 'r') as f:
|
||||
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
|
||||
else:
|
||||
labels_map = None
|
||||
classid_str = "classid"
|
||||
probability_str = "probability"
|
||||
for i, probs in enumerate(res):
|
||||
probs = np.squeeze(probs)
|
||||
top_ind = np.argsort(probs)[-args.number_top:][::-1]
|
||||
print(f"Image {args.input}\n")
|
||||
print(classid_str, probability_str)
|
||||
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
|
||||
for id in top_ind:
|
||||
det_label = labels_map[id] if labels_map else f"{id}"
|
||||
label_length = len(det_label)
|
||||
space_num_before = (len(classid_str) - label_length) // 2
|
||||
space_num_after = len(classid_str) - (space_num_before + label_length) + 2
|
||||
print(f"{' ' * space_num_before}{det_label}"
|
||||
f"{' ' * space_num_after}{probs[id]:.7f}")
|
||||
print("\n")
|
||||
log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
|
||||
labels = [line.split(',')[0].strip() for line in f]
|
||||
|
||||
res = res[out_blob]
|
||||
# Change a shape of a numpy.ndarray with results to get another one with one dimension
|
||||
probs = res.reshape(num_of_classes)
|
||||
# Get an array of args.number_top class IDs in descending order of probability
|
||||
top_n_idexes = np.argsort(probs)[-args.number_top:][::-1]
|
||||
|
||||
header = 'classid probability'
|
||||
header = header + ' label' if args.labels else header
|
||||
|
||||
log.info(f'Image path: {args.input}')
|
||||
log.info(f'Top {args.number_top} results: ')
|
||||
log.info(header)
|
||||
log.info('-' * len(header))
|
||||
|
||||
for class_id in top_n_idexes:
|
||||
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
|
||||
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
|
||||
label = labels[class_id] if args.labels else ''
|
||||
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
|
||||
log.info('')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
61
inference-engine/ie_bridges/python/sample/hello_query_device/hello_query_device.py
Normal file → Executable file
61
inference-engine/ie_bridges/python/sample/hello_query_device/hello_query_device.py
Normal file → Executable file
@ -1,43 +1,54 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import logging as log
|
||||
import sys
|
||||
|
||||
from openvino.inference_engine import IECore
|
||||
|
||||
|
||||
def param_to_string(metric):
|
||||
def param_to_string(metric) -> str:
|
||||
'''Convert a list / tuple of parameters returned from IE to a string'''
|
||||
if isinstance(metric, (list, tuple)):
|
||||
return ", ".join([str(val) for val in metric])
|
||||
elif isinstance(metric, dict):
|
||||
str_param_repr = ""
|
||||
for k, v in metric.items():
|
||||
str_param_repr += f"{k}: {v}\n"
|
||||
return str_param_repr
|
||||
return ', '.join([str(x) for x in metric])
|
||||
else:
|
||||
return str(metric)
|
||||
|
||||
|
||||
def main():
|
||||
ie = IECore()
|
||||
print("Available devices:")
|
||||
for device in ie.available_devices:
|
||||
print(f"\tDevice: {device}")
|
||||
print("\tMetrics:")
|
||||
for metric in ie.get_metric(device, "SUPPORTED_METRICS"):
|
||||
try:
|
||||
metric_val = ie.get_metric(device, metric)
|
||||
print(f"\t\t{metric}: {param_to_string(metric_val)}")
|
||||
except TypeError:
|
||||
print(f"\t\t{metric}: UNSUPPORTED TYPE")
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
|
||||
print("\n\tDefault values for device configuration keys:")
|
||||
for cfg in ie.get_metric(device, "SUPPORTED_CONFIG_KEYS"):
|
||||
# ---------------------------Initialize inference engine core----------------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
|
||||
# ---------------------------Get metrics of available devices----------------------------------------------------------
|
||||
log.info('Available devices:')
|
||||
for device in ie.available_devices:
|
||||
log.info(f'{device} :')
|
||||
log.info('\tSUPPORTED_METRICS:')
|
||||
for metric in ie.get_metric(device, 'SUPPORTED_METRICS'):
|
||||
if metric not in ('SUPPORTED_METRICS', 'SUPPORTED_CONFIG_KEYS'):
|
||||
try:
|
||||
metric_val = ie.get_metric(device, metric)
|
||||
except TypeError:
|
||||
metric_val = 'UNSUPPORTED TYPE'
|
||||
log.info(f'\t\t{metric}: {param_to_string(metric_val)}')
|
||||
log.info('')
|
||||
|
||||
log.info('\tSUPPORTED_CONFIG_KEYS (default values):')
|
||||
for config_key in ie.get_metric(device, 'SUPPORTED_CONFIG_KEYS'):
|
||||
try:
|
||||
cfg_val = ie.get_config(device, cfg)
|
||||
print(f"\t\t{cfg}: {param_to_string(cfg_val)}")
|
||||
config_val = ie.get_config(device, config_key)
|
||||
except TypeError:
|
||||
print(f"\t\t{cfg}: UNSUPPORTED TYPE")
|
||||
config_val = 'UNSUPPORTED TYPE'
|
||||
log.info(f'\t\t{config_key}: {param_to_string(config_val)}')
|
||||
log.info('')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
239
inference-engine/ie_bridges/python/sample/hello_reshape_ssd/hello_reshape_ssd.py
Normal file → Executable file
239
inference-engine/ie_bridges/python/sample/hello_reshape_ssd/hello_reshape_ssd.py
Normal file → Executable file
@ -1,158 +1,145 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
import argparse
|
||||
import logging as log
|
||||
import os
|
||||
import sys
|
||||
|
||||
import cv2
|
||||
import ngraph as ng
|
||||
import numpy as np
|
||||
from openvino.inference_engine import IECore
|
||||
|
||||
def build_argparser():
|
||||
parser = ArgumentParser(add_help=False)
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', help='Required. Path to an .xml or .onnx file with a trained model.',
|
||||
required=True, type=str)
|
||||
args.add_argument('-i', '--input', help='Required. Path to an image file.',
|
||||
required=True, type=str)
|
||||
args.add_argument('-d', '--device',
|
||||
help='Optional. Specify the target device to infer on; '
|
||||
'CPU, GPU, FPGA or MYRIAD is acceptable. '
|
||||
'Sample will look for a suitable plugin for device specified (CPU by default)',
|
||||
default='CPU', type=str)
|
||||
return parser
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to an .xml or .onnx file with a trained model.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
|
||||
args.add_argument('-l', '--extension', type=str, default=None,
|
||||
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
|
||||
'Absolute path to a shared library with the kernels implementations.')
|
||||
args.add_argument('-c', '--config', type=str, default=None,
|
||||
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
|
||||
'Absolute path to operation description file (.xml).')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
log.info('Loading Inference Engine')
|
||||
args = parse_args()
|
||||
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
|
||||
# ---1. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format ---
|
||||
model = args.model
|
||||
log.info(f'Loading network:')
|
||||
log.info(f' {model}')
|
||||
net = ie.read_network(model=model)
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
if args.extension and args.device == 'CPU':
|
||||
log.info(f'Loading the {args.device} extension: {args.extension}')
|
||||
ie.add_extension(args.extension, args.device)
|
||||
|
||||
# ------------- 2. Load Plugin for inference engine and extensions library if specified --------------
|
||||
log.info('Device info:')
|
||||
versions = ie.get_versions(args.device)
|
||||
log.info(f' {args.device}')
|
||||
log.info(f' MKLDNNPlugin version ......... {versions[args.device].major}.{versions[args.device].minor}')
|
||||
log.info(f' Build ........... {versions[args.device].build_number}')
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
|
||||
log.info(f'Loading the {args.device} configuration: {args.config}')
|
||||
ie.set_config({'CONFIG_FILE': args.config}, args.device)
|
||||
|
||||
# --------------------------- 3. Read and preprocess input --------------------------------------------
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
|
||||
log.info(f'Reading the network: {args.model}')
|
||||
# (.xml and .bin files) or (.onnx file)
|
||||
net = ie.read_network(model=args.model)
|
||||
|
||||
log.info(f'Inputs number: {len(net.input_info.keys())}')
|
||||
assert len(net.input_info.keys()) == 1, 'Sample supports clean SSD network with one input'
|
||||
assert len(net.outputs.keys()) == 1, 'Sample supports clean SSD network with one output'
|
||||
input_name = list(net.input_info.keys())[0]
|
||||
input_info = net.input_info[input_name]
|
||||
supported_input_dims = 4 # Supported input layout - NHWC
|
||||
if len(net.input_info) != 1:
|
||||
log.error('Sample supports only single input topologies')
|
||||
return -1
|
||||
if len(net.outputs) != 1:
|
||||
log.error('Sample supports only single output topologies')
|
||||
return -1
|
||||
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get names of input and output blobs
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
net.outputs[out_blob].precision = 'FP32'
|
||||
|
||||
original_image = cv2.imread(args.input)
|
||||
image = original_image.copy()
|
||||
|
||||
# Change data layout from HWC to CHW
|
||||
image = image.transpose((2, 0, 1))
|
||||
# Add N dimension to transform to NCHW
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
||||
log.info(f' Input name: {input_name}')
|
||||
log.info(f' Input shape: {str(input_info.input_data.shape)}')
|
||||
if len(input_info.input_data.layout) == supported_input_dims:
|
||||
n, c, h, w = input_info.input_data.shape
|
||||
assert n == 1, 'Sample supports topologies with one input image only'
|
||||
else:
|
||||
raise AssertionError('Sample supports input with NHWC shape only')
|
||||
|
||||
image = cv2.imread(args.input)
|
||||
h_new, w_new = image.shape[:-1]
|
||||
images = np.ndarray(shape=(n, c, h_new, w_new))
|
||||
log.info('File was added: ')
|
||||
log.info(f' {args.input}')
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[0] = image
|
||||
|
||||
log.info('Reshaping the network to the height and width of the input image')
|
||||
net.reshape({input_name: [n, c, h_new, w_new]})
|
||||
log.info(f'Input shape after reshape: {str(net.input_info["data"].input_data.shape)}')
|
||||
log.info(f'Input shape before reshape: {net.input_info[input_blob].input_data.shape}')
|
||||
net.reshape({input_blob: image.shape})
|
||||
log.info(f'Input shape after reshape: {net.input_info[input_blob].input_data.shape}')
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- 4. Configure input & output ---------------------------------------------
|
||||
# --------------------------- Prepare input blobs -----------------------------------------------------
|
||||
log.info('Preparing input blobs')
|
||||
|
||||
if len(input_info.layout) == supported_input_dims:
|
||||
input_info.precision = 'U8'
|
||||
|
||||
data = {}
|
||||
data[input_name] = images
|
||||
|
||||
# --------------------------- Prepare output blobs ----------------------------------------------------
|
||||
log.info('Preparing output blobs')
|
||||
|
||||
func = ng.function_from_cnn(net)
|
||||
ops = func.get_ordered_ops()
|
||||
output_name, output_info = '', net.outputs[next(iter(net.outputs.keys()))]
|
||||
output_ops = {op.friendly_name : op for op in ops \
|
||||
if op.friendly_name in net.outputs and op.get_type_name() == 'DetectionOutput'}
|
||||
if len(output_ops) == 1:
|
||||
output_name, output_info = output_ops.popitem()
|
||||
|
||||
assert output_name != '', 'Can''t find a DetectionOutput layer in the topology'
|
||||
|
||||
output_dims = output_info.shape
|
||||
assert output_dims != 4, 'Incorrect output dimensions for SSD model'
|
||||
assert output_dims[3] == 7, 'Output item should have 7 as a last dimension'
|
||||
|
||||
output_info.precision = 'FP32'
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- Performing inference ----------------------------------------------------
|
||||
log.info('Loading model to the device')
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
log.info('Creating infer request and starting inference')
|
||||
exec_result = exec_net.infer(inputs=data)
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- Read and postprocess output ---------------------------------------------
|
||||
log.info('Processing output blobs')
|
||||
result = exec_result[output_name]
|
||||
boxes = {}
|
||||
detections = result[0][0] # [0][0] - location of detections in result blob
|
||||
for number, proposal in enumerate(detections):
|
||||
imid, label, confidence, coords = np.int(proposal[0]), np.int(proposal[1]), proposal[2], proposal[3:]
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
# This sample changes a network input layer shape instead of a image shape. See Step 4.
|
||||
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in synchronous mode')
|
||||
res = exec_net.infer(inputs={input_blob: image})
|
||||
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
# Generate a label list
|
||||
if args.labels:
|
||||
with open(args.labels, 'r') as f:
|
||||
labels = [line.split(',')[0].strip() for line in f]
|
||||
|
||||
res = res[out_blob]
|
||||
|
||||
output_image = original_image.copy()
|
||||
h, w, _ = output_image.shape
|
||||
|
||||
# Change a shape of a numpy.ndarray with results ([1, 1, N, 7]) to get another one ([N, 7]),
|
||||
# where N is the number of detected bounding boxes
|
||||
detections = res.reshape(-1, 7)
|
||||
|
||||
for detection in detections:
|
||||
confidence = detection[2]
|
||||
if confidence > 0.5:
|
||||
# correcting coordinates to actual image resolution
|
||||
xmin, ymin, xmax, ymax = w_new * coords[0], h_new * coords[1], w_new * coords[2], h_new * coords[3]
|
||||
class_id = int(detection[1])
|
||||
label = labels[class_id] if args.labels else class_id
|
||||
|
||||
log.info(f' [{number},{label}] element, prob = {confidence:.6f}, '
|
||||
f'bbox = ({xmin:.3f},{ymin:.3f})-({xmax:.3f},{ymax:.3f}), batch id = {imid}')
|
||||
if not imid in boxes.keys():
|
||||
boxes[imid] = []
|
||||
boxes[imid].append([xmin, ymin, xmax, ymax])
|
||||
xmin = int(detection[3] * w)
|
||||
ymin = int(detection[4] * h)
|
||||
xmax = int(detection[5] * w)
|
||||
ymax = int(detection[6] * h)
|
||||
|
||||
imid = 0 # as sample supports one input image only, imid in results will always be 0
|
||||
log.info(f'Found: label = {label}, confidence = {confidence:.2f}, '
|
||||
f'coords = ({xmin}, {ymin}), ({xmax}, {ymax})')
|
||||
|
||||
tmp_image = cv2.imread(args.input)
|
||||
for box in boxes[imid]:
|
||||
# drawing bounding boxes on the output image
|
||||
cv2.rectangle(
|
||||
tmp_image,
|
||||
(np.int(box[0]), np.int(box[1])), (np.int(box[2]), np.int(box[3])),
|
||||
color=(232, 35, 244), thickness=2)
|
||||
cv2.imwrite('out.bmp', tmp_image)
|
||||
log.info('Image out.bmp created!')
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
# Draw a bounding box on a output image
|
||||
cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
|
||||
|
||||
log.info('Execution successful\n')
|
||||
log.info(
|
||||
'This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool')
|
||||
cv2.imwrite('out.bmp', output_image)
|
||||
log.info('Image out.bmp was created!')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
262
inference-engine/ie_bridges/python/sample/ngraph_function_creation_sample/ngraph_function_creation_sample.py
Normal file → Executable file
262
inference-engine/ie_bridges/python/sample/ngraph_function_creation_sample/ngraph_function_creation_sample.py
Normal file → Executable file
@ -1,84 +1,68 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import sys
|
||||
import os
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
import cv2
|
||||
import numpy as np
|
||||
import argparse
|
||||
import logging as log
|
||||
from openvino.inference_engine import IECore, IENetwork
|
||||
import ngraph
|
||||
from ngraph.impl import Function
|
||||
from functools import reduce
|
||||
import struct as st
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
import cv2
|
||||
import ngraph
|
||||
import numpy as np
|
||||
from openvino.inference_engine import IECore, IENetwork
|
||||
|
||||
|
||||
def build_argparser() -> ArgumentParser:
|
||||
parser = ArgumentParser(add_help=False)
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument('-i', '--input', help='Required. Path to a folder with images or path to an image files',
|
||||
required=True, type=str, nargs="+")
|
||||
args.add_argument('-m', '--model', help='Required. Path to file where weights for the network are located',
|
||||
required=True)
|
||||
args.add_argument('-d', '--device',
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL, MYRIAD or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. Default '
|
||||
'value is CPU',
|
||||
default='CPU', type=str)
|
||||
args.add_argument('--labels', help='Optional. Path to a labels mapping file', default=None, type=str)
|
||||
args.add_argument('-nt', '--number_top', help='Optional. Number of top results', default=1, type=int)
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to a file with network weights.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to an image file.')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
|
||||
args.add_argument('-nt', '--number_top', default=10, type=int, help='Optional. Number of top results.')
|
||||
|
||||
return parser
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def list_input_images(input_dirs: list):
|
||||
images = []
|
||||
for input_dir in input_dirs:
|
||||
if os.path.isdir(input_dir):
|
||||
for root, directories, filenames in os.walk(input_dir):
|
||||
for filename in filenames:
|
||||
images.append(os.path.join(root, filename))
|
||||
elif os.path.isfile(input_dir):
|
||||
images.append(input_dir)
|
||||
def read_image(image_path: str) -> np.ndarray:
|
||||
'''Read and return an image as grayscale (one channel)'''
|
||||
image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def read_image(image_path: np):
|
||||
# try to read image in usual image formats
|
||||
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
|
||||
|
||||
# try to open image as ubyte
|
||||
# Try to open image as ubyte
|
||||
if image is None:
|
||||
with open(image_path, 'rb') as f:
|
||||
image_file = open(image_path, 'rb')
|
||||
image_file.seek(0)
|
||||
st.unpack('>4B', image_file.read(4)) # need to skip 4 bytes
|
||||
nimg = st.unpack('>I', image_file.read(4))[0] # number of images
|
||||
nrow = st.unpack('>I', image_file.read(4))[0] # number of rows
|
||||
ncolumn = st.unpack('>I', image_file.read(4))[0] # number of column
|
||||
st.unpack('>4B', f.read(4)) # need to skip 4 bytes
|
||||
nimg = st.unpack('>I', f.read(4))[0] # number of images
|
||||
nrow = st.unpack('>I', f.read(4))[0] # number of rows
|
||||
ncolumn = st.unpack('>I', f.read(4))[0] # number of column
|
||||
nbytes = nimg * nrow * ncolumn * 1 # each pixel data is 1 byte
|
||||
|
||||
assert nimg == 1, log.error('Sample supports ubyte files with 1 image inside')
|
||||
if nimg != 1:
|
||||
raise Exception('Sample supports ubyte files with 1 image inside')
|
||||
|
||||
image = np.asarray(st.unpack('>' + 'B' * nbytes, image_file.read(nbytes))).reshape(
|
||||
(nrow, ncolumn))
|
||||
image = np.asarray(st.unpack('>' + 'B' * nbytes, f.read(nbytes))).reshape((nrow, ncolumn))
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def shape_and_length(shape: list):
|
||||
length = reduce(lambda x, y: x*y, shape)
|
||||
return shape, length
|
||||
def create_ngraph_function(args: argparse.Namespace) -> ngraph.impl.Function:
|
||||
'''Create a network on the fly from the source code using ngraph'''
|
||||
|
||||
def shape_and_length(shape: list) -> (list, int):
|
||||
length = reduce(lambda x, y: x * y, shape)
|
||||
return shape, length
|
||||
|
||||
def create_ngraph_function(args) -> Function:
|
||||
weights = np.fromfile(args.model, dtype=np.float32)
|
||||
weights_offset = 0
|
||||
padding_begin = [0, 0]
|
||||
padding_end = [0, 0]
|
||||
padding_begin = padding_end = [0, 0]
|
||||
|
||||
# input
|
||||
input_shape = [64, 1, 28, 28]
|
||||
@ -93,7 +77,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# add 1
|
||||
add_1_kernel_shape, add_1_kernel_length = shape_and_length([1, 20, 1, 1])
|
||||
add_1_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + add_1_kernel_length].reshape(add_1_kernel_shape)
|
||||
weights[weights_offset:weights_offset + add_1_kernel_length].reshape(add_1_kernel_shape),
|
||||
)
|
||||
weights_offset += add_1_kernel_length
|
||||
add_1_node = ngraph.add(conv_1_node, add_1_kernel)
|
||||
@ -104,7 +88,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# convolution 2
|
||||
conv_2_kernel_shape, conv_2_kernel_length = shape_and_length([50, 20, 5, 5])
|
||||
conv_2_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + conv_2_kernel_length].reshape(conv_2_kernel_shape)
|
||||
weights[weights_offset:weights_offset + conv_2_kernel_length].reshape(conv_2_kernel_shape),
|
||||
)
|
||||
weights_offset += conv_2_kernel_length
|
||||
conv_2_node = ngraph.convolution(maxpool_1_node, conv_2_kernel, [1, 1], padding_begin, padding_end, [1, 1])
|
||||
@ -112,7 +96,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# add 2
|
||||
add_2_kernel_shape, add_2_kernel_length = shape_and_length([1, 50, 1, 1])
|
||||
add_2_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + add_2_kernel_length].reshape(add_2_kernel_shape)
|
||||
weights[weights_offset:weights_offset + add_2_kernel_length].reshape(add_2_kernel_shape),
|
||||
)
|
||||
weights_offset += add_2_kernel_length
|
||||
add_2_node = ngraph.add(conv_2_node, add_2_kernel)
|
||||
@ -124,16 +108,16 @@ def create_ngraph_function(args) -> Function:
|
||||
reshape_1_dims, reshape_1_length = shape_and_length([2])
|
||||
# workaround to get int64 weights from float32 ndarray w/o unnecessary copying
|
||||
dtype_weights = np.frombuffer(
|
||||
weights[weights_offset:weights_offset + 2*reshape_1_length], dtype=np.int64
|
||||
weights[weights_offset:weights_offset + 2 * reshape_1_length], dtype=np.int64,
|
||||
)
|
||||
reshape_1_kernel = ngraph.constant(dtype_weights)
|
||||
weights_offset += 2*reshape_1_length
|
||||
weights_offset += 2 * reshape_1_length
|
||||
reshape_1_node = ngraph.reshape(maxpool_2_node, reshape_1_kernel, True)
|
||||
|
||||
# matmul 1
|
||||
matmul_1_kernel_shape, matmul_1_kernel_length = shape_and_length([500, 800])
|
||||
matmul_1_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + matmul_1_kernel_length].reshape(matmul_1_kernel_shape)
|
||||
weights[weights_offset:weights_offset + matmul_1_kernel_length].reshape(matmul_1_kernel_shape),
|
||||
)
|
||||
weights_offset += matmul_1_kernel_length
|
||||
matmul_1_node = ngraph.matmul(reshape_1_node, matmul_1_kernel, False, True)
|
||||
@ -141,7 +125,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# add 3
|
||||
add_3_kernel_shape, add_3_kernel_length = shape_and_length([1, 500])
|
||||
add_3_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + add_3_kernel_length].reshape(add_3_kernel_shape)
|
||||
weights[weights_offset:weights_offset + add_3_kernel_length].reshape(add_3_kernel_shape),
|
||||
)
|
||||
weights_offset += add_3_kernel_length
|
||||
add_3_node = ngraph.add(matmul_1_node, add_3_kernel)
|
||||
@ -156,7 +140,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# matmul 2
|
||||
matmul_2_kernel_shape, matmul_2_kernel_length = shape_and_length([10, 500])
|
||||
matmul_2_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + matmul_2_kernel_length].reshape(matmul_2_kernel_shape)
|
||||
weights[weights_offset:weights_offset + matmul_2_kernel_length].reshape(matmul_2_kernel_shape),
|
||||
)
|
||||
weights_offset += matmul_2_kernel_length
|
||||
matmul_2_node = ngraph.matmul(reshape_2_node, matmul_2_kernel, False, True)
|
||||
@ -164,7 +148,7 @@ def create_ngraph_function(args) -> Function:
|
||||
# add 4
|
||||
add_4_kernel_shape, add_4_kernel_length = shape_and_length([1, 10])
|
||||
add_4_kernel = ngraph.constant(
|
||||
weights[weights_offset:weights_offset + add_4_kernel_length].reshape(add_4_kernel_shape)
|
||||
weights[weights_offset:weights_offset + add_4_kernel_length].reshape(add_4_kernel_shape),
|
||||
)
|
||||
weights_offset += add_4_kernel_length
|
||||
add_4_node = ngraph.add(matmul_2_node, add_4_kernel)
|
||||
@ -175,85 +159,101 @@ def create_ngraph_function(args) -> Function:
|
||||
|
||||
# result
|
||||
result_node = ngraph.result(softmax_node)
|
||||
|
||||
# nGraph function
|
||||
function = Function(result_node, [param_node], 'lenet')
|
||||
|
||||
return function
|
||||
return ngraph.impl.Function(result_node, [param_node], 'lenet')
|
||||
|
||||
|
||||
def main():
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
args = parse_args()
|
||||
|
||||
input_images = list_input_images(args.input)
|
||||
|
||||
# Loading network using ngraph function
|
||||
ngraph_function = create_ngraph_function(args)
|
||||
net = IENetwork(Function.to_capsule(ngraph_function))
|
||||
|
||||
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
|
||||
assert len(net.outputs) == 1, "Sample supports only single output topologies"
|
||||
|
||||
log.info("Preparing input blobs")
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
net.batch_size = len(input_images)
|
||||
|
||||
# Read and pre-process input images
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
images = np.ndarray(shape=(n, c, h, w))
|
||||
for i in range(n):
|
||||
image = read_image(input_images[i])
|
||||
assert image is not None, log.error(f"Can't open an image {input_images[i]}")
|
||||
assert len(image.shape) == 2, log.error('Sample supports images with 1 channel only')
|
||||
if image.shape[:] != (w, h):
|
||||
log.warning(f"Image {input_images[i]} is resized from {image.shape[:]} to {(w, h)}")
|
||||
image = cv2.resize(image, (w, h))
|
||||
images[i] = image
|
||||
log.info(f"Batch size is {n}")
|
||||
|
||||
log.info("Creating Inference Engine")
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
|
||||
log.info('Loading model to the device')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device.upper())
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation------------------------------
|
||||
log.info(f'Loading the network using ngraph function with weights from {args.model}')
|
||||
ngraph_function = create_ngraph_function(args)
|
||||
net = IENetwork(ngraph.impl.Function.to_capsule(ngraph_function))
|
||||
|
||||
# Start sync inference
|
||||
log.info('Creating infer request and starting inference')
|
||||
res = exec_net.infer(inputs={input_blob: images})
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get names of input and output blobs
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
|
||||
# Processing results
|
||||
log.info("Processing output blob")
|
||||
res = res[out_blob]
|
||||
log.info(f"Top {args.number_top} results: ")
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
net.outputs[out_blob].precision = 'FP32'
|
||||
|
||||
# Read labels file if it is provided as argument
|
||||
labels_map = None
|
||||
# Set a batch size to a equal number of input images
|
||||
net.batch_size = len(args.input)
|
||||
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
input_data = np.ndarray(shape=(n, c, h, w))
|
||||
|
||||
for i in range(n):
|
||||
image = read_image(args.input[i])
|
||||
|
||||
light_pixel_count = np.count_nonzero(image > 127)
|
||||
darK_pixel_count = np.count_nonzero(image < 127)
|
||||
is_light_image = (light_pixel_count - darK_pixel_count) > 0
|
||||
|
||||
if is_light_image:
|
||||
log.warning(f'Image {args.input[i]} is inverted to white over black')
|
||||
image = cv2.bitwise_not(image)
|
||||
|
||||
if image.shape != (h, w):
|
||||
log.warning(f'Image {args.input[i]} is resized from {image.shape} to {(h, w)}')
|
||||
image = cv2.resize(image, (w, h))
|
||||
|
||||
input_data[i] = image
|
||||
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in synchronous mode')
|
||||
res = exec_net.infer(inputs={input_blob: input_data})
|
||||
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
# Generate a label list
|
||||
if args.labels:
|
||||
with open(args.labels, 'r') as f:
|
||||
labels_map = [x.split(sep=' ', maxsplit=1)[-1].strip() for x in f]
|
||||
labels = [line.split(',')[0].strip() for line in f]
|
||||
|
||||
classid_str = "classid"
|
||||
probability_str = "probability"
|
||||
for i, probs in enumerate(res):
|
||||
probs = np.squeeze(probs)
|
||||
top_ind = np.argsort(probs)[-args.number_top:][::-1]
|
||||
print(f"Image {input_images[i]}\n")
|
||||
print(classid_str, probability_str)
|
||||
print(f"{'-' * len(classid_str)} {'-' * len(probability_str)}")
|
||||
for class_id in top_ind:
|
||||
det_label = labels_map[class_id] if labels_map else f"{class_id}"
|
||||
label_length = len(det_label)
|
||||
space_num_before = (len(classid_str) - label_length) // 2
|
||||
space_num_after = len(classid_str) - (space_num_before + label_length) + 2
|
||||
print(f"{' ' * space_num_before}{det_label}"
|
||||
f"{' ' * space_num_after}{probs[class_id]:.7f}")
|
||||
print("\n")
|
||||
res = res[out_blob]
|
||||
|
||||
log.info('This sample is an API example, for any performance measurements '
|
||||
'please use the dedicated benchmark_app tool')
|
||||
for i in range(n):
|
||||
probs = res[i]
|
||||
# Get an array of args.number_top class IDs in descending order of probability
|
||||
top_n_idexes = np.argsort(probs)[-args.number_top:][::-1]
|
||||
|
||||
header = 'classid probability'
|
||||
header = header + ' label' if args.labels else header
|
||||
|
||||
log.info(f'Image path: {args.input[i]}')
|
||||
log.info(f'Top {args.number_top} results: ')
|
||||
log.info(header)
|
||||
log.info('-' * len(header))
|
||||
|
||||
for class_id in top_n_idexes:
|
||||
probability_indent = ' ' * (len('classid') - len(str(class_id)) + 1)
|
||||
label_indent = ' ' * (len('probability') - 8) if args.labels else ''
|
||||
label = labels[class_id] if args.labels else ''
|
||||
log.info(f'{class_id}{probability_indent}{probs[class_id]:.7f}{label_indent}{label}')
|
||||
log.info('')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
293
inference-engine/ie_bridges/python/sample/object_detection_sample_ssd/object_detection_sample_ssd.py
Normal file → Executable file
293
inference-engine/ie_bridges/python/sample/object_detection_sample_ssd/object_detection_sample_ssd.py
Normal file → Executable file
@ -1,193 +1,158 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging as log
|
||||
from openvino.inference_engine import IECore
|
||||
|
||||
import ngraph as ng
|
||||
|
||||
def build_argparser():
|
||||
parser = ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group("Options")
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.",
|
||||
required=True, type=str)
|
||||
args.add_argument("-i", "--input", help="Required. Path to an image file.",
|
||||
required=True, type=str)
|
||||
args.add_argument("-l", "--cpu_extension",
|
||||
help="Optional. Required for CPU custom layers. "
|
||||
"Absolute path to a shared library with the kernels implementations.",
|
||||
type=str, default=None)
|
||||
args.add_argument("-d", "--device",
|
||||
help="Optional. Specify the target device to infer on; "
|
||||
"CPU, GPU, FPGA or MYRIAD is acceptable. "
|
||||
"Sample will look for a suitable plugin for device specified (CPU by default)",
|
||||
default="CPU", type=str)
|
||||
args.add_argument("--labels", help="Optional. Labels mapping file", default=None, type=str)
|
||||
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to an .xml or .onnx file with a trained model.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, help='Required. Path to an image file.')
|
||||
args.add_argument('-l', '--extension', type=str, default=None,
|
||||
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
|
||||
'Absolute path to a shared library with the kernels implementations.')
|
||||
args.add_argument('-c', '--config', type=str, default=None,
|
||||
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
|
||||
'Absolute path to operation description file (.xml).')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--labels', default=None, type=str, help='Optional. Path to a labels mapping file.')
|
||||
|
||||
return parser
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
log.info("Loading Inference Engine")
|
||||
def main(): # noqa
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = parse_args()
|
||||
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
|
||||
# ---1. Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format ---
|
||||
model = args.model
|
||||
log.info(f"Loading network:\n\t{model}")
|
||||
net = ie.read_network(model=model)
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
if args.extension and args.device == 'CPU':
|
||||
log.info(f'Loading the {args.device} extension: {args.extension}')
|
||||
ie.add_extension(args.extension, args.device)
|
||||
|
||||
# ------------- 2. Load Plugin for inference engine and extensions library if specified --------------
|
||||
log.info("Device info:")
|
||||
versions = ie.get_versions(args.device)
|
||||
print(f"{' ' * 8}{args.device}")
|
||||
print(f"{' ' * 8}MKLDNNPlugin version ......... {versions[args.device].major}.{versions[args.device].minor}")
|
||||
print(f"{' ' * 8}Build ........... {versions[args.device].build_number}")
|
||||
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
|
||||
log.info(f'Loading the {args.device} configuration: {args.config}')
|
||||
ie.set_config({'CONFIG_FILE': args.config}, args.device)
|
||||
|
||||
if args.cpu_extension and "CPU" in args.device:
|
||||
ie.add_extension(args.cpu_extension, "CPU")
|
||||
log.info(f"CPU extension loaded: {args.cpu_extension}")
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
|
||||
log.info(f'Reading the network: {args.model}')
|
||||
# (.xml and .bin files) or (.onnx file)
|
||||
net = ie.read_network(model=args.model)
|
||||
|
||||
# --------------------------- 3. Read and preprocess input --------------------------------------------
|
||||
for input_key in net.input_info:
|
||||
if len(net.input_info[input_key].input_data.layout) == 4:
|
||||
n, c, h, w = net.input_info[input_key].input_data.shape
|
||||
if len(net.input_info) != 1:
|
||||
log.error('The sample supports only single input topologies')
|
||||
return - 1
|
||||
|
||||
images = np.ndarray(shape=(n, c, h, w))
|
||||
images_hw = []
|
||||
for i in range(n):
|
||||
image = cv2.imread(args.input)
|
||||
ih, iw = image.shape[:-1]
|
||||
images_hw.append((ih, iw))
|
||||
log.info("File was added: ")
|
||||
log.info(f" {args.input}")
|
||||
if (ih, iw) != (h, w):
|
||||
log.warning(f"Image {args.input} is resized from {image.shape[:-1]} to {(h, w)}")
|
||||
image = cv2.resize(image, (w, h))
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[i] = image
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
if len(net.outputs) != 1 and not ('boxes' in net.outputs or 'labels' in net.outputs):
|
||||
log.error('The sample supports models with 1 output or with 2 with the names "boxes" and "labels"')
|
||||
return -1
|
||||
|
||||
# --------------------------- 4. Configure input & output ---------------------------------------------
|
||||
# --------------------------- Prepare input blobs -----------------------------------------------------
|
||||
log.info("Preparing input blobs")
|
||||
assert (len(net.input_info.keys()) == 1 or len(
|
||||
net.input_info.keys()) == 2), "Sample supports topologies only with 1 or 2 inputs"
|
||||
out_blob = next(iter(net.outputs))
|
||||
input_name, input_info_name = "", ""
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get name of input blob
|
||||
input_blob = next(iter(net.input_info))
|
||||
|
||||
for input_key in net.input_info:
|
||||
if len(net.input_info[input_key].layout) == 4:
|
||||
input_name = input_key
|
||||
net.input_info[input_key].precision = 'U8'
|
||||
elif len(net.input_info[input_key].layout) == 2:
|
||||
input_info_name = input_key
|
||||
net.input_info[input_key].precision = 'FP32'
|
||||
if net.input_info[input_key].input_data.shape[1] != 3 and net.input_info[input_key].input_data.shape[1] != 6 or \
|
||||
net.input_info[input_key].input_data.shape[0] != 1:
|
||||
log.error('Invalid input info. Should be 3 or 6 values length.')
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
|
||||
data = {}
|
||||
data[input_name] = images
|
||||
|
||||
if input_info_name != "":
|
||||
detection_size = net.input_info[input_info_name].input_data.shape[1]
|
||||
infos = np.ndarray(shape=(n, detection_size), dtype=float)
|
||||
for i in range(n):
|
||||
infos[i, 0] = h
|
||||
infos[i, 1] = w
|
||||
for j in range(2, detection_size):
|
||||
infos[i, j] = 1.0
|
||||
data[input_info_name] = infos
|
||||
|
||||
# --------------------------- Prepare output blobs ----------------------------------------------------
|
||||
log.info('Preparing output blobs')
|
||||
|
||||
output_name, output_info = "", None
|
||||
func = ng.function_from_cnn(net)
|
||||
if func:
|
||||
ops = func.get_ordered_ops()
|
||||
for op in ops:
|
||||
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput":
|
||||
output_name = op.friendly_name
|
||||
output_info = net.outputs[output_name]
|
||||
break
|
||||
if len(net.outputs) == 1:
|
||||
output_blob = next(iter(net.outputs))
|
||||
net.outputs[output_blob].precision = 'FP32'
|
||||
else:
|
||||
output_name = list(net.outputs.keys())[0]
|
||||
output_info = net.outputs[output_name]
|
||||
net.outputs['boxes'].precision = 'FP32'
|
||||
net.outputs['labels'].precision = 'U16'
|
||||
|
||||
if output_name == "":
|
||||
log.error("Can't find a DetectionOutput layer in the topology")
|
||||
|
||||
output_dims = output_info.shape
|
||||
if len(output_dims) != 4:
|
||||
log.error("Incorrect output dimensions for SSD model")
|
||||
max_proposal_count, object_size = output_dims[2], output_dims[3]
|
||||
|
||||
if object_size != 7:
|
||||
log.error("Output item should have 7 as a last dimension")
|
||||
|
||||
output_info.precision = "FP32"
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- Performing inference ----------------------------------------------------
|
||||
log.info("Loading model to the device")
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
log.info("Creating infer request and starting inference")
|
||||
res = exec_net.infer(inputs=data)
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- Read and postprocess output ---------------------------------------------
|
||||
log.info("Processing output blobs")
|
||||
res = res[out_blob]
|
||||
boxes, classes = {}, {}
|
||||
data = res[0][0]
|
||||
for number, proposal in enumerate(data):
|
||||
if proposal[2] > 0:
|
||||
imid = np.int(proposal[0])
|
||||
ih, iw = images_hw[imid]
|
||||
label = np.int(proposal[1])
|
||||
confidence = proposal[2]
|
||||
xmin = np.int(iw * proposal[3])
|
||||
ymin = np.int(ih * proposal[4])
|
||||
xmax = np.int(iw * proposal[5])
|
||||
ymax = np.int(ih * proposal[6])
|
||||
print(f"[{number},{label}] element, prob = {confidence:.6f} ({xmin},{ymin})-({xmax},{ymax}) "
|
||||
f"batch id : {imid}", end="")
|
||||
if proposal[2] > 0.5:
|
||||
print(" WILL BE PRINTED!")
|
||||
if not imid in boxes.keys():
|
||||
boxes[imid] = []
|
||||
boxes[imid].append([xmin, ymin, xmax, ymax])
|
||||
if not imid in classes.keys():
|
||||
classes[imid] = []
|
||||
classes[imid].append(label)
|
||||
else:
|
||||
print()
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
tmp_image = cv2.imread(args.input)
|
||||
for imid in classes:
|
||||
for box in boxes[imid]:
|
||||
cv2.rectangle(tmp_image, (box[0], box[1]), (box[2], box[3]), (232, 35, 244), 2)
|
||||
cv2.imwrite("out.bmp", tmp_image)
|
||||
log.info("Image out.bmp created!")
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
original_image = cv2.imread(args.input)
|
||||
image = original_image.copy()
|
||||
_, _, net_h, net_w = net.input_info[input_blob].input_data.shape
|
||||
|
||||
log.info("Execution successful\n")
|
||||
log.info(
|
||||
"This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool")
|
||||
if image.shape[:-1] != (net_h, net_w):
|
||||
log.warning(f'Image {args.input} is resized from {image.shape[:-1]} to {(net_h, net_w)}')
|
||||
image = cv2.resize(image, (net_w, net_h))
|
||||
|
||||
# Change data layout from HWC to CHW
|
||||
image = image.transpose((2, 0, 1))
|
||||
# Add N dimension to transform to NCHW
|
||||
image = np.expand_dims(image, axis=0)
|
||||
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in synchronous mode')
|
||||
res = exec_net.infer(inputs={input_blob: image})
|
||||
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
# Generate a label list
|
||||
if args.labels:
|
||||
with open(args.labels, 'r') as f:
|
||||
labels = [line.split(',')[0].strip() for line in f]
|
||||
|
||||
output_image = original_image.copy()
|
||||
h, w, _ = output_image.shape
|
||||
|
||||
if len(net.outputs) == 1:
|
||||
res = res[output_blob]
|
||||
# Change a shape of a numpy.ndarray with results ([1, 1, N, 7]) to get another one ([N, 7]),
|
||||
# where N is the number of detected bounding boxes
|
||||
detections = res.reshape(-1, 7)
|
||||
else:
|
||||
detections = res['boxes']
|
||||
labels = res['labels']
|
||||
# Redefine scale coefficients
|
||||
w, h = w / net_w, h / net_h
|
||||
|
||||
for i, detection in enumerate(detections):
|
||||
if len(net.outputs) == 1:
|
||||
_, class_id, confidence, xmin, ymin, xmax, ymax = detection
|
||||
else:
|
||||
class_id = labels[i]
|
||||
xmin, ymin, xmax, ymax, confidence = detection
|
||||
|
||||
if confidence > 0.5:
|
||||
label = int(labels[class_id]) if args.labels else int(class_id)
|
||||
|
||||
xmin = int(xmin * w)
|
||||
ymin = int(ymin * h)
|
||||
xmax = int(xmax * w)
|
||||
ymax = int(ymax * h)
|
||||
|
||||
log.info(f'Found: label = {label}, confidence = {confidence:.2f}, '
|
||||
f'coords = ({xmin}, {ymin}), ({xmax}, {ymax})')
|
||||
|
||||
# Draw a bounding box on a output image
|
||||
cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), (0, 255, 0), 2)
|
||||
|
||||
cv2.imwrite('out.bmp', output_image)
|
||||
log.info('Image out.bmp created!')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
194
inference-engine/ie_bridges/python/sample/style_transfer_sample/style_transfer_sample.py
Normal file → Executable file
194
inference-engine/ie_bridges/python/sample/style_transfer_sample/style_transfer_sample.py
Normal file → Executable file
@ -1,104 +1,146 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright (C) 2018-2021 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import logging as log
|
||||
import sys
|
||||
import os
|
||||
from argparse import ArgumentParser, SUPPRESS
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import logging as log
|
||||
from openvino.inference_engine import IECore
|
||||
|
||||
|
||||
def build_argparser():
|
||||
parser = ArgumentParser(add_help=False)
|
||||
def parse_args() -> argparse.Namespace:
|
||||
'''Parse and return command line arguments'''
|
||||
parser = argparse.ArgumentParser(add_help=False)
|
||||
args = parser.add_argument_group('Options')
|
||||
args.add_argument('-h', '--help', action='help', default=SUPPRESS, help='Show this help message and exit.')
|
||||
args.add_argument("-m", "--model", help="Required. Path to an .xml or .onnx file with a trained model.", required=True, type=str)
|
||||
args.add_argument("-i", "--input", help="Required. Path to an image files", required=True,
|
||||
type=str, nargs="+")
|
||||
args.add_argument("-l", "--cpu_extension",
|
||||
help="Optional. Required for CPU custom layers. "
|
||||
"Absolute MKLDNN (CPU)-targeted custom layers. Absolute path to a shared library with the "
|
||||
"kernels implementations", type=str, default=None)
|
||||
args.add_argument("-d", "--device",
|
||||
help="Optional. Specify the target device to infer on; CPU, GPU, FPGA, HDDL or MYRIAD is acceptable. Sample "
|
||||
"will look for a suitable plugin for device specified. Default value is CPU", default="CPU",
|
||||
type=str)
|
||||
args.add_argument("-nt", "--number_top", help="Optional. Number of top results", default=10, type=int)
|
||||
args.add_argument("--mean_val_r", "-mean_val_r",
|
||||
help="Optional. Mean value of red channel for mean value subtraction in postprocessing ", default=0,
|
||||
type=float)
|
||||
args.add_argument("--mean_val_g", "-mean_val_g",
|
||||
help="Optional. Mean value of green channel for mean value subtraction in postprocessing ", default=0,
|
||||
type=float)
|
||||
args.add_argument("--mean_val_b", "-mean_val_b",
|
||||
help="Optional. Mean value of blue channel for mean value subtraction in postprocessing ", default=0,
|
||||
type=float)
|
||||
return parser
|
||||
args.add_argument('-h', '--help', action='help', help='Show this help message and exit.')
|
||||
args.add_argument('-m', '--model', required=True, type=str,
|
||||
help='Required. Path to an .xml or .onnx file with a trained model.')
|
||||
args.add_argument('-i', '--input', required=True, type=str, nargs='+', help='Required. Path to an image file.')
|
||||
args.add_argument('-l', '--extension', type=str, default=None,
|
||||
help='Optional. Required by the CPU Plugin for executing the custom operation on a CPU. '
|
||||
'Absolute path to a shared library with the kernels implementations.')
|
||||
args.add_argument('-c', '--config', type=str, default=None,
|
||||
help='Optional. Required by GPU or VPU Plugins for the custom operation kernel. '
|
||||
'Absolute path to operation description file (.xml).')
|
||||
args.add_argument('-d', '--device', default='CPU', type=str,
|
||||
help='Optional. Specify the target device to infer on; CPU, GPU, MYRIAD, HDDL or HETERO: '
|
||||
'is acceptable. The sample will look for a suitable plugin for device specified. '
|
||||
'Default value is CPU.')
|
||||
args.add_argument('--original_size', action='store_true', default=False,
|
||||
help='Optional. Resize an output image to original image size.')
|
||||
args.add_argument('--mean_val_r', default=0, type=float,
|
||||
help='Optional. Mean value of red channel for mean value subtraction in postprocessing.')
|
||||
args.add_argument('--mean_val_g', default=0, type=float,
|
||||
help='Optional. Mean value of green channel for mean value subtraction in postprocessing.')
|
||||
args.add_argument('--mean_val_b', default=0, type=float,
|
||||
help='Optional. Mean value of blue channel for mean value subtraction in postprocessing.')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
log.basicConfig(format="[ %(levelname)s ] %(message)s", level=log.INFO, stream=sys.stdout)
|
||||
args = build_argparser().parse_args()
|
||||
log.basicConfig(format='[ %(levelname)s ] %(message)s', level=log.INFO, stream=sys.stdout)
|
||||
args = parse_args()
|
||||
|
||||
# Plugin initialization for specified device and load extensions library if specified
|
||||
log.info("Creating Inference Engine")
|
||||
# ---------------------------Step 1. Initialize inference engine core--------------------------------------------------
|
||||
log.info('Creating Inference Engine')
|
||||
ie = IECore()
|
||||
if args.cpu_extension and 'CPU' in args.device:
|
||||
ie.add_extension(args.cpu_extension, "CPU")
|
||||
|
||||
# Read a model in OpenVINO Intermediate Representation (.xml and .bin files) or ONNX (.onnx file) format
|
||||
model = args.model
|
||||
log.info(f"Loading network:\n\t{model}")
|
||||
net = ie.read_network(model=model)
|
||||
if args.extension and args.device == 'CPU':
|
||||
log.info(f'Loading the {args.device} extension: {args.extension}')
|
||||
ie.add_extension(args.extension, args.device)
|
||||
|
||||
assert len(net.input_info.keys()) == 1, "Sample supports only single input topologies"
|
||||
assert len(net.outputs) == 1, "Sample supports only single output topologies"
|
||||
if args.config and args.device in ('GPU', 'MYRIAD', 'HDDL'):
|
||||
log.info(f'Loading the {args.device} configuration: {args.config}')
|
||||
ie.set_config({'CONFIG_FILE': args.config}, args.device)
|
||||
|
||||
log.info("Preparing input blobs")
|
||||
# ---------------------------Step 2. Read a model in OpenVINO Intermediate Representation or ONNX format---------------
|
||||
log.info(f'Reading the network: {args.model}')
|
||||
# (.xml and .bin files) or (.onnx file)
|
||||
net = ie.read_network(model=args.model)
|
||||
|
||||
if len(net.input_info) != 1:
|
||||
log.error('Sample supports only single input topologies')
|
||||
return -1
|
||||
if len(net.outputs) != 1:
|
||||
log.error('Sample supports only single output topologies')
|
||||
return -1
|
||||
|
||||
# ---------------------------Step 3. Configure input & output----------------------------------------------------------
|
||||
log.info('Configuring input and output blobs')
|
||||
# Get names of input and output blobs
|
||||
input_blob = next(iter(net.input_info))
|
||||
out_blob = next(iter(net.outputs))
|
||||
|
||||
# Set input and output precision manually
|
||||
net.input_info[input_blob].precision = 'U8'
|
||||
net.outputs[out_blob].precision = 'FP32'
|
||||
|
||||
# Set a batch size to a equal number of input images
|
||||
net.batch_size = len(args.input)
|
||||
|
||||
# Read and pre-process input images
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
images = np.ndarray(shape=(n, c, h, w))
|
||||
for i in range(n):
|
||||
image = cv2.imread(args.input[i])
|
||||
if image.shape[:-1] != (h, w):
|
||||
log.warning(f"Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}")
|
||||
image = cv2.resize(image, (w, h))
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[i] = image
|
||||
log.info(f"Batch size is {n}")
|
||||
|
||||
# Loading model to the plugin
|
||||
log.info("Loading model to the plugin")
|
||||
# ---------------------------Step 4. Loading model to the device-------------------------------------------------------
|
||||
log.info('Loading the model to the plugin')
|
||||
exec_net = ie.load_network(network=net, device_name=args.device)
|
||||
|
||||
# Start sync inference
|
||||
log.info("Starting inference")
|
||||
res = exec_net.infer(inputs={input_blob: images})
|
||||
# ---------------------------Step 5. Create infer request--------------------------------------------------------------
|
||||
# load_network() method of the IECore class with a specified number of requests (default 1) returns an ExecutableNetwork
|
||||
# instance which stores infer requests. So you already created Infer requests in the previous step.
|
||||
|
||||
# Processing output blob
|
||||
log.info("Processing output blob")
|
||||
# ---------------------------Step 6. Prepare input---------------------------------------------------------------------
|
||||
original_images = []
|
||||
|
||||
n, c, h, w = net.input_info[input_blob].input_data.shape
|
||||
input_data = np.ndarray(shape=(n, c, h, w))
|
||||
|
||||
for i in range(n):
|
||||
image = cv2.imread(args.input[i])
|
||||
original_images.append(image)
|
||||
|
||||
if image.shape[:-1] != (h, w):
|
||||
log.warning(f'Image {args.input[i]} is resized from {image.shape[:-1]} to {(h, w)}')
|
||||
image = cv2.resize(image, (w, h))
|
||||
|
||||
# Change data layout from HWC to CHW
|
||||
image = image.transpose((2, 0, 1))
|
||||
|
||||
input_data[i] = image
|
||||
|
||||
# ---------------------------Step 7. Do inference----------------------------------------------------------------------
|
||||
log.info('Starting inference in synchronous mode')
|
||||
res = exec_net.infer(inputs={input_blob: input_data})
|
||||
|
||||
# ---------------------------Step 8. Process output--------------------------------------------------------------------
|
||||
res = res[out_blob]
|
||||
# Post process output
|
||||
for batch, data in enumerate(res):
|
||||
# Clip values to [0, 255] range
|
||||
data = np.swapaxes(data, 0, 2)
|
||||
data = np.swapaxes(data, 0, 1)
|
||||
data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
|
||||
data[data < 0] = 0
|
||||
data[data > 255] = 255
|
||||
data = data[::] - (args.mean_val_r, args.mean_val_g, args.mean_val_b)
|
||||
out_img = os.path.join(os.path.dirname(__file__), f"out_{batch}.bmp")
|
||||
cv2.imwrite(out_img, data)
|
||||
log.info(f"Result image was saved to {out_img}")
|
||||
log.info("This sample is an API example, for any performance measurements please use the dedicated benchmark_app tool\n")
|
||||
|
||||
for i in range(n):
|
||||
output_image = res[i]
|
||||
# Change data layout from CHW to HWC
|
||||
output_image = output_image.transpose((1, 2, 0))
|
||||
# Convert BGR color order to RGB
|
||||
output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Apply mean argument values
|
||||
output_image = output_image[::] - (args.mean_val_r, args.mean_val_g, args.mean_val_b)
|
||||
# Set pixel values bitween 0 and 255
|
||||
output_image = np.clip(output_image, 0, 255)
|
||||
|
||||
# Resize a output image to original size
|
||||
if args.original_size:
|
||||
h, w, _ = original_images[i].shape
|
||||
output_image = cv2.resize(output_image, (w, h))
|
||||
|
||||
cv2.imwrite(f'out_{i}.bmp', output_image)
|
||||
log.info(f'Image out_{i}.bmp created!')
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------------------------
|
||||
log.info('This sample is an API example, '
|
||||
'for any performance measurements please use the dedicated benchmark_app tool\n')
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
sys.exit(main() or 0)
|
||||
sys.exit(main())
|
||||
|
Loading…
Reference in New Issue
Block a user