Handle negative axis in SimplifySecondInputOfReshape (#11524)
Fixes #11501
This commit is contained in:
parent
22398ac9cd
commit
e53f702f81
@ -234,7 +234,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
return false;
|
||||
|
||||
const auto concat_axis = concat->get_axis();
|
||||
OPENVINO_ASSERT(concat_axis == 0, "axis is not valid for matched Concat with 1D output");
|
||||
OPENVINO_ASSERT(concat_axis == 0 || concat_axis == -1, "axis is not valid for matched Concat with 1D output");
|
||||
|
||||
auto data = m.get_pattern_value_map().at(input);
|
||||
if (is_type<opset8::FakeQuantize>(data.get_node_shared_ptr()) ||
|
||||
|
@ -584,3 +584,27 @@ TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest20) {
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, SimplifySecondInputOfReshapeTest21) {
|
||||
PartialShape data_shape{1, 128, 12, 64};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, -1);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
|
||||
manager.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
comparator.enable(FunctionsComparator::CONST_VALUES);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user