From 67c88f4434b51eb582a2d28cd35ff3d75268364a Mon Sep 17 00:00:00 2001 From: Alexandra Sidorova Date: Fri, 14 Jul 2023 09:31:24 +0400 Subject: [PATCH] [Snippets] Added SplitDimensionM optimization (#18160) --- .../snippets/pass/common_optimizations.hpp | 11 +- .../include/snippets/pass/tokenization.hpp | 6 +- .../src/pass/common_optimizations.cpp | 251 +++++++++++++++++- src/common/snippets/src/pass/tokenization.cpp | 2 +- .../tests/include/pass/mha_tokenization.hpp | 5 + .../tests/src/pass/mha_tokenization.cpp | 20 +- .../transformation_pipeline.cpp | 11 +- .../include/subgraph_mha.hpp | 59 ++-- .../src/subgraph_mha.cpp | 130 +++++++++ 9 files changed, 470 insertions(+), 25 deletions(-) diff --git a/src/common/snippets/include/snippets/pass/common_optimizations.hpp b/src/common/snippets/include/snippets/pass/common_optimizations.hpp index 2961603077f..dbbb723394d 100644 --- a/src/common/snippets/include/snippets/pass/common_optimizations.hpp +++ b/src/common/snippets/include/snippets/pass/common_optimizations.hpp @@ -5,8 +5,8 @@ #pragma once #include "openvino/pass/graph_rewrite.hpp" - #include "snippets/op/subgraph.hpp" +#include "snippets/pass/tokenization.hpp" namespace ov { namespace snippets { @@ -15,13 +15,20 @@ namespace pass { class CommonOptimizations : public ov::pass::MatcherPass { public: OPENVINO_RTTI("CommonOptimizations", "0"); - CommonOptimizations(); + CommonOptimizations(const SnippetsTokenization::Config& config = {}); + + // Returns True if parallelism work amount can be increased using SplitDimensionM optimization + static bool CanOptimizeParallelWA(const std::shared_ptr& node, size_t minimal_concurrency); private: // Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body void ExtractConstants(const std::shared_ptr& subgraph); // Move up unsupported Transposes on Parameter outputs from body void ExtractUnsupportedTransposes(const std::shared_ptr& subgraph); + // Insert Reshape nodes after and before Parameters and Results in Subgraphs with MatMul inside + // to split dimension M for MatMuls to increase work amount for parallelism + // Note: works only with 3D MHA patterns + void SplitDimensionM(const std::shared_ptr& subgraph, size_t minimal_concurrency); }; } // namespace pass diff --git a/src/common/snippets/include/snippets/pass/tokenization.hpp b/src/common/snippets/include/snippets/pass/tokenization.hpp index 151e49bb00d..3d71ad28172 100644 --- a/src/common/snippets/include/snippets/pass/tokenization.hpp +++ b/src/common/snippets/include/snippets/pass/tokenization.hpp @@ -61,8 +61,12 @@ public: * @ingroup snippets */ struct Config { - Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {} + Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose = true) + : minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension), mha_token_enable_transpose(enable_transpose) {} + size_t minimal_concurrency = 1; + // True if "SplitDimensionM" optimization is enabled. Otherwise, it's disabled. + bool split_m_dimension = true; // False if all Transposes aren't tokenized in MHA Tokenization. // Otherwise, they may be fused into Subgraph if possible // TODO [106921]: Remove please when the ticket 106921 is implemented diff --git a/src/common/snippets/src/pass/common_optimizations.cpp b/src/common/snippets/src/pass/common_optimizations.cpp index 180207aa841..f9197afc4ec 100644 --- a/src/common/snippets/src/pass/common_optimizations.cpp +++ b/src/common/snippets/src/pass/common_optimizations.cpp @@ -19,6 +19,251 @@ namespace ov { namespace snippets { namespace pass { +bool CommonOptimizations::CanOptimizeParallelWA(const std::shared_ptr& node, size_t minimal_concurrency) { + if (!ov::is_type(node)) + return false; + // It's needed only for 3D MHA patterns + const auto mm_shape = node->get_shape(); + if (mm_shape.size() != 3) + return false; + const auto current_parallel_work_amount = + std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies()); + const auto dim_M = *(mm_shape.rbegin() + 1); + return (current_parallel_work_amount < minimal_concurrency) && + (current_parallel_work_amount * dim_M >= minimal_concurrency); +} + +void CommonOptimizations::SplitDimensionM(const std::shared_ptr& subgraph, size_t minimal_concurrency) { + // To increase parallelism work in 3D cases for MHA pattern, + // we split 1st dimension (starting from 0th) into 2 new dimensions to get 4D Shapes where + // - 0th and 1st dimensions are used in parallel scheduling, + // - 2nd and 3rd dimensions are used in kernel + // Note: 3D Patterns don't contain Transpose inside so the reshaping is valid + + // It's needed only for MHA patterns. Need to add support for common patterns + if (!subgraph->has_domain_sensitive_ops()) + return; + + const auto& body = subgraph->body_ptr(); + const auto& parameters = body->get_parameters(); + // [107806]: If count of Parameters isn't equal to Subgraph inputs (it's possible case in general), + // we cannot garantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O. + OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(), + "Failed to extract unsupported transposes: the count of Parameters isn't equal to Subgraph inputs"); + + // Need to find MatMul0 and check output shape + const auto& ops = body->get_ordered_ops(); + const auto mm_it = std::find_if(ops.begin(), ops.end(), + [](const std::shared_ptr& node){ return ov::is_type(node); }); + if (mm_it == ops.end()) + return; + + const auto matmul0 = ov::as_type_ptr(*mm_it); + if (!matmul0 || !CanOptimizeParallelWA(matmul0, minimal_concurrency)) + return; + + auto get_dim_M = [](const ov::Shape& shape) { + return *(shape.rbegin() + 1); + }; + + const auto mm_shape = matmul0->get_shape(); + const auto m_dim = get_dim_M(mm_shape); // M + const auto n_dim = mm_shape.back(); // N + // [113745] Heurestic is equal to double block size. + // When this optimization will be moved into Subgraph and blocking param will be implemented as dependents of shapes, + // need to implement common way (for all backends) to calculate optimal value of M dimension + const auto optimal_m_dim = 32 * 2; + const auto optimal_parallelism_work_amount = minimal_concurrency; + if (m_dim <= optimal_m_dim) + return; + + const auto batch_dim = + std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies()); // B (batch) + size_t batch_m_dim = 1; + size_t new_m_dim = m_dim; + + // Need to find optimized dimension splitting: [b1..bk, m, n] -> [b1..bk, batch_m_dim, new_m_dim, n] + // The work amount for parallelism should be divided by max thread count in ideal case + // that all threads have the same full work amount (avoid of thread downtime) + // If it's impossible, it should be more than max thread count + // [115284]: Find solution for finding of optimal splitting in these cases + // For example, there are 16 threads and shape [6, 512, 32] + // LCM(6, 16) = 48 <- ideal work amount for parallelism + // new_shape [6, 48 / 6, 512 / (48 / 6), 32 ] => [6, 8, 64, 32] + // Each thread has parallelism_work_amount = 6 * 8 / nthrs = 3 + auto get_lcm = [](size_t a, size_t b) { + std::function get_gcd; + get_gcd = [&get_gcd](size_t a, size_t b) { + if (b == 0) + return a; + return get_gcd(b, a % b); + }; + return a / get_gcd(a, b) * b; + }; + const auto lcm = get_lcm(batch_dim, optimal_parallelism_work_amount); // LCM(b, nthrs) + const auto batch_dim_multiplier = lcm / batch_dim; // LCM(b, nthrs) / b + const auto needed_new_dim = m_dim / batch_dim_multiplier; // m / (LCM(b, nthrs) / b) - needed factors of dimension m + + auto is_optimized = [&](size_t batch_m_dim, size_t new_m_dim) { + return batch_m_dim != 1 && new_m_dim >= optimal_m_dim; + }; + + if (batch_dim_multiplier * needed_new_dim == m_dim) { + batch_m_dim = batch_dim_multiplier; + new_m_dim = needed_new_dim; + } + if (!is_optimized(batch_m_dim, new_m_dim)) { + auto get_factors = [](size_t dim) -> std::vector { + std::vector factors; + size_t div = 2; + while (div <= dim) { + const auto res = dim / div; + if (res * div == dim) { + factors.push_back(div); + dim = res; + } else { + div++; + } + } + return factors; + }; + const auto m_factors = get_factors(m_dim); + // If m_dim is Prime number + if (m_factors.size() == 2) + return; + + batch_m_dim = 1; + new_m_dim = m_dim; + size_t idx = 0; + // [115284] The current solution is not enough optimized. For more details please go to the ticket + while (batch_m_dim * batch_dim < optimal_parallelism_work_amount && idx < m_factors.size()) { + auto tmp_batch_m_dim = batch_m_dim * m_factors[idx]; + // There should be enough work for kernel execution + if (m_dim / tmp_batch_m_dim * n_dim < optimal_m_dim) + break; + batch_m_dim = tmp_batch_m_dim; + } + new_m_dim = m_dim / batch_m_dim; + } + + OPENVINO_ASSERT(batch_m_dim * new_m_dim == m_dim, "Incorrect dimension M splitting!"); + // nothing to split + if (!is_optimized(batch_m_dim, new_m_dim)) + return; + + /***** Reshape insertion *****/ + + // There are two Parameter variants: + // - Parameter on branches for Second input of MatMul - the shape should be only unsqueezed (add just 1) + // - Other Parameters (on First input of MatMuls and between) - the shape should be splitted on M dimension + + bool updated = false; + std::set> reshaped_params; + + auto insert_reshape = [&](const std::shared_ptr& param, const ov::Shape& new_shape) { + const auto index = std::distance(parameters.begin(), std::find(parameters.begin(), parameters.end(), param)); + const auto shape_const = std::make_shared(ov::element::i32, ov::Shape{new_shape.size()}, new_shape); + const auto reshape = std::make_shared(subgraph->input_value(index), shape_const, false); + subgraph->input(index).replace_source_output(reshape); + param->set_partial_shape(new_shape); + reshaped_params.insert(param); + updated = true; + }; + + auto get_updated_shape = [&](const ov::Shape& shape, bool split_m_dim) { + const auto current_m_dim = get_dim_M(shape); + OPENVINO_ASSERT(!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, "Incorrect shape for splitting!"); + ov::Shape new_shape = shape; + if ((split_m_dim && current_m_dim == 1) || !split_m_dim) { + new_shape.insert((new_shape.rbegin() + 2).base(), 1); + } else { + new_shape.insert((new_shape.rbegin() + 2).base(), batch_m_dim); + *(new_shape.rbegin() + 1) = new_m_dim; + } + OPENVINO_ASSERT(ov::shape_size(new_shape) == ov::shape_size(shape), "Incorrect shape splitting!"); + return new_shape; + }; + + auto reshape_parameter = [&](const std::shared_ptr& node, bool split_m_dim = true) { + const auto param = ov::as_type_ptr(node); + if (!param || reshaped_params.count(param) > 0) + return; + insert_reshape(param, get_updated_shape(param->get_partial_shape().get_shape(), split_m_dim)); + }; + + auto update_matmul_second_branch = [&](const std::shared_ptr& node) { + auto parent = node->get_input_node_shared_ptr(1); + while (!ov::is_type(parent)) { + if (parent->get_input_size() > 1) { + for (const auto& input_source : parent->input_values()) { + reshape_parameter(input_source.get_node_shared_ptr(), false); + } + } + + // [107731]: It's covered my MHA tokenization + parent = parent->get_input_node_shared_ptr(0); + } + reshape_parameter(parent, false); + }; + + // Firstly, Unsqueeze parameters on second branches of MatMuls + for (const auto& op : ops) { + if (ov::is_type(op)) { + update_matmul_second_branch(op); + } + } + + // Secondly, Update All M dimensions for remaining parameters + for (const auto& param : parameters) { + if (reshaped_params.count(param) == 0) + reshape_parameter(param, true); + } + + // Return the previous shape on outputs + for (size_t i = 0; i < subgraph->get_output_size() && updated; ++i) { + const auto output_shape = subgraph->get_output_shape(i); + if (is_scalar(output_shape)) + continue; + + const auto& target_inputs = subgraph->get_output_target_inputs(i); + const auto shape_const = std::make_shared(ov::element::i32, ov::Shape{output_shape.size()}, output_shape); + const auto reshape = std::make_shared(subgraph->output(i), shape_const, false); + // Save output name + const auto original_output = body->get_results()[i]->get_input_node_shared_ptr(0); + const auto original_name = original_output->get_friendly_name(); + reshape->set_friendly_name(original_name); + original_output->set_friendly_name(original_name + "_original"); + + for (const auto& input : target_inputs) { + input.replace_source_output(reshape); + // Result input tensor name was changed, the name has to be restored + if (ov::is_type(input.get_node())) { + input.get_tensor_ptr()->add_names(subgraph->output(i).get_tensor_ptr()->get_names()); + } + } + subgraph->output(i).get_tensor_ptr()->set_names({}); + updated = true; + } + subgraph->set_friendly_name(subgraph->get_friendly_name() + "_original"); + + // Need to update inner Shapes and Softmax Axis + if (updated) { + for (const auto &op : ops) { + if (const auto softmax_v8 = ov::as_type_ptr(op)) { + softmax_v8->set_axis(-1); + } else if (const auto softmax_v1 = ov::as_type_ptr(op)) { + softmax_v1->set_axis(softmax_v1->get_output_partial_shape(0).size()); // since new_shape.size() = old_shape.size() + 1 + } else if (const auto broadcast = ov::as_type_ptr(op)) { + // Broadcast is tokenized only between MatMuls -> Split M dimension + const auto shape_const = ov::as_type_ptr(broadcast->input_value(1).get_node_shared_ptr()); + OPENVINO_ASSERT(shape_const, "SplitDimensionM expects Broadcast with Constant output shape"); + const auto new_shape = get_updated_shape(shape_const->cast_vector(), true); + broadcast->set_argument(1, std::make_shared(ov::element::i32, ov::Shape{new_shape.size()}, new_shape)); + } + } + subgraph->validate_and_infer_types(); + } +} void CommonOptimizations::ExtractConstants(const std::shared_ptr& subgraph) { OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants"); @@ -98,9 +343,9 @@ void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr(m.get_match_root()); @@ -130,6 +375,8 @@ CommonOptimizations::CommonOptimizations() { // Extract unsupported Transposes from body if (subgraph->has_domain_sensitive_ops()) { ExtractUnsupportedTransposes(subgraph); + if (config.split_m_dimension) + SplitDimensionM(subgraph, config.minimal_concurrency); } return true; }; diff --git a/src/common/snippets/src/pass/tokenization.cpp b/src/common/snippets/src/pass/tokenization.cpp index 13346efabef..2a57834ea2c 100644 --- a/src/common/snippets/src/pass/tokenization.cpp +++ b/src/common/snippets/src/pass/tokenization.cpp @@ -80,7 +80,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr& m) { manager.register_pass(); manager.register_pass(m_config); manager.register_pass(); - manager.register_pass(); + manager.register_pass(m_config); manager.run_passes(m); // Returning value is false because pass::Manager always apply Validation pass if function was changed. diff --git a/src/common/snippets/tests/include/pass/mha_tokenization.hpp b/src/common/snippets/tests/include/pass/mha_tokenization.hpp index 6b092209e98..c1be616f972 100644 --- a/src/common/snippets/tests/include/pass/mha_tokenization.hpp +++ b/src/common/snippets/tests/include/pass/mha_tokenization.hpp @@ -6,6 +6,8 @@ #include +#include "snippets/pass/tokenization.hpp" + namespace ov { namespace test { namespace snippets { @@ -13,6 +15,9 @@ namespace snippets { class TokenizeMHASnippetsTests : public TransformationTestsF { public: virtual void run(); + +protected: + ov::snippets::pass::SnippetsTokenization::Config config; }; } // namespace snippets diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index 86e211d5c1d..cc6aa41299b 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -17,7 +17,7 @@ void TokenizeMHASnippetsTests::run() { ASSERT_TRUE(function); manager.register_pass(); manager.register_pass(); - manager.register_pass(); + manager.register_pass(config); disable_rt_info_check(); } @@ -67,7 +67,25 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) { run(); } +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM) { + const auto& f = MHAWOTransposeSplitMFunction(std::vector{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}}, + std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), + std::vector{{10, 9, 1024, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}}); + function = f.getOriginal(); + function_ref = f.getReference(); + config.minimal_concurrency = 18; + run(); +} +TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) { + const auto& f = MHASelectSplitMFunction(std::vector{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}}, + std::vector{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64}, + {8, 1, 64, 512}, {8, 512, 512}}); + function = f.getOriginal(); + function_ref = f.getReference(); + config.minimal_concurrency = 16; + run(); +} } // namespace snippets } // namespace test diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 209c3ca3167..1ec5d40071d 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -627,13 +627,17 @@ void Transformations::MainSnippets(void) { !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions) return; + ov::snippets::pass::SnippetsTokenization::Config tokenization_config; // At the moment Snippets supports Transposes in MHA pattern only in FP32 case since // - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage // - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage // Because of that Transposes won't be fused into Brgemm // TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs - ov::snippets::pass::SnippetsTokenization::Config tokenization_config; tokenization_config.mha_token_enable_transpose = (inferencePrecision == ov::element::f32); + tokenization_config.minimal_concurrency = parallel_get_num_threads(); + // The optimization "SplitDimensionM" depends on target machine (thread count). + // To avoid uncontrolled behavior in tests, we disabled the optimization when there is Config::SnippetsMode::IgnoreCallback + tokenization_config.split_m_dimension = snippetsMode != Config::SnippetsMode::IgnoreCallback; ngraph::pass::Manager snippetsManager; snippetsManager.set_per_pass_validation(false); @@ -687,8 +691,9 @@ void Transformations::MainSnippets(void) { // TODO: The heuristic will be removed after parallelism support on JIT level const auto needed_num_of_threads = 12lu; const auto is_unsupported_parallel_work_amount = - parallel_get_num_threads() / 2 > parallel_work_amount && - static_cast(parallel_work_amount) < needed_num_of_threads; + parallel_get_num_threads() / 2 > parallel_work_amount && + static_cast(parallel_work_amount) < needed_num_of_threads && + !ov::snippets::pass::CommonOptimizations::CanOptimizeParallelWA(n, tokenization_config.minimal_concurrency); return is_unsupported_parallel_work_amount; }, snippets::pass::TokenizeMHASnippets); diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp index f3392902094..8837a0aa610 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_mha.hpp @@ -45,8 +45,8 @@ class MHAFunction : public SnippetsFunctionBase { public: explicit MHAFunction(const std::vector& inputShapes, const std::vector& precisions, bool with_mul = true) : SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); - NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; @@ -75,8 +75,8 @@ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase { public: explicit MHAMatMul0TransposeFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); - NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; @@ -104,8 +104,8 @@ class MHASelectFunction : public SnippetsFunctionBase { public: explicit MHASelectFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { - NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes"); - NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions"); + OPENVINO_ASSERT(input_shapes.size() == 6, "Got invalid number of input shapes"); + OPENVINO_ASSERT(precisions.size() == 6, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; @@ -113,6 +113,22 @@ protected: std::vector precisions; }; +// Only for tokenization tests since boolean type->u8 +// Without Transposes +class MHASelectSplitMFunction : public SnippetsFunctionBase { +public: + explicit MHASelectSplitMFunction(const std::vector& inputShapes, const std::vector& reshapes) + : SnippetsFunctionBase(inputShapes), reshapes(reshapes) { + OPENVINO_ASSERT(input_shapes.size() == 5, "Got invalid number of input shapes"); + OPENVINO_ASSERT(reshapes.size() == 6, "Got invalid number of input precisions"); + } +protected: + std::shared_ptr initOriginal() const override; + std::shared_ptr initReference() const override; + + std::vector reshapes; +}; + /* Graph: * Constant * \ / @@ -129,7 +145,7 @@ protected: class MHAWOTransposeOnInputsFunction : public SnippetsFunctionBase { public: explicit MHAWOTransposeOnInputsFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -148,8 +164,8 @@ class MHAWOTransposeFunction : public SnippetsFunctionBase { public: explicit MHAWOTransposeFunction(const std::vector& inputShapes, const std::vector& precisions) : SnippetsFunctionBase(inputShapes), precisions(precisions) { - NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); - NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions"); + OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes"); + OPENVINO_ASSERT(precisions.size() == 3, "Got invalid number of input precisions"); } protected: std::shared_ptr initOriginal() const override; @@ -157,6 +173,19 @@ protected: std::vector precisions; }; +class MHAWOTransposeSplitMFunction : public MHAWOTransposeFunction { +public: + explicit MHAWOTransposeSplitMFunction(const std::vector& inputShapes, const std::vector& precisions, + const std::vector& reshapes) + : MHAWOTransposeFunction(inputShapes, precisions), reshapes(reshapes) { + OPENVINO_ASSERT(reshapes.size() == 4, "Got invalid number of Reshape shapes"); + } +protected: + std::shared_ptr initReference() const override; + + std::vector reshapes; +}; + /* Graph: * Transpose0[0,2,1,3] Transpose1[0,2,3,1] * \ / @@ -176,7 +205,7 @@ class MHAFQAfterMatMulFunction : public SnippetsFunctionBase { public: explicit MHAFQAfterMatMulFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -203,7 +232,7 @@ class MHAINT8MatMulFunction : public SnippetsFunctionBase { public: explicit MHAINT8MatMulFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -232,7 +261,7 @@ class MHAFQFunction : public SnippetsFunctionBase { public: explicit MHAFQFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -261,7 +290,7 @@ class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase { public: explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -283,7 +312,7 @@ protected: class MHAMulAddFunction : public SnippetsFunctionBase { public: explicit MHAMulAddFunction(const std::vector& inputShapes) : SnippetsFunctionBase(inputShapes) { - NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; @@ -304,7 +333,7 @@ public: explicit MHATransposedInputFunction(const std::vector& inputShapes, bool transposed_b = false, std::vector order = {}) : SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) { - NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes"); } protected: std::shared_ptr initOriginal() const override; diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp index 440e3607a2f..7a51db410cd 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_mha.cpp @@ -324,6 +324,105 @@ std::shared_ptr MHASelectFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } +std::shared_ptr MHASelectSplitMFunction::initOriginal() const { + auto transpose0Param = std::make_shared(precision, input_shapes[0]); + auto transpose1Param = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto selectParam = std::make_shared(ov::element::u8, input_shapes[3]); + auto transpose2Param = std::make_shared(precision, input_shapes[4]); + ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, selectParam, transpose2Param}; + + // Value is equal to '1' - to avoid situation e^(-1000) / (sum(e^(-1000)) = 0/0 = NAN + auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector{1}); + + const auto matMul0 = std::make_shared(transpose0Param, transpose1Param); + const auto add = std::make_shared(matMul0, addParam); + std::shared_ptr selectCond = selectParam; + if (add->get_output_partial_shape(0) != selectParam->get_output_partial_shape(0)) { + const auto broadcast_shape = + ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0)); + selectCond = std::make_shared(selectCond, broadcast_shape); + } + const auto select = std::make_shared>( + std::vector{ element::boolean, element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(), + ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(), + ov::op::TemporaryReplaceOutputType(add, element::f32).get()); + + const auto interm_shape = select->get_shape(); + std::vector reshape0ConstData = {-1, static_cast(interm_shape.back())}; + auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape0ConstData.size()}, reshape0ConstData); + + std::vector reshape1ConstData; + for (const auto& dim : interm_shape) + reshape1ConstData.push_back(static_cast(dim)); + auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape1ConstData.size()}, reshape1ConstData); + + const auto reshape0 = std::make_shared(select, reshape0Const, true); + const auto softMax = std::make_shared(reshape0, 1); + const auto reshape1 = std::make_shared(softMax, reshape1Const, true); + const auto matMul1 = std::make_shared(reshape1, transpose2Param); + + ngraph::ResultVector results{std::make_shared(matMul1)}; + return std::make_shared(results, ngraphParam, "mha"); +} + +std::shared_ptr MHASelectSplitMFunction::initReference() const { + auto param0 = std::make_shared(precision, input_shapes[0]); + auto param1 = std::make_shared(precision, input_shapes[1]); + auto addParam = std::make_shared(precision, input_shapes[2]); + auto selectParam = std::make_shared(ov::element::u8, input_shapes[3]); + auto param2 = std::make_shared(precision, input_shapes[4]); + ngraph::ParameterVector ngraphParam = {param0, param1, addParam, selectParam, param2}; + + auto make_reshape = [](const std::shared_ptr& node, const ov::Shape& new_shape) { + auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape); + return std::make_shared(node, shape_const, true); + }; + + auto reshape0 = make_reshape(param0, reshapes[0]); + auto reshape1 = make_reshape(param1, reshapes[1]); + auto reshapeAdd = make_reshape(addParam, reshapes[2]); + auto reshapeSelect = make_reshape(selectParam, reshapes[3]); + auto reshape2 = make_reshape(param2, reshapes[4]); + + auto data0 = std::make_shared(reshape0->get_element_type(), reshape0->get_shape()); + auto data1 = std::make_shared(reshape1->get_element_type(), reshape1->get_shape()); + auto dataAdd = std::make_shared(reshapeAdd->get_element_type(), reshapeAdd->get_shape()); + auto dataSelect = std::make_shared(reshapeSelect->get_element_type(), reshapeSelect->get_shape()); + auto data2 = std::make_shared(reshape2->get_element_type(), reshape2->get_shape()); + + const auto matMul0 = std::make_shared(data0, data1); + const auto add = std::make_shared(matMul0, dataAdd); + + // Value is equal to '1' - to avoid situation e^(-1000) / (sum(e^(-1000)) = 0/0 = NAN + auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector{1}); + std::shared_ptr selectCond = dataSelect; + if (add->get_output_partial_shape(0) != dataSelect->get_output_partial_shape(0)) { + const auto broadcast_shape = + ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0)); + selectCond = std::make_shared(selectCond, broadcast_shape); + } + const auto select = std::make_shared>( + std::vector{ element::boolean, element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(), + ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(), + ov::op::TemporaryReplaceOutputType(add, element::f32).get()); + + const auto softMax = std::make_shared(select, add->get_shape().size() - 1); + const auto matMul1 = std::make_shared(softMax, data2); + + const auto subgraph = + std::make_shared( + ov::NodeVector{reshape0, reshape1, reshapeAdd, reshapeSelect, reshape2}, + std::make_shared(ov::OutputVector{matMul1}, ov::ParameterVector{data0, data1, dataAdd, dataSelect, data2})); + auto reshape3 = make_reshape(subgraph, reshapes[5]); + ngraph::ResultVector results{std::make_shared(reshape3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + std::shared_ptr MHAWOTransposeOnInputsFunction::initOriginal() const { auto param0 = std::make_shared(precision, input_shapes[0]); auto param1 = std::make_shared(precision, input_shapes[1]); @@ -361,6 +460,37 @@ std::shared_ptr MHAWOTransposeFunction::initOriginal() const { return std::make_shared(results, ngraphParam, "mha"); } +std::shared_ptr MHAWOTransposeSplitMFunction::initReference() const { + auto param0 = std::make_shared(precisions[0], input_shapes[0]); + auto param1 = std::make_shared(precisions[1], input_shapes[1]); + auto param2 = std::make_shared(precisions[2], input_shapes[2]); + ngraph::ParameterVector ngraphParam = {param0, param1, param2}; + + auto make_reshape = [](const std::shared_ptr& node, const ov::Shape& new_shape) { + auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape); + return std::make_shared(node, shape_const, true); + }; + + auto reshape0 = make_reshape(param0, reshapes[0]); + auto reshape1 = make_reshape(param1, reshapes[1]); + auto reshape2 = make_reshape(param2, reshapes[2]); + + auto data0 = std::make_shared(precisions[0], reshape0->get_shape()); + auto data1 = std::make_shared(precisions[1], reshape1->get_shape()); + auto data2 = std::make_shared(precisions[2], reshape2->get_shape()); + + const auto matMul0 = std::make_shared(data0, data1); + const auto softmax = std::make_shared(matMul0, -1); + const auto matMul1 = std::make_shared(softmax, data2); + + const auto subgraph = std::make_shared(ov::NodeVector{reshape0, reshape1, reshape2}, + std::make_shared(ov::OutputVector{matMul1}, + ov::ParameterVector{data0, data1, data2})); + auto reshape3 = make_reshape(subgraph, reshapes[3]); + ngraph::ResultVector results{std::make_shared(reshape3)}; + return std::make_shared(results, ngraphParam, "mha"); +} + std::shared_ptr MHAFQAfterMatMulFunction::initOriginal() const { auto transpose0Param = std::make_shared(precision, input_shapes[0]); auto transpose1Param = std::make_shared(precision, input_shapes[1]);