[IE Transformations] Disable pull_transpose_through_fq transformation for activations path (#8178)
* [IE Transformations] Disable pull_transpose_through_fq transformation for activations path * Update MO part * Replace TEST to TEST_F * Fix tests with fixture. Remove dynamic shapes tests due to the fact that there is always a constant at the input
This commit is contained in:
parent
b807800321
commit
0feecd4450
@ -17,7 +17,8 @@ NGRAPH_RTTI_DEFINITION(ngraph::pass::PullTransposeThroughFQUp, "PullTransposeThr
|
|||||||
|
|
||||||
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
ngraph::pass::PullTransposeThroughFQUp::PullTransposeThroughFQUp() {
|
||||||
MATCHER_SCOPE(PullTransposeThroughFQUp);
|
MATCHER_SCOPE(PullTransposeThroughFQUp);
|
||||||
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({pattern::any_input(pattern::has_static_rank()),
|
const auto weights = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||||
|
auto m_fq = pattern::wrap_type<opset1::FakeQuantize>({weights,
|
||||||
pattern::any_input(pattern::has_static_shape()),
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
pattern::any_input(pattern::has_static_shape()),
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
pattern::any_input(pattern::has_static_shape()),
|
pattern::any_input(pattern::has_static_shape()),
|
||||||
|
@ -56,28 +56,28 @@ TEST_F(TransformationTestsF, FQTransposeTest1) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(TransformationTests, FQTransposeDynamic) {
|
TEST_F(TransformationTestsF, FQTransposeNegativeCase) {
|
||||||
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape::dynamic());
|
auto create_graph = []() -> std::shared_ptr<ngraph::Function> {
|
||||||
auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
auto data = std::make_shared<ngraph::op::Parameter>(ngraph::element::f32, ngraph::PartialShape{1, 3, 1});
|
||||||
auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
auto sigmoid = std::make_shared<ngraph::op::Sigmoid>(data);
|
||||||
auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
auto input_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||||
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
auto input_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||||
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
|
auto output_low = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {2});
|
||||||
|
auto output_high = ngraph::op::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {3});
|
||||||
|
auto transpose_order = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{3}, {0, 2, 1});
|
||||||
|
|
||||||
std::shared_ptr<ngraph::Function> f(nullptr);
|
auto fq = std::make_shared<ngraph::op::FakeQuantize>(sigmoid, input_low, input_high, output_low, output_high, 1);
|
||||||
{
|
|
||||||
auto fq = std::make_shared<ngraph::op::FakeQuantize>(data, input_low, input_high, output_low, output_high, 1);
|
|
||||||
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
|
auto transpose = std::make_shared<ngraph::op::Transpose>(fq, transpose_order);
|
||||||
|
|
||||||
f = std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data});
|
return std::make_shared<ngraph::Function>(ngraph::NodeVector{transpose}, ngraph::ParameterVector{data});
|
||||||
|
};
|
||||||
|
function = create_graph();
|
||||||
|
|
||||||
ngraph::pass::Manager manager;
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
|
||||||
manager.register_pass<ngraph::pass::PullTransposeThroughFQUp>();
|
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
||||||
manager.register_pass<ngraph::pass::InjectionPass>([](std::shared_ptr<ngraph::Function> f) {
|
check_rt_info(f);
|
||||||
check_rt_info(f);
|
});
|
||||||
});
|
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
function_ref = create_graph();
|
||||||
ASSERT_NO_THROW(manager.run_passes(f));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@ -73,14 +73,13 @@ class MatMulConstTransposesExtraction(BackReplacementPattern):
|
|||||||
class PullTransposeThroughFQUp(BackReplacementPattern):
|
class PullTransposeThroughFQUp(BackReplacementPattern):
|
||||||
r"""
|
r"""
|
||||||
BEFORE AFTER
|
BEFORE AFTER
|
||||||
T T T T T
|
Const Const
|
||||||
\ \ | / / \ \ | / /
|
\ \ | / / |
|
||||||
FakeQuantize FakeQuantize
|
FakeQuantize T T T T T
|
||||||
|
| \ \ | / /
|
||||||
|
Transpose FakeQuantize
|
||||||
| |
|
| |
|
||||||
Transpose next_op
|
next_op next_op
|
||||||
|
|
|
||||||
next_op
|
|
||||||
|
|
||||||
`T` is Transpose for short
|
`T` is Transpose for short
|
||||||
"""
|
"""
|
||||||
enabled = True
|
enabled = True
|
||||||
@ -94,13 +93,17 @@ class PullTransposeThroughFQUp(BackReplacementPattern):
|
|||||||
def pattern():
|
def pattern():
|
||||||
return dict(
|
return dict(
|
||||||
nodes=[
|
nodes=[
|
||||||
|
('fq_const_input', dict(kind='op', type='Const')),
|
||||||
|
('fq_const_input_d', dict()),
|
||||||
('fq', dict(kind='op', type='FakeQuantize')),
|
('fq', dict(kind='op', type='FakeQuantize')),
|
||||||
('data', dict()),
|
('fq_d', dict()),
|
||||||
('transpose', dict(kind='op', type='Transpose')),
|
('transpose', dict(kind='op', type='Transpose')),
|
||||||
],
|
],
|
||||||
edges=[
|
edges=[
|
||||||
('fq', 'data'),
|
('fq_const_input', 'fq_const_input_d'),
|
||||||
('data', 'transpose'),
|
('fq_const_input_d', 'fq', {'in': 0}),
|
||||||
|
('fq', 'fq_d'),
|
||||||
|
('fq_d', 'transpose'),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,7 +15,7 @@ from mo.front.common.partial_infer.utils import int64_array
|
|||||||
from mo.ops.reshape import Reshape
|
from mo.ops.reshape import Reshape
|
||||||
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
from mo.utils.ir_engine.compare_graphs import compare_graphs
|
||||||
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
|
from unit_tests.utils.graph import build_graph, regular_op_with_shaped_data, valued_const_with_data, \
|
||||||
result, connect, connect_data
|
shaped_const_with_data, result, connect, connect_data
|
||||||
from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data
|
from unit_tests.utils.graph import regular_op_with_empty_data as op_with_empty_data
|
||||||
|
|
||||||
|
|
||||||
@ -101,9 +101,8 @@ class SmartReshape_HC_Reshape_MatMulTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
class FQTransposePullerTest(unittest.TestCase):
|
class FQTransposePullerTest(unittest.TestCase):
|
||||||
def nodes(self, input_shape, transpose_shape, fq_shape):
|
def nodes(self, input_shape, transpose_shape, fq_shape, is_input_const):
|
||||||
return {
|
nodes = {
|
||||||
**regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter')),
|
|
||||||
**valued_const_with_data('il', np.array([[[[0]]]])),
|
**valued_const_with_data('il', np.array([[[[0]]]])),
|
||||||
**valued_const_with_data('ih', np.array([[[[255]]]])),
|
**valued_const_with_data('ih', np.array([[[[255]]]])),
|
||||||
**valued_const_with_data('ol', np.array([[[[0]]]])),
|
**valued_const_with_data('ol', np.array([[[[0]]]])),
|
||||||
@ -116,8 +115,16 @@ class FQTransposePullerTest(unittest.TestCase):
|
|||||||
**result(),
|
**result(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if is_input_const:
|
||||||
|
input_node = shaped_const_with_data('input', input_shape)
|
||||||
|
else:
|
||||||
|
input_node = regular_op_with_shaped_data('input', input_shape, dict(type='Parameter', op='Parameter'))
|
||||||
|
|
||||||
|
nodes.update(input_node)
|
||||||
|
return nodes
|
||||||
|
|
||||||
def test_positive(self):
|
def test_positive(self):
|
||||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
|
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
|
||||||
edges = [
|
edges = [
|
||||||
*connect('input', '0:FQ'),
|
*connect('input', '0:FQ'),
|
||||||
*connect('il', '1:FQ'),
|
*connect('il', '1:FQ'),
|
||||||
@ -132,7 +139,7 @@ class FQTransposePullerTest(unittest.TestCase):
|
|||||||
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
|
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
|
||||||
graph.clean_up()
|
graph.clean_up()
|
||||||
|
|
||||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3])
|
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 224, 224, 3], True)
|
||||||
edges = [
|
edges = [
|
||||||
*connect('input', '0:transpose'),
|
*connect('input', '0:transpose'),
|
||||||
*connect('order:0', '1:transpose'),
|
*connect('order:0', '1:transpose'),
|
||||||
@ -148,8 +155,8 @@ class FQTransposePullerTest(unittest.TestCase):
|
|||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
def test_negative(self):
|
def test_negative_1(self):
|
||||||
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224])
|
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], True)
|
||||||
edges = [
|
edges = [
|
||||||
*connect('input', '0:FQ'),
|
*connect('input', '0:FQ'),
|
||||||
*connect('il', '1:FQ'),
|
*connect('il', '1:FQ'),
|
||||||
@ -168,3 +175,21 @@ class FQTransposePullerTest(unittest.TestCase):
|
|||||||
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||||
self.assertTrue(flag, resp)
|
self.assertTrue(flag, resp)
|
||||||
|
|
||||||
|
def test_negative_2(self):
|
||||||
|
nodes = self.nodes([1, 3, 224, 224], [1, 224, 224, 3], [1, 3, 224, 224], False)
|
||||||
|
edges = [
|
||||||
|
*connect('input', '0:FQ'),
|
||||||
|
*connect('il', '1:FQ'),
|
||||||
|
*connect('ih', '2:FQ'),
|
||||||
|
*connect('ol', '3:FQ'),
|
||||||
|
*connect('oh', '4:FQ'),
|
||||||
|
*connect('FQ:0', '0:transpose'),
|
||||||
|
*connect('order:0', '1:transpose'),
|
||||||
|
*connect('transpose:0', 'output'),
|
||||||
|
]
|
||||||
|
graph = build_graph(nodes_attrs=nodes, edges=edges, nodes_with_edges_only=True)
|
||||||
|
graph_ref = graph.copy()
|
||||||
|
PullTransposeThroughFQUp().find_and_replace_pattern(graph)
|
||||||
|
|
||||||
|
(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
|
||||||
|
self.assertTrue(flag, resp)
|
||||||
|
Loading…
Reference in New Issue
Block a user