[IE Transformations] Disable pull_transpose_through_fq transformation for activations path (#8178)
* [IE Transformations] Disable pull_transpose_through_fq transformation for activations path * Update MO part * Replace TEST to TEST_F * Fix tests with fixture. Remove dynamic shapes tests due to the fact that there is always a constant at the input
This commit is contained in:
parent
b807800321
commit
0feecd4450
@ -17,7 +17,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThr
|
||||
|
||||
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
||||
MATCHER_SCOPE(PullTransposeThroughFQUp);
|
||||
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()),
|
||||
const auto weights = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({weights,
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
pattern::any_input(pattern::has_static_shape()),
|
||||
|
@ -56,28 +56,28 @@ TEST_F(TransformationTestsF, FQTransposeTest1) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST(TransformationTests, FQTransposeDynamic) {
|
||||
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
||||
auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||
auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||
auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
|
||||
TEST_F(TransformationTestsF, FQTransposeNegativeCase) {
|
||||
auto create_graph = []() -> std::shared_ptr<ngraph::Function> {
|
||||
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 3, 1});
|
||||
auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(data);
|
||||
auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||
auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||
auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
|
||||
|
||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
||||
{
|
||||
auto fq = std::make_shared<ngraph::op::FakeQuantize>(data, input_low, input_high, output_low, output_high, 1);
|
||||
auto fq = std::make_shared<ngraph::op::FakeQuantize>(sigmoid, input_low, input_high, output_low, output_high, 1);
|
||||
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
|
||||
|
||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data});
|
||||
return std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data});
|
||||
};
|
||||
function = create_graph();
|
||||
|
||||
ngraph::pass::Manager manager;
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
ASSERT_NO_THROW(manager.run_passes(f));
|
||||
}
|
||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||
check_rt_info(f);
|
||||
});
|
||||
|
||||
function_ref = create_graph();
|
||||
}
|
||||
|
@ -73,14 +73,13 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
|
||||
class PullTransposeThroughFQUp(BackReplacementPattern):
|
||||
r"""
|
||||
BEFORE AFTER
|
||||
T T T T T
|
||||
\ \ | / / \ \ | / /
|
||||
FakeQuantize FakeQuantize
|
||||
Const Const
|
||||
\ \ | / / |
|
||||
FakeQuantize T T T T T
|
||||
| \ \ | / /
|
||||
Transpose FakeQuantize
|
||||
| |
|
||||
Transpose next_op
|
||||
|
|
||||
next_op
|
||||
|
||||
next_op next_op
|
||||
`T` is Transpose for short
|
||||
"""
|
||||
enabled = True
|
||||
@ -94,13 +93,17 @@ class PullTransposeThroughFQUp(BackReplacementPattern):
|
||||
def pattern():
|
||||
return dict(
|
||||
nodes=[
|
||||
('fq_const_input', dict(kind='op', type='Const')),
|
||||
('fq_const_input_d', dict()),
|
||||
('fq', dict(kind='op', type='FakeQuantize')),
|
||||
('data', dict()),
|
||||
('fq_d', dict()),
|
||||
('transpose', dict(kind='op', type='Transpose')),
|
||||
],
|
||||
edges=[
|
||||
('fq', 'data'),
|
||||
('data', 'transpose'),
|
||||
('fq_const_input', 'fq_const_input_d'),
|
||||
('fq_const_input_d', 'fq', {'in': 0}),
|
||||
('fq', 'fq_d'),
|
||||
('fq_d', 'transpose'),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -15,7 +15,7 @@ from mo.front.common.partial_infer.utils import int64_array
|
||||
from mo.ops.reshape import Reshape
|
||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
|
||||
result, connect, connect_data
|
||||
shaped_const_with_data, result, connect, connect_data
|
||||
from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data
|
||||
|
||||
|
||||
@ -101,9 +101,8 @@ class SmartReshape_HC_Reshape_MatMulTest(unittest.TestCase):
|
||||
|
||||
|
||||
class FQTransposePullerTest(unittest.TestCase):
|
||||
def nodes(self, input_shape, transpose_shape, fq_shape):
|
||||
return {
|
||||
**regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')),
|
||||
def nodes(self, input_shape, transpose_shape, fq_shape, is_input_const):
|
||||
nodes = {
|
||||
**valued_const_with_data('il', np.array([[[[0]]]])),
|
||||
**valued_const_with_data('ih', np.array([[[[255]]]])),
|
||||
**valued_const_with_data('ol', np.array([[[[0]]]])),
|
||||
@ -116,8 +115,16 @@ class FQTransposePullerTest(unittest.TestCase):
|
||||
**result(),
|
||||
}
|
||||
|
||||
if is_input_const:
|
||||
input_node = shaped_const_with_data('input', input_shape)
|
||||
else:
|
||||
input_node = regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter'))
|
||||
|
||||
nodes.update(input_node)
|
||||
return nodes
|
||||
|
||||
def test_positive(self):
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
|
||||
edges = [
|
||||
*connect('input', '0:FQ'),
|
||||
*connect('il', '1:FQ'),
|
||||
@ -132,7 +139,7 @@ class FQTransposePullerTest(unittest.TestCase):
|
||||
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
|
||||
graph.clean_up()
|
||||
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3])
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3], True)
|
||||
edges = [
|
||||
*connect('input', '0:transpose'),
|
||||
*connect('order:0', '1:transpose'),
|
||||
@ -148,8 +155,8 @@ class FQTransposePullerTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_negative(self):
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
|
||||
def test_negative_1(self):
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
|
||||
edges = [
|
||||
*connect('input', '0:FQ'),
|
||||
*connect('il', '1:FQ'),
|
||||
@ -168,3 +175,21 @@ class FQTransposePullerTest(unittest.TestCase):
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
||||
def test_negative_2(self):
|
||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], False)
|
||||
edges = [
|
||||
*connect('input', '0:FQ'),
|
||||
*connect('il', '1:FQ'),
|
||||
*connect('ih', '2:FQ'),
|
||||
*connect('ol', '3:FQ'),
|
||||
*connect('oh', '4:FQ'),
|
||||
*connect('FQ:0', '0:transpose'),
|
||||
*connect('order:0', '1:transpose'),
|
||||
*connect('transpose:0', 'output'),
|
||||
]
|
||||
graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
|
||||
graph_ref = graph.copy()
|
||||
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
|
||||
|
||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||
self.assertTrue(flag, resp)
|
||||
|
Loading…
Reference in New Issue
Block a user