Added TransposeReshapeEliminationForMatmul Transformation (#7938)

This commit is contained in:
Alexandra Sidorova
2021-12-08 17:45:40 +03:00
committed by GitHub
parent 672565a8ed
commit 92760949bf
4 changed files with 358 additions and 0 deletions

View File

@@ -0,0 +1,149 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <memory>
#include <queue>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <transformations/common_optimizations/transpose_reshape_elimination_for_matmul.hpp>
#include <transformations/op_conversions/einsum_decomposition.hpp>
#include <transformations/init_node_info.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ngraph;
// MatMul without transposed inputs: the Transpose + Reshape pair in front of the second MatMul input and the
// Reshape + Transpose pair behind the MatMul output are expected to collapse into a single MatMul that consumes
// the original 3D input directly.
TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul) {
    Shape data_shape_1{10, 2};
    Shape data_shape_2{10, 2, 25};
    {
        auto data_1 = std::make_shared<opset1::Parameter>(element::f32, data_shape_1);
        auto data_2 = std::make_shared<opset1::Parameter>(element::f32, data_shape_2);

        // Second input: [10, 2, 25] -> Transpose -> [2, 25, 10] -> Reshape -> [2, 250]
        auto const_transpose_before = opset1::Constant::create(element::i32, Shape{3}, {1, 2, 0});
        auto transpose_before = std::make_shared<opset1::Transpose>(data_2, const_transpose_before);
        auto const_reshape_before = opset1::Constant::create(element::i32, Shape{2}, {2, 250});
        auto reshape_before = std::make_shared<opset1::Reshape>(transpose_before, const_reshape_before, false);

        auto matmul = std::make_shared<opset1::MatMul>(data_1, reshape_before);

        // Output: [10, 250] has to be unpacked in the same order the second input was flattened ([25, 10]),
        // i.e. [10, 25, 10], so that the final Transpose yields [10, 10, 25] exactly like the reference MatMul.
        auto const_reshape_after = opset1::Constant::create(element::i32, Shape{3}, {10, 25, 10});
        auto reshape_after = std::make_shared<opset1::Reshape>(matmul, const_reshape_after, false);
        auto const_transpose_after = opset1::Constant::create(element::i32, Shape{3}, {2, 0, 1});
        auto transpose_after = std::make_shared<opset1::Transpose>(reshape_after, const_transpose_after);

        function = std::make_shared<Function>(NodeVector{transpose_after}, ParameterVector{data_1, data_2});
        manager.register_pass<pass::InitNodeInfo>();
        manager.register_pass<pass::TransposeReshapeEliminationForMatmul>();
    }
    {
        // Reference: the batched MatMul applied to the untouched inputs
        auto data_1 = std::make_shared<opset1::Parameter>(element::f32, data_shape_1);
        auto data_2 = std::make_shared<opset1::Parameter>(element::f32, data_shape_2);
        auto matmul = std::make_shared<opset1::MatMul>(data_1, data_2);
        function_ref = std::make_shared<Function>(NodeVector{matmul}, ParameterVector{data_1, data_2});
    }
}
// MatMul with transpose_a = true: the Transpose/Reshape pairs around the second input and the output are
// expected to be removed while the transpose_a flag is carried over to the resulting MatMul.
TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_TransposedA) {
    Shape data_shape_1{2, 10};
    Shape data_shape_2{10, 2, 25};
    {
        auto data_1 = std::make_shared<opset1::Parameter>(element::f32, data_shape_1);
        auto data_2 = std::make_shared<opset1::Parameter>(element::f32, data_shape_2);

        // Second input: [10, 2, 25] -> Transpose -> [2, 25, 10] -> Reshape -> [2, 250]
        auto const_transpose_before = opset1::Constant::create(element::i32, Shape{3}, {1, 2, 0});
        auto transpose_before = std::make_shared<opset1::Transpose>(data_2, const_transpose_before);
        auto const_reshape_before = opset1::Constant::create(element::i32, Shape{2}, {2, 250});
        auto reshape_before = std::make_shared<opset1::Reshape>(transpose_before, const_reshape_before, false);

        auto matmul = std::make_shared<opset1::MatMul>(data_1, reshape_before, true, false);

        // Output: [10, 250] has to be unpacked in the same order the second input was flattened ([25, 10]),
        // i.e. [10, 25, 10], so that the final Transpose yields [10, 10, 25] exactly like the reference MatMul.
        auto const_reshape_after = opset1::Constant::create(element::i32, Shape{3}, {10, 25, 10});
        auto reshape_after = std::make_shared<opset1::Reshape>(matmul, const_reshape_after, false);
        auto const_transpose_after = opset1::Constant::create(element::i32, Shape{3}, {2, 0, 1});
        auto transpose_after = std::make_shared<opset1::Transpose>(reshape_after, const_transpose_after);

        function = std::make_shared<Function>(NodeVector{transpose_after}, ParameterVector{data_1, data_2});
        manager.register_pass<pass::TransposeReshapeEliminationForMatmul>();
    }
    {
        // Reference: transpose_a is preserved, the second input is consumed as is
        auto data_1 = std::make_shared<opset1::Parameter>(element::f32, data_shape_1);
        auto data_2 = std::make_shared<opset1::Parameter>(element::f32, data_shape_2);
        auto matmul = std::make_shared<opset1::MatMul>(data_1, data_2, true, false);
        function_ref = std::make_shared<Function>(NodeVector{matmul}, ParameterVector{data_1, data_2});
    }
}
// MatMul with transpose_b = true: the second input is flattened to [N, K] instead of [K, N]; after the
// transformation the reference MatMul consumes the original input with transpose_b reset to false.
TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_TransposedB) {
    Shape lhs_shape{10, 2};
    Shape rhs_shape{10, 2, 25};
    {
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);

        // Second input: [10, 2, 25] -> Transpose -> [10, 25, 2] -> Reshape -> [250, 2]
        auto order_in = opset1::Constant::create(element::i32, Shape{3}, {0, 2, 1});
        auto transpose_in = std::make_shared<opset1::Transpose>(rhs, order_in);
        auto target_in = opset1::Constant::create(element::i32, Shape{2}, {250, 2});
        auto reshape_in = std::make_shared<opset1::Reshape>(transpose_in, target_in, false);

        auto matmul = std::make_shared<opset1::MatMul>(lhs, reshape_in, false, true);

        // Output: [10, 250] -> Reshape -> [10, 10, 25] -> Transpose -> [10, 10, 25]
        auto target_out = opset1::Constant::create(element::i32, Shape{3}, {10, 10, 25});
        auto reshape_out = std::make_shared<opset1::Reshape>(matmul, target_out, false);
        auto order_out = opset1::Constant::create(element::i32, Shape{3}, {1, 0, 2});
        auto transpose_out = std::make_shared<opset1::Transpose>(reshape_out, order_out);

        function = std::make_shared<Function>(NodeVector{transpose_out}, ParameterVector{lhs, rhs});
        manager.register_pass<pass::TransposeReshapeEliminationForMatmul>();
    }
    {
        // Reference: plain MatMul on the untouched inputs
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);
        auto matmul = std::make_shared<opset1::MatMul>(lhs, rhs);
        function_ref = std::make_shared<Function>(NodeVector{matmul}, ParameterVector{lhs, rhs});
    }
}
// MatMul with both transpose_a and transpose_b set: transpose_a must survive the transformation, while
// transpose_b is reset to false in the reference MatMul.
TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_TransposedAB) {
    Shape lhs_shape{2, 10};
    Shape rhs_shape{10, 2, 25};
    {
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);

        // Second input: [10, 2, 25] -> Transpose -> [10, 25, 2] -> Reshape -> [250, 2]
        auto order_in = opset1::Constant::create(element::i32, Shape{3}, {0, 2, 1});
        auto transpose_in = std::make_shared<opset1::Transpose>(rhs, order_in);
        auto target_in = opset1::Constant::create(element::i32, Shape{2}, {250, 2});
        auto reshape_in = std::make_shared<opset1::Reshape>(transpose_in, target_in, false);

        auto matmul = std::make_shared<opset1::MatMul>(lhs, reshape_in, true, true);

        // Output: [10, 250] -> Reshape -> [10, 10, 25] -> Transpose -> [10, 10, 25]
        auto target_out = opset1::Constant::create(element::i32, Shape{3}, {10, 10, 25});
        auto reshape_out = std::make_shared<opset1::Reshape>(matmul, target_out, false);
        auto order_out = opset1::Constant::create(element::i32, Shape{3}, {1, 0, 2});
        auto transpose_out = std::make_shared<opset1::Transpose>(reshape_out, order_out);

        function = std::make_shared<Function>(NodeVector{transpose_out}, ParameterVector{lhs, rhs});
        manager.register_pass<pass::TransposeReshapeEliminationForMatmul>();
    }
    {
        // Reference: only transpose_a is kept
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);
        auto matmul = std::make_shared<opset1::MatMul>(lhs, rhs, true, false);
        function_ref = std::make_shared<Function>(NodeVector{matmul}, ParameterVector{lhs, rhs});
    }
}
// End-to-end scenario: Einsum is decomposed by EinsumDecomposition first, then the pass under test is run
// on the resulting sub-graph.
TEST_F(TransformationTestsF, TransposeReshapeEliminationForMatMul_Einsum) {
    Shape lhs_shape{5, 2};
    Shape rhs_shape{10, 2, 25};
    {
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);
        auto einsum = std::make_shared<opset7::Einsum>(OutputVector{lhs, rhs}, "kl,mlj->mkj");
        function = std::make_shared<Function>(NodeVector{einsum}, ParameterVector{lhs, rhs});
        manager.register_pass<pass::EinsumDecomposition>();
        manager.register_pass<pass::TransposeReshapeEliminationForMatmul>();
    }
    {
        auto lhs = std::make_shared<opset1::Parameter>(element::f32, lhs_shape);
        auto rhs = std::make_shared<opset1::Parameter>(element::f32, rhs_shape);
        // In some cases a Reshape remains on the first MatMul input; here it reshapes the input to its own
        // shape and is expected to be kept in the reference graph.
        auto lhs_target = std::make_shared<opset1::Constant>(element::i64, Shape{lhs_shape.size()}, lhs_shape);
        auto lhs_reshape = std::make_shared<opset1::Reshape>(lhs, lhs_target, false);
        auto matmul = std::make_shared<opset1::MatMul>(lhs_reshape, rhs, false, false);
        function_ref = std::make_shared<Function>(NodeVector{matmul}, ParameterVector{lhs, rhs});
    }
}

