[Symbolic Transformation] DeReshape FullyConnected (#21419)
This commit is contained in:
parent
24209239bf
commit
a6903b8398
@ -10,6 +10,7 @@
|
||||
namespace ov {
|
||||
namespace pass {
|
||||
class TRANSFORMATIONS_API DeReshapeMatMul;
|
||||
class TRANSFORMATIONS_API DeReshapeFullyConnected;
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
|
||||
@ -64,3 +65,28 @@ public:
|
||||
OPENVINO_RTTI("DeReshapeMatMul", "0");
|
||||
DeReshapeMatMul();
|
||||
};
|
||||
|
||||
/**
 * @ingroup ie_transformation_common_api
 * @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding special cases of
 * MatMul. It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind
 * of way. The difference with previous optimization (DeReshapeMatMul) is that this case has Reshape only on one input
 * of MatMul and the other input is strictly 2D. Such MatMuls are also called FullyConnected.
 *
 * Example:
 *     Before:
 *      [A,B,4096] -> Reshape -> [A*B,4096]
 *                                          MatMul [A*B,4608] -> Reshape -> [A,B,4608]
 *                              [4096,4608]
 *
 *     After:
 *      [A,B,4096] ->
 *                    MatMul -> [A,B,4608]
 *      [4096,4608] ->
 *
 */
class ov::pass::DeReshapeFullyConnected : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("DeReshapeFullyConnected", "0");
    DeReshapeFullyConnected();
};
|
||||
|
@ -8,6 +8,7 @@
|
||||
#include "openvino/core/dimension_tracker.hpp"
|
||||
#include "openvino/core/validation_util.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "openvino/op/util/binary_elementwise_arithmetic.hpp"
|
||||
@ -334,3 +335,68 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
|
||||
auto m = std::make_shared<pattern::Matcher>(final_reshape, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
|
||||
ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() {
    MATCHER_SCOPE(DeReshapeFullyConnected);

    // Input Reshape: flattens an ND (rank >= 3) tensor down to at least 2D while
    // keeping the innermost (channel) dimension intact — only batch dims are merged.
    auto reshaped_input = pattern::wrap_type<op::v1::Reshape>([](Output<Node> out) -> bool {
        const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
        if (input_shape.rank().is_dynamic() || input_shape.size() < 3)
            return false;
        const auto& output_shape = out.get_partial_shape();
        if (output_shape.rank().is_dynamic() || output_shape.size() < 2)
            return false;
        // last dimension must survive the reshape untouched
        return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
    });
    auto converted =
        pattern::wrap_type<op::v0::Convert>({reshaped_input}, pattern::consumers_count(1));  // optional convert

    // MatMul's first input is either the Reshape directly or Reshape -> Convert.
    auto dynamic_input = std::make_shared<pattern::op::Or>(OutputVector{reshaped_input, converted});
    // Second input is strictly 2D — the "weights" of a FullyConnected-style MatMul.
    auto static_input = pattern::any_input(pattern::rank_equals(2));
    // MatMul must not transpose its first input (batch flattening would not commute
    // with the transpose) and must feed exactly one consumer (the output Reshape).
    auto mm_label = pattern::wrap_type<op::v0::MatMul>({dynamic_input, static_input}, [](Output<Node> out) -> bool {
        auto mm = ov::as_type_ptr<op::v0::MatMul>(out.get_node_shared_ptr());
        return mm && !mm->get_transpose_a() && pattern::consumers_count(1)(out);
    });

    // Output Reshape: restores rank >= 3 from the 2D MatMul result, again keeping
    // the innermost dimension untouched (mirror of the input Reshape check).
    auto reshaped_output =
        pattern::wrap_type<op::v1::Reshape>({mm_label, pattern::any_input()}, [](Output<Node> out) -> bool {
            const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
            if (input_shape.rank().is_dynamic() || input_shape.size() < 2)
                return false;
            const auto& output_shape = out.get_partial_shape();
            if (output_shape.rank().is_dynamic() || output_shape.size() < 3)
                return false;
            return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
        });

    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
        const auto& pm = m.get_pattern_map();

        const auto& in_reshape = pm.at(reshaped_input);
        const auto& out_reshape = pm.at(reshaped_output);
        const auto& matmul = pm.at(mm_label);

        // Shape before the input Reshape vs shape after the output Reshape:
        // they must agree label-wise on all batch dims for the do-undo pair
        // of Reshapes to be removable.
        const auto& in_shape = in_reshape->get_input_partial_shape(0);
        const auto& out_shape = out_reshape->get_output_partial_shape(0);

        if (in_shape.size() != out_shape.size())
            return false;

        // Compare all but the last dimension (the last one is the MatMul's
        // output channel dim and legitimately differs from the input's).
        for (size_t i = 0; i < in_shape.size() - 1; ++i)
            if (!dims_are_equal(in_shape[i], out_shape[i]))
                return false;
        // Bypass the input Reshape: rewire whichever node consumed it
        // (the optional Convert, otherwise the MatMul itself).
        if (pm.count(converted)) {
            const auto& convert = pm.at(converted);
            convert->input(0).replace_source_output(in_reshape->input_value(0));
            convert->validate_and_infer_types();
        } else {
            matmul->input(0).replace_source_output(in_reshape->input_value(0));
        }
        // Bypass the output Reshape (preserving its friendly name), then
        // re-infer MatMul's shape for the now-ND first input.
        ov::replace_output_update_name(out_reshape->output(0), matmul->output(0));
        matmul->validate_and_infer_types();
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(reshaped_output, matcher_name);
    register_matcher(m, matcher_pass_callback);
}
|
||||
|
@ -202,6 +202,7 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
|
||||
REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues) // reduce shape sub-graphs
|
||||
REGISTER_SYMBOLIC(LabelResolvingThroughSelect) // figures out that broadcasting didn't happen through Select op
|
||||
REGISTER_SYMBOLIC(DeReshapeMatMul)
|
||||
REGISTER_SYMBOLIC(DeReshapeFullyConnected)
|
||||
REGISTER_SYMBOLIC(ReshapeOptimizations)
|
||||
REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph)
|
||||
}
|
||||
|
@ -0,0 +1,112 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
#include "openvino/core/dimension_tracker.hpp"
|
||||
#include "openvino/core/model.hpp"
|
||||
#include "openvino/op/concat.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/op/parameter.hpp"
|
||||
#include "openvino/op/reshape.hpp"
|
||||
#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
using namespace ov;
|
||||
using namespace ov::op;
|
||||
using namespace std;
|
||||
|
||||
namespace {
|
||||
void label_shape(ov::PartialShape& shape) {
|
||||
auto table = std::make_shared<ov::TableOfEquivalence>(42);
|
||||
auto tracker = ov::DimensionTracker(table);
|
||||
tracker.set_up_for_tracking(shape);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TEST_F(TransformationTestsF, DeReshapeFC) {
    {
        auto shape = PartialShape{-1, -1, 40};
        label_shape(shape);  // dimensions get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f32, shape);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});

        // Flatten batch dims, run the FC-style MatMul, then restore batch dims.
        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto matmul = make_shared<v0::MatMul>(flatten, weights);

        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(input, {0, 1});
        auto target_shape =
            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
        auto unflatten = make_shared<v1::Reshape>(matmul, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
    {
        auto shape = PartialShape{-1, -1, 40};
        label_shape(shape);  // dimensions get consecutive labels: 42, 43, 44

        // Expected result: both Reshapes are gone, MatMul consumes the ND input directly.
        auto input = make_shared<v0::Parameter>(element::f32, shape);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto matmul = make_shared<v0::MatMul>(input, weights);

        model_ref = make_shared<Model>(NodeVector{matmul}, ParameterVector{input, weights});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, DeReshapeFCWithConvert) {
    {
        auto shape = PartialShape{-1, -1, 40};
        label_shape(shape);  // dimensions get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f16, shape);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});

        // Same as DeReshapeFC, but with a Convert between the input Reshape and MatMul.
        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto to_f32 = make_shared<v0::Convert>(flatten, element::f32);
        auto matmul = make_shared<v0::MatMul>(to_f32, weights);

        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(input, {0, 1});
        auto target_shape =
            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
        auto unflatten = make_shared<v1::Reshape>(matmul, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
    {
        auto shape = PartialShape{-1, -1, 40};
        label_shape(shape);  // dimensions get consecutive labels: 42, 43, 44

        // Expected result: Reshapes are gone, Convert now feeds MatMul directly.
        auto input = make_shared<v0::Parameter>(element::f16, shape);
        auto to_f32 = make_shared<v0::Convert>(input, element::f32);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
        auto matmul = make_shared<v0::MatMul>(to_f32, weights);

        model_ref = make_shared<Model>(NodeVector{matmul}, ParameterVector{input, weights});
    }
}
|
||||
|
||||
TEST_F(TransformationTestsF, DeReshapeFCNegative) {
    {
        auto shape = PartialShape{-1, -1, 40};
        label_shape(shape);  // dimensions get consecutive labels: 42, 43, 44

        auto input = make_shared<v0::Parameter>(element::f16, shape);
        auto weights = make_shared<v0::Parameter>(element::f32, Shape{40, 80});

        auto flatten = make_shared<v1::Reshape>(input, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
        auto to_f32 = make_shared<v0::Convert>(flatten, element::f32);
        auto matmul = make_shared<v0::MatMul>(to_f32, weights);

        // Output Reshape restores a batch dim (4) that is not label-equal to the
        // original batch dims, so the transformation must not apply.
        auto target_shape = v0::Constant::create(element::i64, {3}, {4, -1, 80});
        auto unflatten = make_shared<v1::Reshape>(matmul, target_shape, false);

        model = make_shared<Model>(NodeVector{unflatten}, ParameterVector{input, weights});
        manager.register_pass<pass::DeReshapeFullyConnected>();
    }
    // no model_ref: the graph is expected to stay unchanged
}
|
Loading…
Reference in New Issue
Block a user