Fix MO IR Reader extender for StridedSlice to support empty begin and end masks (#17019)
This commit is contained in:
parent
aa5b6ecac2
commit
01065338ef
@ -18,8 +18,6 @@ class StridedSlice_extender(Extender):
|
||||
if op.has(attr) and op[attr] != '':
|
||||
Extender.attr_to_list(op, attr)
|
||||
else:
|
||||
assert attr not in ['begin_mask', 'end_mask'],\
|
||||
'{} is not defined for the node {}'.format(attr, op.soft_get('name', op.id))
|
||||
op[attr] = int64_array([])
|
||||
|
||||
op.begin_mask = int64_array([1 - i for i in op.begin_mask])
|
||||
|
@ -182,3 +182,14 @@ class TestOps(unittest.TestCase):
|
||||
_, callable_attribute = layer_info[0]
|
||||
self.assertTrue(callable(callable_attribute))
|
||||
self.assertEqual(callable_attribute(if_node), "If_opset8")
|
||||
|
||||
def test_strided_slice_no_begin_end_mask(self):
|
||||
data_shape = [6, 12, 10, 24]
|
||||
data_parameter = opset11.parameter(
|
||||
data_shape, name="Data", dtype=np.float32)
|
||||
strided_slice = opset11.strided_slice(data_parameter, np.int32([1, 2, 3, 4]), np.int32(
|
||||
[3, 6, 9, 12]), np.int32([1, 1, 1, 1]), begin_mask=[], end_mask=[], name="StridedSlice_10")
|
||||
model = Model(strided_slice, [data_parameter])
|
||||
graph = TestOps.check_graph_can_save(model, 'strided_slice_model')
|
||||
strided_slice_node = graph.get_op_nodes(op="StridedSlice")[0]
|
||||
self.assertEqual(strided_slice_node["version"], "opset1")
|
||||
|
Loading…
Reference in New Issue
Block a user