Fixed transformations looking for FusedBatchNorm operation to look for FBNV2 and FBNV3 also (#3078)
* Fixed transformations looking for FusedBatchNorm operation to consider FusedBatchNormV2 and FusedBatchNormV3 also. * Updated unit test for FusedBatchNormTraining * Fixed unit test
This commit is contained in:
parent
9420b6e599
commit
f4d399f471
@ -35,7 +35,7 @@ class MVNReplacer(FrontReplacementSubgraph):
|
||||
('variance', dict(op='ReduceMean')),
|
||||
('squeeze_mean', dict(op='Squeeze')),
|
||||
('squeeze_variance', dict(op='Squeeze')),
|
||||
('fbn', dict(op='FusedBatchNorm')),
|
||||
('fbn', dict(op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'])),
|
||||
],
|
||||
edges=[
|
||||
('mean', 'stop_grad', {'in': 0}),
|
||||
|
@ -40,7 +40,8 @@ class FusedBatchNormNonConstant(MiddleReplacementPattern):
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='FusedBatchNorm'))],
|
||||
('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2',
|
||||
'FusedBatchNormV3']))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
|
@ -44,7 +44,8 @@ class FusedBatchNormTraining(MiddleReplacementPattern):
|
||||
def pattern(self):
|
||||
return dict(
|
||||
nodes=[
|
||||
('op', dict(kind='op', op='FusedBatchNorm', is_training=True))],
|
||||
('op', dict(kind='op', op=lambda op: op in ['FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3'],
|
||||
is_training=True))],
|
||||
edges=[]
|
||||
)
|
||||
|
||||
|
@ -17,6 +17,7 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from generator import generator, generate
|
||||
|
||||
from extensions.middle.FusedBatchNormTraining import FusedBatchNormTraining
|
||||
from mo.front.common.partial_infer.utils import int64_array
|
||||
@ -74,8 +75,12 @@ nodes_attributes = {
|
||||
}
|
||||
|
||||
|
||||
@generator
|
||||
class FusedBatchNormTrainingTest(unittest.TestCase):
|
||||
def test_transformation(self):
|
||||
@generate(*[
|
||||
'FusedBatchNorm', 'FusedBatchNormV2', 'FusedBatchNormV3',
|
||||
])
|
||||
def test_transformation(self, op: str):
|
||||
graph = build_graph(nodes_attributes,
|
||||
[('placeholder', 'placeholder_data', {}),
|
||||
('scale', 'scale_data'),
|
||||
@ -91,7 +96,7 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
|
||||
('batchnorm_data', 'result'),
|
||||
],
|
||||
{}, nodes_with_edges_only=True)
|
||||
|
||||
graph.nodes['batchnorm']['op'] = op
|
||||
graph_ref = build_graph(nodes_attributes,
|
||||
[('placeholder', 'placeholder_data', {}),
|
||||
('scale', 'scale_data'),
|
||||
@ -125,6 +130,8 @@ class FusedBatchNormTrainingTest(unittest.TestCase):
|
||||
FusedBatchNormTraining().find_and_replace_pattern(graph)
|
||||
shape_inference(graph)
|
||||
|
||||
graph_ref.nodes['batchnorm']['op'] = op
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user