Fixes for Object Detection SSD Python sample (#3518)

* added check so that sample only supports networks with one input
* moved ngraph-realted operations to related segment of the sample
* fix for output image not being saved correcly due
This commit is contained in:
Maksim Makridin 2020-12-08 23:16:59 +03:00 committed by GitHub
parent 2aec8a610b
commit 7d8144f160
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -58,8 +58,6 @@ def main():
model = args.model model = args.model
log.info(f"Loading network:\n\t{model}") log.info(f"Loading network:\n\t{model}")
net = ie.read_network(model=model) net = ie.read_network(model=model)
func = ng.function_from_cnn(net)
ops = func.get_ordered_ops()
# ----------------------------------------------------------------------------------------------------- # -----------------------------------------------------------------------------------------------------
# ------------- 2. Load Plugin for inference engine and extensions library if specified -------------- # ------------- 2. Load Plugin for inference engine and extensions library if specified --------------
@ -78,6 +76,7 @@ def main():
# --------------------------- 3. Read and preprocess input -------------------------------------------- # --------------------------- 3. Read and preprocess input --------------------------------------------
print("inputs number: " + str(len(net.input_info.keys()))) print("inputs number: " + str(len(net.input_info.keys())))
assert len(net.input_info.keys()) == 1, 'Sample supports networks with one input'
for input_key in net.input_info: for input_key in net.input_info:
print("input shape: " + str(net.input_info[input_key].input_data.shape)) print("input shape: " + str(net.input_info[input_key].input_data.shape))
@ -92,9 +91,9 @@ def main():
ih, iw = image.shape[:-1] ih, iw = image.shape[:-1]
images_hw.append((ih, iw)) images_hw.append((ih, iw))
log.info("File was added: ") log.info("File was added: ")
log.info(" {}".format(args.input[i])) log.info(" {}".format(args.input))
if (ih, iw) != (h, w): if (ih, iw) != (h, w):
log.warning("Image {} is resized from {} to {}".format(args.input[i], image.shape[:-1], (h, w))) log.warning("Image {} is resized from {} to {}".format(args.input, image.shape[:-1], (h, w)))
image = cv2.resize(image, (w, h)) image = cv2.resize(image, (w, h))
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
images[i] = image images[i] = image
@ -134,6 +133,8 @@ def main():
# --------------------------- Prepare output blobs ---------------------------------------------------- # --------------------------- Prepare output blobs ----------------------------------------------------
log.info('Preparing output blobs') log.info('Preparing output blobs')
func = ng.function_from_cnn(net)
ops = func.get_ordered_ops()
output_name, output_info = "", net.outputs[next(iter(net.outputs.keys()))] output_name, output_info = "", net.outputs[next(iter(net.outputs.keys()))]
output_ops = {op.friendly_name : op for op in ops \ output_ops = {op.friendly_name : op for op in ops \
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput"} if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput"}
@ -190,7 +191,7 @@ def main():
print() print()
for imid in classes: for imid in classes:
tmp_image = cv2.imread(args.input[imid]) tmp_image = cv2.imread(args.input)
for box in boxes[imid]: for box in boxes[imid]:
cv2.rectangle(tmp_image, (box[0], box[1]), (box[2], box[3]), (232, 35, 244), 2) cv2.rectangle(tmp_image, (box[0], box[1]), (box[2], box[3]), (232, 35, 244), 2)
cv2.imwrite("out.bmp", tmp_image) cv2.imwrite("out.bmp", tmp_image)