[Snippets] Added SplitDimensionM optimization (#18160)
parent b7935bb869
commit 67c88f4434
@@ -5,8 +5,8 @@
#pragma once

#include "openvino/pass/graph_rewrite.hpp"

#include "snippets/op/subgraph.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {
@@ -15,13 +15,20 @@ namespace pass {
class CommonOptimizations : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("CommonOptimizations", "0");
-   CommonOptimizations();
+   CommonOptimizations(const SnippetsTokenization::Config& config = {});

    // Returns true if the parallel work amount can be increased by the SplitDimensionM optimization
    static bool CanOptimizeParallelWA(const std::shared_ptr<const ov::Node>& node, size_t minimal_concurrency);

private:
    // Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
    void ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph);
    // Move up unsupported Transposes on Parameter outputs from body
    void ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph);
    // Insert Reshape nodes before the inputs and after the outputs of Subgraphs with MatMul inside
    // to split the M dimension of the MatMuls and increase the parallel work amount
    // Note: works only with 3D MHA patterns
    void SplitDimensionM(const std::shared_ptr<op::Subgraph>& subgraph, size_t minimal_concurrency);
};

} // namespace pass
@@ -61,8 +61,12 @@ public:
 * @ingroup snippets
 */
struct Config {
-   Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}
+   Config(size_t minimal_concurrency = 1, bool split_m_dimension = true, bool enable_transpose = true)
+       : minimal_concurrency(minimal_concurrency), split_m_dimension(split_m_dimension), mha_token_enable_transpose(enable_transpose) {}

    size_t minimal_concurrency = 1;
    // True if the "SplitDimensionM" optimization is enabled, false otherwise.
    bool split_m_dimension = true;
    // If false, Transposes are not tokenized in MHA tokenization.
    // Otherwise, they may be fused into the Subgraph if possible.
    // TODO [106921]: Remove when ticket 106921 is implemented
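For orientation, here is a minimal sketch (not part of the commit) of how the extended Config might be wired into a pass manager, mirroring the registration changes later in this diff; the header path for CommonOptimizations is an assumption:

```cpp
#include "openvino/pass/manager.hpp"
#include "snippets/pass/common_optimizations.hpp"  // assumed header path
#include "snippets/pass/tokenization.hpp"

// Register CommonOptimizations so that SplitDimensionM targets n_threads
// of parallel work; split_m_dimension keeps the optimization enabled.
void register_snippets_common_optimizations(ov::pass::Manager& manager, size_t n_threads) {
    ov::snippets::pass::SnippetsTokenization::Config config(/*minimal_concurrency=*/n_threads,
                                                            /*split_m_dimension=*/true,
                                                            /*enable_transpose=*/true);
    manager.register_pass<ov::snippets::pass::CommonOptimizations>(config);
}
```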
@@ -19,6 +19,251 @@ namespace ov {
namespace snippets {
namespace pass {

bool CommonOptimizations::CanOptimizeParallelWA(const std::shared_ptr<const ov::Node>& node, size_t minimal_concurrency) {
    if (!ov::is_type<ov::op::v0::MatMul>(node))
        return false;
    // It's needed only for 3D MHA patterns
    const auto mm_shape = node->get_shape();
    if (mm_shape.size() != 3)
        return false;
    const auto current_parallel_work_amount =
        std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies<size_t>());
    const auto dim_M = *(mm_shape.rbegin() + 1);
    return (current_parallel_work_amount < minimal_concurrency) &&
           (current_parallel_work_amount * dim_M >= minimal_concurrency);
}

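To make the condition concrete, here is a standalone sketch (not part of the commit) that reproduces the check for a 3D MatMul output shape:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Standalone reproduction of the check above: the batch (all dims except the
// last two) must currently be below the target concurrency, while batch * M
// must be large enough to reach it after splitting M.
bool can_optimize_parallel_wa(const std::vector<size_t>& mm_shape, size_t minimal_concurrency) {
    if (mm_shape.size() != 3)
        return false;
    const size_t batch = std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(),
                                         size_t(1), std::multiplies<size_t>());
    const size_t m = *(mm_shape.rbegin() + 1);
    return batch < minimal_concurrency && batch * m >= minimal_concurrency;
}

int main() {
    // Shape [6, 512, 32] with 16 threads (the example used in the comments below):
    // batch = 6 < 16 and 6 * 512 = 3072 >= 16, so the optimization applies.
    std::cout << can_optimize_parallel_wa({6, 512, 32}, 16) << "\n";  // prints 1
}
```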
void CommonOptimizations::SplitDimensionM(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph, size_t minimal_concurrency) {
    // To increase the parallel work amount in 3D cases for the MHA pattern,
    // we split the 1st dimension (counting from the 0th) into 2 new dimensions to get 4D shapes where
    // - the 0th and 1st dimensions are used in parallel scheduling,
    // - the 2nd and 3rd dimensions are used in the kernel
    // Note: 3D patterns don't contain a Transpose inside, so the reshaping is valid

    // It's needed only for MHA patterns. Support for common patterns needs to be added
    if (!subgraph->has_domain_sensitive_ops())
        return;

    const auto& body = subgraph->body_ptr();
    const auto& parameters = body->get_parameters();
    // [107806]: If the count of Parameters isn't equal to the count of Subgraph inputs (a possible case in general),
    // we cannot guarantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O.
    OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(),
                    "Failed to extract unsupported transposes: the count of Parameters isn't equal to Subgraph inputs");

    // Need to find MatMul0 and check its output shape
    const auto& ops = body->get_ordered_ops();
    const auto mm_it = std::find_if(ops.begin(), ops.end(),
                                    [](const std::shared_ptr<ov::Node>& node){ return ov::is_type<ov::op::v0::MatMul>(node); });
    if (mm_it == ops.end())
        return;

    const auto matmul0 = ov::as_type_ptr<ov::op::v0::MatMul>(*mm_it);
    if (!matmul0 || !CanOptimizeParallelWA(matmul0, minimal_concurrency))
        return;

    auto get_dim_M = [](const ov::Shape& shape) {
        return *(shape.rbegin() + 1);
    };

    const auto mm_shape = matmul0->get_shape();
    const auto m_dim = get_dim_M(mm_shape); // M
    const auto n_dim = mm_shape.back();     // N
    // [113745]: The heuristic is equal to double the block size.
    // When this optimization is moved into Subgraph and the blocking parameters are derived from shapes,
    // a common way (for all backends) to calculate the optimal value of the M dimension needs to be implemented
    const auto optimal_m_dim = 32 * 2;
    const auto optimal_parallelism_work_amount = minimal_concurrency;
    if (m_dim <= optimal_m_dim)
        return;

    const auto batch_dim =
        std::accumulate(mm_shape.rbegin() + 2, mm_shape.rend(), size_t(1), std::multiplies<size_t>()); // B (batch)
    size_t batch_m_dim = 1;
    size_t new_m_dim = m_dim;

    // Need to find an optimal dimension splitting: [b1..bk, m, n] -> [b1..bk, batch_m_dim, new_m_dim, n]
    // Ideally, the parallel work amount should be divisible by the max thread count,
    // so that all threads get the same full work amount (avoiding thread downtime).
    // If that's impossible, it should at least exceed the max thread count
    // [115284]: Find a solution for optimal splitting in these cases
    // For example, given 16 threads and shape [6, 512, 32]:
    // LCM(6, 16) = 48 <- ideal work amount for parallelism
    // new_shape [6, 48 / 6, 512 / (48 / 6), 32] => [6, 8, 64, 32]
    // Each thread gets parallelism_work_amount = 6 * 8 / nthrs = 3
    auto get_lcm = [](size_t a, size_t b) {
        std::function<size_t(size_t, size_t)> get_gcd;
        get_gcd = [&get_gcd](size_t a, size_t b) {
            if (b == 0)
                return a;
            return get_gcd(b, a % b);
        };
        return a / get_gcd(a, b) * b;
    };
    const auto lcm = get_lcm(batch_dim, optimal_parallelism_work_amount); // LCM(b, nthrs)
    const auto batch_dim_multiplier = lcm / batch_dim;                    // LCM(b, nthrs) / b
    const auto needed_new_dim = m_dim / batch_dim_multiplier;             // m / (LCM(b, nthrs) / b) - the needed factor of dimension m

    auto is_optimized = [&](size_t batch_m_dim, size_t new_m_dim) {
        return batch_m_dim != 1 && new_m_dim >= optimal_m_dim;
    };

    if (batch_dim_multiplier * needed_new_dim == m_dim) {
        batch_m_dim = batch_dim_multiplier;
        new_m_dim = needed_new_dim;
    }
    if (!is_optimized(batch_m_dim, new_m_dim)) {
        auto get_factors = [](size_t dim) -> std::vector<size_t> {
            std::vector<size_t> factors;
            size_t div = 2;
            while (div <= dim) {
                const auto res = dim / div;
                if (res * div == dim) {
                    factors.push_back(div);
                    dim = res;
                } else {
                    div++;
                }
            }
            return factors;
        };
        const auto m_factors = get_factors(m_dim);
        // If m_dim is a prime number
        if (m_factors.size() == 2)
            return;

        batch_m_dim = 1;
        new_m_dim = m_dim;
        size_t idx = 0;
        // [115284]: The current solution is not optimal enough. For more details, see the ticket
        while (batch_m_dim * batch_dim < optimal_parallelism_work_amount && idx < m_factors.size()) {
            auto tmp_batch_m_dim = batch_m_dim * m_factors[idx];
            // There should be enough work left for kernel execution
            if (m_dim / tmp_batch_m_dim * n_dim < optimal_m_dim)
                break;
            batch_m_dim = tmp_batch_m_dim;
            idx++;
        }
        new_m_dim = m_dim / batch_m_dim;
    }

    OPENVINO_ASSERT(batch_m_dim * new_m_dim == m_dim, "Incorrect dimension M splitting!");
    // Nothing to split
    if (!is_optimized(batch_m_dim, new_m_dim))
        return;

    /***** Reshape insertion *****/

    // There are two Parameter variants:
    // - Parameters on the branches of a MatMul's second input - the shape should only be unsqueezed (just insert 1)
    // - Other Parameters (on the first inputs of MatMuls and between them) - the shape should be split on the M dimension

    bool updated = false;
    std::set<std::shared_ptr<ov::op::v0::Parameter>> reshaped_params;

    auto insert_reshape = [&](const std::shared_ptr<ov::op::v0::Parameter>& param, const ov::Shape& new_shape) {
        const auto index = std::distance(parameters.begin(), std::find(parameters.begin(), parameters.end(), param));
        const auto shape_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_shape.size()}, new_shape);
        const auto reshape = std::make_shared<ov::op::v1::Reshape>(subgraph->input_value(index), shape_const, false);
        subgraph->input(index).replace_source_output(reshape);
        param->set_partial_shape(new_shape);
        reshaped_params.insert(param);
        updated = true;
    };

    auto get_updated_shape = [&](const ov::Shape& shape, bool split_m_dim) {
        const auto current_m_dim = get_dim_M(shape);
        OPENVINO_ASSERT(!split_m_dim || current_m_dim == 1 || current_m_dim == m_dim, "Incorrect shape for splitting!");
        ov::Shape new_shape = shape;
        if ((split_m_dim && current_m_dim == 1) || !split_m_dim) {
            new_shape.insert((new_shape.rbegin() + 2).base(), 1);
        } else {
            new_shape.insert((new_shape.rbegin() + 2).base(), batch_m_dim);
            *(new_shape.rbegin() + 1) = new_m_dim;
        }
        OPENVINO_ASSERT(ov::shape_size(new_shape) == ov::shape_size(shape), "Incorrect shape splitting!");
        return new_shape;
    };

    auto reshape_parameter = [&](const std::shared_ptr<ov::Node>& node, bool split_m_dim = true) {
        const auto param = ov::as_type_ptr<ov::op::v0::Parameter>(node);
        if (!param || reshaped_params.count(param) > 0)
            return;
        insert_reshape(param, get_updated_shape(param->get_partial_shape().get_shape(), split_m_dim));
    };

    auto update_matmul_second_branch = [&](const std::shared_ptr<ov::Node>& node) {
        auto parent = node->get_input_node_shared_ptr(1);
        while (!ov::is_type<ov::op::v0::Parameter>(parent)) {
            if (parent->get_input_size() > 1) {
                for (const auto& input_source : parent->input_values()) {
                    reshape_parameter(input_source.get_node_shared_ptr(), false);
                }
            }

            // [107731]: It's covered by MHA tokenization
            parent = parent->get_input_node_shared_ptr(0);
        }
        reshape_parameter(parent, false);
    };

    // First, unsqueeze Parameters on the second branches of MatMuls
    for (const auto& op : ops) {
        if (ov::is_type<ov::op::v0::MatMul>(op)) {
            update_matmul_second_branch(op);
        }
    }

    // Second, update the M dimension of the remaining Parameters
    for (const auto& param : parameters) {
        if (reshaped_params.count(param) == 0)
            reshape_parameter(param, true);
    }

    // Restore the previous shapes on the outputs
    for (size_t i = 0; i < subgraph->get_output_size() && updated; ++i) {
        const auto output_shape = subgraph->get_output_shape(i);
        if (is_scalar(output_shape))
            continue;

        const auto& target_inputs = subgraph->get_output_target_inputs(i);
        const auto shape_const = std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{output_shape.size()}, output_shape);
        const auto reshape = std::make_shared<ov::op::v1::Reshape>(subgraph->output(i), shape_const, false);
        // Save the output name
        const auto original_output = body->get_results()[i]->get_input_node_shared_ptr(0);
        const auto original_name = original_output->get_friendly_name();
        reshape->set_friendly_name(original_name);
        original_output->set_friendly_name(original_name + "_original");

        for (const auto& input : target_inputs) {
            input.replace_source_output(reshape);
            // The Result input tensor name was changed, so the name has to be restored
            if (ov::is_type<ov::op::v0::Result>(input.get_node())) {
                input.get_tensor_ptr()->add_names(subgraph->output(i).get_tensor_ptr()->get_names());
            }
        }
        subgraph->output(i).get_tensor_ptr()->set_names({});
        updated = true;
    }
    subgraph->set_friendly_name(subgraph->get_friendly_name() + "_original");

    // Need to update the inner shapes and the Softmax axis
    if (updated) {
        for (const auto& op : ops) {
            if (const auto softmax_v8 = ov::as_type_ptr<ov::op::v8::Softmax>(op)) {
                softmax_v8->set_axis(-1);
            } else if (const auto softmax_v1 = ov::as_type_ptr<ov::op::v1::Softmax>(op)) {
                softmax_v1->set_axis(softmax_v1->get_output_partial_shape(0).size()); // since new_shape.size() = old_shape.size() + 1
            } else if (const auto broadcast = ov::as_type_ptr<ov::op::v1::Broadcast>(op)) {
                // Broadcast is tokenized only between MatMuls -> split the M dimension
                const auto shape_const = ov::as_type_ptr<ov::op::v0::Constant>(broadcast->input_value(1).get_node_shared_ptr());
                OPENVINO_ASSERT(shape_const, "SplitDimensionM expects Broadcast with a Constant output shape");
                const auto new_shape = get_updated_shape(shape_const->cast_vector<size_t>(), true);
                broadcast->set_argument(1, std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{new_shape.size()}, new_shape));
            }
        }
        subgraph->validate_and_infer_types();
    }
}

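As a standalone illustration of the splitting heuristic above (not part of the commit), the sketch below computes the ideal-case split from LCM(batch, n_threads), assuming M is divisible by the resulting multiplier; the prime-factor fallback is omitted:

```cpp
#include <cstddef>
#include <iostream>
#include <numeric>  // std::gcd (C++17)

// Ideal case of the splitting above: make batch * batch_m a multiple of
// the thread count by splitting off batch_m = LCM(batch, n_threads) / batch
// from M, so every thread gets the same amount of parallel work.
void split_m(size_t batch, size_t m, size_t n_threads, size_t& batch_m, size_t& new_m) {
    const size_t lcm = batch / std::gcd(batch, n_threads) * n_threads;
    batch_m = lcm / batch;  // how much parallel work is missing
    new_m = m / batch_m;    // remaining kernel work along M
}

int main() {
    // The example from the comments: 16 threads, shape [6, 512, 32].
    // LCM(6, 16) = 48, multiplier = 8, so [6, 512, 32] -> [6, 8, 64, 32].
    size_t batch_m = 1, new_m = 0;
    split_m(/*batch=*/6, /*m=*/512, /*n_threads=*/16, batch_m, new_m);
    std::cout << batch_m << " " << new_m << "\n";  // prints "8 64"
}
```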
void CommonOptimizations::ExtractConstants(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants");
@@ -98,9 +343,9 @@ void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr<op:
    }
}

-CommonOptimizations::CommonOptimizations() {
+CommonOptimizations::CommonOptimizations(const SnippetsTokenization::Config& config) {
    MATCHER_SCOPE(CommonOptimizations);
-   ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) {
+   ov::graph_rewrite_callback callback = [&](ov::pass::pattern::Matcher& m) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::CommonOptimizations");

        auto subgraph = ov::as_type_ptr<ov::snippets::op::Subgraph>(m.get_match_root());
@@ -130,6 +375,8 @@ CommonOptimizations::CommonOptimizations() {
        // Extract unsupported Transposes from body
        if (subgraph->has_domain_sensitive_ops()) {
            ExtractUnsupportedTransposes(subgraph);
            if (config.split_m_dimension)
                SplitDimensionM(subgraph, config.minimal_concurrency);
        }
        return true;
    };
@@ -80,7 +80,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
    manager.register_pass<EnumerateNodes>();
    manager.register_pass<TokenizeMHASnippets>(m_config);
    manager.register_pass<TokenizeSnippets>();
-   manager.register_pass<CommonOptimizations>();
+   manager.register_pass<CommonOptimizations>(m_config);
    manager.run_passes(m);

    // Returning value is false because pass::Manager always apply Validation pass if function was changed.
@@ -6,6 +6,8 @@

#include <common_test_utils/ngraph_test_utils.hpp>

#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace test {
namespace snippets {
@@ -13,6 +15,9 @@ namespace snippets {
class TokenizeMHASnippetsTests : public TransformationTestsF {
public:
    virtual void run();

protected:
    ov::snippets::pass::SnippetsTokenization::Config config;
};

} // namespace snippets
@@ -17,7 +17,7 @@ void TokenizeMHASnippetsTests::run() {
    ASSERT_TRUE(function);
    manager.register_pass<ov::snippets::pass::EnumerateNodes>();
    manager.register_pass<ov::snippets::pass::TokenizeMHASnippets>();
-   manager.register_pass<ov::snippets::pass::CommonOptimizations>();
+   manager.register_pass<ov::snippets::pass::CommonOptimizations>(config);
    disable_rt_info_check();
}

@@ -67,7 +67,25 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) {
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM) {
    const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
                                                 std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
                                                 std::vector<Shape>{{10, 9, 1024, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
    function = f.getOriginal();
    function_ref = f.getReference();
    config.minimal_concurrency = 18;
    run();
}

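For reference, the expected reshapes here follow from the LCM heuristic in SplitDimensionM: with batch = 10, M = 9216 and minimal_concurrency = 18, LCM(10, 18) = 90, so batch_m_dim = 90 / 10 = 9 and new_m_dim = 9216 / 9 = 1024, giving the first reshape {10, 9, 1024, 128}; the MatMul second-input branches are only unsqueezed ({10, 1, 128, 9216} and {10, 1, 9216, 128}), and the output is reshaped back to {10, 9216, 128}.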
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
    const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
                                            std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
                                                               {8, 1, 64, 512}, {8, 512, 512}});
    function = f.getOriginal();
    function_ref = f.getReference();
    config.minimal_concurrency = 16;
    run();
}

} // namespace snippets
} // namespace test
@@ -627,13 +627,17 @@ void Transformations::MainSnippets(void) {
        !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions)
        return;

-   ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
    // At the moment Snippets supports Transposes in MHA pattern only in FP32 case since
    // - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage
    // - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage
    // Because of that Transposes won't be fused into Brgemm
    // TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
+   ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
    tokenization_config.mha_token_enable_transpose = (inferencePrecision == ov::element::f32);
    tokenization_config.minimal_concurrency = parallel_get_num_threads();
    // The "SplitDimensionM" optimization depends on the target machine (thread count).
    // To avoid uncontrolled behavior in tests, the optimization is disabled under Config::SnippetsMode::IgnoreCallback
    tokenization_config.split_m_dimension = snippetsMode != Config::SnippetsMode::IgnoreCallback;

    ngraph::pass::Manager snippetsManager;
    snippetsManager.set_per_pass_validation(false);
@@ -687,8 +691,9 @@ void Transformations::MainSnippets(void) {
            // TODO: The heuristic will be removed after parallelism support on JIT level
            const auto needed_num_of_threads = 12lu;
            const auto is_unsupported_parallel_work_amount =
-               parallel_get_num_threads() / 2 > parallel_work_amount &&
-               static_cast<size_t>(parallel_work_amount) < needed_num_of_threads;
+               parallel_get_num_threads() / 2 > parallel_work_amount &&
+               static_cast<size_t>(parallel_work_amount) < needed_num_of_threads &&
+               !ov::snippets::pass::CommonOptimizations::CanOptimizeParallelWA(n, tokenization_config.minimal_concurrency);
            return is_unsupported_parallel_work_amount;
        },
        snippets::pass::TokenizeMHASnippets);
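In effect, the CPU callback no longer rejects MHA tokenization for a small parallel work amount when CanOptimizeParallelWA reports that SplitDimensionM can later raise that work amount to the configured minimal_concurrency.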
@@ -45,8 +45,8 @@ class MHAFunction : public SnippetsFunctionBase {
public:
    explicit MHAFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions, bool with_mul = true)
        : SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
-       NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -75,8 +75,8 @@ class MHAMatMul0TransposeFunction : public SnippetsFunctionBase {
public:
    explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
        : SnippetsFunctionBase(inputShapes), precisions(precisions) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
-       NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(precisions.size() == 4, "Got invalid number of input precisions");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -104,8 +104,8 @@ class MHASelectFunction : public SnippetsFunctionBase {
public:
    explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
        : SnippetsFunctionBase(inputShapes), precisions(precisions) {
-       NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes");
-       NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions");
+       OPENVINO_ASSERT(input_shapes.size() == 6, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(precisions.size() == 6, "Got invalid number of input precisions");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -113,6 +113,22 @@ protected:
    std::vector<ov::element::Type> precisions;
};

// Only for tokenization tests since the boolean type is converted to u8
// Without Transposes
class MHASelectSplitMFunction : public SnippetsFunctionBase {
public:
    explicit MHASelectSplitMFunction(const std::vector<PartialShape>& inputShapes, const std::vector<Shape>& reshapes)
        : SnippetsFunctionBase(inputShapes), reshapes(reshapes) {
        OPENVINO_ASSERT(input_shapes.size() == 5, "Got invalid number of input shapes");
        OPENVINO_ASSERT(reshapes.size() == 6, "Got invalid number of reshape shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
    std::shared_ptr<ov::Model> initReference() const override;

    std::vector<Shape> reshapes;
};

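Note: the six reshapes here are the five per-input shapes after splitting plus the final shape that the Subgraph output is reshaped back to (see MHASelectSplitMFunction::initReference below).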
/* Graph:
 * Constant
 * \ /
@@ -129,7 +145,7 @@ protected:
class MHAWOTransposeOnInputsFunction : public SnippetsFunctionBase {
public:
    explicit MHAWOTransposeOnInputsFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -148,8 +164,8 @@ class MHAWOTransposeFunction : public SnippetsFunctionBase {
public:
    explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
        : SnippetsFunctionBase(inputShapes), precisions(precisions) {
-       NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
-       NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions");
+       OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(precisions.size() == 3, "Got invalid number of input precisions");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -157,6 +173,19 @@ protected:
    std::vector<ov::element::Type> precisions;
};

class MHAWOTransposeSplitMFunction : public MHAWOTransposeFunction {
public:
    explicit MHAWOTransposeSplitMFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions,
                                          const std::vector<Shape>& reshapes)
        : MHAWOTransposeFunction(inputShapes, precisions), reshapes(reshapes) {
        OPENVINO_ASSERT(reshapes.size() == 4, "Got invalid number of Reshape shapes");
    }
protected:
    std::shared_ptr<ov::Model> initReference() const override;

    std::vector<ov::Shape> reshapes;
};

/* Graph:
 * Transpose0[0,2,1,3] Transpose1[0,2,3,1]
 * \ /
@@ -176,7 +205,7 @@ class MHAFQAfterMatMulFunction : public SnippetsFunctionBase {
public:
    explicit MHAFQAfterMatMulFunction(const std::vector<PartialShape>& inputShapes)
        : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -203,7 +232,7 @@ class MHAINT8MatMulFunction : public SnippetsFunctionBase {
public:
    explicit MHAINT8MatMulFunction(const std::vector<PartialShape>& inputShapes)
        : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -232,7 +261,7 @@ class MHAFQFunction : public SnippetsFunctionBase {
public:
    explicit MHAFQFunction(const std::vector<PartialShape>& inputShapes)
        : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -261,7 +290,7 @@ class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase {
public:
    explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector<PartialShape>& inputShapes)
        : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 4, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -283,7 +312,7 @@ protected:
class MHAMulAddFunction : public SnippetsFunctionBase {
public:
    explicit MHAMulAddFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
-       NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -304,7 +333,7 @@ public:
    explicit MHATransposedInputFunction(const std::vector<PartialShape>& inputShapes, bool transposed_b = false,
                                        std::vector<int64_t> order = {})
        : SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) {
-       NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
+       OPENVINO_ASSERT(input_shapes.size() == 3, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
@@ -324,6 +324,105 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHASelectSplitMFunction::initOriginal() const {
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto selectParam = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, input_shapes[3]);
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, selectParam, transpose2Param};

    // The value is equal to '1' - to avoid the situation e^(-1000) / sum(e^(-1000)) = 0/0 = NaN
    auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});

    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0Param, transpose1Param);
    const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
    std::shared_ptr<ov::Node> selectCond = selectParam;
    if (add->get_output_partial_shape(0) != selectParam->get_output_partial_shape(0)) {
        const auto broadcast_shape =
            ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0));
        selectCond = std::make_shared<ngraph::opset1::Broadcast>(selectCond, broadcast_shape);
    }
    const auto select = std::make_shared<op::TypeRelaxed<ngraph::opset1::Select>>(
        std::vector<element::Type>{ element::boolean, element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(),
        ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(add, element::f32).get());

    const auto interm_shape = select->get_shape();
    std::vector<int64_t> reshape0ConstData = {-1, static_cast<int64_t>(interm_shape.back())};
    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape0ConstData.size()}, reshape0ConstData);

    std::vector<int64_t> reshape1ConstData;
    for (const auto& dim : interm_shape)
        reshape1ConstData.push_back(static_cast<int64_t>(dim));
    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{reshape1ConstData.size()}, reshape1ConstData);

    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(select, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2Param);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHASelectSplitMFunction::initReference() const {
    auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto selectParam = std::make_shared<ngraph::opset1::Parameter>(ov::element::u8, input_shapes[3]);
    auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
    ngraph::ParameterVector ngraphParam = {param0, param1, addParam, selectParam, param2};

    auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
        auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape);
        return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
    };

    auto reshape0 = make_reshape(param0, reshapes[0]);
    auto reshape1 = make_reshape(param1, reshapes[1]);
    auto reshapeAdd = make_reshape(addParam, reshapes[2]);
    auto reshapeSelect = make_reshape(selectParam, reshapes[3]);
    auto reshape2 = make_reshape(param2, reshapes[4]);

    auto data0 = std::make_shared<ngraph::opset1::Parameter>(reshape0->get_element_type(), reshape0->get_shape());
    auto data1 = std::make_shared<ngraph::opset1::Parameter>(reshape1->get_element_type(), reshape1->get_shape());
    auto dataAdd = std::make_shared<ngraph::opset1::Parameter>(reshapeAdd->get_element_type(), reshapeAdd->get_shape());
    auto dataSelect = std::make_shared<ngraph::opset1::Parameter>(reshapeSelect->get_element_type(), reshapeSelect->get_shape());
    auto data2 = std::make_shared<ngraph::opset1::Parameter>(reshape2->get_element_type(), reshape2->get_shape());

    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(data0, data1);
    const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, dataAdd);

    // The value is equal to '1' - to avoid the situation e^(-1000) / sum(e^(-1000)) = 0/0 = NaN
    auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
    std::shared_ptr<ov::Node> selectCond = dataSelect;
    if (add->get_output_partial_shape(0) != dataSelect->get_output_partial_shape(0)) {
        const auto broadcast_shape =
            ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{add->get_output_shape(0).size()}, add->get_output_shape(0));
        selectCond = std::make_shared<ngraph::opset1::Broadcast>(selectCond, broadcast_shape);
    }
    const auto select = std::make_shared<op::TypeRelaxed<ngraph::opset1::Select>>(
        std::vector<element::Type>{ element::boolean, element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(selectCond, element::boolean).get(),
        ov::op::TemporaryReplaceOutputType(selectConst, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(add, element::f32).get());

    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(select, add->get_shape().size() - 1);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, data2);

    const auto subgraph =
        std::make_shared<ov::snippets::op::Subgraph>(
            ov::NodeVector{reshape0, reshape1, reshapeAdd, reshapeSelect, reshape2},
            std::make_shared<ov::Model>(ov::OutputVector{matMul1}, ov::ParameterVector{data0, data1, dataAdd, dataSelect, data2}));
    auto reshape3 = make_reshape(subgraph, reshapes[5]);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(reshape3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const {
    auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
@@ -361,6 +460,37 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHAWOTransposeSplitMFunction::initReference() const {
    auto param0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
    auto param1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
    auto param2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
    ngraph::ParameterVector ngraphParam = {param0, param1, param2};

    auto make_reshape = [](const std::shared_ptr<ov::Node>& node, const ov::Shape& new_shape) {
        auto shape_const = ngraph::builder::makeConstant(ngraph::element::i32, {new_shape.size()}, new_shape);
        return std::make_shared<ov::op::v1::Reshape>(node, shape_const, true);
    };

    auto reshape0 = make_reshape(param0, reshapes[0]);
    auto reshape1 = make_reshape(param1, reshapes[1]);
    auto reshape2 = make_reshape(param2, reshapes[2]);

    auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], reshape0->get_shape());
    auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], reshape1->get_shape());
    auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], reshape2->get_shape());

    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(data0, data1);
    const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, data2);

    const auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{reshape0, reshape1, reshape2},
                                                                       std::make_shared<ov::Model>(ov::OutputVector{matMul1},
                                                                                                   ov::ParameterVector{data0, data1, data2}));
    auto reshape3 = make_reshape(subgraph, reshapes[3]);
    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(reshape3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

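The reference model above mirrors what SplitDimensionM produces on the real graph: Reshapes on the Subgraph inputs to the split 4D shapes, a 4D MHA body (MatMul0 -> Softmax -> MatMul1) inside the Subgraph, and a final Reshape back to the original 3D output shape.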
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);