Fix TransposeReduce Transformation (#8291)
This commit is contained in:
@@ -63,7 +63,7 @@ std::shared_ptr<ngraph::opset6::Constant> get_reversed_order_constant(const std:
|
||||
ngraph::pass::TransposeReduction::TransposeReduction() {
|
||||
MATCHER_SCOPE(TransposeReduction);
|
||||
|
||||
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()});
|
||||
auto transpose_label = pattern::wrap_type<opset6::Transpose>({pattern::any_input(), pattern::wrap_type<opset6::Constant>()}, pattern::consumers_count(1));
|
||||
auto reduce_or_squeeze_label = pattern::wrap_type<op::util::ArithmeticReductionKeepDims, op::util::LogicalReductionKeepDims, opset6::Squeeze>(
|
||||
{transpose_label, pattern::wrap_type<opset6::Constant>()});
|
||||
|
||||
|
||||
@@ -265,3 +265,17 @@ TEST_F(TransformationTestsF, TransposeFuses) {
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, TransposeReduceNegative) {
|
||||
{
|
||||
auto input = std::make_shared<ngraph::opset6::Parameter>(ngraph::element::f32, ngraph::Shape{ 1, 3, 64});
|
||||
auto order = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1});
|
||||
auto transpose = std::make_shared<ngraph::opset6::Transpose>(input, order);
|
||||
auto axes = ngraph::opset6::Constant::create(ngraph::element::i64, ngraph::Shape{}, {-1});
|
||||
auto reduce_mean = std::make_shared<ngraph::opset6::ReduceMean>(transpose, axes, true);
|
||||
auto sub = std::make_shared<opset6::Subtract>(transpose, reduce_mean);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ sub }, ngraph::ParameterVector{ input });
|
||||
manager.register_pass<ngraph::pass::TransposeReduction>();
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user