[ MO ONNX ] Resize-11 clear error message (#620)

* Small refactoring of extractors

* [ MO ] Throwing an exception while extracting Resize-11 which is not supported
This commit is contained in:
Evgenya Stepyreva 2020-05-27 08:09:15 +03:00 committed by GitHub
parent d3ea03bbfc
commit 5c2eb05990
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 18 deletions

View File

@ -16,16 +16,20 @@
from extensions.ops.upsample import UpsampleOp
from mo.front.extractor import FrontExtractorOp
from mo.front.onnx.extractors.utils import onnx_attr
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
from mo.graph.graph import Node
from mo.utils.error import Error
class ResizeExtractor(FrontExtractorOp):
op = 'Resize'
enabled = True
@staticmethod
def extract(node: Node):
@classmethod
def extract(cls, node: Node):
onnx_opset_version = get_onnx_opset_version(node)
if onnx_opset_version is not None and onnx_opset_version >= 11:
raise Error("ONNX Resize operation from opset {} is not supported.".format(onnx_opset_version))
mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
UpsampleOp.update_node_stat(node, {'mode': mode})
return __class__.enabled
return cls.enabled

View File

@ -23,8 +23,8 @@ class ReverseSequenceExtractor(FrontExtractorOp):
op = 'ReverseSequence'
enabled = True
@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
batch_axis = onnx_attr(node, 'batch_axis', 'i', default=1)
time_axis = onnx_attr(node, 'time_axis', 'i', default=0)
@ -33,4 +33,4 @@ class ReverseSequenceExtractor(FrontExtractorOp):
'seq_axis': time_axis,
}
ReverseSequence.update_node_stat(node, attrs)
return __class__.enabled
return cls.enabled

View File

@ -24,9 +24,8 @@ class BucketizeFrontExtractor(FrontExtractorOp):
op = 'Bucketize'
enabled = True
@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False})
return __class__.enabled
return cls.enabled

View File

@ -22,8 +22,7 @@ class SparseToDenseFrontExtractor(FrontExtractorOp):
op = 'SparseToDense'
enabled = True
@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
SparseToDense.update_node_stat(node)
return __class__.enabled
return cls.enabled

View File

@ -25,8 +25,8 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
op = 'linearcomponent'
enabled = True
@staticmethod
def extract(node):
@classmethod
def extract(cls, node):
pb = node.parameters
collect_until_token(pb, b'<Params>')
weights, weights_shape = read_binary_matrix(pb)
@ -39,4 +39,4 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
embed_input(mapping_rule, 1, 'weights', weights)
FullyConnected.update_node_stat(node, mapping_rule)
return __class__.enabled
return cls.enabled