[IE TOOLS] Use input_info in python benchmark app (#660)
This commit is contained in:
parent
cbad43f3a5
commit
3ef1a26174
@ -66,7 +66,7 @@ class Benchmark:
|
||||
|
||||
ie_network = self.ie.read_network(xml_filename, bin_filename)
|
||||
|
||||
input_info = ie_network.inputs
|
||||
input_info = ie_network.input_info
|
||||
|
||||
if not input_info:
|
||||
raise AttributeError('No inputs info is provided')
|
||||
|
@ -175,12 +175,12 @@ def run(args):
|
||||
# --------------------- 5. Resizing network to match image sizes and given batch ---------------------------
|
||||
next_step()
|
||||
|
||||
shapes = {k: v.shape.copy() for k, v in ie_network.inputs.items()}
|
||||
shapes = {k: v.input_data.shape.copy() for k, v in ie_network.input_info.items()}
|
||||
reshape = False
|
||||
if args.shape:
|
||||
reshape |= update_shapes(shapes, args.shape, ie_network.inputs)
|
||||
reshape |= update_shapes(shapes, args.shape, ie_network.input_info)
|
||||
if args.batch_size and args.batch_size != ie_network.batch_size:
|
||||
reshape |= adjust_shapes_batch(shapes, args.batch_size, ie_network.inputs)
|
||||
reshape |= adjust_shapes_batch(shapes, args.batch_size, ie_network.input_info)
|
||||
|
||||
if reshape:
|
||||
start_time = datetime.utcnow()
|
||||
@ -259,7 +259,7 @@ def run(args):
|
||||
if args.paths_to_input:
|
||||
for path in args.paths_to_input:
|
||||
paths_to_input.append(os.path.abspath(*path) if args.paths_to_input else None)
|
||||
set_inputs(paths_to_input, batch_size, exe_network.inputs, infer_requests)
|
||||
set_inputs(paths_to_input, batch_size, exe_network.input_info, infer_requests)
|
||||
|
||||
if statistics:
|
||||
statistics.add_parameters(StatisticsReport.Category.RUNTIME_CONFIG,
|
||||
|
@ -47,13 +47,13 @@ def set_inputs(paths_to_input, batch_size, input_info, requests):
|
||||
def get_inputs(paths_to_input, batch_size, input_info, requests):
|
||||
input_image_sizes = {}
|
||||
for key in sorted(input_info.keys()):
|
||||
if is_image(input_info[key]):
|
||||
input_image_sizes[key] = (input_info[key].shape[2], input_info[key].shape[3])
|
||||
if is_image(input_info[key].input_data):
|
||||
input_image_sizes[key] = (input_info[key].input_data.shape[2], input_info[key].input_data.shape[3])
|
||||
logger.info("Network input '{}' precision {}, dimensions ({}): {}".format(key,
|
||||
input_info[key].precision,
|
||||
input_info[key].layout,
|
||||
input_info[key].input_data.precision,
|
||||
input_info[key].input_data.layout,
|
||||
" ".join(str(x) for x in
|
||||
input_info[key].shape)))
|
||||
input_info[key].input_data.shape)))
|
||||
|
||||
images_count = len(input_image_sizes.keys())
|
||||
binaries_count = len(input_info) - images_count
|
||||
@ -102,31 +102,31 @@ def get_inputs(paths_to_input, batch_size, input_info, requests):
|
||||
input_data = {}
|
||||
keys = list(sorted(input_info.keys()))
|
||||
for key in keys:
|
||||
if is_image(input_info[key]):
|
||||
if is_image(input_info[key].input_data):
|
||||
# input is image
|
||||
if len(image_files) > 0:
|
||||
input_data[key] = fill_blob_with_image(image_files, request_id, batch_size, keys.index(key),
|
||||
len(keys), input_info[key])
|
||||
len(keys), input_info[key].input_data)
|
||||
continue
|
||||
|
||||
# input is binary
|
||||
if len(binary_files):
|
||||
input_data[key] = fill_blob_with_binary(binary_files, request_id, batch_size, keys.index(key),
|
||||
len(keys), input_info[key])
|
||||
len(keys), input_info[key].input_data)
|
||||
continue
|
||||
|
||||
# most likely input is image info
|
||||
if is_image_info(input_info[key]) and len(input_image_sizes) == 1:
|
||||
if is_image_info(input_info[key].input_data) and len(input_image_sizes) == 1:
|
||||
image_size = input_image_sizes[list(input_image_sizes.keys()).pop()]
|
||||
logger.info("Fill input '" + key + "' with image size " + str(image_size[0]) + "x" +
|
||||
str(image_size[1]))
|
||||
input_data[key] = fill_blob_with_image_info(image_size, input_info[key])
|
||||
input_data[key] = fill_blob_with_image_info(image_size, input_info[key].input_data)
|
||||
continue
|
||||
|
||||
# fill with random data
|
||||
logger.info("Fill input '{}' with random values ({} is expected)".format(key, "image" if is_image(
|
||||
input_info[key]) else "some binary data"))
|
||||
input_data[key] = fill_blob_with_random(input_info[key])
|
||||
input_info[key].input_data) else "some binary data"))
|
||||
input_data[key] = fill_blob_with_random(input_info[key].input_data)
|
||||
|
||||
requests_input_data.append(input_data)
|
||||
|
||||
|
@ -62,10 +62,10 @@ def next_step(additional_info='', step_id=0):
|
||||
|
||||
|
||||
def config_network_inputs(ie_network: IENetwork):
|
||||
input_info = ie_network.inputs
|
||||
input_info = ie_network.input_info
|
||||
|
||||
for key in input_info.keys():
|
||||
if is_image(input_info[key]):
|
||||
if is_image(input_info[key].input_data):
|
||||
# Set the precision of input data provided by the user
|
||||
# Should be called before load of the network to the plugin
|
||||
input_info[key].precision = 'U8'
|
||||
@ -261,7 +261,7 @@ def update_shapes(shapes, shapes_string: str, inputs_info):
|
||||
def adjust_shapes_batch(shapes, batch_size: int, inputs_info):
|
||||
updated = False
|
||||
for name, data in inputs_info.items():
|
||||
layout = data.layout
|
||||
layout = data.input_data.layout
|
||||
batch_index = layout.index('N') if 'N' in layout else -1
|
||||
if batch_index != -1 and shapes[name][batch_index] != batch_size:
|
||||
shapes[name][batch_index] = batch_size
|
||||
|
Loading…
Reference in New Issue
Block a user