[IE TOOLS] Use input_info in python benchmark app (#660)

This commit is contained in:
Anastasia Kuporosova 2020-05-29 21:28:17 +03:00 committed by GitHub
parent cbad43f3a5
commit 3ef1a26174
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 20 additions and 20 deletions

View File

@ -66,7 +66,7 @@ class Benchmark:
ie_network = self.ie.read_network(xml_filename, bin_filename) ie_network = self.ie.read_network(xml_filename, bin_filename)
input_info = ie_network.inputs input_info = ie_network.input_info
if not input_info: if not input_info:
raise AttributeError('No inputs info is provided') raise AttributeError('No inputs info is provided')

View File

@ -175,12 +175,12 @@ def run(args):
# --------------------- 5. Resizing network to match image sizes and given batch --------------------------- # --------------------- 5. Resizing network to match image sizes and given batch ---------------------------
next_step() next_step()
shapes = {k: v.shape.copy() for k, v in ie_network.inputs.items()} shapes = {k: v.input_data.shape.copy() for k, v in ie_network.input_info.items()}
reshape = False reshape = False
if args.shape: if args.shape:
reshape |= update_shapes(shapes, args.shape, ie_network.inputs) reshape |= update_shapes(shapes, args.shape, ie_network.input_info)
if args.batch_size and args.batch_size != ie_network.batch_size: if args.batch_size and args.batch_size != ie_network.batch_size:
reshape |= adjust_shapes_batch(shapes, args.batch_size, ie_network.inputs) reshape |= adjust_shapes_batch(shapes, args.batch_size, ie_network.input_info)
if reshape: if reshape:
start_time = datetime.utcnow() start_time = datetime.utcnow()
@ -259,7 +259,7 @@ def run(args):
if args.paths_to_input: if args.paths_to_input:
for path in args.paths_to_input: for path in args.paths_to_input:
paths_to_input.append(os.path.abspath(*path) if args.paths_to_input else None) paths_to_input.append(os.path.abspath(*path) if args.paths_to_input else None)
set_inputs(paths_to_input, batch_size, exe_network.inputs, infer_requests) set_inputs(paths_to_input, batch_size, exe_network.input_info, infer_requests)
if statistics: if statistics:
statistics.add_parameters(StatisticsReport.Category.RUNTIME_CONFIG, statistics.add_parameters(StatisticsReport.Category.RUNTIME_CONFIG,

View File

@ -47,13 +47,13 @@ def set_inputs(paths_to_input, batch_size, input_info, requests):
def get_inputs(paths_to_input, batch_size, input_info, requests): def get_inputs(paths_to_input, batch_size, input_info, requests):
input_image_sizes = {} input_image_sizes = {}
for key in sorted(input_info.keys()): for key in sorted(input_info.keys()):
if is_image(input_info[key]): if is_image(input_info[key].input_data):
input_image_sizes[key] = (input_info[key].shape[2], input_info[key].shape[3]) input_image_sizes[key] = (input_info[key].input_data.shape[2], input_info[key].input_data.shape[3])
logger.info("Network input '{}' precision {}, dimensions ({}): {}".format(key, logger.info("Network input '{}' precision {}, dimensions ({}): {}".format(key,
input_info[key].precision, input_info[key].input_data.precision,
input_info[key].layout, input_info[key].input_data.layout,
" ".join(str(x) for x in " ".join(str(x) for x in
input_info[key].shape))) input_info[key].input_data.shape)))
images_count = len(input_image_sizes.keys()) images_count = len(input_image_sizes.keys())
binaries_count = len(input_info) - images_count binaries_count = len(input_info) - images_count
@ -102,31 +102,31 @@ def get_inputs(paths_to_input, batch_size, input_info, requests):
input_data = {} input_data = {}
keys = list(sorted(input_info.keys())) keys = list(sorted(input_info.keys()))
for key in keys: for key in keys:
if is_image(input_info[key]): if is_image(input_info[key].input_data):
# input is image # input is image
if len(image_files) > 0: if len(image_files) > 0:
input_data[key] = fill_blob_with_image(image_files, request_id, batch_size, keys.index(key), input_data[key] = fill_blob_with_image(image_files, request_id, batch_size, keys.index(key),
len(keys), input_info[key]) len(keys), input_info[key].input_data)
continue continue
# input is binary # input is binary
if len(binary_files): if len(binary_files):
input_data[key] = fill_blob_with_binary(binary_files, request_id, batch_size, keys.index(key), input_data[key] = fill_blob_with_binary(binary_files, request_id, batch_size, keys.index(key),
len(keys), input_info[key]) len(keys), input_info[key].input_data)
continue continue
# most likely input is image info # most likely input is image info
if is_image_info(input_info[key]) and len(input_image_sizes) == 1: if is_image_info(input_info[key].input_data) and len(input_image_sizes) == 1:
image_size = input_image_sizes[list(input_image_sizes.keys()).pop()] image_size = input_image_sizes[list(input_image_sizes.keys()).pop()]
logger.info("Fill input '" + key + "' with image size " + str(image_size[0]) + "x" + logger.info("Fill input '" + key + "' with image size " + str(image_size[0]) + "x" +
str(image_size[1])) str(image_size[1]))
input_data[key] = fill_blob_with_image_info(image_size, input_info[key]) input_data[key] = fill_blob_with_image_info(image_size, input_info[key].input_data)
continue continue
# fill with random data # fill with random data
logger.info("Fill input '{}' with random values ({} is expected)".format(key, "image" if is_image( logger.info("Fill input '{}' with random values ({} is expected)".format(key, "image" if is_image(
input_info[key]) else "some binary data")) input_info[key].input_data) else "some binary data"))
input_data[key] = fill_blob_with_random(input_info[key]) input_data[key] = fill_blob_with_random(input_info[key].input_data)
requests_input_data.append(input_data) requests_input_data.append(input_data)

View File

@ -62,10 +62,10 @@ def next_step(additional_info='', step_id=0):
def config_network_inputs(ie_network: IENetwork): def config_network_inputs(ie_network: IENetwork):
input_info = ie_network.inputs input_info = ie_network.input_info
for key in input_info.keys(): for key in input_info.keys():
if is_image(input_info[key]): if is_image(input_info[key].input_data):
# Set the precision of input data provided by the user # Set the precision of input data provided by the user
# Should be called before load of the network to the plugin # Should be called before load of the network to the plugin
input_info[key].precision = 'U8' input_info[key].precision = 'U8'
@ -261,7 +261,7 @@ def update_shapes(shapes, shapes_string: str, inputs_info):
def adjust_shapes_batch(shapes, batch_size: int, inputs_info): def adjust_shapes_batch(shapes, batch_size: int, inputs_info):
updated = False updated = False
for name, data in inputs_info.items(): for name, data in inputs_info.items():
layout = data.layout layout = data.input_data.layout
batch_index = layout.index('N') if 'N' in layout else -1 batch_index = layout.index('N') if 'N' in layout else -1
if batch_index != -1 and shapes[name][batch_index] != batch_size: if batch_index != -1 and shapes[name][batch_index] != batch_size:
shapes[name][batch_index] = batch_size shapes[name][batch_index] = batch_size