[Snippets] Added support for MatMuls with transposed inputs (#17819)

Alexandra Sidorova 2023-06-14 11:55:24 +04:00 committed by GitHub
parent 9754117a61
commit 67dc220d38
11 changed files with 394 additions and 183 deletions

View File

@@ -5,7 +5,8 @@
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/op/subgraph.hpp"
namespace ov {
namespace snippets {
@@ -15,6 +16,12 @@ class CommonOptimizations : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("CommonOptimizations", "0");
CommonOptimizations();
private:
// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
void ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph);
// Move up unsupported Transposes on Parameter outputs from body
void ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph);
};
} // namespace pass

View File

@@ -5,7 +5,6 @@
#pragma once
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
namespace ov {
namespace snippets {
@@ -13,18 +12,25 @@ namespace pass {
/**
* @interface ExplicitTransposeMatMulInputs
* @brief At the moment Snippets supports Transpose only with order {0, 2, 3, 1},
* so if there is pattern in graph:
* in0 Transpose{0, 2, 1, 3}
* \ /
* MatMul[false, true]
* We can set false in MatMul parameter `transposed_b` and
* change Transpose order to {0, 2, 3, 1} which is supported by Snippets
* @brief The pass extracts an explicit Transpose node from a MatMul with transposed_<a|b> and moves it up to the Parameter.
* If there is already a Transpose there, the pass fuses the extracted Transpose with the existing one.
* For example, at the moment Snippets supports Transpose only with the order {0, 2, 3, 1}, so if this pattern occurs in the graph:
* in0 Transpose{0, 2, 1, 3}
* \ /
* MatMul[false, true]
* We can set the MatMul parameter `transposed_b` to `false` and change the Transpose order to {0, 2, 3, 1}, which is supported by Snippets
* @ingroup snippets
*/
class ExplicitTransposeMatMulInputs: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("ExplicitTransposeMatMulInputs", "0");
ExplicitTransposeMatMulInputs();
// Returns `true` if all inputs (except the 0-th input) have scalar shapes; otherwise returns `false`
static bool are_weights_scalar(const std::shared_ptr<ov::Node>& node);
private:
static void extract(const ov::Input<ov::Node>& input);
};
} // namespace pass
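
To make the rewrite described above concrete, here is a minimal, self-contained sketch (illustration only, not part of the pass API) of the order adjustment: clearing MatMul's transposed_<a|b> flag is compensated by swapping the two innermost axes of the Transpose order.

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Swap the two innermost axes of a Transpose order; this compensates
// for clearing MatMul's transposed_<a|b> flag.
std::vector<std::int32_t> fuse_matmul_transpose_flag(std::vector<std::int32_t> order) {
    assert(order.size() > 2 && "Incorrect Transpose order");
    std::swap(order[order.size() - 1], order[order.size() - 2]);
    return order;
}

// fuse_matmul_transpose_flag({0, 2, 1, 3}) == {0, 2, 3, 1} -- the order
// supported by Snippets, matching the example above.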

View File

@@ -4,27 +4,24 @@
#include "snippets/pass/common_optimizations.hpp"
#include <memory>
#include "openvino/opsets/opset1.hpp"
#include <ngraph/pass/constant_folding.hpp>
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
#include "snippets/pass/fq_decomposition.hpp"
#include "snippets/pass/softmax_reshape_elimination.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"
#include "snippets/itt.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"
namespace ov {
namespace snippets {
namespace pass {
// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
void ConvertConstantsToParameters(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ConvertConstantsToParameters");
void CommonOptimizations::ExtractConstants(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants");
auto body = subgraph->body_ptr();
ParameterVector new_parameters;
@@ -55,6 +52,52 @@ void ConvertConstantsToParameters(const std::shared_ptr<ov::snippets::op::Subgra
}
}
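// Schematically, ExtractConstants performs the following rewiring (a sketch,
// per the contract documented in the header):
//
//   Before:  Subgraph[ ... Constant{non-scalar} -> Consumer ... ]
//   After:   Constant -> Subgraph[ ... Parameter -> Consumer ... ]
//
// i.e. the non-scalar Constant becomes a new Subgraph input, and a new body
// Parameter takes its place inside the body.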
void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes");
const auto& body = subgraph->body_ptr();
const auto parameters = body->get_parameters();
// [107806]: If the count of Parameters isn't equal to the count of Subgraph inputs,
// we cannot guarantee correct extraction, since we don't have the correct connections between body I/O and Subgraph I/O.
OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(),
"Failed to extract unsupported transposes: the count of Parameters isn't equal to Subgraph inputs");
bool updated = false;
for (size_t i = 0; i < parameters.size(); ++i) {
const auto& parameter = parameters[i];
const auto& consumers = parameter->get_output_target_inputs(0);
if (consumers.size() != 1)
continue;
const auto transpose = ov::as_type_ptr<opset1::Transpose>(consumers.begin()->get_node()->shared_from_this());
if (!transpose)
continue;
const auto& order = ov::as_type_ptr<opset1::Constant>(transpose->get_input_node_shared_ptr(1));
if (!order)
continue;
const auto order_value = order->cast_vector<int>();
const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
const auto is_brgemm_case = ov::is_type<opset1::MatMul>(transpose_child.get_node()->shared_from_this());
// If Transpose is supported (can be decomposed or fused into Brgemm), skip
if ((is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0) ||
(TransposeDecomposition::supported_cases.count(order_value) != 0))
continue;
// If the Transpose isn't supported, we have to extract it from the Subgraph
transpose->set_argument(0, subgraph->input_value(i));
subgraph->set_argument(i, transpose);
transpose_child.replace_source_output(parameter);
// Update shape
parameter->set_partial_shape(transpose->get_output_partial_shape(0));
updated = true;
}
if (updated) {
subgraph->validate_and_infer_types();
}
}
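// The rewiring above, schematically (for a Transpose whose order is supported
// neither by TransposeDecomposition nor by FuseTransposeBrgemm):
//
//   Before:  input_i -> Subgraph[ Parameter_i -> Transpose -> child ... ]
//   After:   input_i -> Transpose -> Subgraph[ Parameter_i -> child ... ]
//
// Parameter_i keeps the Transpose's consumer and takes the Transpose's output
// shape, while the Transpose itself becomes the i-th Subgraph argument.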
CommonOptimizations::CommonOptimizations() {
MATCHER_SCOPE(CommonOptimizations);
ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) {
@@ -65,10 +108,10 @@ CommonOptimizations::CommonOptimizations() {
return false;
}
auto body = subgraph->body_ptr();
const auto& body = subgraph->body_ptr();
const auto is_quantized = subgraph->is_quantized();
// Firsly we should transform all original Converts inside body to ConvertTruncation to save original behavior.
// Firstly, we should transform all original Converts inside body to ConvertTruncation to save original behavior.
// Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs.
ov::pass::Manager manager;
manager.register_pass<ov::snippets::pass::TransformConvertToConvertTruncation>();
@@ -80,15 +123,18 @@ CommonOptimizations::CommonOptimizations() {
manager.run_passes(body);
// At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph
// so we can enable ConvertConstantsToParameters pass for quantized models
// so we can enable ExtractConstants pass for quantized models
if (is_quantized) {
ConvertConstantsToParameters(subgraph);
ExtractConstants(subgraph);
}
// Extract unsupported Transposes from body
if (subgraph->has_domain_sensitive_ops()) {
ExtractUnsupportedTransposes(subgraph);
}
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<ov::snippets::op::Subgraph>(),
matcher_name);
auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<ov::snippets::op::Subgraph>(), matcher_name);
this->register_matcher(m, callback);
}
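
A minimal usage sketch (assuming a model that already contains tokenized snippets::op::Subgraph ops, e.g. after the tokenization passes have run), mirroring how the unit tests below register the pass:

#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/common_optimizations.hpp"

void run_snippets_common_optimizations(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // The matcher fires once per Subgraph op found in the model.
    manager.register_pass<ov::snippets::pass::CommonOptimizations>();
    manager.run_passes(model);
}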

View File

@@ -2,79 +2,101 @@
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/itt.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/op/subgraph.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/rt_info.hpp"
bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const std::shared_ptr<ov::Node>& node) {
const auto inputs = node->inputs();
return std::all_of(inputs.begin() + 1, inputs.end(),
[](const ov::Input<ov::Node>& in) {
return in.get_partial_shape().is_static() && ov::shape_size(in.get_shape()) == 1;
});
}
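// For example (shapes illustrative): a node with inputs {data, Constant{1}}
// or {data, Constant{1, 1, 1, 1}} passes the check (shape_size == 1 for every
// non-0th input), while {data, Constant{1, 128}} does not.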
void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<ov::Node>& input) {
auto parent = input.get_source_output().get_node_shared_ptr();
auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent);
while (!transpose && !ov::is_type<ov::op::v0::Parameter>(parent)) {
// We can set a supported order and transposed_<a|b>=false only if the ops on the path have scalar shapes, to avoid shape mismatches
if (!are_weights_scalar(parent))
break;
parent = parent->get_input_node_shared_ptr(0);
transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent);
}
// If an existing Transpose was found, just update its order; otherwise a new Transpose is created below
if (transpose) {
const auto transpose_pattern = ov::as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
OPENVINO_ASSERT(transpose_pattern,
"ExplicitTransposeMatMulInputs expects existing Transpose with Constant order");
auto transposed_order = transpose_pattern->cast_vector<int32_t>();
OPENVINO_ASSERT(transposed_order.size() > 2, "Incorrect Transpose order for ExplicitTransposeMatMulInputs");
std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));
auto new_transpose_order = std::make_shared<ov::op::v0::Constant>(transpose_pattern->get_element_type(),
ov::Shape{transposed_order.size()},
transposed_order);
new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
ov::copy_runtime_info(transpose_pattern, new_transpose_order);
transpose->set_argument(1, new_transpose_order);
return;
}
// Create new Transpose before Parameter
OPENVINO_ASSERT(ov::is_type<opset1::Parameter>(parent),
"ExplicitTransposeMatMulInputs expects Parameter in cases when there isn't existing Transpose on input");
const auto& consumers = parent->get_output_target_inputs(0);
OPENVINO_ASSERT(consumers.size() == 1,
"ExplicitTransposeMatMulInputs expects Parameter with one consumer in cases when there isn't existing Transpose on input");
// Extract Transpose from MatMul
OPENVINO_ASSERT(input.get_partial_shape().is_static(), "ExplicitTransposeMatMulInputs supports only static shapes");
const auto rank = input.get_shape().size();
std::vector<size_t> transpose_order(rank, 0);
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);
const auto constant_order = std::make_shared<opset1::Constant>(ov::element::i32, ov::Shape{rank}, transpose_order);
const auto new_transpose = std::make_shared<opset1::Transpose>(parent, constant_order); // parent is Parameter
const auto consumer_input = *(consumers.begin());
consumer_input.replace_source_output(new_transpose);
}
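// Worked example for the branch above: for a rank-4 static input with no
// existing Transpose, std::iota yields {0, 1, 2, 3} and the swap produces the
// order {0, 1, 3, 2} -- a pure swap of the two innermost dimensions, which
// reproduces the semantics of the cleared transposed_<a|b> flag.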
ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs() {
MATCHER_SCOPE(ExplicitTransposeMatMulInputs);
auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(
auto m_matmul0 = std::make_shared<ov::op::v0::MatMul>(
ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));
register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
[=](ov::pass::pattern::Matcher &m) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
auto root = m.get_match_root();
bool rewritten = false;
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
auto root = m.get_match_root();
bool rewritten = false;
auto matmul0 = ov::as_type_ptr<ov::opset1::MatMul>(root);
if (!matmul0)
return false;
auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(root);
if (!matmul)
return false;
for (size_t i = 0; i < matmul0->get_input_size(); i++) {
if (i == 0 && !matmul0->get_transpose_a())
continue;
if (i == 1 && !matmul0->get_transpose_b())
continue;
auto parent1 = matmul0->get_input_node_shared_ptr(i);
auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
while (!transpose1 && !ov::is_type<ov::opset1::Parameter>(parent1)) {
// We can set supported order and transposed_b(false) only if ops have scalar shapes to avoid shape mismatching
const auto parent_count = parent1->inputs().size();
bool are_weights_scalar = true;
for (size_t j = 1; j < parent_count; ++j) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent1->get_input_shape(j)) == 1;
}
if (!are_weights_scalar)
break;
parent1 = parent1->get_input_node_shared_ptr(0);
transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
if (matmul->get_transpose_a()) {
extract(matmul->input(0));
matmul->set_transpose_a(false);
rewritten |= true;
}
if (!transpose1)
continue;
const auto transpose_pattern = ov::as_type_ptr<ov::opset1::Constant>(transpose1->get_input_node_shared_ptr(1));
if (!transpose_pattern)
continue;
auto transposed_order = transpose_pattern->cast_vector<int32_t>();
std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));
if (pass::TransposeDecomposition::supported_cases.count(transposed_order) == 0)
continue;
auto new_transpose_order = std::make_shared<ov::opset1::Constant>(transpose_pattern->get_element_type(),
ov::Shape{4},
transposed_order);
new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
ov::copy_runtime_info(transpose_pattern, new_transpose_order);
transpose1->set_argument(1, new_transpose_order);
if (i == 0) {
matmul0->set_transpose_a(false);
} else {
matmul0->set_transpose_b(false);
if (matmul->get_transpose_b()) {
extract(matmul->input(1));
matmul->set_transpose_b(false);
rewritten |= true;
}
rewritten |= true;
}
return rewritten;
});
return rewritten;
});
}
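
A shape walk-through of the matcher on hypothetical rank-4 inputs (the shapes are illustrative only):

// in0: {1, 12, 128, 64}
// in1: {1, 12, 128, 64},  MatMul[transposed_b = true]  -> out: {1, 12, 128, 128}
//
// After ExplicitTransposeMatMulInputs:
// in1 -> Transpose{0, 1, 3, 2} -> {1, 12, 64, 128}
// MatMul[transposed_b = false]                         -> out: {1, 12, 128, 128}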

View File

@ -7,6 +7,7 @@
#include "snippets/itt.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/utils.hpp"
@@ -156,12 +157,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov:
break;
// Add node only if there are scalar constants on inputs because of plugin-specific limitation
bool are_weights_scalar = true;
const auto parent_count = parent->get_input_size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar)
if (!ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent))
break;
ordered_ops.insert(ordered_ops.begin() + shift, parent);
@@ -321,22 +317,27 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
/***** Transposes *****/
/* There may be Transpose and Reshape ops on inputs and outputs of MHA-pattern skeleton
* We can add them into Subgraph body
* Transpose0 Transpose1
* \ /
* MatMul0
* |
* [...] Transpose2
* \ /
* MatMul1
* |
* Transpose3
*/
auto tokenize_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
: nullptr;
};
// First input branch of MatMul0 should be executed before second input branch of MatMul0,
// so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1
bool are_weights_scalar = true;
// so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose0
// Note: If MatMul0 has transposed_b, we should tokenize only scalar ops from the 1st branch
// to move the extracted Transpose from the MatMul input to a body Parameter
auto parent = matmul0->get_input_node_shared_ptr(1);
// We can support several ops between MatMul0 with transposed_b and Transpose1 with the {0, 2, 1, 3} order (or without this Transpose1)
// only if these ops have scalar shapes on their other inputs.
// There is the ExplicitTransposeMatMulInputs transformation that sets a supported order and transposed_b(false).
// We can allow this pass to be called only if the ops have scalar shapes, to avoid shape mismatches
const auto is_transposed_b_0 = matmul0->get_transpose_b();
auto parent = matmul0->get_input_node_shared_ptr(1);
while (is_supported_intermediate_op(parent)) {
// All supported ops have only one output port
if (parent->get_output_target_inputs(0).size() != 1)
@@ -344,15 +345,8 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
// Only if MatMul0 has transposed_b, we have to tokenize scalar ops
// to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body
if (is_transposed_b_0) {
const auto parent_count = parent->get_input_size();
bool are_weights_scalar = true;
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar) {
break;
}
if (is_transposed_b_0 && !ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) {
break;
}
// To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
@@ -360,53 +354,45 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(parent)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
potential_body_params_count += get_potential_body_params(parent);
ordered_ops.insert(ordered_ops.begin(), parent);
// TODO [107731] To go always through 0-th port - is it safe?
// [107731] To go always through 0-th port - is it safe?
parent = parent->get_input_node_shared_ptr(0);
}
const auto transpose1 = tokenize_transpose(parent);
if (is_transposed_b_0) {
if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
// We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
// only if these ops have scalar shapes on other inputs.
// There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
// We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
if (are_weights_scalar) {
ordered_ops.insert(ordered_ops.begin(), transpose1);
} else {
return false;
auto tokenize_transpose = [&](const std::shared_ptr<ov::opset1::Transpose>& transpose,
bool is_input_transposed, std::vector<int64_t> order,
const ov::NodeVector::const_iterator& pos) {
// If the Transpose has an order valid for Transpose fusing (the ExplicitTransposeMatMulInputs pass call), tokenize it.
// Otherwise, skip the Transpose.
if (!is_input_transposed) {
if (is_valid_transpose(transpose, order)) {
ordered_ops.insert(pos, transpose);
}
} else {
return false;
return;
}
} else {
if (is_valid_transpose(transpose1, {0, 2, 3, 1})) {
ordered_ops.insert(ordered_ops.begin(), transpose1);
auto transposed_order = order;
const auto rank = transposed_order.size();
if (rank < 2)
return;
std::swap(transposed_order[rank - 1], transposed_order[rank - 2]);
if (is_valid_transpose(transpose, transposed_order)) {
ordered_ops.insert(pos, transpose);
}
}
};
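// For example, for MatMul0's input 1 tokenized with the base order {0, 2, 3, 1}:
//  - transposed_b == false: Transpose1 is tokenized only if its order is {0, 2, 3, 1};
//  - transposed_b == true:  the checked order becomes {0, 2, 1, 3}, since
//    ExplicitTransposeMatMulInputs will later fuse the flag into the Transpose.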
if (transpose1) {
// Between Transpose1 and MatMul0 there will be only one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape both on input and output.
// Need to check the precisions to determine whether we need one more register for the Buffer
if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) {
buffer_count++;
}
}
auto get_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
: nullptr;
};
const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0));
if (is_valid_transpose(transpose0, {0, 2, 1, 3})) {
ordered_ops.insert(ordered_ops.begin(), transpose0);
} else if (matmul0->get_transpose_a()) {
return false;
}
const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1));
if (is_valid_transpose(transpose2, {0, 2, 1, 3})) {
ordered_ops.push_back(transpose2);
}
const auto transpose1 = get_transpose(parent);
const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0));
const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1));
tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin());
tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin());
tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end());
ordered_ops.push_back(matmul1);
bool are_ops_after_matmul1 = false;
@@ -439,7 +425,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
// <Supported ops>
// Transpose3
if (!are_ops_after_matmul1) {
auto transpose3 = tokenize_transpose(child);
auto transpose3 = get_transpose(child);
if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3
ordered_ops.push_back(transpose3);

View File

@@ -7,7 +7,7 @@
#include <subgraph_mha.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/common_optimizations.hpp"
namespace ov {
namespace test {
@@ -15,9 +15,10 @@ namespace snippets {
void TokenizeMHASnippetsTests::run() {
ASSERT_TRUE(function);
std::string name;
manager.register_pass<ov::snippets::pass::EnumerateNodes>();
manager.register_pass<ov::snippets::pass::TokenizeMHASnippets>();
manager.register_pass<ov::snippets::pass::CommonOptimizations>();
disable_rt_info_check();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
@@ -43,6 +44,31 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction) {
const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 128, 12, 64}}, true);
function = f.getOriginal();
function_ref = f.getReference();
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction_and_unsupported_existing_transpose) {
const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 12, 64, 128}, {1, 128, 12, 64}}, true,
std::vector<int64_t>{0, 3, 1, 2});
function = f.getOriginal();
function_ref = f.getReference();
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) {
const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 64, 128, 12}, {1, 128, 12, 64}}, false,
std::vector<int64_t>{0, 2, 1, 3});
function = f.getOriginal();
function_ref = f.getReference();
run();
}
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -227,6 +227,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ,
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapesTransposedB = {
{{1, 12, 12, 64}, {1, 12, 48, 64}, {1, 12, 48, 64}}
};
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHATransposedB, MHATransposedB,
::testing::Combine(
::testing::ValuesIn(inputShapesTransposedB),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support `false` for the graph builder in tests
::testing::Values(2),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
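// For reference, the parameter order above matches MHA::SetUp (see mha.cpp):
// shapes, input types, precision, m_with_mul, ref_num_nodes, ref_num_subgraphs,
// device, config.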
} // namespace
} // namespace snippets

View File

@@ -5,6 +5,7 @@
#pragma once
#include "shared_test_classes/base/snippets_test_utils.hpp"
#include "ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp"
namespace ov {
namespace test {
@@ -30,7 +31,7 @@
void SetUp() override;
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override;
virtual void init_subgraph();
virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph();
bool m_with_mul = false;
std::vector<ov::element::Type> m_input_types;
@@ -39,39 +40,42 @@ protected:
class MHASelect : public MHA {
protected:
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override;
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAWOTransposeOnInputs : public MHA {
protected:
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAWOTranspose : public MHA {
protected:
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAMulAdd : public MHA {
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHATransposedB : public MHA {
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAINT8MatMul : public MHA {
protected:
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAFQAfterMatMul : public MHA {
protected:
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAFQ : public MHA {
protected:
void init_subgraph() override;
std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};
class MHAMulAdd : public MHA {
void init_subgraph() override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -51,7 +51,8 @@ void MHA::SetUp() {
std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
init_subgraph();
const auto subgraph_model = get_subgraph();
function = subgraph_model->getOriginal();
configuration.insert(additionalConfig.begin(), additionalConfig.end());
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
@@ -76,9 +77,8 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
}
}
void MHA::init_subgraph() {
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHA::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAFunction>(inputDynamicShapes, m_input_types, m_with_mul);
}
void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@@ -99,39 +99,36 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
}
}
void MHASelect::init_subgraph() {
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHASelect::get_subgraph() {
return std::make_shared<ov::test::snippets::MHASelectFunction>(inputDynamicShapes, m_input_types);
}
void MHAWOTransposeOnInputs::init_subgraph() {
auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAWOTransposeOnInputs::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAWOTransposeOnInputsFunction>(inputDynamicShapes);
}
void MHAWOTranspose::init_subgraph() {
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAWOTranspose::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAWOTransposeFunction>(inputDynamicShapes, m_input_types);
}
void MHAINT8MatMul::init_subgraph() {
auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAINT8MatMul::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAINT8MatMulFunction>(inputDynamicShapes);
}
void MHAFQAfterMatMul::init_subgraph() {
auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAFQAfterMatMul::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAFQAfterMatMulFunction>(inputDynamicShapes);
}
void MHAFQ::init_subgraph() {
auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAFQ::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAFQFunction>(inputDynamicShapes);
}
void MHAMulAdd::init_subgraph() {
auto f = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes);
function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAMulAdd::get_subgraph() {
return std::make_shared<ov::test::snippets::MHAMulAddFunction>(inputDynamicShapes);
}
std::shared_ptr<SnippetsFunctionBase> MHATransposedB::get_subgraph() {
return std::make_shared<ov::test::snippets::MHATransposedInputFunction>(inputDynamicShapes, true);
}
TEST_P(MHA, CompareWithRefImpl) {
@ -153,26 +150,37 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) {
}
TEST_P(MHAWOTranspose, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}
TEST_P(MHAMulAdd, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}
TEST_P(MHATransposedB, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}
TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}
TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}
TEST_P(MHAFQ, CompareWithRefImpl) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
validateNumSubgraphs();
}

View File

@@ -289,6 +289,31 @@ protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
/* Graph:
* Transpose/Parameter
* \ /
* MatMul0 [transposed_b = true/false]
* |
* Softmax
* \ /
* MatMul1
* |
*/
class MHATransposedInputFunction : public SnippetsFunctionBase {
public:
explicit MHATransposedInputFunction(const std::vector<PartialShape>& inputShapes, bool transposed_b = false,
std::vector<int64_t> order = {})
: SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
bool m_transposed_b = false;
std::vector<int64_t> m_order = {};
};
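// Usage mirrors the tokenization tests above (a sketch; shapes illustrative):
//   MHATransposedInputFunction f({{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 128, 12, 64}}, true);
//   const auto model     = f.getOriginal();   // MatMul0 with transposed_b = true
//   const auto model_ref = f.getReference();  // Subgraph with the Transpose made explicit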
} // namespace snippets
} // namespace test
} // namespace ov

View File

@@ -685,6 +685,72 @@ std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHATransposedInputFunction::initOriginal() const {
const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
ngraph::ParameterVector ngraphParam = {param0, param1, param2};
std::shared_ptr<ov::Node> matmul0_in1 = param1;
if (!m_order.empty()) {
const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
matmul0_in1 = std::make_shared<ov::op::v1::Transpose>(param1, transposeConst);
}
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, matmul0_in1, false, m_transposed_b);
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHATransposedInputFunction::initReference() const {
const auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
const auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
const auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
ngraph::ParameterVector ngraphParam = {data0, data1, data2};
bool is_supported = ((m_transposed_b && m_order == std::vector<int64_t>{0, 2, 1, 3}) ||
(!m_transposed_b && m_order == std::vector<int64_t>{0, 2, 3, 1}));
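// E.g. with transposed_b == true and the order {0, 3, 1, 2} (an unsupported
// case, cf. the Transpose_extraction_and_unsupported_existing_transpose test),
// the original Transpose stays outside the Subgraph, and the Transpose
// {0, 1, 3, 2} made explicit from the MatMul is also left outside the body,
// since CommonOptimizations::ExtractUnsupportedTransposes moves it out.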
std::shared_ptr<ov::Node> in1 = data1;
if (!m_order.empty() && !is_supported) {
const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
in1 = std::make_shared<ov::op::v1::Transpose>(in1, transposeConst);
}
if (m_transposed_b) {
if (m_order != std::vector<int64_t>{0, 2, 1, 3}) {
const auto rank = input_shapes[1].size();
std::vector<int32_t> transpose_order(rank, 0);
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);
const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{transpose_order.size()}, transpose_order);
in1 = std::make_shared<ov::op::v1::Transpose>(in1, transposeConst);
}
}
const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, data0->get_shape());
const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, in1->get_shape());
const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, data2->get_shape());
std::shared_ptr<ov::Node> matmul0_in1 = param1;
if (!m_order.empty() && is_supported) {
const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{m_order.size()}, m_order);
matmul0_in1 = std::make_shared<ov::op::v1::Transpose>(param1, transposeConst);
}
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, matmul0_in1);
const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2);
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{data0, in1, data2},
std::make_shared<ov::Model>(NodeVector{matMul1}, ov::ParameterVector{param0, param1, param2}));
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(subgraph)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
} // namespace snippets
} // namespace test
} // namespace ov