[Symbolic Transformation] DeReshape FullyConnected (#21419)

This commit is contained in:
Evgenya Nugmanova 2023-12-05 12:05:03 +04:00 committed by GitHub
parent 24209239bf
commit a6903b8398
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 205 additions and 0 deletions

View File

@ -10,6 +10,7 @@
namespace ov {
namespace pass {
class TRANSFORMATIONS_API DeReshapeMatMul;
class TRANSFORMATIONS_API DeReshapeFullyConnected;
} // namespace pass
} // namespace ov
@ -64,3 +65,28 @@ public:
OPENVINO_RTTI("DeReshapeMatMul", "0");
DeReshapeMatMul();
};
/**
 * @ingroup ie_transformation_common_api
 * @brief Uses symbol / label information to optimize out the Reshape operations surrounding a special
 * case of MatMul. It checks that the surrounding Reshapes only manipulate the batch dimensions of the
 * tensor in a do-undo fashion. The difference from DeReshapeMatMul is that here the Reshape is present
 * on only one input of the MatMul and the other input is strictly 2D. Such MatMuls are also called
 * FullyConnected.
 *
 * Example:
 *    Before:
 *     [A,B,4096] -> Reshape -> [A*B,4096]
 *                                          MatMul [A*B,4608] -> Reshape -> [A,B,4608]
 *                              [4096,4608]
 *    After:
 *     [A,B,4096] ->
 *                    MatMul -> [A,B,4608]
 *     [4096,4608] ->
 *
 */
class ov::pass::DeReshapeFullyConnected : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("DeReshapeFullyConnected", "0");
DeReshapeFullyConnected();
};

View File

@ -8,6 +8,7 @@
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/util/binary_elementwise_arithmetic.hpp"
@ -334,3 +335,68 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
auto m = std::make_shared<pattern::Matcher>(final_reshape, matcher_name);
register_matcher(m, matcher_pass_callback);
}
// Matches the FullyConnected-style pattern
//   Reshape(rank>=3 -> rank>=2) -> [optional Convert] -> MatMul(<matched input>, 2D input) -> Reshape(rank>=2 -> rank>=3)
// and removes both Reshapes when they provably only fold/unfold batch dimensions
// (the innermost dimension is kept intact on both ends).
ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() {
MATCHER_SCOPE(DeReshapeFullyConnected);
// Input Reshape: static rank, >= 3D in, >= 2D out, and the last (channel) dimension unchanged.
auto reshaped_input = pattern::wrap_type<op::v1::Reshape>([](Output<Node> out) -> bool {
const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
if (input_shape.rank().is_dynamic() || input_shape.size() < 3)
return false;
const auto& output_shape = out.get_partial_shape();
if (output_shape.rank().is_dynamic() || output_shape.size() < 2)
return false;
// last dims must be provably equal (by value or by tracked label)
return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
});
// Optional Convert between the Reshape and the MatMul; single consumer so re-wiring is safe.
auto converted =
pattern::wrap_type<op::v0::Convert>({reshaped_input}, pattern::consumers_count(1));  // optional convert
auto dynamic_input = std::make_shared<pattern::op::Or>(OutputVector{reshaped_input, converted});
// Second MatMul input must be strictly 2D — this is what makes the MatMul a FullyConnected.
auto static_input = pattern::any_input(pattern::rank_equals(2));
// MatMul must not transpose the first input (batch layout must be preserved)
// and must feed only the output Reshape.
auto mm_label = pattern::wrap_type<op::v0::MatMul>({dynamic_input, static_input}, [](Output<Node> out) -> bool {
auto mm = ov::as_type_ptr<op::v0::MatMul>(out.get_node_shared_ptr());
return mm && !mm->get_transpose_a() && pattern::consumers_count(1)(out);
});
// Output Reshape: restores rank >= 3, again keeping the last dimension unchanged.
auto reshaped_output =
pattern::wrap_type<op::v1::Reshape>({mm_label, pattern::any_input()}, [](Output<Node> out) -> bool {
const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
if (input_shape.rank().is_dynamic() || input_shape.size() < 2)
return false;
const auto& output_shape = out.get_partial_shape();
if (output_shape.rank().is_dynamic() || output_shape.size() < 3)
return false;
return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
});
ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
const auto& pm = m.get_pattern_map();
const auto& in_reshape = pm.at(reshaped_input);
const auto& out_reshape = pm.at(reshaped_output);
const auto& matmul = pm.at(mm_label);
// Shape before the input Reshape vs shape after the output Reshape: equal rank and
// matching batch dims mean the second Reshape undoes exactly what the first one did.
const auto& in_shape = in_reshape->get_input_partial_shape(0);
const auto& out_shape = out_reshape->get_output_partial_shape(0);
if (in_shape.size() != out_shape.size())
return false;
// compare all dims except the last one (which changes from K to M through the MatMul)
for (size_t i = 0; i < in_shape.size() - 1; ++i)
if (!dims_are_equal(in_shape[i], out_shape[i]))
return false;
// Bypass the input Reshape: feed its source directly into the Convert (if matched) or the MatMul.
if (pm.count(converted)) {
const auto& convert = pm.at(converted);
convert->input(0).replace_source_output(in_reshape->input_value(0));
convert->validate_and_infer_types();
} else {
matmul->input(0).replace_source_output(in_reshape->input_value(0));
}
// Bypass the output Reshape, transferring its friendly name onto the MatMul output.
ov::replace_output_update_name(out_reshape->output(0), matmul->output(0));
matmul->validate_and_infer_types();
return true;
};
auto m = std::make_shared<pattern::Matcher>(reshaped_output, matcher_name);
register_matcher(m, matcher_pass_callback);
}

