[Snippets] Added support of MatMuls with transposed inputs (#17819)

parent 9754117a61
commit 67dc220d38

@@ -5,7 +5,8 @@
#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {
@@ -15,6 +16,12 @@ class CommonOptimizations : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("CommonOptimizations", "0");
    CommonOptimizations();

private:
    // Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
    void ExtractConstants(const std::shared_ptr<op::Subgraph>& subgraph);
    // Move up unsupported Transposes on Parameter outputs from body
    void ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph);
};

} // namespace pass
@@ -5,7 +5,6 @@
#pragma once

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

namespace ov {
namespace snippets {
@@ -13,18 +12,25 @@ namespace pass {

/**
 * @interface ExplicitTransposeMatMulInputs
 * @brief At the moment Snippets supports Transpose only with order {0, 2, 3, 1},
 *        so if there is pattern in graph:
 *            in0   Transpose{0, 2, 1, 3}
 *               \     /
 *            MatMul[false, true]
 *        We can set false in MatMul parameter `transposed_b` and
 *        change Transpose order to {0, 2, 3, 1} which is supported by Snippets
 * @brief The pass extracts explicit Transpose node from MatMul with transposed_<a|b> and moves it to Parameter.
 *        If there is another Transpose, the pass fuses extracted Transpose and existing Transpose.
 *        For example, at the moment Snippets supports Transpose only with order {0, 2, 3, 1}, so if there is pattern in graph:
 *            in0   Transpose{0, 2, 1, 3}
 *               \     /
 *            MatMul[false, true]
 *        We can set `false` in MatMul parameter `transposed_b` and change Transpose order to {0, 2, 3, 1} which is supported by Snippets
 * @ingroup snippets
 */
class ExplicitTransposeMatMulInputs: public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("ExplicitTransposeMatMulInputs", "0");
    ExplicitTransposeMatMulInputs();

    // Return `True` if all inputs (except 0-th input) have scalar shape. Otherwise returns `False`
    static bool are_weights_scalar(const std::shared_ptr<ov::Node>& node);

private:
    static void extract(const ov::Input<ov::Node>& input);
};

} // namespace pass
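The rewrite described in the @brief above comes down to one invariant: clearing MatMul's transposed_<a|b> flag is compensated by swapping the two innermost axes of the neighbouring Transpose order. A minimal standalone sketch of that order rewrite (the helper name is illustrative, not part of the PR):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // {0, 2, 1, 3} with transposed_b=true becomes {0, 2, 3, 1} with transposed_b=false:
    // the cleared flag is absorbed by swapping the two innermost axes of the order.
    std::vector<int32_t> absorb_matmul_transpose(std::vector<int32_t> order) {
        if (order.size() >= 2)
            std::swap(*order.rbegin(), *(order.rbegin() + 1));
        return order;
    }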
@@ -4,27 +4,24 @@

#include "snippets/pass/common_optimizations.hpp"

#include <memory>
#include "openvino/opsets/opset1.hpp"
#include <ngraph/pass/constant_folding.hpp>
#include "openvino/pass/pattern/op/wrap_type.hpp"

#include "transformations/utils/utils.hpp"
#include "snippets/pass/fq_decomposition.hpp"
#include "snippets/pass/softmax_reshape_elimination.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"
#include "snippets/itt.hpp"

#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

namespace ov {
namespace snippets {
namespace pass {


// Move up Constants which aren't scalars from body to Subgraph and replace them with Parameters inside body
void ConvertConstantsToParameters(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ConvertConstantsToParameters");
void CommonOptimizations::ExtractConstants(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractConstants");
    auto body = subgraph->body_ptr();

    ParameterVector new_parameters;
@@ -55,6 +52,52 @@ void ConvertConstantsToParameters(const std::shared_ptr<ov::snippets::op::Subgra
    }
}

void CommonOptimizations::ExtractUnsupportedTransposes(const std::shared_ptr<op::Subgraph>& subgraph) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::ExtractUnsupportedTransposes");
    const auto& body = subgraph->body_ptr();
    const auto parameters = body->get_parameters();
    // [107806]: If count of Parameters isn't equal to Subgraph inputs,
    // we cannot guarantee correct extraction since we don't have correct connections between body I/O and Subgraph I/O.
    OPENVINO_ASSERT(parameters.size() == subgraph->input_values().size(),
                    "Failed to extract unsupported transposes: the count of Parameters isn't equal to Subgraph inputs");

    bool updated = false;
    for (size_t i = 0; i < parameters.size(); ++i) {
        const auto& parameter = parameters[i];
        const auto& consumers = parameter->get_output_target_inputs(0);
        if (consumers.size() != 1)
            continue;

        const auto transpose = ov::as_type_ptr<opset1::Transpose>(consumers.begin()->get_node()->shared_from_this());
        if (!transpose)
            continue;

        const auto& order = ov::as_type_ptr<opset1::Constant>(transpose->get_input_node_shared_ptr(1));
        if (!order)
            continue;

        const auto order_value = order->cast_vector<int>();
        const auto transpose_child = *(transpose->get_output_target_inputs(0).begin());
        const auto is_brgemm_case = ov::is_type<opset1::MatMul>(transpose_child.get_node()->shared_from_this());
        // If Transpose is supported (can be decomposed or fused into Brgemm), skip
        if ((is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0) ||
            (TransposeDecomposition::supported_cases.count(order_value) != 0))
            continue;

        // If the transpose isn't supported - we have to extract it from Subgraph
        transpose->set_argument(0, subgraph->input_value(i));
        subgraph->set_argument(i, transpose);
        transpose_child.replace_source_output(parameter);
        // Update shape
        parameter->set_partial_shape(transpose->get_output_partial_shape(0));
        updated = true;
    }

    if (updated) {
        subgraph->validate_and_infer_types();
    }
}

CommonOptimizations::CommonOptimizations() {
    MATCHER_SCOPE(CommonOptimizations);
    ov::graph_rewrite_callback callback = [this](ov::pass::pattern::Matcher& m) {
@@ -65,10 +108,10 @@ CommonOptimizations::CommonOptimizations() {
            return false;
        }

        auto body = subgraph->body_ptr();
        const auto& body = subgraph->body_ptr();
        const auto is_quantized = subgraph->is_quantized();

        // Firsly we should transform all original Converts inside body to ConvertTruncation to save original behavior.
        // Firstly, we should transform all original Converts inside body to ConvertTruncation to save original behavior.
        // Then if Subgraph contains FakeQuantize we enable specific transformation for quantized subgraphs.
        ov::pass::Manager manager;
        manager.register_pass<ov::snippets::pass::TransformConvertToConvertTruncation>();
@@ -80,15 +123,18 @@ CommonOptimizations::CommonOptimizations() {
        manager.run_passes(body);

        // At the moment only non-scalar Constants of FakeQuantize can be inside Subgraph
        // so we can enable ConvertConstantsToParameters pass for quantized models
        // so we can enable ExtractConstants pass for quantized models
        if (is_quantized) {
            ConvertConstantsToParameters(subgraph);
            ExtractConstants(subgraph);
        }
        // Extract unsupported Transposes from body
        if (subgraph->has_domain_sensitive_ops()) {
            ExtractUnsupportedTransposes(subgraph);
        }
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<ov::snippets::op::Subgraph>(),
                                                          matcher_name);
    auto m = std::make_shared<ov::pass::pattern::Matcher>(ov::pass::pattern::wrap_type<ov::snippets::op::Subgraph>(), matcher_name);
    this->register_matcher(m, callback);
}

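The skip condition in ExtractUnsupportedTransposes compares a concrete permutation against static allow-lists. A distilled sketch of that lookup, assuming the sets hold the orders this diff mentions ({0, 2, 3, 1} for Transpose decomposition, {0, 2, 1, 3} for Brgemm fusion); names are illustrative, not the real members:

    #include <set>
    #include <vector>

    const std::set<std::vector<int>> decomposition_cases{{0, 2, 3, 1}};
    const std::set<std::vector<int>> brgemm_fusion_cases{{0, 2, 1, 3}};

    // A Transpose stays inside the Subgraph only if its order is decomposable,
    // or if it feeds a MatMul and can be fused into Brgemm.
    bool transpose_is_supported(const std::vector<int>& order, bool feeds_matmul) {
        return decomposition_cases.count(order) != 0 ||
               (feeds_matmul && brgemm_fusion_cases.count(order) != 0);
    }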
@@ -2,79 +2,101 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"

#include "snippets/op/subgraph.hpp"
#include "snippets/itt.hpp"

#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/op/subgraph.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "openvino/core/rt_info.hpp"

bool ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(const std::shared_ptr<ov::Node>& node) {
    const auto inputs = node->inputs();
    return std::all_of(inputs.begin() + 1, inputs.end(),
                       [](const ov::Input<ov::Node>& in) {
                           return in.get_partial_shape().is_static() && ov::shape_size(in.get_shape()) == 1;
                       });
}

void ov::snippets::pass::ExplicitTransposeMatMulInputs::extract(const ov::Input<ov::Node>& input) {
    auto parent = input.get_source_output().get_node_shared_ptr();
    auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent);
    while (!transpose && !ov::is_type<ov::op::v0::Parameter>(parent)) {
        // We can set supported order and transposed_<a|b>=false only if ops have scalar shapes to avoid shape mismatching
        if (!are_weights_scalar(parent))
            break;

        parent = parent->get_input_node_shared_ptr(0);
        transpose = ov::as_type_ptr<ov::op::v1::Transpose>(parent);
    }

    // If there isn't another Transpose, need to create new Transpose
    if (transpose) {
        const auto transpose_pattern = ov::as_type_ptr<ov::op::v0::Constant>(transpose->get_input_node_shared_ptr(1));
        OPENVINO_ASSERT(transpose_pattern,
                        "ExplicitTransposeMatMulInputs expects existing Transpose with Constant order");

        auto transposed_order = transpose_pattern->cast_vector<int32_t>();
        OPENVINO_ASSERT(transposed_order.size() > 2, "Incorrect Transpose order for ExplicitTransposeMatMulInputs");
        std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));

        auto new_transpose_order = std::make_shared<ov::op::v0::Constant>(transpose_pattern->get_element_type(),
                                                                          ov::Shape{transposed_order.size()},
                                                                          transposed_order);
        new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
        ov::copy_runtime_info(transpose_pattern, new_transpose_order);
        transpose->set_argument(1, new_transpose_order);
        return;
    }

    // Create new Transpose before Parameter
    OPENVINO_ASSERT(ov::is_type<opset1::Parameter>(parent),
                    "ExplicitTransposeMatMulInputs expects Parameter in cases when there isn't existing Transpose on input");
    const auto& consumers = parent->get_output_target_inputs(0);
    OPENVINO_ASSERT(consumers.size() == 1,
                    "ExplicitTransposeMatMulInputs expects Parameter with one consumer in cases when there isn't existing Transpose on input");
    // Extract Transpose from MatMul
    OPENVINO_ASSERT(input.get_partial_shape().is_static(), "ExplicitTransposeMatMulInputs supports only static shapes");
    const auto rank = input.get_shape().size();
    std::vector<size_t> transpose_order(rank, 0);
    std::iota(transpose_order.begin(), transpose_order.end(), 0);
    std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);

    const auto constant_order = std::make_shared<opset1::Constant>(ov::element::i32, ov::Shape{rank}, transpose_order);
    const auto new_transpose = std::make_shared<opset1::Transpose>(parent, constant_order);  // parent is Parameter
    const auto consumer_input = *(consumers.begin());
    consumer_input.replace_source_output(new_transpose);
}

ov::snippets::pass::ExplicitTransposeMatMulInputs::ExplicitTransposeMatMulInputs() {
    MATCHER_SCOPE(ExplicitTransposeMatMulInputs);

    auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(
    auto m_matmul0 = std::make_shared<ov::op::v0::MatMul>(
        ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
        ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()));

    register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul0, matcher_name),
        [=](ov::pass::pattern::Matcher &m) {
            OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
            auto root = m.get_match_root();
            bool rewritten = false;
            OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::ExplicitTransposeMatMulInputs")
            auto root = m.get_match_root();
            bool rewritten = false;

            auto matmul0 = ov::as_type_ptr<ov::opset1::MatMul>(root);
            if (!matmul0)
                return false;
            auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(root);
            if (!matmul)
                return false;

            for (size_t i = 0; i < matmul0->get_input_size(); i++) {
                if (i == 0 && !matmul0->get_transpose_a())
                    continue;
                if (i == 1 && !matmul0->get_transpose_b())
                    continue;

                auto parent1 = matmul0->get_input_node_shared_ptr(i);
                auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
                while (!transpose1 && !ov::is_type<ov::opset1::Parameter>(parent1)) {
                    // We can set supported order and transposed_b(false) only if ops have scalar shapes to avoid shape mismatching
                    const auto parent_count = parent1->inputs().size();
                    bool are_weights_scalar = true;
                    for (size_t j = 1; j < parent_count; ++j) {
                        are_weights_scalar = are_weights_scalar && ov::shape_size(parent1->get_input_shape(j)) == 1;
                    }
                    if (!are_weights_scalar)
                        break;

                    parent1 = parent1->get_input_node_shared_ptr(0);
                    transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent1);
            if (matmul->get_transpose_a()) {
                extract(matmul->input(0));
                matmul->set_transpose_a(false);
                rewritten |= true;
            }
                if (!transpose1)
                    continue;

                const auto transpose_pattern = ov::as_type_ptr<ov::opset1::Constant>(transpose1->get_input_node_shared_ptr(1));
                if (!transpose_pattern)
                    continue;

                auto transposed_order = transpose_pattern->cast_vector<int32_t>();
                std::swap(*transposed_order.rbegin(), *(transposed_order.rbegin() + 1));
                if (pass::TransposeDecomposition::supported_cases.count(transposed_order) == 0)
                    continue;

                auto new_transpose_order = std::make_shared<ov::opset1::Constant>(transpose_pattern->get_element_type(),
                                                                                  ov::Shape{4},
                                                                                  transposed_order);
                new_transpose_order->set_friendly_name(transpose_pattern->get_friendly_name());
                ov::copy_runtime_info(transpose_pattern, new_transpose_order);
                transpose1->set_argument(1, new_transpose_order);
                if (i == 0) {
                    matmul0->set_transpose_a(false);
                } else {
                    matmul0->set_transpose_b(false);
            if (matmul->get_transpose_b()) {
                extract(matmul->input(1));
                matmul->set_transpose_b(false);
                rewritten |= true;
            }
                rewritten |= true;
            }

            return rewritten;
        });
            return rewritten;
        });
}
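For orientation, ExplicitTransposeMatMulInputs is an ordinary ov::pass::MatcherPass, so it plugs into the usual pass-manager flow; a minimal usage sketch (a standard registration pattern, not code from this PR):

    #include "openvino/pass/manager.hpp"
    #include "snippets/pass/explicit_transpose_matmul_inputs.hpp"

    void run_explicit_transpose(const std::shared_ptr<ov::Model>& model) {
        ov::pass::Manager manager;
        // The matcher visits every MatMul with static-shaped inputs and
        // rewrites transposed_a/b into explicit Transpose nodes.
        manager.register_pass<ov::snippets::pass::ExplicitTransposeMatMulInputs>();
        manager.run_passes(model);
    }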
@@ -7,6 +7,7 @@

#include "snippets/itt.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/utils.hpp"
@@ -156,12 +157,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov:
            break;

        // Add node only if there are scalar constants on inputs because of plugin-specific limitation
        bool are_weights_scalar = true;
        const auto parent_count = parent->get_input_size();
        for (size_t i = 1; i < parent_count; ++i) {
            are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
        }
        if (!are_weights_scalar)
        if (!ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent))
            break;

        ordered_ops.insert(ordered_ops.begin() + shift, parent);
@@ -321,22 +317,27 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
    /***** Transposes *****/
    /* There may be Transpose and Reshape ops on inputs and outputs of MHA-pattern skeleton
     * We can add them into Subgraph body
     *       Transpose0  Transpose1
     *             \      /
     *             MatMul0
     *                |
     *     [...]   Transpose2
     *        \      /
     *        MatMul1
     *           |
     *       Transpose3
     */

    auto tokenize_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
        return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
                                                 : nullptr;
    };

    // First input branch of MatMul0 should be executed before second input branch of MatMul0,
    // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose1
    bool are_weights_scalar = true;
    // so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose0
    // Note: If MatMul0 has transposed_b, we should tokenize only scalars ops from 1st branch
    //       to move extracted Transpose from MatMul input to body Parameter
    auto parent = matmul0->get_input_node_shared_ptr(1);
    // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1)
    // only if these ops have scalar shapes on other inputs.
    // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
    // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
    const auto is_transposed_b_0 = matmul0->get_transpose_b();
    auto parent = matmul0->get_input_node_shared_ptr(1);
    while (is_supported_intermediate_op(parent)) {
        // All supported ops have only one output port
        if (parent->get_output_target_inputs(0).size() != 1)
@@ -344,15 +345,8 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken

        // Only if MatMul0 has transposed_b, we have to tokenize scalar ops
        // to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body
        if (is_transposed_b_0) {
            const auto parent_count = parent->get_input_size();
            bool are_weights_scalar = true;
            for (size_t i = 1; i < parent_count; ++i) {
                are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
            }
            if (!are_weights_scalar) {
                break;
            }
        if (is_transposed_b_0 && !ov::snippets::pass::ExplicitTransposeMatMulInputs::are_weights_scalar(parent)) {
            break;
        }

        // To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
@@ -360,53 +354,45 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
        if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(parent)) {
            hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
        }

        potential_body_params_count += get_potential_body_params(parent);
        ordered_ops.insert(ordered_ops.begin(), parent);
        // TODO [107731] To go always through 0-th port - is it safe?
        // [107731] To go always through 0-th port - is it safe?
        parent = parent->get_input_node_shared_ptr(0);
    }

    const auto transpose1 = tokenize_transpose(parent);
    if (is_transposed_b_0) {
        if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
            // We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
            // only if these ops have scalar shapes on other inputs.
            // There is transformation ExplicitTransposeMatMulInputs that set supported order and transposed_b(false).
            // We can allow to call this pass only if ops have scalar shapes to avoid shape mismatching
            if (are_weights_scalar) {
                ordered_ops.insert(ordered_ops.begin(), transpose1);
            } else {
                return false;
    auto tokenize_transpose = [&](const std::shared_ptr<ov::opset1::Transpose>& transpose,
                                  bool is_input_transposed, std::vector<int64_t> order,
                                  const ov::NodeVector::const_iterator& pos) {
        // If Transpose has valid order for the Transpose fusing (ExplicitTransposeMatMulInputs pass call), tokenize it.
        // Otherwise, skip the Transpose.
        if (!is_input_transposed) {
            if (is_valid_transpose(transpose, order)) {
                ordered_ops.insert(pos, transpose);
            }
        } else {
            return false;
            return;
        }
    } else {
        if (is_valid_transpose(transpose1, {0, 2, 3, 1})) {
            ordered_ops.insert(ordered_ops.begin(), transpose1);
        auto transposed_order = order;
        const auto rank = transposed_order.size();
        if (rank < 2)
            return;
        std::swap(transposed_order[rank - 1], transposed_order[rank - 2]);
        if (is_valid_transpose(transpose, transposed_order)) {
            ordered_ops.insert(pos, transpose);
        }
    }
    };

    if (transpose1) {
        // Between Transpose1 and MatMul0 will be the one Loop because of LoopFusing optimization.
        // The Loop will have one Buffer with the same shape both on input and output.
        // Need to check for precision to get if we need one more register for Buffer
        if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) {
            buffer_count++;
        }
    }
    auto get_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
        return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
                                                 : nullptr;
    };

    const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0));
    if (is_valid_transpose(transpose0, {0, 2, 1, 3})) {
        ordered_ops.insert(ordered_ops.begin(), transpose0);
    } else if (matmul0->get_transpose_a()) {
        return false;
    }

    const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1));
    if (is_valid_transpose(transpose2, {0, 2, 1, 3})) {
        ordered_ops.push_back(transpose2);
    }
    const auto transpose1 = get_transpose(parent);
    const auto transpose0 = get_transpose(matmul0->get_input_node_shared_ptr(0));
    const auto transpose2 = get_transpose(matmul1->get_input_node_shared_ptr(1));
    tokenize_transpose(transpose1, is_transposed_b_0, {0, 2, 3, 1}, ordered_ops.begin());
    tokenize_transpose(transpose0, matmul0->get_transpose_a(), {0, 2, 1, 3}, ordered_ops.begin());
    tokenize_transpose(transpose2, matmul1->get_transpose_b(), {0, 2, 1, 3}, ordered_ops.end());
    ordered_ops.push_back(matmul1);

    bool are_ops_after_matmul1 = false;
@@ -439,7 +425,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
            // <Supported ops>
            // Transpose3
            if (!are_ops_after_matmul1) {
                auto transpose3 = tokenize_transpose(child);
                auto transpose3 = get_transpose(child);
                if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
                    transpose3->get_input_element_type(0) == matmul1_out_type) {  // To avoid Convert between MatMul1 and Transpose3
                    ordered_ops.push_back(transpose3);
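The reworked tokenize_transpose helper above accepts either the plain order or, for a transposed MatMul input, that order with its two innermost axes swapped. A distilled sketch of the expected-order computation (illustrative only):

    #include <cstdint>
    #include <utility>
    #include <vector>

    // An untransposed input must match `order` exactly; a transposed input matches
    // `order` with the two innermost axes swapped, because ExplicitTransposeMatMulInputs
    // will rewrite the pair later.
    std::vector<int64_t> expected_transpose_order(std::vector<int64_t> order, bool is_input_transposed) {
        const auto rank = order.size();
        if (is_input_transposed && rank >= 2)
            std::swap(order[rank - 1], order[rank - 2]);
        return order;
    }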
@@ -7,7 +7,7 @@
#include <subgraph_mha.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
#include "snippets/pass/common_optimizations.hpp"

namespace ov {
namespace test {
@@ -15,9 +15,10 @@ namespace snippets {

void TokenizeMHASnippetsTests::run() {
    ASSERT_TRUE(function);
    std::string name;
    manager.register_pass<ov::snippets::pass::EnumerateNodes>();
    manager.register_pass<ov::snippets::pass::TokenizeMHASnippets>();
    manager.register_pass<ov::snippets::pass::CommonOptimizations>();
    disable_rt_info_check();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
@@ -43,6 +44,31 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction) {
    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 128, 12, 64}}, true);
    function = f.getOriginal();
    function_ref = f.getReference();
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction_and_unsupported_existing_transpose) {
    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 12, 64, 128}, {1, 128, 12, 64}}, true,
                                               std::vector<int64_t>{0, 3, 1, 2});
    function = f.getOriginal();
    function_ref = f.getReference();
    run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_fusion) {
    const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 64, 128, 12}, {1, 128, 12, 64}}, false,
                                               std::vector<int64_t>{0, 2, 1, 3});
    function = f.getOriginal();
    function_ref = f.getReference();
    run();
}


} // namespace snippets
} // namespace test
} // namespace ov
@@ -227,6 +227,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ,
                             ::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
                         MHA::getTestCaseName);

const std::vector<std::vector<ov::PartialShape>> inputShapesTransposedB = {
    {{1, 12, 12, 64}, {1, 12, 48, 64}, {1, 12, 48, 64}}
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHATransposedB, MHATransposedB,
                         ::testing::Combine(
                             ::testing::ValuesIn(inputShapesTransposedB),
                             ::testing::Values(std::vector<element::Type>{}),
                             ::testing::Values(ov::element::f32),
                             ::testing::ValuesIn({true}),  // Need to support False for graph builder in tests
                             ::testing::Values(2),
                             ::testing::Values(1),
                             ::testing::Values(CommonTestUtils::DEVICE_CPU),
                             ::testing::Values(std::map<std::string, std::string>{})),
                         MHA::getTestCaseName);

} // namespace
} // namespace snippets
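As a sanity check on the inputShapesTransposedB values above: MatMul contracts the two innermost axes, so [1, 12, 12, 64] x [1, 12, 48, 64] with transpose_b = true yields [1, 12, 12, 48]. A small shape-rule sketch (illustrative; equal batch dimensions assumed):

    #include <cstddef>
    #include <vector>

    // Output shape of a batched MatMul over the last two axes.
    std::vector<size_t> matmul_out_shape(std::vector<size_t> a, const std::vector<size_t>& b, bool transpose_b) {
        const size_t rank = a.size();
        a[rank - 1] = transpose_b ? b[rank - 2] : b[rank - 1];
        return a;
    }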
@@ -5,6 +5,7 @@
#pragma once

#include "shared_test_classes/base/snippets_test_utils.hpp"
#include "ngraph_helpers/snippets_ngraph_functions/include/snippets_helpers.hpp"

namespace ov {
namespace test {
@@ -30,7 +31,7 @@ protected:
    void SetUp() override;

    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override;
    virtual void init_subgraph();
    virtual std::shared_ptr<SnippetsFunctionBase> get_subgraph();

    bool m_with_mul = false;
    std::vector<ov::element::Type> m_input_types;
@@ -39,39 +40,42 @@ protected:
class MHASelect : public MHA {
protected:
    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override;
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAWOTransposeOnInputs : public MHA {
protected:
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAWOTranspose : public MHA {
protected:
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAMulAdd : public MHA {
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHATransposedB : public MHA {
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAINT8MatMul : public MHA {
protected:
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAFQAfterMatMul : public MHA {
protected:
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAFQ : public MHA {
protected:
    void init_subgraph() override;
    std::shared_ptr<SnippetsFunctionBase> get_subgraph() override;
};

class MHAMulAdd : public MHA {
    void init_subgraph() override;
};


} // namespace snippets
} // namespace test
} // namespace ov
@@ -51,7 +51,8 @@ void MHA::SetUp() {
    std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
    init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));

    init_subgraph();
    const auto subgraph_model = get_subgraph();
    function = subgraph_model->getOriginal();

    configuration.insert(additionalConfig.begin(), additionalConfig.end());
    if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
@@ -76,9 +77,8 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
    }
}

void MHA::init_subgraph() {
    auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHA::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAFunction>(inputDynamicShapes, m_input_types, m_with_mul);
}

void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@@ -99,39 +99,36 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
    }
}

void MHASelect::init_subgraph() {
    auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHASelect::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHASelectFunction>(inputDynamicShapes, m_input_types);
}

void MHAWOTransposeOnInputs::init_subgraph() {
    auto f = ov::test::snippets::MHAWOTransposeOnInputsFunction(inputDynamicShapes);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAWOTransposeOnInputs::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAWOTransposeOnInputsFunction>(inputDynamicShapes);
}

void MHAWOTranspose::init_subgraph() {
    auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAWOTranspose::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAWOTransposeFunction>(inputDynamicShapes, m_input_types);
}

void MHAINT8MatMul::init_subgraph() {
    auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAINT8MatMul::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAINT8MatMulFunction>(inputDynamicShapes);
}

void MHAFQAfterMatMul::init_subgraph() {
    auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAFQAfterMatMul::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAFQAfterMatMulFunction>(inputDynamicShapes);
}

void MHAFQ::init_subgraph() {
    auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAFQ::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAFQFunction>(inputDynamicShapes);
}

void MHAMulAdd::init_subgraph() {
    auto f = ov::test::snippets::MHAMulAddFunction(inputDynamicShapes);
    function = f.getOriginal();
std::shared_ptr<SnippetsFunctionBase> MHAMulAdd::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHAMulAddFunction>(inputDynamicShapes);
}

std::shared_ptr<SnippetsFunctionBase> MHATransposedB::get_subgraph() {
    return std::make_shared<ov::test::snippets::MHATransposedInputFunction>(inputDynamicShapes, true);
}

TEST_P(MHA, CompareWithRefImpl) {
@@ -153,26 +150,37 @@ TEST_P(MHAWOTransposeOnInputs, CompareWithRefImpl) {
}

TEST_P(MHAWOTranspose, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MHAMulAdd, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MHATransposedB, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

TEST_P(MHAFQ, CompareWithRefImpl) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()
    run();
    validateNumSubgraphs();
}

@@ -289,6 +289,31 @@
    std::shared_ptr<ov::Model> initOriginal() const override;
};

/* Graph:
 *    Transpose/Parameter
 *      \     /
 *      MatMul0 [transposed_b = true/false]
 *         |
 *      Softmax
 *        \     /
 *        MatMul1
 *           |
 */
class MHATransposedInputFunction : public SnippetsFunctionBase {
public:
    explicit MHATransposedInputFunction(const std::vector<PartialShape>& inputShapes, bool transposed_b = false,
                                        std::vector<int64_t> order = {})
        : SnippetsFunctionBase(inputShapes), m_transposed_b(transposed_b), m_order(order) {
        NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
    }
protected:
    std::shared_ptr<ov::Model> initOriginal() const override;
    std::shared_ptr<ov::Model> initReference() const override;

    bool m_transposed_b = false;
    std::vector<int64_t> m_order = {};
};

} // namespace snippets
} // namespace test
} // namespace ov
@@ -685,6 +685,72 @@ std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHATransposedInputFunction::initOriginal() const {
    const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    ngraph::ParameterVector ngraphParam = {param0, param1, param2};

    std::shared_ptr<ov::Node> matmul0_in1 = param1;
    if (!m_order.empty()) {
        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
        matmul0_in1 = std::make_shared<ov::op::v1::Transpose>(param1, transposeConst);
    }

    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, matmul0_in1, false, m_transposed_b);
    const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(matMul1)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

std::shared_ptr<ov::Model> MHATransposedInputFunction::initReference() const {
    const auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    const auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    const auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    ngraph::ParameterVector ngraphParam = {data0, data1, data2};

    bool is_supported = ((m_transposed_b && m_order == std::vector<int64_t>{0, 2, 1, 3}) ||
                         (!m_transposed_b && m_order == std::vector<int64_t>{0, 2, 3, 1}));

    std::shared_ptr<ov::Node> in1 = data1;
    if (!m_order.empty() && !is_supported) {
        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i64, ov::Shape{m_order.size()}, m_order);
        in1 = std::make_shared<ov::op::v1::Transpose>(in1, transposeConst);
    }
    if (m_transposed_b) {
        if (m_order != std::vector<int64_t>{0, 2, 1, 3}) {
            const auto rank = input_shapes[1].size();
            std::vector<int32_t> transpose_order(rank, 0);
            std::iota(transpose_order.begin(), transpose_order.end(), 0);
            std::swap(transpose_order[rank - 1], transpose_order[rank - 2]);
            const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{transpose_order.size()}, transpose_order);
            in1 = std::make_shared<ov::op::v1::Transpose>(in1, transposeConst);
        }
    }

    const auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, data0->get_shape());
    const auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, in1->get_shape());
    const auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, data2->get_shape());

    std::shared_ptr<ov::Node> matmul0_in1 = param1;
    if (!m_order.empty() && is_supported) {
        const auto transposeConst = ngraph::builder::makeConstant(ngraph::element::i32, ov::Shape{m_order.size()}, m_order);
        matmul0_in1 = std::make_shared<ov::op::v1::Transpose>(param1, transposeConst);
    }

    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(param0, matmul0_in1);
    const auto softmax = std::make_shared<ngraph::opset8::Softmax>(matMul0, -1);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, param2);

    auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(ov::NodeVector{data0, in1, data2},
            std::make_shared<ov::Model>(NodeVector{matMul1}, ov::ParameterVector{param0, param1, param2}));

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(subgraph)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}

} // namespace snippets
} // namespace test
} // namespace ov
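The is_supported flag in initReference above encodes which order/flag combinations the tokenizer folds into the Subgraph; a compact restatement of that predicate (illustrative helper mirroring the two branches in the code):

    #include <cstdint>
    #include <vector>

    // The reference keeps the explicit Transpose inside the Subgraph body only for
    // the combinations the tokenizer can rewrite: 0213 with transposed_b, 0231 without.
    bool reference_fuses_transpose(bool transposed_b, const std::vector<int64_t>& order) {
        return (transposed_b && order == std::vector<int64_t>{0, 2, 1, 3}) ||
               (!transposed_b && order == std::vector<int64_t>{0, 2, 3, 1});
    }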