Handle negative axis in SimplifySecondInputOfReshape (#11524)

Fixes #11501
This commit is contained in:
Mateusz Tabaka 2022-04-19 12:20:21 +02:00 committed by GitHub
parent 22398ac9cd
commit e53f702f81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 25 additions and 1 deletions

View File

@ -234,7 +234,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
return false;
const auto concat_axis = concat->get_axis();
OPENVINO_ASSERT(concat_axis == 0, "axis is not valid for matched Concat with 1D output");
OPENVINO_ASSERT(concat_axis == 0 || concat_axis == -1, "axis is not valid for matched Concat with 1D output");
auto data = m.get_pattern_value_map().at(input);
if (is_type<opset8::FakeQuantize>(data.get_node_shared_ptr()) ||

View File

@ -584,3 +584,27 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest20) {
}
comparator.enable(FunctionsComparator::CONST_VALUES);
}
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
PartialShape data_shape{1, 128, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, -1);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
manager.register_pass<pass::SimplifySecondInputOfReshape>();
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
}
comparator.enable(FunctionsComparator::CONST_VALUES);
}