[Snippets] Added SplitDimensionM optimization (#18160)

Alexandra Sidorova 2023-07-14 09:31:24 +04:00 committed by GitHub
parent b7935bb869
commit 67c88f4434
9 changed files with 470 additions and 25 deletions

View File

@ -5,8 +5,8 @@
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/pass/tokenization.hpp"
namespace ov {
namespace snippets {
@ -15,13 +15,20 @@ namespace pass {
class CommonOptimizations : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("CommonOptimizations", "0");
CommonOptimizations();
CommonOptimizations(const SnippetsTokenization::Config& config = {});
// Returns True if the parallelism work amount can be increased by the SplitDimensionM optimization
static bool CanOptimizeParallelWA(const std::shared_ptr<const ov::Node>& node, size_t minimal_concurrency);
private:
// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
void ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph);
// Move up unsupported Transposes on Parameter outputs from body
void ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph);
// Insert Reshape nodes before Parameters and after Results in Subgraphs with MatMul inside
// to split the M dimension of MatMuls and increase the parallelism work amount
// Note: works only with 3D MHA patterns
void SplitDimensionM(const std::shared_ptr<op::Subgraph>& subgraph, size_t minimal_concurrency);
};
} // namespace pass

View File

@ -61,8 +61,12 @@ public:
* @ingroup snippets
*/
struct Config {
Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}
Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose = true)
: minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension), mha_token_enable_transpose(enable_transpose) {}
size_t minimal_concurrency = 1;
// True if the "SplitDimensionM" optimization is enabled, false otherwise.
bool split_m_dimension = true;
// If False, Transposes are not tokenized in MHA Tokenization.
// Otherwise, they may be fused into the Subgraph when possible.
// TODO [106921]: Please remove when ticket 106921 is implemented

View File

@ -19,6 +19,251 @@ namespace ov {
namespace snippets {
namespace pass {
bool CommonOptimizations::CanOptimizeParallelWA(const std::shared_ptr<const ov::Node>& node, size_t minimal_concurrency) {
if (!ov::is_type<ov::op::v0::MatMul>(node))
return false;
// It's needed only for 3D MHA patterns
const auto mm_shape = node->get_shape();
if (mm_shape.size() != 3)
return false;
const auto current_parallel_work_amount =
std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies<size_t>());
const auto dim_M = *(mm_shape.rbegin() + 1);
return (current_parallel_work_amount < minimal_concurrency) &&
(current_parallel_work_amount * dim_M >= minimal_concurrency);
}
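For reference, a minimal standalone sketch of the same check using only the standard library (the function name and the shapes in the trailing comments are made up for illustration):

#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Same idea for a plain 3D shape [B, M, N]: the batch alone gives too little
// parallel work, but B * M would be enough, so splitting M can help.
bool can_split_m(const std::vector<size_t>& shape, size_t minimal_concurrency) {
    if (shape.size() != 3)
        return false;
    const auto batch = std::accumulate(shape.rbegin() + 2, shape.rend(), size_t(1), std::multiplies<size_t>());
    const auto m = *(shape.rbegin() + 1);
    return batch < minimal_concurrency && batch * m >= minimal_concurrency;
}
// can_split_m({2, 1024, 64}, 16) == true  (2 < 16 and 2 * 1024 >= 16)
// can_split_m({32, 64, 64}, 16) == false  (the batch already covers 16 threads)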
void CommonOptimizations::SplitDimensionM(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph, size_t minimal_concurrency) {
// To increase the parallelism work amount in 3D cases of the MHA pattern,
// we split the 1st dimension (counting from the 0th) into 2 new dimensions to get 4D shapes where
// - the 0th and 1st dimensions are used in parallel scheduling,
// - the 2nd and 3rd dimensions are used in the kernel
// Note: 3D patterns don't contain a Transpose inside, so the reshaping is valid
// Currently it's applied only to MHA patterns; support for common patterns needs to be added
if (!subgraph->has_domain_sensitive_ops())
return;
const auto& body = subgraph->body_ptr();
const auto& parameters = body->get_parameters();
// [107806]: If the count of Parameters isn't equal to the count of Subgraph inputs (a possible case in general),
// we cannot guarantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O.
OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(),
"Failed to split dimension M: the count of Parameters isn't equal to Subgraph inputs");
// Need to find MatMul0 and check output shape
const auto& ops = body->get_ordered_ops();
const auto mm_it = std::find_if(ops.begin(), ops.end(),
[](const std::shared_ptr<ov::Node>& node){ return ov::is_type<ov::op::v0::MatMul>(node); });
if (mm_it == ops.end())
return;
const auto matmul0 = ov::as_type_ptr<ov::op::v0::MatMul>(*mm_it);
if (!matmul0 || !CanOptimizeParallelWA(matmul0, minimal_concurrency))
return;
auto get_dim_M = [](const ov::Shape& shape) {
return *(shape.rbegin() + 1);
};
const auto mm_shape = matmul0->get_shape();
const auto m_dim = get_dim_M(mm_shape); // M
const auto n_dim = mm_shape.back(); // N
// [113745] The heuristic is equal to double the block size.
// When this optimization is moved into Subgraph and the blocking parameter is derived from shapes,
// a common way (for all backends) to calculate the optimal value of the M dimension needs to be implemented
const auto optimal_m_dim = 32 * 2;
const auto optimal_parallelism_work_amount = minimal_concurrency;
if (m_dim <= optimal_m_dim)
return;
const auto batch_dim =
std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies<size_t>()); // B (batch)
size_t batch_m_dim = 1;
size_t new_m_dim = m_dim;
// Need to find an optimized dimension splitting: [b1..bk, m, n] -> [b1..bk, batch_m_dim, new_m_dim, n]
// Ideally, the parallelism work amount should be divisible by the max thread count
// so that all threads get the same full work amount (no thread downtime).
// If that's impossible, it should be greater than the max thread count.
// [115284]: Find a solution for optimal splitting in such cases
// For example, there are 16 threads and shape [6, 512, 32]
// LCM(6, 16) = 48 <- ideal work amount for parallelism
// new_shape [6, 48 / 6, 512 / (48 / 6), 32 ] => [6, 8, 64, 32]
// Each thread has parallelism_work_amount = 6 * 8 / nthrs = 3
auto get_lcm = [](size_t a, size_t b) {
std::function<size_t(size_t, size_t)> get_gcd;
get_gcd = [&get_gcd](size_t a, size_t b) {
if (b == 0)
return a;
return get_gcd(b, a % b);
};
return a / get_gcd(a, b) * b;
};
const auto lcm = get_lcm(batch_dim, optimal_parallelism_work_amount); // LCM(b, nthrs)
const auto batch_dim_multiplier = lcm / batch_dim; // LCM(b, nthrs) / b
const auto needed_new_dim = m_dim / batch_dim_multiplier; // m / (LCM(b, nthrs) / b) - needed factors of dimension m
auto is_optimized = [&](size_t batch_m_dim, size_t new_m_dim) {
return batch_m_dim != 1 && new_m_dim >= optimal_m_dim;
};
if (batch_dim_multiplier * needed_new_dim == m_dim) {
batch_m_dim = batch_dim_multiplier;
new_m_dim = needed_new_dim;
}
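As a sanity check of the ideal case above, here is a small standalone sketch that reproduces the numbers from the comment's example (the helper name is made up; std::lcm is available in <numeric> since C++17):

#include <cstddef>
#include <numeric>

// B = 6, M = 512, 16 threads: std::lcm(6, 16) = 48, multiplier = 48 / 6 = 8,
// needed_new_dim = 512 / 8 = 64, and 8 * 64 == 512, so batch_m_dim = 8 and
// new_m_dim = 64, i.e. [6, 512, 32] -> [6, 8, 64, 32] with parallel work amount 6 * 8 = 48.
constexpr size_t ideal_batch_m_dim(size_t batch, size_t m, size_t threads) {
    const size_t multiplier = std::lcm(batch, threads) / batch;
    return m % multiplier == 0 ? multiplier : 1;
}
static_assert(ideal_batch_m_dim(6, 512, 16) == 8, "matches the example above");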
if (!is_optimized(batch_m_dim, new_m_dim)) {
auto get_factors = [](size_t dim) -> std::vector<size_t> {
std::vector<size_t> factors;
size_t div = 2;
while (div <= dim) {
const auto res = dim / div;
if (res * div == dim) {
factors.push_back(div);
dim = res;
} else {
div++;
}
}
return factors;
};
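// For example, get_factors(512) yields nine 2s, get_factors(9216) yields ten 2s and two 3s,
// and a prime such as 13 yields the single factor {13}.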
const auto m_factors = get_factors(m_dim);
// If m_dim is a prime number, there is nothing to split
if (m_factors.size() == 1)
return;
batch_m_dim = 1;
new_m_dim = m_dim;
size_t idx = 0;
// [115284] The current solution is not optimal. For more details, please see the ticket
while (batch_m_dim * batch_dim < optimal_parallelism_work_amount && idx < m_factors.size()) {
const auto tmp_batch_m_dim = batch_m_dim * m_factors[idx];
// There should be enough work left for kernel execution
if (m_dim / tmp_batch_m_dim * n_dim < optimal_m_dim)
break;
batch_m_dim = tmp_batch_m_dim;
++idx;
}
new_m_dim = m_dim / batch_m_dim;
}
OPENVINO_ASSERT(batch_m_dim * new_m_dim == m_dim, "Incorrect dimension M splitting!");
// nothing to split
if (!is_optimized(batch_m_dim, new_m_dim))
return;
/***** Reshape insertion *****/
// There are two kinds of Parameters:
// - Parameters on branches feeding the second input of a MatMul - the shape should only be unsqueezed (just insert a 1)
// - Other Parameters (on the first inputs of MatMuls and between them) - the shape should be split on the M dimension
bool updated = false;
std::set<std::shared_ptr<ov::op::v0::Parameter>> reshaped_params;
auto insert_reshape = [&](const std::shared_ptr<ov::op::v0::Parameter>& param, const ov::Shape& new_shape) {
const auto index = std::distance(parameters.begin(), std::find(parameters.begin(), parameters.end(), param));
const auto shape_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_shape.size()}, new_shape);
const auto reshape = std::make_shared<ov::op::v1::Reshape>(subgraph->input_value(index), shape_const, false);
subgraph->input(index).replace_source_output(reshape);
param->set_partial_shape(new_shape);
reshaped_params.insert(param);
updated = true;
};
auto get_updated_shape = [&](const ov::Shape& shape, bool split_m_dim) {
const auto current_m_dim = get_dim_M(shape);
OPENVINO_ASSERT(!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, "Incorrect shape for splitting!");
ov::Shape new_shape = shape;
if ((split_m_dim && current_m_dim == 1) || !split_m_dim) {
new_shape.insert((new_shape.rbegin() + 2).base(), 1);
} else {
new_shape.insert((new_shape.rbegin() + 2).base(), batch_m_dim);
*(new_shape.rbegin() + 1) = new_m_dim;
}
OPENVINO_ASSERT(ov::shape_size(new_shape) == ov::shape_size(shape), "Incorrect shape splitting!");
return new_shape;
};
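// Example (shapes taken from the smoke_Snippets_MHA_SplitM test in this commit, with batch_m_dim = 9, new_m_dim = 1024):
// - split_m_dim == true and M == m_dim: {10, 9216, 128} -> {10, 9, 1024, 128}
// - split_m_dim == false (second MatMul input): {10, 128, 9216} -> {10, 1, 128, 9216}, i.e. only unsqueezed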
auto reshape_parameter = [&](const std::shared_ptr<ov::Node>& node, bool split_m_dim = true) {
const auto param = ov::as_type_ptr<ov::op::v0::Parameter>(node);
if (!param || reshaped_params.count(param) > 0)
return;
insert_reshape(param, get_updated_shape(param->get_partial_shape().get_shape(), split_m_dim));
};
auto update_matmul_second_branch = [&](const std::shared_ptr<ov::Node>& node) {
auto parent = node->get_input_node_shared_ptr(1);
while (!ov::is_type<ov::op::v0::Parameter>(parent)) {
if (parent->get_input_size() > 1) {
for (const auto& input_source : parent->input_values()) {
reshape_parameter(input_source.get_node_shared_ptr(), false);
}
}
// [107731]: It's covered by MHA tokenization
parent = parent->get_input_node_shared_ptr(0);
}
reshape_parameter(parent, false);
};
// First, unsqueeze the Parameters on the second branches of MatMuls
for (const auto& op : ops) {
if (ov::is_type<ov::op::v0::MatMul>(op)) {
update_matmul_second_branch(op);
}
}
// Second, update the M dimension for the remaining Parameters
for (const auto& param : parameters) {
if (reshaped_params.count(param) == 0)
reshape_parameter(param, true);
}
// Restore the previous shape on the outputs
for (size_t i = 0; i < subgraph->get_output_size() && updated; ++i) {
const auto output_shape = subgraph->get_output_shape(i);
if (is_scalar(output_shape))
continue;
const auto& target_inputs = subgraph->get_output_target_inputs(i);
const auto shape_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{output_shape.size()}, output_shape);
const auto reshape = std::make_shared<ov::op::v1::Reshape>(subgraph->output(i), shape_const, false);
// Save output name
const auto original_output = body->get_results()[i]->get_input_node_shared_ptr(0);
const auto original_name = original_output->get_friendly_name();
reshape->set_friendly_name(original_name);
original_output->set_friendly_name(original_name + "_original");
for (const auto& input : target_inputs) {
input.replace_source_output(reshape);
// Result input tensor name was changed, the name has to be restored
if (ov::is_type<ov::op::v0::Result>(input.get_node())) {
input.get_tensor_ptr()->add_names(subgraph->output(i).get_tensor_ptr()->get_names());
}
}
subgraph->output(i).get_tensor_ptr()->set_names({});
updated = true;
}
subgraph->set_friendly_name(subgraph->get_friendly_name() + "_original");
// Need to update inner Shapes and Softmax Axis
if (updated) {
for (const auto &op : ops) {
if (const auto softmax_v8 = ov::as_type_ptr<ov::op::v8::Softmax>(op)) {
softmax_v8->set_axis(-1);
} else if (const auto softmax_v1 = ov::as_type_ptr<ov::op::v1::Softmax>(op)) {
softmax_v1->set_axis(softmax_v1->get_output_partial_shape(0).size()); // since new_shape.size() = old_shape.size() + 1
} else if (const auto broadcast = ov::as_type_ptr<ov::op::v1::Broadcast>(op)) {
// Broadcast is tokenized only between MatMuls -> Split M dimension
const auto shape_const = ov::as_type_ptr<ov::op::v0::Constant>(broadcast->input_value(1).get_node_shared_ptr());
OPENVINO_ASSERT(shape_const, "SplitDimensionM expects Broadcast with Constant output shape");
const auto new_shape = get_updated_shape(shape_const->cast_vector<size_t>(), true);
broadcast->set_argument(1, std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_shape.size()}, new_shape));
}
}
subgraph->validate_and_infer_types();
}
}
void CommonOptimizations::ExtractConstants(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants");
@ -98,9 +343,9 @@ void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr<op:
}
}
CommonOptimizations::CommonOptimizations() {
CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& config) {
MATCHER_SCOPE(CommonOptimizations);
ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) {
ov::graph_rewrite_callback callback = [&](ov::pass::pattern::Matcher& m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::CommonOptimizations");
auto subgraph = ov::as_type_ptr<ov::snippets::op::Subgraph>(m.get_match_root());
@ -130,6 +375,8 @@ CommonOptimizations::CommonOptimizations() {
// Extract unsupported Transposes from body
if (subgraph->has_domain_sensitive_ops()) {
ExtractUnsupportedTransposes(subgraph);
if (config.split_m_dimension)
SplitDimensionM(subgraph, config.minimal_concurrency);
}
return true;
};

View File

@ -80,7 +80,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
manager.register_pass<EnumerateNodes>();
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeSnippets>();
manager.register_pass<CommonOptimizations>();
manager.register_pass<CommonOptimizations>(m_config);
manager.run_passes(m);
// The return value is false because pass::Manager always applies the Validation pass if the function was changed.

View File

@ -6,6 +6,8 @@
#include <common_test_utils/ngraph_test_utils.hpp>
#include "snippets/pass/tokenization.hpp"
namespace ov {
namespace test {
namespace snippets {
@ -13,6 +15,9 @@ namespace snippets {
class TokenizeMHASnippetsTests : public TransformationTestsF {
public:
virtual void run();
protected:
ov::snippets::pass::SnippetsTokenization::Config config;
};
} // namespace snippets

View File

@ -17,7 +17,7 @@ void TokenizeMHASnippetsTests::run() {
ASSERT_TRUE(function);
manager.register_pass<ov::snippets::pass::EnumerateNodes>();
manager.register_pass<ov::snippets::pass::TokenizeMHASnippets>();
manager.register_pass<ov::snippets::pass::CommonOptimizations>();
manager.register_pass<ov::snippets::pass::CommonOptimizations>(config);
disable_rt_info_check();
}
@ -67,7 +67,25 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) {
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM) {
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{10, 9, 1024, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
function = f.getOriginal();
function_ref = f.getReference();
config.minimal_concurrency = 18;
run();
}
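// For reference: this case takes the ideal branch of the splitting heuristic, since LCM(10, 18) = 90,
// so batch_m_dim = 90 / 10 = 9 and new_m_dim = 9216 / 9 = 1024, matching the {10, 9, 1024, 128} reference shape.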
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
{8, 1, 64, 512}, {8, 512, 512}});
function = f.getOriginal();
function_ref = f.getReference();
config.minimal_concurrency = 16;
run();
}
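// For reference: here LCM(8, 16) = 16, so batch_m_dim = 16 / 8 = 2 and new_m_dim = 512 / 2 = 256,
// matching the {8, 2, 256, 18} reference shape.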
} // namespace snippets
} // namespace test

View File

@ -627,13 +627,17 @@ void Transformations::MainSnippets(void) {
!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions)
return;
ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
// At the moment Snippets supports Transposes in MHA pattern only in FP32 case since
// - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage
// - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage
// Because of that Transposes won't be fused into Brgemm
// TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
tokenization_config.mha_token_enable_transpose = (inferencePrecision == ov::element::f32);
tokenization_config.minimal_concurrency = parallel_get_num_threads();
// The optimization "SplitDimensionM" depends on target machine (thread count).
// To avoid uncontrolled behavior in tests, the optimization is disabled when Config::SnippetsMode::IgnoreCallback is set
tokenization_config.split_m_dimension = snippetsMode != Config::SnippetsMode::IgnoreCallback;
ngraph::pass::Manager snippetsManager;
snippetsManager.set_per_pass_validation(false);
@ -687,8 +691,9 @@ void Transformations::MainSnippets(void) {
// TODO: The heuristic will be removed after parallelism support on JIT level
const auto needed_num_of_threads = 12lu;
const auto is_unsupported_parallel_work_amount =
parallel_get_num_threads() / 2 > parallel_work_amount &&
static_cast<size_t>(parallel_work_amount) < needed_num_of_threads;
parallel_get_num_threads() / 2 > parallel_work_amount &&
static_cast<size_t>(parallel_work_amount) < needed_num_of_threads &&
!ov::snippets::pass::CommonOptimizations::CanOptimizeParallelWA(n, tokenization_config.minimal_concurrency);
return is_unsupported_parallel_work_amount;
},
snippets::pass::TokenizeMHASnippets);

View File

@ -45,8 +45,8 @@ class MHAFunction : public SnippetsFunctionBase {
public:
explicit MHAFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions, bool with_mul = true)
: SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -75,8 +75,8 @@ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -104,8 +104,8 @@ class MHASelectFunction : public SnippetsFunctionBase {
public:
explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions");
OPENVINO_ASSERT(input_shapes.size() == 6, "Got invalid number of input shapes");
OPENVINO_ASSERT(precisions.size() == 6, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -113,6 +113,22 @@ protected:
std::vector<ov::element::Type> precisions;
};
// Only for tokenization tests, since the boolean type is replaced with u8
// Without Transposes
class MHASelectSplitMFunction : public SnippetsFunctionBase {
public:
explicit MHASelectSplitMFunction(const std::vector<PartialShape>& inputShapes, const std::vector<Shape>& reshapes)
: SnippetsFunctionBase(inputShapes), reshapes(reshapes) {
OPENVINO_ASSERT(input_shapes.size() == 5, "Got invalid number of input shapes");
OPENVINO_ASSERT(reshapes.size() == 6, "Got invalid number of Reshape shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
std::vector<Shape> reshapes;
};
/* Graph:
* Constant
* \ /
@ -129,7 +145,7 @@ protected:
class MHAWOTransposeOnInputsFunction : public SnippetsFunctionBase {
public:
explicit MHAWOTransposeOnInputsFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -148,8 +164,8 @@ class MHAWOTransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions");
OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
OPENVINO_ASSERT(precisions.size() == 3, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -157,6 +173,19 @@ protected:
std::vector<ov::element::Type> precisions;
};
class MHAWOTransposeSplitMFunction : public MHAWOTransposeFunction {
public:
explicit MHAWOTransposeSplitMFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions,
const std::vector<Shape>& reshapes)
: MHAWOTransposeFunction(inputShapes, precisions), reshapes(reshapes) {
OPENVINO_ASSERT(reshapes.size() == 4, "Got invalid number of Reshape shapes");
}
protected:
std::shared_ptr<ov::Model> initReference() const override;
std::vector<ov::Shape> reshapes;
};
/* Graph:
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
@ -176,7 +205,7 @@ class MHAFQAfterMatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAFQAfterMatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -203,7 +232,7 @@ class MHAINT8MatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -232,7 +261,7 @@ class MHAFQFunction : public SnippetsFunctionBase {
public:
explicit MHAFQFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -261,7 +290,7 @@ class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -283,7 +312,7 @@ protected:
class MHAMulAddFunction : public SnippetsFunctionBase {
public:
explicit MHAMulAddFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -304,7 +333,7 @@ public:
explicit MHATransposedInputFunction(const std::vector<PartialShape>& inputShapes, bool transposed_b = false,
std::vector<int64_t> order = {})
: SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;

View File

@ -324,6 +324,105 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHASelectSplitMFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto selectParam = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, input_shapes[3]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, selectParam, transpose2Param};
// The value is equal to '1' to avoid the situation exp(-1000) / sum(exp(-1000)) = 0/0 = NaN
auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0Param, transpose1Param);
const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
std::shared_ptr<ov::Node> selectCond = selectParam;
if (add->get_output_partial_shape(0) != selectParam->get_output_partial_shape(0)) {
const auto broadcast_shape =
ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0));
selectCond = std::make_shared<ngraph::opset1::Broadcast>(selectCond, broadcast_shape);
}
const auto select = std::make_shared<op::TypeRelaxed<ngraph::opset1::Select>>(
std::vector<element::Type>{ element::boolean, element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(),
ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(),
ov::op::TemporaryReplaceOutputType(add, element::f32).get());
const auto interm_shape = select->get_shape();
std::vector<int64_t> reshape0ConstData = {-1, static_cast<int64_t>(interm_shape.back())};
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape0ConstData.size()}, reshape0ConstData);
std::vector<int64_t> reshape1ConstData;
for (const auto& dim : interm_shape)
reshape1ConstData.push_back(static_cast<int64_t>(dim));
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape1ConstData.size()}, reshape1ConstData);
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(select, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2Param);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHASelectSplitMFunction::initReference() const {
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto selectParam = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, input_shapes[3]);
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
ngraph::ParameterVector ngraphParam = {param0, param1, addParam, selectParam, param2};
auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape);
return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
};
auto reshape0 = make_reshape(param0, reshapes[0]);
auto reshape1 = make_reshape(param1, reshapes[1]);
auto reshapeAdd = make_reshape(addParam, reshapes[2]);
auto reshapeSelect = make_reshape(selectParam, reshapes[3]);
auto reshape2 = make_reshape(param2, reshapes[4]);
auto data0 = std::make_shared<ngraph::opset1::Parameter>(reshape0->get_element_type(), reshape0->get_shape());
auto data1 = std::make_shared<ngraph::opset1::Parameter>(reshape1->get_element_type(), reshape1->get_shape());
auto dataAdd = std::make_shared<ngraph::opset1::Parameter>(reshapeAdd->get_element_type(), reshapeAdd->get_shape());
auto dataSelect = std::make_shared<ngraph::opset1::Parameter>(reshapeSelect->get_element_type(), reshapeSelect->get_shape());
auto data2 = std::make_shared<ngraph::opset1::Parameter>(reshape2->get_element_type(), reshape2->get_shape());
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(data0, data1);
const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, dataAdd);
// The value is equal to '1' to avoid the situation exp(-1000) / sum(exp(-1000)) = 0/0 = NaN
auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
std::shared_ptr<ov::Node> selectCond = dataSelect;
if (add->get_output_partial_shape(0) != dataSelect->get_output_partial_shape(0)) {
const auto broadcast_shape =
ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0));
selectCond = std::make_shared<ngraph::opset1::Broadcast>(selectCond, broadcast_shape);
}
const auto select = std::make_shared<op::TypeRelaxed<ngraph::opset1::Select>>(
std::vector<element::Type>{ element::boolean, element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(),
ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(),
ov::op::TemporaryReplaceOutputType(add, element::f32).get());
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(select, add->get_shape().size() - 1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, data2);
const auto subgraph =
std::make_shared<ov::snippets::op::Subgraph>(
ov::NodeVector{reshape0, reshape1, reshapeAdd, reshapeSelect, reshape2},
std::make_shared<ov::Model>(ov::OutputVector{matMul1}, ov::ParameterVector{data0, data1, dataAdd, dataSelect, data2}));
auto reshape3 = make_reshape(subgraph, reshapes[5]);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(reshape3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const {
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
@ -361,6 +460,37 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAWOTransposeSplitMFunction::initReference() const {
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
ngraph::ParameterVector ngraphParam = {param0, param1, param2};
auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape);
return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
};
auto reshape0 = make_reshape(param0, reshapes[0]);
auto reshape1 = make_reshape(param1, reshapes[1]);
auto reshape2 = make_reshape(param2, reshapes[2]);
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], reshape0->get_shape());
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], reshape1->get_shape());
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], reshape2->get_shape());
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(data0, data1);
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, data2);
const auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{reshape0, reshape1, reshape2},
std::make_shared<ov::Model>(ov::OutputVector{matMul1},
ov::ParameterVector{data0, data1, data2}));
auto reshape3 = make_reshape(subgraph, reshapes[3]);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(reshape3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);