View File

@ -202,6 +202,7 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs
REGISTER_SYMBOLIC(LabelResolvingThroughSelect) // figures out that broadcasting didn't happen through Select op
REGISTER_SYMBOLIC(DeReshapeMatMul)
REGISTER_SYMBOLIC(DeReshapeFullyConnected)
REGISTER_SYMBOLIC(ReshapeOptimizations)
REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph)
}

View File

@ -0,0 +1,112 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "common_test_utils/ov_test_utils.hpp"
#include "openvino/core/dimension_tracker.hpp"
#include "openvino/core/model.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
#include "transformations/utils/utils.hpp"
using namespace ov;
using namespace ov::op;
using namespace std;
namespace {
// Assigns tracking labels to every dimension of `shape`, starting from label 42.
void label_shape(ov::PartialShape& shape) {
    const auto equivalence_table = std::make_shared<ov::TableOfEquivalence>(42);
    ov::DimensionTracker dim_tracker(equivalence_table);
    dim_tracker.set_up_for_tracking(shape);
}
}  // namespace
TEST_F(TransformationTestsF, DeReshapeFC) {
    {
        auto labeled_shape = PartialShape{-1, -1, 40};
        label_shape(labeled_shape);  // dims get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f32, labeled_shape);
        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto mm = make_shared<v0::MatMul>(flatten, weights);

        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(input, {0, 1});
        auto target_shape =
            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
        auto unflatten = make_shared<v1::Reshape>(mm, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
    {
        auto labeled_shape = PartialShape{-1, -1, 40};
        label_shape(labeled_shape);  // dims get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f32, labeled_shape);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto mm = make_shared<v0::MatMul>(input, weights);

        model_ref = make_shared<Model>(NodeVector{mm}, ParameterVector{input, weights});
    }
}
TEST_F(TransformationTestsF, DeReshapeFCWithConvert) {
    {
        auto labeled_shape = PartialShape{-1, -1, 40};
        label_shape(labeled_shape);  // dims get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f16, labeled_shape);
        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto to_f32 = make_shared<v0::Convert>(flatten, element::f32);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto mm = make_shared<v0::MatMul>(to_f32, weights);

        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(input, {0, 1});
        auto target_shape =
            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
        auto unflatten = make_shared<v1::Reshape>(mm, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
    {
        auto labeled_shape = PartialShape{-1, -1, 40};
        label_shape(labeled_shape);  // dims get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f16, labeled_shape);
        auto to_f32 = make_shared<v0::Convert>(input, element::f32);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto mm = make_shared<v0::MatMul>(to_f32, weights);

        model_ref = make_shared<Model>(NodeVector{mm}, ParameterVector{input, weights});
    }
}
TEST_F(TransformationTestsF, DeReshapeFCNegative) {
    {
        auto labeled_shape = PartialShape{-1, -1, 40};
        label_shape(labeled_shape);  // dims get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f16, labeled_shape);
        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto to_f32 = make_shared<v0::Convert>(flatten, element::f32);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto mm = make_shared<v0::MatMul>(to_f32, weights);

        // The static leading dim 4 does not match the labeled batch dims of the original
        // input, so the pass must not apply — no model_ref is provided.
        auto target_shape = v0::Constant::create(element::i64, {3}, {4, -1, 80});
        auto unflatten = make_shared<v1::Reshape>(mm, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
}