diff --git a/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp b/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp
index 6df2e406ee8..dbe44093623 100644
--- a/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp
+++ b/src/common/transformations/include/transformations/symbolic_transformations/dereshape_matmul.hpp
@@ -10,6 +10,7 @@ namespace ov {
 namespace pass {
 class TRANSFORMATIONS_API DeReshapeMatMul;
+class TRANSFORMATIONS_API DeReshapeFullyConnected;
 }  // namespace pass
 }  // namespace ov
 
@@ -64,3 +65,28 @@ public:
     OPENVINO_RTTI("DeReshapeMatMul", "0");
     DeReshapeMatMul();
 };
+
+/**
+ * @ingroup ie_transformation_common_api
+ * @brief Transformation uses symbol / label information to optimize out Reshape operations surrounding special cases of
+ * MatMul. It checks that surrounding Reshapes are only manipulating with batch dimensions of tensor in a do-undo kind
+ * of way. The difference with previous optimization is that this case has Reshape only on one input of MatMul and the
+ * other input is strictly 2D. Such MatMuls are also called FullyConnected
+ *
+ * Example:
+ *    Before:
+ *     [A,B,4096] -> Reshape -> [A*B,4096]
+ *                                         MatMul [A*B,4608] -> Reshape -> [A,B,4608]
+ *                              [4096,4608]
+ *
+ *    After:
+ *     [A,B,4096]   ->
+ *                     MatMul -> [A,B,4608]
+ *     [4096,4608]  ->
+ *
+ */
+class ov::pass::DeReshapeFullyConnected : public ov::pass::MatcherPass {
+public:
+    OPENVINO_RTTI("DeReshapeFullyConnected", "0");
+    DeReshapeFullyConnected();
+};
diff --git a/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp
index 2c7ee44c632..943c912bf68 100644
--- a/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp
+++ b/src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp
@@ -8,6 +8,7 @@
 #include "openvino/core/dimension_tracker.hpp"
 #include "openvino/core/validation_util.hpp"
 #include "openvino/op/concat.hpp"
+#include "openvino/op/convert.hpp"
 #include "openvino/op/matmul.hpp"
 #include "openvino/op/reshape.hpp"
 #include "openvino/op/util/binary_elementwise_arithmetic.hpp"
@@ -334,3 +335,68 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
     auto m = std::make_shared<pattern::Matcher>(final_reshape, matcher_name);
     register_matcher(m, matcher_pass_callback);
 }
