[CommonOptimizations] SimplifySecondInputOfReshape fix (#9210)
* [CommonOptimizations] SimplifySecondInputOfReshape fix * GroupedGatherElimination: added Gather v8 support
This commit is contained in:
parent
4a6575b4b7
commit
01689ee408
@ -74,20 +74,31 @@ ngraph::pass::GroupedGatherElimination::GroupedGatherElimination() {
|
|||||||
while (inputs.size() > i + 1) {
|
while (inputs.size() > i + 1) {
|
||||||
auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
|
auto curr = inputs[i].get_node_shared_ptr(), next = inputs[i + 1].get_node_shared_ptr();
|
||||||
if (curr->get_type_info() != next->get_type_info() ||
|
if (curr->get_type_info() != next->get_type_info() ||
|
||||||
(!ov::is_type<opset1::Gather>(curr) && !ov::is_type<opset7::Gather>(curr)) ||
|
(!ov::is_type<opset1::Gather>(curr) && !ov::is_type<opset7::Gather>(curr) && !ov::is_type<opset8::Gather>(curr)) ||
|
||||||
(curr->input_value(0) != next->input_value(0))) {
|
(curr->input_value(0) != next->input_value(0))) {
|
||||||
++i;
|
++i;
|
||||||
continue;
|
continue;
|
||||||
} // curr and next are the same type of gather which takes data from the same source
|
} // curr and next are the same type of gather which takes data from the same source
|
||||||
bool is_opset1 = ov::is_type<opset1::Gather>(curr);
|
|
||||||
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
|
auto joint_indices = ngraph::op::util::make_try_fold<opset1::Concat>(OutputVector{curr->input_value(1), next->input_value(1)}, 0);
|
||||||
std::shared_ptr<Node> new_gather;
|
std::shared_ptr<Node> new_gather;
|
||||||
if (is_opset1)
|
if (ov::is_type<opset1::Gather>(curr)) {
|
||||||
new_gather = register_new_node<ngraph::opset1::Gather>(
|
new_gather = register_new_node<ngraph::opset1::Gather>(
|
||||||
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
curr->input_value(0),
|
||||||
else
|
joint_indices->output(0),
|
||||||
|
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||||
|
} else if (ov::is_type<opset7::Gather>(curr)) {
|
||||||
new_gather = register_new_node<ngraph::opset7::Gather>(
|
new_gather = register_new_node<ngraph::opset7::Gather>(
|
||||||
curr->input_value(0), joint_indices->output(0), ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
curr->input_value(0),
|
||||||
|
joint_indices->output(0),
|
||||||
|
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||||
|
} else if (ov::is_type<opset8::Gather>(curr)) {
|
||||||
|
new_gather = register_new_node<ngraph::opset8::Gather>(
|
||||||
|
curr->input_value(0),
|
||||||
|
joint_indices->output(0),
|
||||||
|
ngraph::opset1::Constant::create(element::i64, {}, {0})->output(0));
|
||||||
|
} else {
|
||||||
|
OPENVINO_UNREACHABLE("Unexpected Gather version");
|
||||||
|
}
|
||||||
new_ops.push_back(joint_indices);
|
new_ops.push_back(joint_indices);
|
||||||
new_ops.push_back(new_gather);
|
new_ops.push_back(new_gather);
|
||||||
inputs.erase(inputs.begin() + i);
|
inputs.erase(inputs.begin() + i);
|
||||||
@ -239,8 +250,7 @@ ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
|||||||
|
|
||||||
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
||||||
auto shape_of = gather->get_input_node_shared_ptr(0);
|
auto shape_of = gather->get_input_node_shared_ptr(0);
|
||||||
if ((!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) ||
|
if (!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) {
|
||||||
(shape_of->get_output_target_inputs(0).size() > 1)) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return shape_of->input_value(0) == data;
|
return shape_of->input_value(0) == data;
|
||||||
|
@ -547,3 +547,37 @@ TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) {
|
|||||||
auto res = compare_functions(f, f_ref, true);
|
auto res = compare_functions(f, f_ref, true);
|
||||||
ASSERT_TRUE(res.first) << res.second;
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(TransformationTests, SimplifySecondInputOfReshapeTest16) {
|
||||||
|
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||||
|
|
||||||
|
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_op_1 = gather(shape_of, std::vector<int64_t>{0});
|
||||||
|
auto gather_op_2 = gather(shape_of, std::vector<int64_t>{1});
|
||||||
|
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op_1, gather_op_2, constant }, 0);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||||
|
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||||
|
|
||||||
|
pass::Manager m;
|
||||||
|
m.register_pass<pass::InitNodeInfo>();
|
||||||
|
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||||
|
m.run_passes(f);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(f));
|
||||||
|
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||||
|
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||||
|
}
|
||||||
|
|
||||||
|
auto res = compare_functions(f, f_ref, true);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@
|
|||||||
|
|
||||||
#include <ngraph/function.hpp>
|
#include <ngraph/function.hpp>
|
||||||
#include <ngraph/opsets/opset7.hpp>
|
#include <ngraph/opsets/opset7.hpp>
|
||||||
|
#include <ngraph/opsets/opset8.hpp>
|
||||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||||
#include <transformations/init_node_info.hpp>
|
#include <transformations/init_node_info.hpp>
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
@ -20,7 +21,7 @@
|
|||||||
using namespace testing;
|
using namespace testing;
|
||||||
using namespace ngraph;
|
using namespace ngraph;
|
||||||
|
|
||||||
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
auto gatherv7 = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||||
std::shared_ptr<Node> indices_node;
|
std::shared_ptr<Node> indices_node;
|
||||||
if (scalar)
|
if (scalar)
|
||||||
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
||||||
@ -30,18 +31,29 @@ auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices
|
|||||||
input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
|
input, indices_node, opset7::Constant::create(element::i64, {}, {0}));
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
auto gatherv8 = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices, bool scalar = false) -> Output<Node> {
|
||||||
|
std::shared_ptr<Node> indices_node;
|
||||||
|
if (scalar)
|
||||||
|
indices_node = opset7::Constant::create(element::i64, {}, indices);
|
||||||
|
else
|
||||||
|
indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
|
||||||
|
return std::make_shared<ngraph::opset8::Gather>(input,
|
||||||
|
indices_node,
|
||||||
|
opset7::Constant::create(element::i64, {}, {0}));
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv7) {
|
||||||
Shape data_shape{1, 2, 3, 4};
|
Shape data_shape{1, 2, 3, 4};
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
auto gather_1 = gather(shape_op_1, {1}, true);
|
auto gather_1 = gatherv7(shape_op_1, {1}, true);
|
||||||
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
||||||
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
auto gather_2 = gather(shape_op_2, {2}, true);
|
auto gather_2 = gatherv7(shape_op_2, {2}, true);
|
||||||
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
||||||
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
@ -58,7 +70,7 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
|||||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
auto gather_1 = gather(shape_op_1, {1, 2});
|
auto gather_1 = gatherv7(shape_op_1, {1, 2});
|
||||||
|
|
||||||
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
@ -70,18 +82,58 @@ TEST_F(TransformationTestsF, ShapeSubGraphTest) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(TransformationTestsF, ShapeNopSubGraphTest) {
|
TEST_F(TransformationTestsF, ShapeSubGraphTestGatherv8) {
|
||||||
|
Shape data_shape{1, 2, 3, 4};
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_1 = gatherv8(shape_op_1, {1}, true);
|
||||||
|
auto unsqueeze_1 =
|
||||||
|
std::make_shared<opset7::Unsqueeze>(gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
|
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_2 = gatherv8(shape_op_2, {2}, true);
|
||||||
|
auto unsqueeze_2 =
|
||||||
|
std::make_shared<opset7::Unsqueeze>(gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
|
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
|
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
|
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2, const_1, const_2}, 0);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||||
|
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
manager.register_pass<pass::SimplifyShapeOfSubGraph>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_1 = gatherv8(shape_op_1, {1, 2});
|
||||||
|
|
||||||
|
auto const_1 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
|
auto const_2 = opset7::Constant::create(element::i64, Shape{1}, {2});
|
||||||
|
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{gather_1, const_1, const_2}, 0);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||||
|
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv7) {
|
||||||
PartialShape data_shape{-1, -1};
|
PartialShape data_shape{-1, -1};
|
||||||
{
|
{
|
||||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
auto gather_1 = gather(shape_op_1, {0}, true);
|
auto gather_1 = gatherv7(shape_op_1, {0}, true);
|
||||||
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
auto unsqueeze_1 = std::make_shared<opset7::Unsqueeze>(
|
||||||
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
auto gather_2 = gather(shape_op_2, {1}, true);
|
auto gather_2 = gatherv7(shape_op_2, {1}, true);
|
||||||
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
auto unsqueeze_2 = std::make_shared<opset7::Unsqueeze>(
|
||||||
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
@ -98,3 +150,51 @@ TEST_F(TransformationTestsF, ShapeNopSubGraphTest) {
|
|||||||
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, ShapeNopSubGraphTestGatherv8) {
|
||||||
|
PartialShape data_shape{-1, -1};
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_1 = gatherv8(shape_op_1, {0}, true);
|
||||||
|
auto unsqueeze_1 =
|
||||||
|
std::make_shared<opset7::Unsqueeze>(gather_1, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
|
auto shape_op_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather_2 = gatherv8(shape_op_2, {1}, true);
|
||||||
|
auto unsqueeze_2 =
|
||||||
|
std::make_shared<opset7::Unsqueeze>(gather_2, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{unsqueeze_1, unsqueeze_2}, 0);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, false);
|
||||||
|
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
manager.register_pass<pass::SimplifyShapeOfSubGraph>();
|
||||||
|
}
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
auto shape_op_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, shape_op_1, false);
|
||||||
|
function_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TransformationTestsF, GroupedGatherEliminationNegative) {
|
||||||
|
PartialShape data_shape{2, 128};
|
||||||
|
{
|
||||||
|
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||||
|
|
||||||
|
auto shape_op = std::make_shared<opset7::ShapeOf>(data);
|
||||||
|
auto gather = gatherv8(shape_op, {1}, true);
|
||||||
|
auto unsqueeze = std::make_shared<opset7::Unsqueeze>(gather, opset7::Constant::create(element::i64, {1}, {0}));
|
||||||
|
|
||||||
|
auto constant_1 = ngraph::opset7::Constant::create(element::i64, {1}, {0});
|
||||||
|
auto constant_2 = ngraph::opset7::Constant::create(element::i64, {1}, {1});
|
||||||
|
auto concat = std::make_shared<opset7::Concat>(OutputVector{constant_1, constant_2, unsqueeze}, 0);
|
||||||
|
|
||||||
|
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||||
|
function = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||||
|
manager.register_pass<pass::GroupedGatherElimination>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user