[MO] update ROIAlign op to opset9 (#11623)
This commit is contained in:
parent
4426aa58e2
commit
1543e35b1b
@ -30,7 +30,7 @@ class CommonLayerTest:
|
||||
" the specific framework")
|
||||
|
||||
def _test(self, framework_model, ref_net, ie_device, precision, ir_version, temp_dir, api_2,
|
||||
use_new_frontend=False, infer_timeout=60, enabled_transforms='',
|
||||
use_new_frontend=True, infer_timeout=60, enabled_transforms='',
|
||||
disabled_transforms='', **kwargs):
|
||||
"""
|
||||
:param enabled_transforms/disabled_transforms: string with idxs of transforms that should be enabled/disabled.
|
||||
@ -60,6 +60,8 @@ class CommonLayerTest:
|
||||
|
||||
if use_new_frontend:
|
||||
mo_params["use_new_frontend"] = True
|
||||
else:
|
||||
mo_params["use_legacy_frontend"] = True
|
||||
|
||||
exit_code, stderr = generate_ir(**mo_params)
|
||||
|
||||
|
@ -2,6 +2,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
from common.layer_test_class import check_ir_version
|
||||
from common.onnx_layer_test_class import OnnxRuntimeLayerTest
|
||||
|
||||
@ -9,8 +10,22 @@ from unit_tests.utils.graph import build_graph
|
||||
|
||||
|
||||
class TestROIAlign(OnnxRuntimeLayerTest):
|
||||
def _prepare_input(self, inputs_dict):
|
||||
for input in inputs_dict.keys():
|
||||
if input == 'indices':
|
||||
if isinstance(inputs_dict['input'], list):
|
||||
batch = inputs_dict['input'][0]
|
||||
else:
|
||||
batch = inputs_dict['input'].shape[0]
|
||||
inputs_dict[input] = np.random.choice(range(batch), inputs_dict[input])
|
||||
elif input == 'input':
|
||||
inputs_dict[input] = np.ones(inputs_dict[input]).astype(np.float32)
|
||||
else:
|
||||
inputs_dict[input] = np.random.randint(-255, 255, inputs_dict[input]).astype(np.float32)
|
||||
return inputs_dict
|
||||
|
||||
def create_net(self, input_shape, rois_shape, indices_shape, output_shape,
|
||||
pooled_h, pooled_w, mode, sampling_ratio, spatial_scale, ir_version):
|
||||
pooled_h, pooled_w, mode, sampling_ratio, spatial_scale, ir_version, onnx_version):
|
||||
"""
|
||||
ONNX net IR net
|
||||
|
||||
@ -18,21 +33,22 @@ class TestROIAlign(OnnxRuntimeLayerTest):
|
||||
|
||||
"""
|
||||
|
||||
|
||||
#
|
||||
# Create ONNX model
|
||||
#
|
||||
|
||||
import onnx
|
||||
from onnx import helper
|
||||
from onnx import TensorProto
|
||||
from onnx import TensorProto, OperatorSetIdProto
|
||||
|
||||
input = helper.make_tensor_value_info('input', TensorProto.FLOAT, input_shape)
|
||||
rois = helper.make_tensor_value_info('rois', TensorProto.FLOAT, rois_shape)
|
||||
indices = helper.make_tensor_value_info('indices', TensorProto.FLOAT, indices_shape)
|
||||
indices = helper.make_tensor_value_info('indices', TensorProto.INT64, indices_shape)
|
||||
output = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape)
|
||||
|
||||
node_def = onnx.helper.make_node(
|
||||
'ROIAlign',
|
||||
'RoiAlign',
|
||||
inputs=['input', 'rois', 'indices'],
|
||||
outputs=['output'],
|
||||
**{'output_height': pooled_h, 'output_width': pooled_w, 'mode': mode,
|
||||
@ -47,8 +63,11 @@ class TestROIAlign(OnnxRuntimeLayerTest):
|
||||
[output],
|
||||
)
|
||||
|
||||
operatorsetid = OperatorSetIdProto()
|
||||
operatorsetid.domain = ""
|
||||
operatorsetid.version = onnx_version
|
||||
# Create the model (ModelProto)
|
||||
onnx_net = helper.make_model(graph_def, producer_name='test_model')
|
||||
onnx_net = helper.make_model(graph_def, producer_name='test_model', opset_imports=[operatorsetid])
|
||||
|
||||
#
|
||||
# Create reference IR net
|
||||
@ -95,18 +114,30 @@ class TestROIAlign(OnnxRuntimeLayerTest):
|
||||
dict(input_shape=[1, 256, 200, 272], rois_shape=[1000, 4], indices_shape=[1000],
|
||||
pooled_h=7, pooled_w=7, mode="avg", sampling_ratio=2, spatial_scale=0.25,
|
||||
output_shape=[1000, 256, 7, 7]),
|
||||
dict(input_shape=[7, 256, 200, 200], rois_shape=[1000, 4], indices_shape=[1000],
|
||||
pooled_h=6, pooled_w=6, mode="max", sampling_ratio=2, spatial_scale=16.0,
|
||||
output_shape=[1000, 256, 6, 6]),
|
||||
dict(input_shape=[7, 256, 200, 200], rois_shape=[1000, 4], indices_shape=[1000],
|
||||
pooled_h=5, pooled_w=6, mode="max", sampling_ratio=2, spatial_scale=16.0,
|
||||
output_shape=[1000, 256, 5, 6]),
|
||||
|
||||
dict(input_shape=[1, 90, 12, 14], rois_shape=[5, 4], indices_shape=[5],
|
||||
pooled_h=2, pooled_w=2, mode="avg", sampling_ratio=2, spatial_scale=0.25,
|
||||
output_shape=[5, 90, 2, 2]),
|
||||
dict(input_shape=[1, 20, 12, 14], rois_shape=[5, 4], indices_shape=[5],
|
||||
pooled_h=2, pooled_w=2, mode="avg", sampling_ratio=2, spatial_scale=0.25,
|
||||
output_shape=[5, 20, 2, 2]),
|
||||
dict(input_shape=[1, 50, 12, 14], rois_shape=[5, 4], indices_shape=[5],
|
||||
pooled_h=2, pooled_w=2, mode="avg", sampling_ratio=2, spatial_scale=0.25,
|
||||
output_shape=[5, 50, 2, 2]),
|
||||
dict(input_shape=[1, 120, 12, 14], rois_shape=[5, 4], indices_shape=[5],
|
||||
pooled_h=2, pooled_w=2, mode="avg", sampling_ratio=2, spatial_scale=0.25,
|
||||
output_shape=[5, 120, 2, 2]),
|
||||
dict(input_shape=[7, 1, 4, 4], rois_shape=[2, 4], indices_shape=[2],
|
||||
pooled_h=2, pooled_w=2, mode="max", sampling_ratio=2, spatial_scale=16.0,
|
||||
output_shape=[2, 1, 2, 2]),
|
||||
]
|
||||
|
||||
@pytest.mark.parametrize("params", test_data)
|
||||
@pytest.mark.nightly
|
||||
def test_roi_align(self, params, ie_device, precision, ir_version, temp_dir, api_2):
|
||||
self._test(*self.create_net(**params, ir_version=ir_version), ie_device, precision,
|
||||
ir_version,
|
||||
temp_dir=temp_dir, api_2=api_2)
|
||||
@pytest.mark.precommit
|
||||
def test_roi_alignv10(self, params, ie_device, precision, ir_version, temp_dir, api_2):
|
||||
# TODO: ticket for investigating GPU failures: CVS-86300
|
||||
if ie_device != "GPU":
|
||||
self._test(*self.create_net(**params, ir_version=ir_version, onnx_version=10), ie_device, precision,
|
||||
ir_version,
|
||||
temp_dir=temp_dir, api_2=api_2,
|
||||
use_legacy_frontend=True)
|
||||
|
@ -53,4 +53,4 @@ class TestReverseV2Ops(CommonTFLayerTest):
|
||||
def test_reversev2_precommit(self, params, keep_dims, ie_device, precision, ir_version,
|
||||
temp_dir, api_2):
|
||||
self._test(*self.create_reversev2_net(**params, keep_dims=keep_dims, ir_version=ir_version),
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir, api_2=api_2)
|
||||
ie_device, precision, ir_version, temp_dir=temp_dir, api_2=api_2, use_new_frontend=False)
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
from openvino.tools.mo.ops.roialign import ROIAlign
|
||||
from openvino.tools.mo.front.extractor import FrontExtractorOp
|
||||
from openvino.tools.mo.front.onnx.extractors.utils import onnx_attr
|
||||
from openvino.tools.mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
|
||||
|
||||
|
||||
class ROIAlignExtractor(FrontExtractorOp):
|
||||
@ -17,8 +17,16 @@ class ROIAlignExtractor(FrontExtractorOp):
|
||||
output_width = onnx_attr(node, 'output_width', 'i', default=1)
|
||||
sampling_ratio = onnx_attr(node, 'sampling_ratio', 'i', default=0)
|
||||
spatial_scale = onnx_attr(node, 'spatial_scale', 'f', default=1.0)
|
||||
|
||||
ROIAlign.update_node_stat(node, {'pooled_h': output_height, 'pooled_w': output_width,
|
||||
'sampling_ratio': sampling_ratio, 'spatial_scale': spatial_scale,
|
||||
'mode': mode})
|
||||
onnx_opset_version = get_onnx_opset_version(node)
|
||||
if onnx_opset_version >= 16:
|
||||
aligned_mode = onnx_attr(node, 'coordinate_transformation_mode', 's', default=b'half_pixel').decode()
|
||||
if aligned_mode == "output_half_pixel":
|
||||
aligned_mode = "asymmetric"
|
||||
ROIAlign.update_node_stat(node, {'pooled_h': output_height, 'pooled_w': output_width,
|
||||
'sampling_ratio': sampling_ratio, 'spatial_scale': spatial_scale,
|
||||
'mode': mode, 'aligned_mode': aligned_mode})
|
||||
else:
|
||||
ROIAlign.update_node_stat(node, {'pooled_h': output_height, 'pooled_w': output_width,
|
||||
'sampling_ratio': sampling_ratio, 'spatial_scale': spatial_scale,
|
||||
'mode': mode})
|
||||
return cls.enabled
|
||||
|
@ -18,33 +18,45 @@ class ROIAlign(Op):
|
||||
assert 'pooled_w' in attrs, '`pooled_w` attribute is not set for ROIAlign during creation'
|
||||
assert 'sampling_ratio' in attrs, '`sampling_ratio` attribute is not set for ROIAlign during creation'
|
||||
assert 'spatial_scale' in attrs, '`spatial_scale` attribute is not set for ROIAlign during creation'
|
||||
|
||||
super().__init__(graph, {
|
||||
'op': self.op,
|
||||
'type': self.op,
|
||||
'version': 'opset3',
|
||||
'version': 'opset9',
|
||||
|
||||
'infer': self.infer,
|
||||
'reverse_infer': self.reverse_infer,
|
||||
|
||||
'in_ports_count': 3,
|
||||
'out_ports_count': 1,
|
||||
|
||||
'aligned_mode': 'asymmetric',
|
||||
}, attrs)
|
||||
|
||||
def backend_attrs(self):
|
||||
return [
|
||||
('mode', lambda node: str(node.mode)),
|
||||
('pooled_h', lambda node: str(int(node.pooled_h))),
|
||||
('pooled_w', lambda node: str(int(node.pooled_w))),
|
||||
('sampling_ratio', lambda node: str(int(node.sampling_ratio))),
|
||||
('spatial_scale', lambda node: str(float(node.spatial_scale))),
|
||||
]
|
||||
version = self.get_opset()
|
||||
if version == 'opset3':
|
||||
return [
|
||||
('mode', lambda node: str(node.mode)),
|
||||
('pooled_h', lambda node: str(int(node.pooled_h))),
|
||||
('pooled_w', lambda node: str(int(node.pooled_w))),
|
||||
('sampling_ratio', lambda node: str(int(node.sampling_ratio))),
|
||||
('spatial_scale', lambda node: str(float(node.spatial_scale))),
|
||||
]
|
||||
elif version == 'opset9':
|
||||
return [
|
||||
('mode', lambda node: str(node.mode)),
|
||||
('pooled_h', lambda node: str(int(node.pooled_h))),
|
||||
('pooled_w', lambda node: str(int(node.pooled_w))),
|
||||
('sampling_ratio', lambda node: str(int(node.sampling_ratio))),
|
||||
('spatial_scale', lambda node: str(float(node.spatial_scale))),
|
||||
('aligned_mode', lambda node: str(node.aligned_mode))
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def infer(node):
|
||||
|
||||
layout = node.graph.graph['layout']
|
||||
node_name = node.soft_get('name', node.id)
|
||||
|
||||
assert len([port for port in node.in_ports().values() if not port.disconnected()]) == 3, \
|
||||
'The node "{}" must 3 inputs'.format(node_name)
|
||||
|
||||
@ -53,7 +65,9 @@ class ROIAlign(Op):
|
||||
assert node.has_valid('mode'), '"mode" attribute is not set for node "{}"'.format(node_name)
|
||||
assert node.mode in ['avg', 'max'], \
|
||||
'"mode" attribute range of values is ["avg", "max"], got {} for node "{}"'.format(node.mode, node_name)
|
||||
|
||||
if node.get_opset() == 'opset9':
|
||||
assert node.aligned_mode in ['asymmetric', 'half_pixel_for_nn', 'half_pixel'], \
|
||||
'"aligned_mode" attribute range of values is ["asymmetric", "half_pixel_for_nn", "half_pixel"]'
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
rois_shape = node.in_port(1).data.get_shape()
|
||||
indices_shape = node.in_port(2).data.get_shape()
|
||||
@ -63,7 +77,6 @@ class ROIAlign(Op):
|
||||
'to number of ROIs for node "{}"'.format(node_name)
|
||||
assert compatible_dims(rois_shape[1], 4), 'The size of ROI element must be 4 for node "{}"'.format(node_name)
|
||||
assert len(input_shape) == 4, 'The rank of port 0 input tensor of node "{}" must be 4.'.format(node_name)
|
||||
|
||||
node.out_port(0).data.set_shape(
|
||||
shape_for_layout(layout,
|
||||
batch=rois_shape[0],
|
||||
|
140
tools/mo/unit_tests/mo/ops/roialign_test.py
Normal file
140
tools/mo/unit_tests/mo/ops/roialign_test.py
Normal file
@ -0,0 +1,140 @@
|
||||
# Copyright (C) 2018-2022 Intel Corporation
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import unittest
|
||||
from openvino.tools.mo.front.common.partial_infer.utils import int64_array, float32_array
|
||||
|
||||
import numpy as np
|
||||
from openvino.tools.mo.graph.graph import Node
|
||||
from openvino.tools.mo.ops.roialign import ROIAlign
|
||||
from unit_tests.utils.graph import build_graph
|
||||
|
||||
|
||||
class TestROIAlignOps(unittest.TestCase):
|
||||
node_attrs = {
|
||||
# input 1
|
||||
"1_input": {"kind": "op", "type": "Parameter", "value": None},
|
||||
"input_data": {"shape": None, "kind": "data", "value": None},
|
||||
#input 2
|
||||
"2_rois": {"kind": "op", "type": "Parameter","value": None},
|
||||
"rois_data": {"shape": None,"kind": "data", "value": None},
|
||||
# input 3
|
||||
"3_indices": {"kind": "op","type": "Parameter"},
|
||||
"indices_data": {"shape": None, "kind": "data", "value": None},
|
||||
# ROIAlign
|
||||
"node": {
|
||||
"kind": "op",
|
||||
"type": "ROIAlign",
|
||||
"pooled_h": None,
|
||||
"pooled_w": None,
|
||||
"mode": None,
|
||||
"sampling_ratio": 2,
|
||||
"spatial_scale": 16,
|
||||
"aligned_mode": None,
|
||||
},
|
||||
"node_data": {"shape": None, "kind": "data", "value": None},
|
||||
# output
|
||||
"result": {"kind": "op","type": "Result"},
|
||||
}
|
||||
|
||||
def test_roialignv1(self):
|
||||
graph = build_graph(
|
||||
self.node_attrs,
|
||||
[
|
||||
("1_input", "input_data"),
|
||||
("input_data", "node", {"in": 0}),
|
||||
("2_rois", "rois_data"),
|
||||
("rois_data", "node", {"in": 1}),
|
||||
("3_indices", "indices_data"),
|
||||
("indices_data", "node", {"in": 2}),
|
||||
("node", "node_data"),
|
||||
("node_data", "result"),
|
||||
],
|
||||
{
|
||||
'input_data': {'shape': int64_array([1, 256, 200, 272])},
|
||||
'rois_data': {'shape': int64_array([1000, 4])},
|
||||
'indices_data': {'shape': int64_array([1000])},
|
||||
'node': {'mode': 'max', 'pooled_h': 7, 'pooled_w': 7, 'aligned_mode': 'asymmetric', 'version': 'opset9'},
|
||||
}
|
||||
)
|
||||
graph.graph["layout"] = "NCHW"
|
||||
node = Node(graph, "node")
|
||||
ROIAlign.infer(node)
|
||||
self.assertListEqual(list([1000, 256, 7, 7]), graph.node['node_data']['shape'].data.tolist())
|
||||
|
||||
def test_roialignv2(self):
|
||||
graph = build_graph(
|
||||
self.node_attrs,
|
||||
[
|
||||
("1_input", "input_data"),
|
||||
("input_data", "node", {"in": 0}),
|
||||
("2_rois", "rois_data"),
|
||||
("rois_data", "node", {"in": 1}),
|
||||
("3_indices", "indices_data"),
|
||||
("indices_data", "node", {"in": 2}),
|
||||
("node", "node_data"),
|
||||
("node_data", "result"),
|
||||
],
|
||||
{
|
||||
'input_data': {'shape': int64_array([7, 256, 200, 200])},
|
||||
'rois_data': {'shape': int64_array([300, 4])},
|
||||
'indices_data': {'shape': int64_array([300])},
|
||||
'node': {'mode': 'max', 'pooled_h': 5, 'pooled_w': 6, 'aligned_mode': 'half_pixel_for_nn', 'version':'opset9'},
|
||||
}
|
||||
)
|
||||
graph.graph["layout"] = "NCHW"
|
||||
node = Node(graph, "node")
|
||||
|
||||
ROIAlign.infer(node)
|
||||
self.assertListEqual(list([300, 256, 5, 6]), graph.node['node_data']['shape'].data.tolist())
|
||||
|
||||
def test_roialignv3(self):
|
||||
graph = build_graph(
|
||||
self.node_attrs,
|
||||
[
|
||||
("1_input", "input_data"),
|
||||
("input_data", "node", {"in": 0}),
|
||||
("2_rois", "rois_data"),
|
||||
("rois_data", "node", {"in": 1}),
|
||||
("3_indices", "indices_data"),
|
||||
("indices_data", "node", {"in": 2}),
|
||||
("node", "node_data"),
|
||||
("node_data", "result"),
|
||||
],
|
||||
{
|
||||
'input_data': {'shape': int64_array([2, 3, 5, 5])},
|
||||
'rois_data': {'shape': int64_array([7, 4])},
|
||||
'indices_data': {'shape': int64_array([7])},
|
||||
'node': {'mode': 'max', 'pooled_h': 2, 'pooled_w': 2, 'aligned_mode': 'half_pixel', 'version': 'opset9'},
|
||||
}
|
||||
)
|
||||
graph.graph["layout"] = "NCHW"
|
||||
node = Node(graph, "node")
|
||||
|
||||
ROIAlign.infer(node)
|
||||
self.assertListEqual(list([7, 3, 2, 2]), graph.node['node_data']['shape'].data.tolist())
|
||||
|
||||
|
||||
def test_roialign_wrong_aligned_mode(self):
|
||||
graph = build_graph(
|
||||
self.node_attrs,
|
||||
[
|
||||
("1_input", "input_data"),
|
||||
("input_data", "node", {"in": 0}),
|
||||
("2_rois", "rois_data"),
|
||||
("rois_data", "node", {"in": 1}),
|
||||
("3_indices", "indices_data"),
|
||||
("indices_data", "node", {"in": 2}),
|
||||
("node", "node_data"),
|
||||
("node_data", "result"),
|
||||
],
|
||||
{
|
||||
'input_data': {'shape': int64_array([2, 3, 5, 5])},
|
||||
'rois_data': {'shape': int64_array([7, 4])},
|
||||
'indices_data': {'shape': int64_array([7])},
|
||||
'node': {'mode': 'max', 'pooled_h': 2, 'pooled_w': 2, 'aligned_mode': 'full_pixel', 'version': 'opset9'},
|
||||
}
|
||||
)
|
||||
graph.graph["layout"] = "NCHW"
|
||||
node = Node(graph, "node")
|
||||
self.assertRaises(AssertionError, ROIAlign.infer, node)
|
Loading…
Reference in New Issue
Block a user