[ MO ONNX ] TopK-1/10/11 proper extracting (#600)

This commit is contained in:
Evgenya Stepyreva 2020-05-26 21:53:24 +03:00 committed by GitHub
parent 4a44f84dab
commit 73f3b7c8fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 3 deletions

View File

@ -16,7 +16,7 @@
from extensions.ops.topk import TopK
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.onnx.extractors.utils import onnx_attr, onnx_node_has_attr
class TopKExtractor(FrontExtractorOp):
@ -25,6 +25,18 @@ class TopKExtractor(FrontExtractorOp):
@classmethod
def extract(cls, node):
axis = onnx_attr(node, 'axis', 'i', default=-1)
TopK.update_node_stat(node, {'axis': axis, 'sort': 'value'})
"""
TopK-1 (k as attribute, required)
TopK-10 (k as input, no sorting manipulations)
TopK-11 (k as input, sorting manipulations through `sorted` and `largest` attrs)
"""
attrs = {
'axis': onnx_attr(node, 'axis', 'i', default=-1)
}
if onnx_node_has_attr(node, 'k'):
attrs['k'] = onnx_attr(node, 'k', 'i')
attrs['sort'] = 'value' if onnx_attr(node, 'sorted', 'i', default=1) else 'none'
attrs['mode'] = 'max' if onnx_attr(node, 'largest', 'i', default=1) else 'min'
TopK.update_node_stat(node, attrs)
return cls.enabled

View File

@ -20,6 +20,11 @@ from mo.graph.graph import Node
from mo.utils.error import Error
def onnx_node_has_attr(node: Node, name: str):
attrs = [a for a in node.pb.attribute if a.name == name]
return len(attrs) != 0
def onnx_attr(node: Node, name: str, field: str, default=None, dst_type=None):
""" Retrieves ONNX attribute with name `name` from ONNX protobuf `node.pb`.
The final value is casted to dst_type if attribute really exists.