[CommonOptimizations] SimplifySecondInputOfReshape transformation (#7412)

* [CommonOptimizations] SimplifySecondInputOfReshape implementation

* postreview fixes

* review comments fixed
This commit is contained in:
Vladislav Golubev 2021-10-21 10:49:42 +03:00 committed by GitHub
parent 6f754052cf
commit 0cf8d18988
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 658 additions and 0 deletions

View File

@ -22,6 +22,7 @@ class TRANSFORMATIONS_API SharedShapeOf;
class TRANSFORMATIONS_API GroupedGatherElimination;
class TRANSFORMATIONS_API GatherNopElimination;
class TRANSFORMATIONS_API SimplifyGatherShapeOf;
class TRANSFORMATIONS_API SimplifySecondInputOfReshape;
} // namespace pass
} // namespace ngraph
@ -82,3 +83,14 @@ public:
NGRAPH_RTTI_DECLARATION;
SimplifyGatherShapeOf();
};
/**
* @ingroup ie_transformation_common_api
* @brief SimplifySecondInputOfReshape optimizes `shapeof->gather` into zero values for
* reshape pattern values if possible.
*/
class ngraph::pass::SimplifySecondInputOfReshape : public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
SimplifySecondInputOfReshape();
};

View File

@ -10,6 +10,7 @@
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
@ -191,6 +192,102 @@ ngraph::pass::SimplifyGatherShapeOf::SimplifyGatherShapeOf() {
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifySecondInputOfReshape, "SimplifySecondInputOfReshape", 0);
ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
MATCHER_SCOPE(SimplifySecondInputOfReshape);
const auto input = pattern::any_input();
auto has_static_1d_shape = [](const Output<Node>& output) {
return pattern::has_static_shape()(output) && pattern::rank_equals(1)(output);
};
const auto concat = pattern::wrap_type<opset8::Concat>(has_static_1d_shape);
const auto reshape_pattern = pattern::wrap_type<opset8::Reshape>({ input, concat });
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
auto node = m.get_match_root();
const auto reshape = as_type_ptr<opset8::Reshape>(node);
if (!reshape || reshape->get_special_zero() == false) {
return false;
}
const auto concat = as_type_ptr<opset8::Concat>(reshape->get_input_node_shared_ptr(1));
if (!concat)
return false;
const auto concat_axis = concat->get_axis();
OPENVINO_ASSERT(concat_axis == 0, "axis is not valid for matched Concat with 1D output");
auto data = m.get_pattern_value_map().at(input);
if (is_type<opset8::FakeQuantize>(data.get_node_shared_ptr()) ||
ngraph::op::is_unary_elementwise_arithmetic(data.get_node_shared_ptr())) {
data = data.get_node_shared_ptr()->input_value(0);
}
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
auto shape_of = gather->get_input_node_shared_ptr(0);
if ((!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) ||
(shape_of->get_output_target_inputs(0).size() > 1)) {
return false;
}
return shape_of->input_value(0) == data;
};
const auto concat_inputs = concat->input_values();
OutputVector new_concat_inputs = concat_inputs;
std::int64_t gather_dims_expected_location = 0;
bool gather_folded = false;
// We need this check to avoid sequences shapeOf -> gather -> concat
// that change the arrangement of dimensions in the reshape pattern
for (auto& input : new_concat_inputs) {
if (const auto gather = as_type_ptr<op::util::GatherBase>(input.get_node_shared_ptr())) {
auto indices_constant = as_type_ptr<opset8::Constant>(gather->get_input_node_shared_ptr(1));
if (!indices_constant || !check_shape_of_gather(gather)) {
continue;
}
bool gather_can_be_fused = true;
const auto indices = indices_constant->cast_vector<std::int64_t>();
for (size_t i = 0; i < indices.size(); ++i) {
if (indices[i] != gather_dims_expected_location) {
gather_can_be_fused = false;
}
gather_dims_expected_location++;
}
if (gather_can_be_fused) {
const size_t num_of_unchanged_dimensions = indices.size();
const auto subgraph_et = gather->get_input_element_type(0);
input = opset8::Constant::create(subgraph_et, Shape{ num_of_unchanged_dimensions }, { 0 });
gather_folded = true;
}
} else {
const auto concat_input_shape = input.get_shape();
OPENVINO_ASSERT(concat_input_shape.size() == 1, "concat input rank is not valid for matched Concat with 1D output");
gather_dims_expected_location += concat_input_shape[0];
}
}
if (!gather_folded) {
return false;
}
const auto new_concat = op::util::make_try_fold<opset8::Concat>(new_concat_inputs, concat_axis);
new_concat->set_friendly_name(concat->get_friendly_name());
copy_runtime_info(concat, new_concat);
const auto new_reshape = reshape->clone_with_new_inputs({ reshape->input_value(0), new_concat });
new_reshape->set_friendly_name(reshape->get_friendly_name());
copy_runtime_info(reshape, new_reshape);
replace_node(reshape, new_reshape);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_pattern, matcher_name);
this->register_matcher(m, callback);
}
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);
bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
@ -201,6 +298,7 @@ bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngra
manager.register_pass<ngraph::pass::GroupedGatherElimination>();
manager.register_pass<ngraph::pass::GatherNopElimination>();
manager.register_pass<ngraph::pass::SimplifyGatherShapeOf>();
manager.register_pass<ngraph::pass::SimplifySecondInputOfReshape>();
manager.run_passes(f);
return false;
}

