[MO] Move redundant checks in ScatterUpdate operation shape infer (#10306)
* Add extender for ScatterUpdate operation * Remove scatterupdate extender * Remove redundant checks in Scatter shape inference function * Move checks to ScatterElementsUpdate operations * mava checks to appropriate place
This commit is contained in:
committed by
GitHub
parent
a0ad849c19
commit
84ee38d89e
@@ -41,11 +41,6 @@ class Scatter(Op):
|
||||
updates_shape = node.in_port(2).data.get_shape()
|
||||
assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
|
||||
'The node "{}" input shape is None'.format(node_name)
|
||||
assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
|
||||
'same rank. Instead got {} and {}'.format(node_name, len(input_shape), len(indices_shape))
|
||||
assert compatible_shapes(indices_shape, updates_shape), \
|
||||
'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
|
||||
''.format(node_name, indices_shape, updates_shape)
|
||||
|
||||
node.out_port(0).data.set_shape(input_shape)
|
||||
|
||||
@@ -101,8 +96,20 @@ class ScatterElementsUpdate(Scatter):
|
||||
|
||||
input_value = node.in_port(0).data.get_value()
|
||||
indices_value = node.in_port(1).data.get_value()
|
||||
indices_shape = node.in_port(1).data.get_shape()
|
||||
updates_value = node.in_port(2).data.get_value()
|
||||
|
||||
input_shape = node.in_port(0).data.get_shape()
|
||||
indices_shape = node.in_port(1).data.get_shape()
|
||||
updates_shape = node.in_port(2).data.get_shape()
|
||||
|
||||
assert len(input_shape) == len(indices_shape), 'data and indices inputs for node "{}" must be of the ' \
|
||||
'same rank. Instead got {} and {}'.format(node_name,
|
||||
len(input_shape),
|
||||
len(indices_shape))
|
||||
assert compatible_shapes(indices_shape, updates_shape), \
|
||||
'updates and indices shapes for node "{}" must be equal. Instead got {} and {}.' \
|
||||
''.format(node_name, indices_shape, updates_shape)
|
||||
|
||||
axis = node.in_port(3).data.get_value()
|
||||
if input_value is not None and indices_value is not None and updates_value is not None and axis is not None:
|
||||
assert axis.size == 1, "The node {} has axis input value size equal to {} but it should be exactly 1.".format(
|
||||
|
||||
Reference in New Issue
Block a user