View File

@@ -0,0 +1,32 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <vector>
#include <memory>
#include "transformations_visibility.hpp"
#include "ngraph/pass/graph_rewrite.hpp"
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API TransposeReshapeEliminationForMatmul;
} // namespace pass
} // namespace ngraph
/**
 * @ingroup ie_transformation_common_api
 * @brief TransposeReshapeEliminationForMatmul transformation eliminates the Transpose and Reshape operations that
 * were inserted to align input and output dimension ranks: the Transpose -> Reshape pair in front of the second
 * MatMul input and the Reshape -> Transpose pair behind the MatMul output. The matched sub-graph is replaced with
 * a single MatMul that consumes the original inputs directly.
 * Such sub-graphs appear, for example, after Einsum decomposition inside TensorFlow 1 and after the nGraph
 * EinsumDecomposition transformation.
 */
class ngraph::pass::TransposeReshapeEliminationForMatmul: public ngraph::pass::MatcherPass {
public:
NGRAPH_RTTI_DECLARATION;
// Builds the pattern and registers the matcher callback
TransposeReshapeEliminationForMatmul();
};

View File

@@ -51,6 +51,7 @@
#include "transformations/common_optimizations/mul_conv_fusion.hpp"
#include "transformations/common_optimizations/interpolate_sequence_fusion.hpp"
#include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp"
#include <transformations/common_optimizations/transpose_reshape_elimination_for_matmul.hpp>
#include "transformations/op_conversions/bidirectional_sequences_decomposition.hpp"
#include "transformations/op_conversions/convert_pad_to_group_conv.hpp"
#include "transformations/op_conversions/convert_divide.hpp"
@@ -149,6 +150,7 @@ bool ngraph::pass::CommonOptimizations::run_on_function(std::shared_ptr<ngraph::
decomp->add_matcher<ngraph::pass::SoftmaxDecomposition, false>();
decomp->add_matcher<ngraph::pass::GatherNegativeConstIndicesNormalize>();
decomp->add_matcher<ngraph::pass::DropoutWithRandomUniformReplacer>();
decomp->add_matcher<ngraph::pass::TransposeReshapeEliminationForMatmul>();
decomp->set_name("ngraph::pass::CommonDecompositions");
// CF is required after all decompositions

View File

@@ -0,0 +1,175 @@
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/transpose_reshape_elimination_for_matmul.hpp"
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>
#include "ngraph/opsets/opset1.hpp"
#include "ngraph/rt_info.hpp"
#include "ngraph/pattern/op/wrap_type.hpp"
#include "ngraph/validation_util.hpp"
#include "itt.hpp"
namespace {
/// \brief Validate the permutation orders of the Transposes placed before and after MatMul.
///        - second MatMul input is not transposed: the order before MatMul must be N-2, N-1, 0, 1, ..., N-3 and
///          the order after MatMul must be the inverse permutation of the order before MatMul;
///        - second MatMul input is transposed: the order before MatMul must be 0, 1, ..., N-3, N-1, N-2 and
///          the order after MatMul must be 1, ..., N-2, 0, N-1.
///
/// \param before_order Order of Transpose which is before MatMul
/// \param after_order Order of Transpose which is after MatMul
/// \param transposed_b true - second MatMul input is transposed, otherwise, it's not transposed
///
/// \return True - both orders match the expected ones, otherwise, the Transposes cannot be eliminated
///
bool check_transposes(const std::vector<int64_t>& before_order, const std::vector<int64_t>& after_order, const bool transposed_b) {
    const size_t dims = before_order.size();
    // At least one batch dimension is required and both permutations must have the same length
    if (dims < 3 || after_order.size() != dims)
        return false;

    std::vector<int64_t> expected_before(dims);
    if (!transposed_b) {
        // expected order before MatMul: N-2, N-1, 0, 1, ..., N-3
        expected_before[0] = static_cast<int64_t>(dims - 2);
        expected_before[1] = static_cast<int64_t>(dims - 1);
        std::iota(expected_before.begin() + 2, expected_before.end(), int64_t{0});
        if (before_order != expected_before)
            return false;

        // The order after MatMul must undo the order before MatMul: after_order[before_order[i]] == i.
        // before_order is a verified permutation at this point, so every index below is in range.
        for (size_t pos = 0; pos < dims; ++pos) {
            if (after_order[static_cast<size_t>(before_order[pos])] != static_cast<int64_t>(pos))
                return false;
        }
        return true;
    }

    // expected order before MatMul: 0, 1, ..., N-3, N-1, N-2 (only the two innermost dimensions are swapped)
    std::iota(expected_before.begin(), expected_before.end() - 2, int64_t{0});
    expected_before[dims - 2] = static_cast<int64_t>(dims - 1);
    expected_before[dims - 1] = static_cast<int64_t>(dims - 2);
    if (before_order != expected_before)
        return false;

    // expected order after MatMul: 1, ..., N-2, 0, N-1
    std::vector<int64_t> expected_after(dims);
    std::iota(expected_after.begin(), expected_after.end() - 2, int64_t{1});
    expected_after[dims - 2] = 0;
    expected_after[dims - 1] = static_cast<int64_t>(dims - 1);
    return after_order == expected_after;
}
/// \brief Check that the Reshape placed before MatMul only flattens its input (rank >= 3) into the 2D matrix
///        expected by MatMul: [K, N] when the second MatMul input is not transposed, [N, K] when it is
///
/// \param reshape Reshape which is before MatMul
/// \param new_shape Target shape values of the Reshape (content of its constant second input)
/// \param transposed_b true - second MatMul input is transposed, otherwise, it's not transposed
///
/// \return True - Reshape has right new shape for reshaping, otherwise, Reshape has incorrect new shape for
///         transformation
///
bool check_input_reshape(const std::shared_ptr<ngraph::opset1::Reshape>& reshape,
                         const std::vector<int64_t>& new_shape, const bool transposed_b) {
    const auto input_shape = reshape->get_input_shape(0);
    const size_t input_rank = input_shape.size();
    const size_t output_rank = reshape->get_output_shape(0).size();
    if (input_rank < 3 || output_rank != 2)
        return false;

    // K is the dimension kept as is by the Reshape: the last one for a transposed second input, the first otherwise
    const int64_t k = static_cast<int64_t>(transposed_b ? input_shape.back() : input_shape.front());
    // A zero-sized K would turn the division below into a division by zero; such a Reshape is never eliminated
    if (k == 0)
        return false;
    // All remaining dimensions are collapsed into a single one
    const int64_t new_n = static_cast<int64_t>(ov::shape_size(input_shape)) / k;

    if (transposed_b)
        return new_shape == std::vector<int64_t>{new_n, k};

    // For the non-transposed case the collapsed dimension may also be given as -1
    return new_shape == std::vector<int64_t>{k, -1} || new_shape == std::vector<int64_t>{k, new_n};
}
} // namespace
NGRAPH_RTTI_DEFINITION(ngraph::pass::TransposeReshapeEliminationForMatmul, "TransposeReshapeEliminationForMatmul", 0);

// Registers a matcher for the chain
//   input_2 -> Transpose -> Reshape -> MatMul (second input) -> Reshape -> Transpose
// where the first MatMul input (input_1) is a static 2D tensor. When the Reshape before MatMul and both Transpose
// orders pass the checks, the whole chain is replaced with MatMul(input_1, input_2) that keeps transpose_a of the
// matched MatMul and always uses transpose_b = false.
ngraph::pass::TransposeReshapeEliminationForMatmul::TransposeReshapeEliminationForMatmul() {
    MATCHER_SCOPE(TransposeReshapeEliminationForMatmul);

    // First MatMul input: fully static shape of rank 2
    auto input_1_pattern = ngraph::pattern::any_input([] (const Output<Node>& node) -> bool {
        const auto& shape = node.get_partial_shape();
        const auto& rank = shape.rank();
        return rank.is_static() && rank.get_length() == 2 && shape.is_static();
    });
    // Source of the second MatMul input: any fully static shape
    auto input_2_pattern = ngraph::pattern::any_input([] (const Output<Node>& node) -> bool {
        return node.get_partial_shape().is_static();
    });

    auto const_transpose_before_pattern = ngraph::pattern::wrap_type<opset1::Constant>();
    auto transpose_before_pattern = ngraph::pattern::wrap_type<opset1::Transpose>({input_2_pattern, const_transpose_before_pattern});
    auto const_reshape_before_pattern = ngraph::pattern::wrap_type<opset1::Constant>();
    auto reshape_before_pattern = ngraph::pattern::wrap_type<opset1::Reshape>({transpose_before_pattern, const_reshape_before_pattern});
    auto matmul_pattern = ngraph::pattern::wrap_type<opset1::MatMul>({input_1_pattern, reshape_before_pattern});
    auto const_reshape_after_pattern = ngraph::pattern::wrap_type<opset1::Constant>();
    auto reshape_after_pattern = ngraph::pattern::wrap_type<opset1::Reshape>({matmul_pattern, const_reshape_after_pattern});
    auto const_transpose_after_pattern = ngraph::pattern::wrap_type<opset1::Constant>();
    auto transpose_after_pattern = ngraph::pattern::wrap_type<opset1::Transpose>({reshape_after_pattern, const_transpose_after_pattern});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher& m) {
        const auto& pattern_value_map = m.get_pattern_value_map();
        const auto& input_1 = pattern_value_map.at(input_1_pattern);
        const auto& input_2 = pattern_value_map.at(input_2_pattern);

        auto matmul = std::dynamic_pointer_cast<opset1::MatMul>(pattern_value_map.at(matmul_pattern).get_node_shared_ptr());
        if (!matmul)
            return false;
        const bool transposed_a = matmul->get_transpose_a();
        const bool transposed_b = matmul->get_transpose_b();

        auto reshape_before = std::dynamic_pointer_cast<opset1::Reshape>(pattern_value_map.at(reshape_before_pattern).get_node_shared_ptr());
        auto reshape_after = std::dynamic_pointer_cast<opset1::Reshape>(pattern_value_map.at(reshape_after_pattern).get_node_shared_ptr());
        auto reshape_before_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(
            pattern_value_map.at(const_reshape_before_pattern).get_node_shared_ptr());
        if (!reshape_before || !reshape_after || !reshape_before_constant)
            return false;
        if (!check_input_reshape(reshape_before, reshape_before_constant->cast_vector<int64_t>(), transposed_b))
            return false;
        // NOTE(review): the target shape of the Reshape after MatMul is not validated anywhere; only the
        // Reshape before MatMul and the two Transpose orders are checked.

        // check transpose order before and after matmul
        auto transpose_before = std::dynamic_pointer_cast<opset1::Transpose>(pattern_value_map.at(transpose_before_pattern).get_node_shared_ptr());
        auto transpose_after = std::dynamic_pointer_cast<opset1::Transpose>(pattern_value_map.at(transpose_after_pattern).get_node_shared_ptr());
        // The casts must be verified before the nodes are dereferenced to fetch their order inputs
        if (!transpose_before || !transpose_after)
            return false;
        auto transpose_before_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(transpose_before->get_input_node_shared_ptr(1));
        auto transpose_after_constant = std::dynamic_pointer_cast<ngraph::opset1::Constant>(transpose_after->get_input_node_shared_ptr(1));
        if (!transpose_before_constant || !transpose_after_constant)
            return false;

        auto transpose_before_order = transpose_before_constant->cast_vector<int64_t>();
        auto transpose_after_order = transpose_after_constant->cast_vector<int64_t>();
        // need to check that input shape is correctly contracted and output shape is correctly unpacked using transposes
        if (!check_transposes(transpose_before_order, transpose_after_order, transposed_b))
            return false;

        // The eliminated Transpose/Reshape already put the second input into the layout MatMul expects,
        // so the new MatMul never needs transpose_b
        const auto new_matmul = std::make_shared<opset1::MatMul>(input_1, input_2, transposed_a, false);
        new_matmul->set_friendly_name(transpose_after->get_friendly_name());
        copy_runtime_info({transpose_before, reshape_before, matmul, reshape_after, transpose_after}, new_matmul);
        replace_node(transpose_after, new_matmul);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose_after_pattern, matcher_name);
    this->register_matcher(m, callback);
}