[IE Transformations] Disable pull_transpose_through_fq transformation for activations path (#8178)

* [IE Transformations] Disable pull_transpose_through_fq transformation for activations path

* Update MO part

* Replace TEST to TEST_F

* Fix tests with fixture.

Remove dynamic shapes tests due to the fact that there is always a constant at the input
This commit is contained in:
Aleksandr Pertovsky 2021-11-09 13:26:07 +03:00 committed by GitHub
parent b807800321
commit 0feecd4450
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 68 additions and 39 deletions

View File

@ -17,7 +17,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThr
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() { ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
MATCHER_SCOPE(PullTransposeThroughFQUp); MATCHER_SCOPE(PullTransposeThroughFQUp);
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()), const auto weights = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({weights,
pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()),
pattern::any_input(pattern::has_static_shape()), pattern::any_input(pattern::has_static_shape()),

View File

@ -56,28 +56,28 @@ TEST_F(TransformationTestsF, FQTransposeTest1) {
} }
} }
TEST(TransformationTests, FQTransposeDynamic) { TEST_F(TransformationTestsF, FQTransposeNegativeCase) {
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic()); auto create_graph = []() -> std::shared_ptr<ngraph::Function> {
auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 3, 1});
auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(data);
auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2}); auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3}); auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1}); auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
std::shared_ptr<ngraph::Function> f(nullptr); auto fq = std::make_shared<ngraph::op::FakeQuantize>(sigmoid, input_low, input_high, output_low, output_high, 1);
{
auto fq = std::make_shared<ngraph::op::FakeQuantize>(data, input_low, input_high, output_low, output_high, 1);
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order); auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data}); return std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data});
};
function = create_graph();
ngraph::pass::Manager manager; manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::InitNodeInfo>(); manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>(); manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) { check_rt_info(f);
check_rt_info(f); });
});
manager.register_pass<ngraph::pass::ConstantFolding>(); function_ref = create_graph();
ASSERT_NO_THROW(manager.run_passes(f));
}
} }

View File

@ -73,14 +73,13 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
class PullTransposeThroughFQUp(BackReplacementPattern): class PullTransposeThroughFQUp(BackReplacementPattern):
r""" r"""
BEFORE AFTER BEFORE AFTER
T T T T T Const Const
\ \ | / / \ \ | / / \ \ | / / |
FakeQuantize FakeQuantize FakeQuantize T T T T T
| \ \ | / /
Transpose FakeQuantize
| | | |
Transpose next_op next_op next_op
|
next_op
`T` is Transpose for short `T` is Transpose for short
""" """
enabled = True enabled = True
@ -94,13 +93,17 @@ class PullTransposeThroughFQUp(BackReplacementPattern):
def pattern(): def pattern():
return dict( return dict(
nodes=[ nodes=[
('fq_const_input', dict(kind='op', type='Const')),
('fq_const_input_d', dict()),
('fq', dict(kind='op', type='FakeQuantize')), ('fq', dict(kind='op', type='FakeQuantize')),
('data', dict()), ('fq_d', dict()),
('transpose', dict(kind='op', type='Transpose')), ('transpose', dict(kind='op', type='Transpose')),
], ],
edges=[ edges=[
('fq', 'data'), ('fq_const_input', 'fq_const_input_d'),
('data', 'transpose'), ('fq_const_input_d', 'fq', {'in': 0}),
('fq', 'fq_d'),
('fq_d', 'transpose'),
] ]
) )

View File

@ -15,7 +15,7 @@ from mo.front.common.partial_infer.utils import int64_array
from mo.ops.reshape import Reshape from mo.ops.reshape import Reshape
from mo.utils.ir_engine.compare_graphs import compare_graphs from mo.utils.ir_engine.compare_graphs import compare_graphs
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \ from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
result, connect, connect_data shaped_const_with_data, result, connect, connect_data
from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data
@ -101,9 +101,8 @@ class SmartReshape_HC_Reshape_MatMulTest(unittest.TestCase):
class FQTransposePullerTest(unittest.TestCase): class FQTransposePullerTest(unittest.TestCase):
def nodes(self, input_shape, transpose_shape, fq_shape): def nodes(self, input_shape, transpose_shape, fq_shape, is_input_const):
return { nodes = {
**regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')),
**valued_const_with_data('il', np.array([[[[0]]]])), **valued_const_with_data('il', np.array([[[[0]]]])),
**valued_const_with_data('ih', np.array([[[[255]]]])), **valued_const_with_data('ih', np.array([[[[255]]]])),
**valued_const_with_data('ol', np.array([[[[0]]]])), **valued_const_with_data('ol', np.array([[[[0]]]])),
@ -116,8 +115,16 @@ class FQTransposePullerTest(unittest.TestCase):
**result(), **result(),
} }
if is_input_const:
input_node = shaped_const_with_data('input', input_shape)
else:
input_node = regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter'))
nodes.update(input_node)
return nodes
def test_positive(self): def test_positive(self):
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]) nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
edges = [ edges = [
*connect('input', '0:FQ'), *connect('input', '0:FQ'),
*connect('il', '1:FQ'), *connect('il', '1:FQ'),
@ -132,7 +139,7 @@ class FQTransposePullerTest(unittest.TestCase):
PullTransposeThroughFQUp().find_and_replace_pattern(graph) PullTransposeThroughFQUp().find_and_replace_pattern(graph)
graph.clean_up() graph.clean_up()
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3]) nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3], True)
edges = [ edges = [
*connect('input', '0:transpose'), *connect('input', '0:transpose'),
*connect('order:0', '1:transpose'), *connect('order:0', '1:transpose'),
@ -148,8 +155,8 @@ class FQTransposePullerTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_negative(self): def test_negative_1(self):
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224]) nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
edges = [ edges = [
*connect('input', '0:FQ'), *connect('input', '0:FQ'),
*connect('il', '1:FQ'), *connect('il', '1:FQ'),
@ -168,3 +175,21 @@ class FQTransposePullerTest(unittest.TestCase):
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True) (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp) self.assertTrue(flag, resp)
def test_negative_2(self):
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], False)
edges = [
*connect('input', '0:FQ'),
*connect('il', '1:FQ'),
*connect('ih', '2:FQ'),
*connect('ol', '3:FQ'),
*connect('oh', '4:FQ'),
*connect('FQ:0', '0:transpose'),
*connect('order:0', '1:transpose'),
*connect('transpose:0', 'output'),
]
graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
graph_ref = graph.copy()
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)