Modifications for Object Detection SSD Python sample (#3976)
* Modifications for Object Detection SSD Sample Python * Fixes for tests * Changing the way output checks are processed
This commit is contained in:
@@ -73,13 +73,7 @@ def main():
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- 3. Read and preprocess input --------------------------------------------
|
||||
|
||||
print("inputs number: " + str(len(net.input_info.keys())))
|
||||
assert len(net.input_info.keys()) == 1, 'Sample supports networks with one input'
|
||||
|
||||
for input_key in net.input_info:
|
||||
print("input shape: " + str(net.input_info[input_key].input_data.shape))
|
||||
print("input key: " + input_key)
|
||||
if len(net.input_info[input_key].input_data.layout) == 4:
|
||||
n, c, h, w = net.input_info[input_key].input_data.shape
|
||||
|
||||
@@ -96,7 +90,6 @@ def main():
|
||||
image = cv2.resize(image, (w, h))
|
||||
image = image.transpose((2, 0, 1)) # Change data layout from HWC to CHW
|
||||
images[i] = image
|
||||
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
# --------------------------- 4. Configure input & output ---------------------------------------------
|
||||
@@ -122,23 +115,30 @@ def main():
|
||||
data[input_name] = images
|
||||
|
||||
if input_info_name != "":
|
||||
infos = np.ndarray(shape=(n, c), dtype=float)
|
||||
detection_size = net.input_info[input_info_name].input_data.shape[1]
|
||||
infos = np.ndarray(shape=(n, detection_size), dtype=float)
|
||||
for i in range(n):
|
||||
infos[i, 0] = h
|
||||
infos[i, 1] = w
|
||||
infos[i, 2] = 1.0
|
||||
for j in range(2, detection_size):
|
||||
infos[i, j] = 1.0
|
||||
data[input_info_name] = infos
|
||||
|
||||
# --------------------------- Prepare output blobs ----------------------------------------------------
|
||||
log.info('Preparing output blobs')
|
||||
|
||||
output_name, output_info = "", None
|
||||
func = ng.function_from_cnn(net)
|
||||
ops = func.get_ordered_ops()
|
||||
output_name, output_info = "", net.outputs[next(iter(net.outputs.keys()))]
|
||||
output_ops = {op.friendly_name : op for op in ops \
|
||||
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput"}
|
||||
if len(output_ops) != 0:
|
||||
output_name, output_info = output_ops.popitem()
|
||||
if func:
|
||||
ops = func.get_ordered_ops()
|
||||
for op in ops:
|
||||
if op.friendly_name in net.outputs and op.get_type_name() == "DetectionOutput":
|
||||
output_name = op.friendly_name
|
||||
output_info = net.outputs[output_name]
|
||||
break
|
||||
else:
|
||||
output_name = list(net.outputs.keys())[0]
|
||||
output_info = net.outputs[output_name]
|
||||
|
||||
if output_name == "":
|
||||
log.error("Can't find a DetectionOutput layer in the topology")
|
||||
@@ -189,12 +189,12 @@ def main():
|
||||
else:
|
||||
print()
|
||||
|
||||
tmp_image = cv2.imread(args.input)
|
||||
for imid in classes:
|
||||
tmp_image = cv2.imread(args.input)
|
||||
for box in boxes[imid]:
|
||||
cv2.rectangle(tmp_image, (box[0], box[1]), (box[2], box[3]), (232, 35, 244), 2)
|
||||
cv2.imwrite("out.bmp", tmp_image)
|
||||
log.info("Image out.bmp created!")
|
||||
cv2.imwrite("out.bmp", tmp_image)
|
||||
log.info("Image out.bmp created!")
|
||||
# -----------------------------------------------------------------------------------------------------
|
||||
|
||||
log.info("Execution successful\n")
|
||||
|
||||
Reference in New Issue
Block a user