Support Pad-2 in opset-11 ONNX model (#4886)
* Support Pad-2 in opset-11 ONNX model * Add unit test for pad * Apply review feedback
This commit is contained in:
parent
d77869ba71
commit
d40a607ca0
@ -16,6 +16,7 @@
|
||||
|
||||
import numpy as np
|
||||
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.front.extractor import FrontExtractorOp
|
||||
from mo.front.onnx.extractors.utils import onnx_attr, get_onnx_opset_version
|
||||
from mo.ops.pad import AttributedPad, ONNXPad
|
||||
@ -28,11 +29,13 @@ class PadFrontExtractor(FrontExtractorOp):
|
||||
@classmethod
|
||||
def extract(cls, node):
|
||||
mode = onnx_attr(node, 'mode', 's', default='constant', dst_type=lambda x: x.decode())
|
||||
if get_onnx_opset_version(node) < 11:
|
||||
pads = onnx_attr(node, 'pads', 'ints', dst_type=lambda x: np.array(x, dtype=np.int64))
|
||||
# Pytorch 1.3 while converting to opset 11, creates Pad from older opset.
|
||||
# To be able to convert such models we have to check if pads attribute exists.
|
||||
pads = onnx_attr(node, 'pads', 'ints', dst_type=int64_array)
|
||||
if get_onnx_opset_version(node) < 11 or pads is not None:
|
||||
value = onnx_attr(node, 'value', 'f', default=0.)
|
||||
|
||||
assert pads is not None
|
||||
assert pads is not None, 'pads is required attribute for Pad operation'
|
||||
|
||||
# MO Pad op and ONNX Pad op have different format for pads values
|
||||
# MO Pad has Dx2 where D is the total number of dimensions
|
||||
|
@ -57,6 +57,20 @@ class TestPad(BaseExtractorsTestingClass):
|
||||
|
||||
self.compare()
|
||||
|
||||
def test_older_pad_opset_11(self):
|
||||
node = self._create_node()
|
||||
node.graph.graph['fw_opset_version'] = 11
|
||||
PadFrontExtractor.extract(node)
|
||||
self.res = node
|
||||
|
||||
self.expected = {
|
||||
'pads': [[1, 3], [2, 4]],
|
||||
'mode': 'constant',
|
||||
'fill_value': 0
|
||||
}
|
||||
|
||||
self.compare()
|
||||
|
||||
def test_reflect(self):
|
||||
node = self._create_node(mode='reflect')
|
||||
PadFrontExtractor.extract(node)
|
||||
|
Loading…
Reference in New Issue
Block a user