Fix MO IR Reader extender for StridedSlice to support empty begin and end masks (#17019)

This commit is contained in:
Maxim Vafin 2023-04-24 11:08:28 +02:00 committed by GitHub
parent aa5b6ecac2
commit 01065338ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 2 deletions

View File

@ -18,8 +18,6 @@ class StridedSlice_extender(Extender):
if op.has(attr) and op[attr] != '':
Extender.attr_to_list(op, attr)
else:
assert attr not in ['begin_mask', 'end_mask'],\
'{} is not defined for the node {}'.format(attr, op.soft_get('name', op.id))
op[attr] = int64_array([])
op.begin_mask = int64_array([1 - i for i in op.begin_mask])

View File

@ -182,3 +182,14 @@ class TestOps(unittest.TestCase):
_, callable_attribute = layer_info[0]
self.assertTrue(callable(callable_attribute))
self.assertEqual(callable_attribute(if_node), "If_opset8")
def test_strided_slice_no_begin_end_mask(self):
data_shape = [6, 12, 10, 24]
data_parameter = opset11.parameter(
data_shape, name="Data", dtype=np.float32)
strided_slice = opset11.strided_slice(data_parameter, np.int32([1, 2, 3, 4]), np.int32(
[3, 6, 9, 12]), np.int32([1, 1, 1, 1]), begin_mask=[], end_mask=[], name="StridedSlice_10")
model = Model(strided_slice, [data_parameter])
graph = TestOps.check_graph_can_save(model, 'strided_slice_model')
strided_slice_node = graph.get_op_nodes(op="StridedSlice")[0]
self.assertEqual(strided_slice_node["version"], "opset1")