Fix TransposeReduce Transformation (#8291)

This commit is contained in:
Gleb Kazantaev
2021-11-09 21:10:18 +03:00
committed by GitHub
parent eb2b149fca
commit d21f0ed242
2 changed files with 15 additions and 1 deletions

View File

@@ -63,7 +63,7 @@ std::shared_ptr<ngraph::opset6::Constant> get_reversed_order_constant(const std:
ngraph::pass::TransposeReduction::TransposeReduction() {
MATCHER_SCOPE(TransposeReduction);
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()}, pattern::consumers_count(1));
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
{transpose_label, pattern::wrap_type<opset6::Constant>()});

View File

@@ -265,3 +265,17 @@ TEST_F(TransformationTestsF, TransposeFuses) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
}
}
TEST_F(TransformationTestsF, TransposeReduceNegative) {
{
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 64});
auto order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1});
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
auto axes = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {-1});
auto reduce_mean = std::make_shared<ngraph::opset6::ReduceMean>(transpose, axes, true);
auto sub = std::make_shared<opset6::Subtract>(transpose, reduce_mean);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ sub }, ngraph::ParameterVector{ input });
manager.register_pass<ngraph::pass::TransposeReduction>();
}
}