[MO] Move redundant checks in ScatterUpdate operation shape infer (#10306)

* Add extender for ScatterUpdate operation

* Remove scatterupdate extender

* Remove redundant checks in Scatter shape inference function

* Move checks to ScatterElementsUpdate operations

* mava checks to appropriate place
This commit is contained in:
Anton Chetverikov
2022-02-15 04:55:38 +03:00
committed by GitHub
parent a0ad849c19
commit 84ee38d89e

View File

@@ -41,11 +41,6 @@ class Scatter(Op):
updates_shape = node.in_port(2).data.get_shape()
assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
'The node "{}" input shape is None'.format(node_name)
assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
'same rank. Instead got {} and {}'.format(node_name, len(input_shape), len(indices_shape))
assert compatible_shapes(indices_shape, updates_shape), \
'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
''.format(node_name, indices_shape, updates_shape)
node.out_port(0).data.set_shape(input_shape)
@@ -101,8 +96,20 @@ class ScatterElementsUpdate(Scatter):
input_value = node.in_port(0).data.get_value()
indices_value = node.in_port(1).data.get_value()
indices_shape = node.in_port(1).data.get_shape()
updates_value = node.in_port(2).data.get_value()
input_shape = node.in_port(0).data.get_shape()
indices_shape = node.in_port(1).data.get_shape()
updates_shape = node.in_port(2).data.get_shape()
assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
'same rank. Instead got {} and {}'.format(node_name,
len(input_shape),
len(indices_shape))
assert compatible_shapes(indices_shape, updates_shape), \
'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
''.format(node_name, indices_shape, updates_shape)
axis = node.in_port(3).data.get_value()
if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(