Fix issue with np.expand_dims for numpy<18.x (#3436)
* Fix trouble with np.expand_dims for numpy<18.x * Delete function expand_dims * Added additional line
This commit is contained in:
committed by
GitHub
parent
8344c29090
commit
b3124a5c77
@@ -134,7 +134,10 @@ def explicit_broadcasting(input_value: np.array, target_shape: np.array, axes_ma
|
||||
:return: broadcasted value
|
||||
"""
|
||||
res_shape, normalized_axes_mapping = explicit_shape_broadcasting(input_value.shape, target_shape, axes_mapping)
|
||||
#TODO: Function 'expand_dims' should be replaced with 'numpy.expand_dims' if numpy version will be >=18.x in requirements.
|
||||
expand_dim_axis = set(np.arange(len(target_shape))) - set(normalized_axes_mapping)
|
||||
|
||||
input_expanded = np.expand_dims(input_value.copy(), axis=list(expand_dim_axis))
|
||||
input_expanded = input_value.copy()
|
||||
|
||||
for axis in sorted(list(expand_dim_axis)):
|
||||
input_expanded = np.expand_dims(input_expanded, axis)
|
||||
return np.broadcast_to(input_expanded, res_shape)
|
||||
|
||||
Reference in New Issue
Block a user