[ MO ONNX ] Resize-11 clear error message (#620)
* Small refactoring of extractors * [ MO ] Throwing an exception while extracting Resize-11 which is not supported
This commit is contained in:
parent
d3ea03bbfc
commit
5c2eb05990
@ -16,16 +16,20 @@
|
||||
|
||||
from extensions.ops.upsample import UpsampleOp
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr
|
||||
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
|
||||
from mo.graph.graph import Node
|
||||
from mo.utils.error import Error
|
||||
|
||||
|
||||
class ResizeExtractor(FrontExtractorOp):
|
||||
op = 'Resize'
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def extract(node: Node):
|
||||
@classmethod
|
||||
def extract(cls, node: Node):
|
||||
onnx_opset_version = get_onnx_opset_version(node)
|
||||
if onnx_opset_version is not None and onnx_opset_version >= 11:
|
||||
raise Error("ONNX Resize operation from opset {} is not supported.".format(onnx_opset_version))
|
||||
mode = onnx_attr(node, 'mode', 's', default=b'nearest').decode()
|
||||
UpsampleOp.update_node_stat(node, {'mode': mode})
|
||||
return __class__.enabled
|
||||
return cls.enabled
|
||||
|
@ -23,8 +23,8 @@ class ReverseSequenceExtractor(FrontExtractorOp):
|
||||
op = 'ReverseSequence'
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def extract(node):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
batch_axis = onnx_attr(node, 'batch_axis', 'i', default=1)
|
||||
time_axis = onnx_attr(node, 'time_axis', 'i', default=0)
|
||||
|
||||
@ -33,4 +33,4 @@ class ReverseSequenceExtractor(FrontExtractorOp):
|
||||
'seq_axis': time_axis,
|
||||
}
|
||||
ReverseSequence.update_node_stat(node, attrs)
|
||||
return __class__.enabled
|
||||
return cls.enabled
|
||||
|
@ -24,9 +24,8 @@ class BucketizeFrontExtractor(FrontExtractorOp):
|
||||
op = 'Bucketize'
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def extract(node):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
boundaries = np.array(node.pb.attr['boundaries'].list.f, dtype=np.float)
|
||||
Bucketize.update_node_stat(node, {'boundaries': boundaries, 'with_right_bound': False})
|
||||
|
||||
return __class__.enabled
|
||||
return cls.enabled
|
||||
|
@ -22,8 +22,7 @@ class SparseToDenseFrontExtractor(FrontExtractorOp):
|
||||
op = 'SparseToDense'
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def extract(node):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
SparseToDense.update_node_stat(node)
|
||||
|
||||
return __class__.enabled
|
||||
return cls.enabled
|
||||
|
@ -25,8 +25,8 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
|
||||
op = 'linearcomponent'
|
||||
enabled = True
|
||||
|
||||
@staticmethod
|
||||
def extract(node):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
pb = node.parameters
|
||||
collect_until_token(pb, b'<Params>')
|
||||
weights, weights_shape = read_binary_matrix(pb)
|
||||
@ -39,4 +39,4 @@ class LinearComponentFrontExtractor(FrontExtractorOp):
|
||||
embed_input(mapping_rule, 1, 'weights', weights)
|
||||
|
||||
FullyConnected.update_node_stat(node, mapping_rule)
|
||||
return __class__.enabled
|
||||
return cls.enabled
|
||||
|
Loading…
Reference in New Issue
Block a user