Fix issue with np.expand_dims for numpy<18.x (#3436)

* Fix trouble with np.expand_dims for numpy<18.x

* Delete function expand_dims

* Added additional line
This commit is contained in:
Eugeny Volosenkov
2020-12-03 09:48:17 +03:00
committed by GitHub
parent 8344c29090
commit b3124a5c77

View File

@@ -134,7 +134,10 @@ def explicit_broadcasting(input_value: np.array, target_shape: np.array, axes_ma
:return: broadcasted value
"""
res_shape, normalized_axes_mapping = explicit_shape_broadcasting(input_value.shape, target_shape, axes_mapping)
#TODO: Function 'expand_dims' should be replaced with 'numpy.expand_dims' if numpy version will be >=18.x in requirements.
expand_dim_axis = set(np.arange(len(target_shape))) - set(normalized_axes_mapping)
input_expanded = np.expand_dims(input_value.copy(), axis=list(expand_dim_axis))
input_expanded = input_value.copy()
for axis in sorted(list(expand_dim_axis)):
input_expanded = np.expand_dims(input_expanded, axis)
return np.broadcast_to(input_expanded, res_shape)