+
+ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() {
+    MATCHER_SCOPE(DeReshapeFullyConnected);
+
+    // Reshape that only folds batch dimensions: last (channel) dim is kept intact
+    auto reshaped_input = pattern::wrap_type<op::v1::Reshape>([](Output<Node> out) -> bool {
+        const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
+        if (input_shape.rank().is_dynamic() || input_shape.size() < 3)
+            return false;
+        const auto& output_shape = out.get_partial_shape();
+        if (output_shape.rank().is_dynamic() || output_shape.size() < 2)
+            return false;
+        return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
+    });
+    auto converted =
+        pattern::wrap_type<op::v0::Convert>({reshaped_input}, pattern::consumers_count(1));  // optional convert
+
+    auto dynamic_input = std::make_shared<pattern::op::Or>(OutputVector{reshaped_input, converted});
+    auto static_input = pattern::any_input(pattern::rank_equals(2));
+    auto mm_label = pattern::wrap_type<op::v0::MatMul>({dynamic_input, static_input}, [](Output<Node> out) -> bool {
+        auto mm = ov::as_type_ptr<op::v0::MatMul>(out.get_node_shared_ptr());
+        return mm && !mm->get_transpose_a() && pattern::consumers_count(1)(out);
+    });
+
+    // Reshape that restores batch dimensions: last (channel) dim is kept intact
+    auto reshaped_output =
+        pattern::wrap_type<op::v1::Reshape>({mm_label, pattern::any_input()}, [](Output<Node> out) -> bool {
+            const auto& input_shape = out.get_node_shared_ptr()->get_input_partial_shape(0);
+            if (input_shape.rank().is_dynamic() || input_shape.size() < 2)
+                return false;
+            const auto& output_shape = out.get_partial_shape();
+            if (output_shape.rank().is_dynamic() || output_shape.size() < 3)
+                return false;
+            return dims_are_equal(input_shape[input_shape.size() - 1], output_shape[output_shape.size() - 1]);
+        });
+
+    ov::matcher_pass_callback matcher_pass_callback = [=](pattern::Matcher& m) {
+        const auto& pm = m.get_pattern_map();
+
+        const auto& in_reshape = pm.at(reshaped_input);
+        const auto& out_reshape = pm.at(reshaped_output);
+        const auto& matmul = pm.at(mm_label);
+
+        const auto& in_shape = in_reshape->get_input_partial_shape(0);
+        const auto& out_shape = out_reshape->get_output_partial_shape(0);
+
+        // do-undo check: batch dims before the input Reshape must equal batch dims after the output Reshape
+        if (in_shape.size() != out_shape.size())
+            return false;
+
+        for (size_t i = 0; i < in_shape.size() - 1; ++i)
+            if (!dims_are_equal(in_shape[i], out_shape[i]))
+                return false;
+        if (pm.count(converted)) {
+            const auto& convert = pm.at(converted);
+            convert->input(0).replace_source_output(in_reshape->input_value(0));
+            convert->validate_and_infer_types();
+        } else {
+            matmul->input(0).replace_source_output(in_reshape->input_value(0));
+        }
+        ov::replace_output_update_name(out_reshape->output(0), matmul->output(0));
+        matmul->validate_and_infer_types();
+        return true;
+    };
+
+    auto m = std::make_shared<pattern::Matcher>(reshaped_output, matcher_name);
+    register_matcher(m, matcher_pass_callback);
+}
diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
index dc979964299..1437845fb9a 100644
--- a/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
+++ b/src/common/transformations/src/transformations/symbolic_transformations/symbolic_optimizations.cpp
@@ -202,6 +202,7 @@ ov::pass::SymbolicOptimizations::SymbolicOptimizations(bool full_run) {
         REGISTER_SYMBOLIC(OptimizeLabelsUsedAsValues)  // reduce shape sub-graphs
         REGISTER_SYMBOLIC(LabelResolvingThroughSelect)  // figures out that broadcasting didn't happen through Select op
         REGISTER_SYMBOLIC(DeReshapeMatMul)
+        REGISTER_SYMBOLIC(DeReshapeFullyConnected)
         REGISTER_SYMBOLIC(ReshapeOptimizations)
         REGISTER_SYMBOLIC(SimplifyShapeOfSubGraph)
     }
