[ MO ONNX ] TopK-1/10/11 proper extracting (#600)
This commit is contained in:
parent
4a44f84dab
commit
73f3b7c8fc
@ -16,7 +16,7 @@
|
||||
|
||||
from extensions.ops.topk import TopK
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
from mo.front.onnx.extractors.utils import onnx_attr, onnx_node_has_attr
|
||||
|
||||
|
||||
class TopKExtractor(FrontExtractorOp):
|
||||
@ -25,6 +25,18 @@ class TopKExtractor(FrontExtractorOp):
|
||||
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
axis = onnx_attr(node, 'axis', 'i', default=-1)
|
||||
TopK.update_node_stat(node, {'axis': axis, 'sort': 'value'})
|
||||
"""
|
||||
TopK-1 (k as attribute, required)
|
||||
TopK-10 (k as input, no sorting manipulations)
|
||||
TopK-11 (k as input, sorting manipulations through `sorted` and `largest` attrs)
|
||||
"""
|
||||
attrs = {
|
||||
'axis': onnx_attr(node, 'axis', 'i', default=-1)
|
||||
}
|
||||
if onnx_node_has_attr(node, 'k'):
|
||||
attrs['k'] = onnx_attr(node, 'k', 'i')
|
||||
attrs['sort'] = 'value' if onnx_attr(node, 'sorted', 'i', default=1) else 'none'
|
||||
attrs['mode'] = 'max' if onnx_attr(node, 'largest', 'i', default=1) else 'min'
|
||||
|
||||
TopK.update_node_stat(node, attrs)
|
||||
return cls.enabled
|
||||
|
@ -20,6 +20,11 @@ from mo.graph.graph import Node
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
def onnx_node_has_attr(node: Node, name: str):
|
||||
attrs = [a for a in node.pb.attribute if a.name == name]
|
||||
return len(attrs) != 0
|
||||
|
||||
|
||||
def onnx_attr(node: Node, name: str, field: str, default=None, dst_type=None):
|
||||
""" Retrieves ONNX attribute with name `name` from ONNX protobuf `node.pb`.
|
||||
The final value is casted to dst_type if attribute really exists.
|
||||
|
Loading…
Reference in New Issue
Block a user