[CommonOptimizations] SimplifySecondInputOfReshape fix (#9210)
* [CommonOptimizations] SimplifySecondInputOfReshape fix * GroupedGatherElimination: added Gather v8 support
This commit is contained in:
parent
4a6575b4b7
commit
01689ee408
@ -74,20 +74,31 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
|
||||
while (inputs.size() > i + 1) {
|
||||
auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
|
||||
if (curr->get_type_info() != next->get_type_info() ||
|
||||
(!ov::is_type<opset1::Gather>(curr) && !ov::is_type<opset7::Gather>(curr)) ||
|
||||
(!ov::is_type<opset1::Gather>(curr) && !ov::is_type<opset7::Gather>(curr) && !ov::is_type<opset8::Gather>(curr)) ||
|
||||
(curr->input_value(0) != next->input_value(0))) {
|
||||
++i;
|
||||
continue;
|
||||
} // curr and next are the same type of gather which takes data from the same source
|
||||
bool is_opset1 = ov::is_type<opset1::Gather>(curr);
|
||||
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
|
||||
std::shared_ptr<Node> new_gather;
|
||||
if (is_opset1)
|
||||
if (ov::is_type<opset1::Gather>(curr)) {
|
||||
new_gather = register_new_node<ngraph::opset1::Gather>(
|
||||
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||
else
|
||||
curr->input_value(0),
|
||||
joint_indices->output(0),
|
||||
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||
} else if (ov::is_type<opset7::Gather>(curr)) {
|
||||
new_gather = register_new_node<ngraph::opset7::Gather>(
|
||||
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||
curr->input_value(0),
|
||||
joint_indices->output(0),
|
||||
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||
} else if (ov::is_type<opset8::Gather>(curr)) {
|
||||
new_gather = register_new_node<ngraph::opset8::Gather>(
|
||||
curr->input_value(0),
|
||||
joint_indices->output(0),
|
||||
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||
} else {
|
||||
OPENVINO_UNREACHABLE("Unexpected Gather version");
|
||||
}
|
||||
new_ops.push_back(joint_indices);
|
||||
new_ops.push_back(new_gather);
|
||||
inputs.erase(inputs.begin() + i);
|
||||
@ -239,8 +250,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
|
||||
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
||||
auto shape_of = gather->get_input_node_shared_ptr(0);
|
||||
if ((!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) ||
|
||||
(shape_of->get_output_target_inputs(0).size() > 1)) {
|
||||
if (!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) {
|
||||
return false;
|
||||
}
|
||||
return shape_of->input_value(0) == data;
|
||||
|
@ -547,3 +547,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) {
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op_1 = gather(shape_of, std::vector<int64_t>{0});
|
||||
auto gather_op_2 = gather(shape_of, std::vector<int64_t>{1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op_1, gather_op_2, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
@ -10,6 +10,7 @@
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@ -20,7 +21,7 @@
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||
auto gatherv7 = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||
std::shared_ptr<Node> indices_node;
|
||||
if (scalar)
|
||||
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
||||
@ -30,18 +31,29 @@ auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices
|
||||
input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
|
||||
};
|
||||
|
||||
TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
||||
auto gatherv8 = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||
std::shared_ptr<Node> indices_node;
|
||||
if (scalar)
|
||||
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
||||
else
|
||||
indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
|
||||
return std::make_shared<ngraph::opset8::Gather>(input,
|
||||
indices_node,
|
||||
opset7::Constant::create(element::i64, {}, {0}));
|
||||
};
|
||||
|
||||
TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv7) {
|
||||
Shape data_shape{1, 2, 3, 4};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gather(shape_op_1, {1}, true);
|
||||
auto gather_1 = gatherv7(shape_op_1, {1}, true);
|
||||
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_2 = gather(shape_op_2, {2}, true);
|
||||
auto gather_2 = gatherv7(shape_op_2, {2}, true);
|
||||
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
@ -58,7 +70,7 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gather(shape_op_1, {1, 2});
|
||||
auto gather_1 = gatherv7(shape_op_1, {1, 2});
|
||||
|
||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
@ -70,18 +82,58 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ShapeNopSubGraphTest) {
|
||||
TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv8) {
|
||||
Shape data_shape{1, 2, 3, 4};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gatherv8(shape_op_1, {1}, true);
|
||||
auto unsqueeze_1 =
|
||||
std::make_shared<opset7::Unsqueeze>(gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_2 = gatherv8(shape_op_2, {2}, true);
|
||||
auto unsqueeze_2 =
|
||||
std::make_shared<opset7::Unsqueeze>(gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
manager.register_pass<pass::SimplifyShapeOfSubGraph>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gatherv8(shape_op_1, {1, 2});
|
||||
|
||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, const_1, const_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv7) {
|
||||
PartialShape data_shape{-1, -1};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gather(shape_op_1, {0}, true);
|
||||
auto gather_1 = gatherv7(shape_op_1, {0}, true);
|
||||
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_2 = gather(shape_op_2, {1}, true);
|
||||
auto gather_2 = gatherv7(shape_op_2, {1}, true);
|
||||
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
||||
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
@ -98,3 +150,51 @@ TEST_F(TransformationTestsF, ShapeNopSubGraphTest) {
|
||||
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv8) {
|
||||
PartialShape data_shape{-1, -1};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_1 = gatherv8(shape_op_1, {0}, true);
|
||||
auto unsqueeze_1 =
|
||||
std::make_shared<opset7::Unsqueeze>(gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_2 = gatherv8(shape_op_2, {1}, true);
|
||||
auto unsqueeze_2 =
|
||||
std::make_shared<opset7::Unsqueeze>(gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
manager.register_pass<pass::SimplifyShapeOfSubGraph>();
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, shape_op_1, false);
|
||||
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, GroupedGatherEliminationNegative) {
|
||||
PartialShape data_shape{2, 128};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_op = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather = gatherv8(shape_op, {1}, true);
|
||||
auto unsqueeze = std::make_shared<opset7::Unsqueeze>(gather, opset7::Constant::create(element::i64, {1}, {0}));
|
||||
|
||||
auto constant_1 = ngraph::opset7::Constant::create(element::i64, {1}, {0});
|
||||
auto constant_2 = ngraph::opset7::Constant::create(element::i64, {1}, {1});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{constant_1, constant_2, unsqueeze}, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
manager.register_pass<pass::GroupedGatherElimination>();
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user