[CommonOptimizations] SimplifySecondInputOfReshape transformation (#7412)
* [CommonOptimizations] SimplifySecondInputOfReshape implementation * postreview fixes * review comments fixed
This commit is contained in:
parent
6f754052cf
commit
0cf8d18988
@ -22,6 +22,7 @@ class TRANSFORMATIONS_API SharedShapeOf;
|
||||
class TRANSFORMATIONS_API GroupedGatherElimination;
|
||||
class TRANSFORMATIONS_API GatherNopElimination;
|
||||
class TRANSFORMATIONS_API SimplifyGatherShapeOf;
|
||||
class TRANSFORMATIONS_API SimplifySecondInputOfReshape;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
@ -82,3 +83,14 @@ public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SimplifyGatherShapeOf();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief SimplifySecondInputOfReshape optimizes `shapeof->gather` into zero values for
|
||||
* reshape pattern values if possible.
|
||||
*/
|
||||
class ngraph::pass::SimplifySecondInputOfReshape : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
SimplifySecondInputOfReshape();
|
||||
};
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
@ -191,6 +192,102 @@ ngraph::pass::SimplifyGatherShapeOf::SimplifyGatherShapeOf() {
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifySecondInputOfReshape, "SimplifySecondInputOfReshape", 0);
|
||||
|
||||
ngraph::pass::SimplifySecondInputOfReshape::SimplifySecondInputOfReshape() {
|
||||
MATCHER_SCOPE(SimplifySecondInputOfReshape);
|
||||
const auto input = pattern::any_input();
|
||||
auto has_static_1d_shape = [](const Output<Node>& output) {
|
||||
return pattern::has_static_shape()(output) && pattern::rank_equals(1)(output);
|
||||
};
|
||||
const auto concat = pattern::wrap_type<opset8::Concat>(has_static_1d_shape);
|
||||
const auto reshape_pattern = pattern::wrap_type<opset8::Reshape>({ input, concat });
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
|
||||
auto node = m.get_match_root();
|
||||
const auto reshape = as_type_ptr<opset8::Reshape>(node);
|
||||
if (!reshape || reshape->get_special_zero() == false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto concat = as_type_ptr<opset8::Concat>(reshape->get_input_node_shared_ptr(1));
|
||||
if (!concat)
|
||||
return false;
|
||||
|
||||
const auto concat_axis = concat->get_axis();
|
||||
OPENVINO_ASSERT(concat_axis == 0, "axis is not valid for matched Concat with 1D output");
|
||||
|
||||
auto data = m.get_pattern_value_map().at(input);
|
||||
if (is_type<opset8::FakeQuantize>(data.get_node_shared_ptr()) ||
|
||||
ngraph::op::is_unary_elementwise_arithmetic(data.get_node_shared_ptr())) {
|
||||
data = data.get_node_shared_ptr()->input_value(0);
|
||||
}
|
||||
|
||||
auto check_shape_of_gather = [&](const std::shared_ptr<Node>& gather) {
|
||||
auto shape_of = gather->get_input_node_shared_ptr(0);
|
||||
if ((!is_type<opset8::ShapeOf>(shape_of) && !is_type<opset1::ShapeOf>(shape_of)) ||
|
||||
(shape_of->get_output_target_inputs(0).size() > 1)) {
|
||||
return false;
|
||||
}
|
||||
return shape_of->input_value(0) == data;
|
||||
};
|
||||
|
||||
const auto concat_inputs = concat->input_values();
|
||||
OutputVector new_concat_inputs = concat_inputs;
|
||||
std::int64_t gather_dims_expected_location = 0;
|
||||
bool gather_folded = false;
|
||||
|
||||
// We need this check to avoid sequences shapeOf -> gather -> concat
|
||||
// that change the arrangement of dimensions in the reshape pattern
|
||||
for (auto& input : new_concat_inputs) {
|
||||
if (const auto gather = as_type_ptr<op::util::GatherBase>(input.get_node_shared_ptr())) {
|
||||
auto indices_constant = as_type_ptr<opset8::Constant>(gather->get_input_node_shared_ptr(1));
|
||||
if (!indices_constant || !check_shape_of_gather(gather)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool gather_can_be_fused = true;
|
||||
const auto indices = indices_constant->cast_vector<std::int64_t>();
|
||||
for (size_t i = 0; i < indices.size(); ++i) {
|
||||
if (indices[i] != gather_dims_expected_location) {
|
||||
gather_can_be_fused = false;
|
||||
}
|
||||
gather_dims_expected_location++;
|
||||
}
|
||||
|
||||
if (gather_can_be_fused) {
|
||||
const size_t num_of_unchanged_dimensions = indices.size();
|
||||
const auto subgraph_et = gather->get_input_element_type(0);
|
||||
input = opset8::Constant::create(subgraph_et, Shape{ num_of_unchanged_dimensions }, { 0 });
|
||||
gather_folded = true;
|
||||
}
|
||||
} else {
|
||||
const auto concat_input_shape = input.get_shape();
|
||||
OPENVINO_ASSERT(concat_input_shape.size() == 1, "concat input rank is not valid for matched Concat with 1D output");
|
||||
gather_dims_expected_location += concat_input_shape[0];
|
||||
}
|
||||
}
|
||||
|
||||
if (!gather_folded) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto new_concat = op::util::make_try_fold<opset8::Concat>(new_concat_inputs, concat_axis);
|
||||
new_concat->set_friendly_name(concat->get_friendly_name());
|
||||
copy_runtime_info(concat, new_concat);
|
||||
|
||||
const auto new_reshape = reshape->clone_with_new_inputs({ reshape->input_value(0), new_concat });
|
||||
new_reshape->set_friendly_name(reshape->get_friendly_name());
|
||||
|
||||
copy_runtime_info(reshape, new_reshape);
|
||||
replace_node(reshape, new_reshape);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_pattern, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::SimplifyShapeOfSubGraph, "SimplifyShapeOfSubGraph", 0);
|
||||
|
||||
bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngraph::Function> f) {
|
||||
@ -201,6 +298,7 @@ bool ngraph::pass::SimplifyShapeOfSubGraph::run_on_function(std::shared_ptr<ngra
|
||||
manager.register_pass<ngraph::pass::GroupedGatherElimination>();
|
||||
manager.register_pass<ngraph::pass::GatherNopElimination>();
|
||||
manager.register_pass<ngraph::pass::SimplifyGatherShapeOf>();
|
||||
manager.register_pass<ngraph::pass::SimplifySecondInputOfReshape>();
|
||||
manager.run_passes(f);
|
||||
return false;
|
||||
}
|
||||
|
@ -0,0 +1,548 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <transformations/common_optimizations/simplify_shape_of_sub_graph.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
|
||||
auto gather = [](const std::shared_ptr<Node> input, std::vector<int64_t> indices) -> Output<Node> {
|
||||
std::shared_ptr<Node> indices_node = opset7::Constant::create(element::i64, {indices.size()}, indices);
|
||||
std::shared_ptr<Node> axis_node = opset7::Constant::create(element::i64, {}, { 0 });
|
||||
return std::make_shared<opset7::Gather>(input, indices_node, axis_node);
|
||||
};
|
||||
|
||||
auto fake_quantize = [](const std::shared_ptr<Node> input) -> Output<Node> {
|
||||
auto il = opset7::Constant::create(element::f32, Shape{}, { 0.f });
|
||||
auto ih = opset7::Constant::create(element::f32, Shape{}, { 25.5f });
|
||||
auto ol = opset7::Constant::create(element::f32, Shape{}, { 0.f });
|
||||
auto oh = opset7::Constant::create(element::f32, Shape{}, { 25.5f });
|
||||
return std::make_shared<opset7::FakeQuantize>(input, il, ih, ol, oh, 256);
|
||||
};
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest1) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{1, 128, 12, 64};
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{1}, {768});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{reshape}, ParameterVector{data});
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest2) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto fq = fake_quantize(data);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(fq, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto fq = fake_quantize(data);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(fq, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest3) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 768 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest4) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 768 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto fq = fake_quantize(data);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(fq, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto fq = fake_quantize(data);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(fq, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest5) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape = PartialShape::dynamic(3);
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest6) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape = PartialShape::dynamic();
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape::dynamic(3));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, -1 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest7) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{2, 3});
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant_1, constant_2, gather_op }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 0 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest8) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{2});
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 2 });
|
||||
auto constant_3 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant_1, constant_2, gather_op, constant_3 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 64, 2, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 64, 2, 0, 64 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest9) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 2});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { -1 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
f_ref = f;
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 12, 8192 }));
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest10) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op_1 = gather(shape_of_1, std::vector<int64_t>{0, 1});
|
||||
auto gather_op_2 = gather(shape_of_2, std::vector<int64_t>{3});
|
||||
auto gather_op_3 = gather(shape_of_2, std::vector<int64_t>{2});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op_1, gather_op_2, gather_op_3 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 64, 12 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 0, 0 });
|
||||
auto gather_op_2 = gather(shape_of, std::vector<int64_t>{3});
|
||||
auto gather_op_3 = gather(shape_of, std::vector<int64_t>{2});
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ constant, gather_op_2, gather_op_3 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest11) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of_1 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto shape_of_2 = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto concat_input_0 = gather(shape_of_1, std::vector<int64_t>{0});
|
||||
auto concat_input_1 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 64 });
|
||||
auto concat_input_2 = gather(shape_of_2, std::vector<int64_t>{2});
|
||||
auto concat_input_3 = ngraph::opset7::Constant::create(ngraph::element::i64, {1}, { 128 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ concat_input_0, concat_input_1, concat_input_2, concat_input_3 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 64, 12, 128 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 64, 0, 128 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, constant, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest12) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 768 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant_1 = opset7::Constant::create(element::i64, Shape{ 1 }, { 12 });
|
||||
auto constant_2 = opset7::Constant::create(element::i64, Shape{ 1 }, { 64 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant_1, constant_2 }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(gelu, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(gelu, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest13) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data, element::i32);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i32, Shape{ 1 }, { 768 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i32, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest14) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 12, 64 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
|
||||
auto shape_of = std::make_shared<opset1::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 1 }, { 768 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 768 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 3 }, { 0, 0, 768 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(data, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
TEST(TransformationTests, SimplifySecondInputOfReshapeTest15) {
|
||||
std::shared_ptr<Function> f(nullptr), f_ref(nullptr);
|
||||
|
||||
PartialShape data_shape{ 1, 128, 768 };
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
|
||||
auto shape_of = std::make_shared<opset7::ShapeOf>(data);
|
||||
auto gather_op = gather(shape_of, std::vector<int64_t>{0, 1});
|
||||
auto constant = opset7::Constant::create(element::i64, Shape{ 2 }, { 12, 64 });
|
||||
auto concat = std::make_shared<opset7::Concat>(OutputVector{ gather_op, constant }, 0);
|
||||
|
||||
auto reshape = std::make_shared<opset7::Reshape>(gelu, concat, true);
|
||||
f = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
|
||||
pass::Manager m;
|
||||
m.register_pass<pass::InitNodeInfo>();
|
||||
m.register_pass<pass::SimplifySecondInputOfReshape>();
|
||||
m.run_passes(f);
|
||||
ASSERT_NO_THROW(check_rt_info(f));
|
||||
ASSERT_EQ(reshape->get_output_partial_shape(0), PartialShape({ 1, 128, 12, 64 }));
|
||||
}
|
||||
{
|
||||
auto data = std::make_shared<opset7::Parameter>(element::f32, data_shape);
|
||||
auto gelu = std::make_shared<opset7::Gelu>(data);
|
||||
auto reshape_pattern = opset7::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 12, 64 });
|
||||
auto reshape = std::make_shared<opset7::Reshape>(gelu, reshape_pattern, true);
|
||||
f_ref = std::make_shared<Function>(NodeVector{ reshape }, ParameterVector{ data });
|
||||
}
|
||||
|
||||
auto res = compare_functions(f, f_ref, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
Loading…
Reference in New Issue
Block a user