diff --git a/src/common/transformations/tests/symbolic_transformations/dereshape_fullyconnected.cpp b/src/common/transformations/tests/symbolic_transformations/dereshape_fullyconnected.cpp
new file mode 100644
index 00000000000..8b2fc3075be
--- /dev/null
+++ b/src/common/transformations/tests/symbolic_transformations/dereshape_fullyconnected.cpp
@@ -0,0 +1,112 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gtest/gtest.h>
+
+#include "common_test_utils/ov_test_utils.hpp"
+#include "openvino/core/dimension_tracker.hpp"
+#include "openvino/core/model.hpp"
+#include "openvino/op/concat.hpp"
+#include "openvino/op/convert.hpp"
+#include "openvino/op/matmul.hpp"
+#include "openvino/op/parameter.hpp"
+#include "openvino/op/reshape.hpp"
+#include "transformations/symbolic_transformations/dereshape_matmul.hpp"
+#include "transformations/utils/utils.hpp"
+
+using namespace ov;
+using namespace ov::op;
+using namespace std;
+
+namespace {
+void label_shape(ov::PartialShape& shape) {
+    auto table = std::make_shared<ov::TableOfEquivalence>(42);
+    auto tracker = ov::DimensionTracker(table);
+    tracker.set_up_for_tracking(shape);
+}
+}  // namespace
+
+TEST_F(TransformationTestsF, DeReshapeFC) {
+    {
+        auto shape = PartialShape{-1, -1, 40};
+        label_shape(shape);  // we label shape with consecutive labels: 42, 43, 44
+
+        auto data = make_shared<v0::Parameter>(element::f32, shape);
+        auto in_reshape = make_shared<v1::Reshape>(data, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
+        auto second_input = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
+
+        auto matmul = make_shared<v0::MatMul>(in_reshape, second_input);
+
+        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(data, {0, 1});
+        auto pattern =
+            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
+        auto out_reshape = make_shared<v1::Reshape>(matmul, pattern, false);
+
+        model = make_shared<Model>(NodeVector{out_reshape}, ParameterVector{data, second_input});
+        manager.register_pass<pass::DeReshapeFullyConnected>();
+    }
+    {
+        auto shape = PartialShape{-1, -1, 40};
+        label_shape(shape);  // we label shape with consecutive labels: 42, 43, 44
+
+        auto data = make_shared<v0::Parameter>(element::f32, shape);
+        auto second_input = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
+        auto matmul = make_shared<v0::MatMul>(data, second_input);
+
+        model_ref = make_shared<Model>(NodeVector{matmul}, ParameterVector{data, second_input});
+    }
+}
+
+TEST_F(TransformationTestsF, DeReshapeFCWithConvert) {
+    {
+        auto shape = PartialShape{-1, -1, 40};
+        label_shape(shape);  // we label shape with consecutive labels: 42, 43, 44
+
+        auto data = make_shared<v0::Parameter>(element::f16, shape);
+        auto in_reshape = make_shared<v1::Reshape>(data, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
+        auto convert = make_shared<v0::Convert>(in_reshape, element::f32);
+        auto second_input = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
+
+        auto matmul = make_shared<v0::MatMul>(convert, second_input);
+
+        auto batch_dims = ov::op::util::node_to_get_shape_value_of_indices_from_shape_source(data, {0, 1});
+        auto pattern =
+            make_shared<v0::Concat>(OutputVector{batch_dims, v0::Constant::create(element::i64, {1}, {80})}, 0);
+        auto out_reshape = make_shared<v1::Reshape>(matmul, pattern, false);
+
+        model = make_shared<Model>(NodeVector{out_reshape}, ParameterVector{data, second_input});
+        manager.register_pass<pass::DeReshapeFullyConnected>();
+    }
+    {
+        auto shape = PartialShape{-1, -1, 40};
+        label_shape(shape);  // we label shape with consecutive labels: 42, 43, 44
+
+        auto data = make_shared<v0::Parameter>(element::f16, shape);
+        auto convert = make_shared<v0::Convert>(data, element::f32);
+        auto second_input = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
+        auto matmul = make_shared<v0::MatMul>(convert, second_input);
+
+        model_ref = make_shared<Model>(NodeVector{matmul}, ParameterVector{data, second_input});
+    }
+}
+
+TEST_F(TransformationTestsF, DeReshapeFCNegative) {
+    {
+        auto shape = PartialShape{-1, -1, 40};
+        label_shape(shape);  // we label shape with consecutive labels: 42, 43, 44
+
+        auto data = make_shared<v0::Parameter>(element::f16, shape);
+        auto in_reshape = make_shared<v1::Reshape>(data, v0::Constant::create(element::i64, {2}, {-1, 40}), true);
+        auto convert = make_shared<v0::Convert>(in_reshape, element::f32);
+        auto second_input = make_shared<v0::Parameter>(element::f32, Shape{40, 80});
+
+        auto matmul = make_shared<v0::MatMul>(convert, second_input);
+
+        auto pattern = v0::Constant::create(element::i64, {3}, {4, -1, 80});
+        auto out_reshape = make_shared<v1::Reshape>(matmul, pattern, false);
+
+        model = make_shared<Model>(NodeVector{out_reshape}, ParameterVector{data, second_input});
+        manager.register_pass<pass::DeReshapeFullyConnected>();
+    }
+}
\ No newline at end of file