Fixed transformations looking for FusedBatchNorm operation to look for FBNV2 and FBNV3 also (#3078)

* Fixed transformations looking for FusedBatchNorm operation to consider FusedBatchNormV2 and FusedBatchNormV3 also.

* Updated unit test for FusedBatchNormTraining

* Fixed unit test
This commit is contained in:
Evgeny Lazarev 2020-11-12 07:33:39 +03:00 committed by GitHub
parent 9420b6e599
commit f4d399f471
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 5 deletions

View File

@ -35,7 +35,7 @@ class MVNReplacer(FrontReplacementSubgraph):
('variance', dict(op='ReduceMean')),
('squeeze_mean', dict(op='Squeeze')),
('squeeze_variance', dict(op='Squeeze')),
('fbn', dict(op='FusedBatchNorm')),
('fbn', dict(op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'])),
],
edges=[
('mean', 'stop_grad', {'in': 0}),

View File

@ -40,7 +40,8 @@ class FusedBatchNormNonConstant(MiddleReplacementPattern):
def pattern(self):
return dict(
nodes=[
('op', dict(kind='op', op='FusedBatchNorm'))],
('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2',
'FusedBatchNormV3']))],
edges=[]
)

View File

@ -44,7 +44,8 @@ class FusedBatchNormTraining(MiddleReplacementPattern):
def pattern(self):
return dict(
nodes=[
('op', dict(kind='op', op='FusedBatchNorm', is_training=True))],
('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'],
is_training=True))],
edges=[]
)

View File

@ -17,6 +17,7 @@
import unittest
import numpy as np
from generator import generator, generate
from extensions.middle.FusedBatchNormTraining import FusedBatchNormTraining
from mo.front.common.partial_infer.utils import int64_array
@ -74,8 +75,12 @@ nodes_attributes = {
}
@generator
class FusedBatchNormTrainingTest(unittest.TestCase):
def test_transformation(self):
@generate(*[
'FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
])
def test_transformation(self, op: str):
graph = build_graph(nodes_attributes,
[('placeholder', 'placeholder_data', {}),
('scale', 'scale_data'),
@ -91,7 +96,7 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
('batchnorm_data', 'result'),
],
{}, nodes_with_edges_only=True)
graph.nodes['batchnorm']['op'] = op
graph_ref = build_graph(nodes_attributes,
[('placeholder', 'placeholder_data', {}),
('scale', 'scale_data'),
@ -125,6 +130,8 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
FusedBatchNormTraining().find_and_replace_pattern(graph)
shape_inference(graph)
graph_ref.nodes['batchnorm']['op'] = op
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
self.assertTrue(flag, resp)