Fixed extractor for MVN from ONNX (#1653)

* Fixed extractor for MVN from ONNX

* Updated MVN extractor from ONNX

* Code style
This commit is contained in:
Evgeny Lazarev 2020-08-06 13:53:16 +03:00 committed by GitHub
parent 8ae30481f1
commit 853cfaa038
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -33,23 +33,16 @@ class MeanVarianceNormalizationExtractor(FrontExtractorOp):
default=np.array([0, 2, 3], dtype=np.int64),
dst_type=lambda x: np.array(x, dtype=np.int64))
if axes is not None:
if 0 in axes:
raise Error('Reduction over the batch dimension in node "{}" '
'is not supported by the backend.'.format(name))
# Dimension 4 (if it's present in the input tensor)
# should also be in the list of axes for reduction.
# This case will be handled at the MVN Op side,
# 'cause input shape is not available at that stage.
for i in (2, 3):
if i not in axes:
raise Error(
'Reduction over spatial dimensions in node "{}" '
'is obligatory for the backend.'.format(name))
if 0 in axes:
raise Error('Reduction over the batch dimension in node "{}" is not supported by the backend.'.format(name))
# Dimension 4 (if it's present in the input tensor) should also be in the list of axes for reduction.
# This case will be handled at the MVN Op side, because input shape is not available at that stage.
for i in (2, 3):
if i not in axes:
raise Error('Reduction over spatial dimensions in node "{}" is obligatory for a backend.'.format(name))
attrs = {
'eps': 1e-9,
'across_channels': 1 if 1 in axes else 0,
'normalize_variance': 1,
'axes': axes
}