View File

@ -0,0 +1,548 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices) -> Output<Node> {
std::shared_ptr<Node> indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
std::shared_ptr<Node> axis_node = opset7::Constant::create(element::i64, {}, { 0 });
return std::make_shared<opset7::Gather>(input, indices_node, axis_node);
};
auto fake_quantize = [](const std::shared_ptr<Node> input) -> Output<Node> {
auto il = opset7::Constant::create(element::f32, Shape{}, { 0.f });
auto ih = opset7::Constant::create(element::f32, Shape{}, { 25.5f });
auto ol = opset7::Constant::create(element::f32, Shape{}, { 0.f });
auto oh = opset7::Constant::create(element::f32, Shape{}, { 25.5f });
return std::make_shared<opset7::FakeQuantize>(input, il, ih, ol, oh, 256);
};
TEST(TransformationTests, SimplifySecondInputOfReshapeTest1) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{1, 128, 12, 64};
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest2) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto fq = fake_quantize(data);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(fq, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto fq = fake_quantize(data);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
auto reshape = std::make_shared<opset7::Reshape>(fq, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest3) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 768 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest4) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 768 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto fq = fake_quantize(data);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(fq, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto fq = fake_quantize(data);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
auto reshape = std::make_shared<opset7::Reshape>(fq, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest5) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape = PartialShape::dynamic(3);
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest6) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape = PartialShape::dynamic();
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest7) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{2, 3});
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant_1, constant_2, gather_op }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 0 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest8) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{2});
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 });
auto constant_3 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant_1, constant_2, gather_op, constant_3 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 64 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest9) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 2});
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
f_ref = f;
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 12, 8192 }));
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest10) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data);
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data);
auto gather_op_1 = gather(shape_of_1, std::vector<int64_t>{0, 1});
auto gather_op_2 = gather(shape_of_2, std::vector<int64_t>{3});
auto gather_op_3 = gather(shape_of_2, std::vector<int64_t>{2});
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op_1, gather_op_2, gather_op_3 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 64, 12 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 0, 0 });
auto gather_op_2 = gather(shape_of, std::vector<int64_t>{3});
auto gather_op_3 = gather(shape_of, std::vector<int64_t>{2});
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant, gather_op_2, gather_op_3 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest11) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data);
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data);
auto concat_input_0 = gather(shape_of_1, std::vector<int64_t>{0});
auto concat_input_1 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 64 });
auto concat_input_2 = gather(shape_of_2, std::vector<int64_t>{2});
auto concat_input_3 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 128 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ concat_input_0, concat_input_1, concat_input_2, concat_input_3 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 64, 12, 128 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto constant = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 64, 0, 128 });
auto reshape = std::make_shared<opset7::Reshape>(data, constant, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest12) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 768 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto gelu = std::make_shared<opset7::Gelu>(data);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
auto reshape = std::make_shared<opset7::Reshape>(gelu, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto gelu = std::make_shared<opset7::Gelu>(data);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
auto reshape = std::make_shared<opset7::Reshape>(gelu, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest13) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset7::ShapeOf>(data, element::i32);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i32, Shape{ 1 }, { 768 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i32, Shape{ 3 }, { 0, 0, 768 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest14) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 12, 64 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto shape_of = std::make_shared<opset1::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}
TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) {
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
PartialShape data_shape{ 1, 128, 768 };
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto gelu = std::make_shared<opset7::Gelu>(data);
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 12, 64 });
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
auto reshape = std::make_shared<opset7::Reshape>(gelu, concat, true);
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
pass::Manager m;
m.register_pass<pass::InitNodeInfo>();
m.register_pass<pass::SimplifySecondInputOfReshape>();
m.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
}
{
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
auto gelu = std::make_shared<opset7::Gelu>(data);
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
auto reshape = std::make_shared<opset7::Reshape>(gelu, reshape_pattern, true);
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
}
auto res = compare_functions(f, f_ref, true);
ASSERT_TRUE(res.first) << res.second;
}