[Snippets] Add support of MHA Tokenization for different precisions (#15647)

Alexandra Sidorova 2023-06-08 12:05:14 +04:00 committed by GitHub
parent bdfa970c7a
commit eb3e6a65eb
30 changed files with 1111 additions and 283 deletions

View File

@ -28,6 +28,8 @@ public:
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }
static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);
void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

View File

@ -29,7 +29,7 @@ class Buffer : public ov::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");
Buffer() = default;
Buffer(const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);
@ -48,9 +48,10 @@ public:
int64_t get_offset() const { return m_offset; }
void set_id(size_t id) { m_id = id; }
void set_offset(int64_t offset) { m_offset = offset; }
size_t get_byte_size() const;
void set_element_type(ov::element::Type element_type);
bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }
@ -59,6 +60,7 @@ private:
ov::Shape m_shape = {};
int64_t m_offset = 0;
size_t m_id = 0; // Default ID is 0: all Buffers belong to the same set
ov::element::Type m_element_type = ov::element::u8; // u8 - default 1 byte
};
} // namespace op

View File

@ -136,7 +136,6 @@ public:
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns true if a Constant op that is an input of this op should be inside the Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;
static bool check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept;
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;

View File

@ -6,6 +6,7 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"
namespace ov {
namespace snippets {
@ -14,13 +15,34 @@ namespace pass {
/**
* @interface TokenizeMHASnippets
* @brief The pass tokenizes MHA-pattern into Subgraph
* TODO: Write pattern
* Pattern: Transpose1
* |
* Transpose0 [Eltwise, Select]
* \ /
* MatMul0
* |
* [Eltwise, Select, Reshape]
* |
* Softmax
* |
* [Eltwise, Select, Reshape] Transpose2
* \ /
* MatMul1
* |
* [Eltwise, Select, Transpose3]
* Notes:
* - Transposes may be absent
* - Transpose0, Transpose2 and Transpose3 may only have the [0,2,1,3] order
* - Transpose1 may only have the [0,2,3,1] order
* - [...] means any number of different nodes from the list. However:
* * Reshapes are supported only explicitly around Softmax (Reshape -> Softmax -> Reshape)
* * After MatMul1 there may be only Transpose3 or any number of Eltwise/Select ops.
* @ingroup snippets
*/
class TokenizeMHASnippets: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("TokenizeMHASnippets", "0");
TokenizeMHASnippets();
TokenizeMHASnippets(const SnippetsTokenization::Config& config = {});
};
} // namespace pass
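For illustration, a minimal graph matching the pattern above might be built as follows. This is a sketch using public OpenVINO opset ops; the concrete shapes and the omission of the optional Eltwise/Select/Reshape nodes are assumptions, not part of the pass:

    #include "openvino/core/model.hpp"
    #include "openvino/opsets/opset1.hpp"
    #include "openvino/opsets/opset8.hpp"

    std::shared_ptr<ov::Model> make_minimal_mha() {
        using namespace ov;
        auto q = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 128, 12, 64});
        auto k = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 128, 12, 64});
        auto v = std::make_shared<opset1::Parameter>(element::f32, Shape{1, 128, 12, 64});
        auto order0213 = opset1::Constant::create(element::i64, Shape{4}, {0, 2, 1, 3});
        auto order0231 = opset1::Constant::create(element::i64, Shape{4}, {0, 2, 3, 1});
        auto transpose0 = std::make_shared<opset1::Transpose>(q, order0213);
        auto transpose1 = std::make_shared<opset1::Transpose>(k, order0231);
        auto matmul0 = std::make_shared<opset1::MatMul>(transpose0, transpose1);
        auto softmax = std::make_shared<opset8::Softmax>(matmul0, -1);  // Softmax on the last axis
        auto transpose2 = std::make_shared<opset1::Transpose>(v, order0213);
        auto matmul1 = std::make_shared<opset1::MatMul>(softmax, transpose2);
        return std::make_shared<Model>(OutputVector{matmul1}, ParameterVector{q, k, v});
    }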

View File

@ -7,8 +7,7 @@
#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"
namespace ov {
namespace snippets {
@ -19,8 +18,16 @@ namespace pass {
SkippedByPlugin - indicates that Snippets cannot include this node in a subgraph. Can be set by the plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value returned if the subgraph wasn't marked; Snippets may include nodes in this subgraph
Completed - indicates that Snippets cannot include any more nodes in this subgraph.
It's used in a separate tokenization pass, for example, tokenization by matcher (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);
@ -48,8 +55,26 @@ public:
*/
class SnippetsTokenization : public ov::pass::ModelPass {
public:
/**
* @interface Config
* @brief Allows adjusting the tokenization passes
* @ingroup snippets
*/
struct Config {
Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}
// If false, Transposes are not tokenized in MHA Tokenization.
// Otherwise, they may be fused into the Subgraph when possible.
// TODO [106921]: Remove when ticket 106921 is implemented
bool mha_token_enable_transpose = true;
};
OPENVINO_RTTI("SnippetsTokenization", "0");
SnippetsTokenization(const Config& config) : m_config(config) {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
private:
Config m_config{};
};
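A possible usage sketch (the function name is hypothetical): a plugin builds a Config and registers the pass in a pass manager, as the CPU plugin does further below:

    #include "openvino/pass/manager.hpp"
    #include "snippets/pass/tokenization.hpp"

    void run_tokenization(const std::shared_ptr<ov::Model>& model, bool enable_transpose) {
        // e.g. enable_transpose = false for the BF16 case handled by the CPU plugin
        ov::snippets::pass::SnippetsTokenization::Config config(enable_transpose);
        ov::pass::Manager manager;
        manager.register_pass<ov::snippets::pass::SnippetsTokenization>(config);
        manager.run_passes(model);
    }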

View File

@ -124,8 +124,8 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {
const auto current_allocated_memory_size = m_buffer_scratchpad_size - offset;
if (buffer_size > current_allocated_memory_size) {
m_buffer_scratchpad_size += (buffer_size - current_allocated_memory_size);
// Note: we don't update the offset because we only grow the scratchpad to the needed size
allocate(buffer, expr, buffer_size);
continue;
}
propagate_offset(linear_ir, *expr_it, offset);
allocated_buffers.insert(expr);

View File

@ -100,9 +100,9 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
// Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified"
// the same problem with all the other lambdas in this file
auto enumerate_out_tensors = [=] (const ExpressionPtr& expr,
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
decltype(regs_vec)& reg_map,
const std::map<tensor, Reg>& manually_assigned_regs,
size_t& counter) {
for (const auto& out_tensor : expr->get_output_port_connectors()) {
// Note that some ops might have identical input&output tensors (Result and Tile* for ex.)
// so we have to check that the tensor has not been enumerated already

View File

@ -62,24 +62,31 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
}
ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
return element::f32;
return element::f32;
} else if (is_int8) {
return element::i32;
} else {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
element_type_a.get_type_name() +
" and " +
element_type_b.get_type_name());
return element::undefined;
}
}
ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
}
return output_type;
}
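The static overload lets callers pre-check precision compatibility without constructing a node, because it returns element::undefined instead of throwing. A sketch of the check that the MHA tokenization below performs (matmul0 stands for a matched ov::op::v0::MatMul):

    const auto out_type = ov::snippets::op::Brgemm::get_output_type(
        matmul0->get_input_element_type(0), matmul0->get_input_element_type(1));
    if (out_type == ov::element::undefined)
        return false;  // incompatible precision pair (e.g. f32 x i8) - skip tokenization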
std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const {
OPENVINO_ASSERT(inputs.size() == 2, "Brgemm::get_planar_input_shapes() expects 2 inputs");
return { utils::get_port_planar_shape(inputs[0]), utils::get_port_planar_shape(inputs[1]) };

View File

@ -14,8 +14,8 @@ namespace snippets {
namespace op {
Buffer::Buffer(const ov::Shape& shape, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id) {
Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id)
: Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) {
constructor_validate_and_infer_types();
}
@ -40,26 +40,25 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("allocation_shape", m_shape);
visitor.on_attribute("offset", m_offset);
visitor.on_attribute("id", m_id);
visitor.on_attribute("element_type", m_element_type);
return true;
}
void Buffer::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
ov::element::Type output_type;
ov::Shape output_shape;
if (m_type == Type::NewMemory) {
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with newly allocated memory must not have arguments!");
output_shape = m_shape;
output_type = ov::element::u8; // 1Byte
} else if (m_type == Type::IntermediateMemory) {
const auto& input_shape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
output_type = get_input_element_type(0);
m_element_type = get_input_element_type(0);
output_shape = input_shape.get_shape();
} else {
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
}
set_output_type(0, output_type, output_shape);
set_output_type(0, m_element_type, output_shape);
}
std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args) const {
@ -67,7 +66,7 @@ std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args
check_new_args_count(this, new_args);
std::shared_ptr<op::Buffer> new_buffer = nullptr;
if (m_type == Type::NewMemory) {
new_buffer = std::make_shared<Buffer>(m_shape, m_id);
new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
} else if (m_type == Type::IntermediateMemory) {
new_buffer = std::make_shared<Buffer>(new_args.at(0), m_shape, m_id);
} else {
@ -82,6 +81,13 @@ size_t Buffer::get_byte_size() const {
return ov::shape_size(shape) * get_element_type().size();
}
void Buffer::set_element_type(ov::element::Type element_type) {
OPENVINO_ASSERT(is_new_memory(), "Only a Buffer with NewMemory type can change its output precision!");
m_element_type = std::move(element_type);
// Apply the change
validate_and_infer_types();
}
} // namespace op
} // namespace snippets
} // namespace ov
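A sketch of how the element type now drives Buffer size planning (the shape value is hypothetical):

    // NewMemory Buffer: defaults to u8, so get_byte_size() == shape_size(shape) * 1
    auto buffer = std::make_shared<ov::snippets::op::Buffer>(ov::Shape{1, 12, 128, 64});
    // set_element_type() revalidates the node; get_byte_size() is now 4x larger
    buffer->set_element_type(ov::element::f32);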

View File

@ -14,7 +14,6 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/transform_convert.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"
@ -75,12 +74,11 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
}
void snippets::op::Subgraph::init_config() {
auto update = [](bool& flag, bool status) { flag = flag || status; };
const auto ops = body_ptr()->get_ops();
for (const auto& op : ops) {
config.m_is_quantized = config.m_is_quantized ||
ov::is_type<ov::op::v0::FakeQuantize>(op);
config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops ||
is_domain_sensitive_op(op);
update(config.m_is_quantized, ov::is_type<ov::op::v0::FakeQuantize>(op));
update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op));
}
}
@ -93,6 +91,13 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
// and where the Loops will be - we can only predict.
// Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
std::vector<size_t> used_precision_size;
auto push_prc_size = [&used_precision_size](size_t precision_size) {
if (used_precision_size.empty() || used_precision_size.back() != precision_size) {
used_precision_size.push_back(precision_size);
}
};
for (const auto& op : ops) {
if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) {
// At the moment Transposes are supported only on Results and Parameters but
@ -106,34 +111,23 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
}) ||
!ov::is_type<ov::op::v0::Parameter>(transpose->get_input_node_shared_ptr(0));
if (are_prev_or_next_ops) {
const auto prc_size = transpose->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(transpose->get_element_type().size());
}
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
// Softmax always uses 2 FP32 Buffers
const auto prc_size = ov::element::f32.size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
// Softmax always uses 2 FP32 Buffers after decomposition.
// They are inplace and identical, so we push the precision size only once
push_prc_size(ov::element::f32.size());
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
// Checking the first input is enough because MatMul requires the same precision size on both inputs
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)) ||
!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1))) {
const auto prc_size = matmul->get_input_element_type(0).size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_input_element_type(0).size());
}
const auto consumers = matmul->get_output_target_inputs(0);
if (std::none_of(consumers.begin(), consumers.end(),
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
const auto prc_size = matmul->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_element_type().size());
}
}
}
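To make the collapsing rule concrete, a minimal self-contained sketch (the pushed precision sizes are hypothetical): adjacent equal sizes are counted once, so same-precision inplace Buffers don't inflate the estimate:

    #include <cstddef>
    #include <vector>

    size_t estimate_demo() {
        std::vector<size_t> used_precision_size;
        auto push_prc_size = [&used_precision_size](size_t precision_size) {
            if (used_precision_size.empty() || used_precision_size.back() != precision_size)
                used_precision_size.push_back(precision_size);
        };
        push_prc_size(4);  // Softmax buffers (always f32)
        push_prc_size(4);  // f32 MatMul input buffer - collapsed with the previous entry
        push_prc_size(2);  // bf16 buffer - a new entry
        return used_precision_size.size();  // == 2 estimated unique buffers
    }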

View File

@ -63,11 +63,21 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
const auto& transpose = as_type_ptr<const opset1::Transpose>(n);
const auto& out_shape = n->get_output_partial_shape(0);
if (transpose && out_shape.is_static()) {
const auto parent = transpose->get_input_node_shared_ptr(0);
const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
auto is_brgemm_case = ov::is_type<opset1::MatMul>(parent) || ov::is_type<opset1::MatMul>(child);
// Check for Transpose parent is MatMul inside Subgraph
if (const auto subgraph = ov::as_type_ptr<op::Subgraph>(parent)) {
const auto body = subgraph->body_ptr();
const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0);
is_brgemm_case = is_brgemm_case || ov::is_type<opset1::MatMul>(subgraph_output);
}
const auto& order = as_type_ptr<const opset1::Constant>(n->get_input_node_shared_ptr(1));
if (order) {
const auto order_value = order->cast_vector<int>();
return TransposeDecomposition::supported_cases.count(order_value) != 0 ||
FuseTransposeBrgemm::supported_cases.count(order_value) != 0;
return (TransposeDecomposition::supported_cases.count(order_value) != 0) ||
(is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0);
}
}
return false;
@ -337,7 +347,7 @@ TokenizeSnippets::TokenizeSnippets() {
for (const auto& input_node : ov::as_node_vector(input_values)) {
if (auto subgraph = ov::as_type_ptr<op::Subgraph>(input_node)) {
if (!clones.count(input_node)) {
if (!clones.count(input_node) && GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) {
auto f = subgraph->body().clone();
f->set_friendly_name(subgraph->body_ptr()->get_friendly_name());
clones[input_node] = f;
@ -524,15 +534,18 @@ TokenizeSnippets::TokenizeSnippets() {
ResultVector body_results;
std::vector<std::set<Input<Node>>> subgraph_result_inputs;
ov::NodeVector new_body_ops;
ov::NodeVector ops_for_buffer_count;
for (auto subgraph : input_subgraphs) {
// we should sum up the additionally needed data count (non-scalar Constants and Buffers) from all input subgraphs
// because we will collapse them with our node and need the total count
const auto subgraph_ptr = ov::as_type_ptr<ov::snippets::op::Subgraph>(subgraph);
hidden_data_count += subgraph_ptr->get_virtual_port_count();
// Buffers can exist only in Subgraphs with domain-sensitive ops, which
// require intermediate memory for data repacking
// To avoid load-time regressions, we verify only Subgraphs with domain-sensitive ops
if (subgraph_ptr->has_domain_sensitive_ops()) {
const auto ops = subgraph_ptr->body_ptr()->get_ordered_ops();
new_body_ops.insert(new_body_ops.end(), ops.begin(), ops.end());
ops_for_buffer_count.insert(ops_for_buffer_count.end(), ops.begin(), ops.end());
}
for (auto output : subgraph->outputs()) {
@ -566,7 +579,7 @@ TokenizeSnippets::TokenizeSnippets() {
}
if (op::Subgraph::is_domain_sensitive_op(node)) {
new_body_ops.push_back(node);
ops_for_buffer_count.push_back(node);
}
for (auto output : node->outputs()) {
@ -582,7 +595,7 @@ TokenizeSnippets::TokenizeSnippets() {
// At the moment, the CPU plugin has a limitation on GPR registers: only 12 registers are available.
// This limitation will be resolved once generator supports gprs spills [75622].
// TODO [75567]: move this plugin-specific constraint to the plugin callback
const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(new_body_ops);
const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(ops_for_buffer_count);
if (body_parameters.size() + body_results.size() + hidden_data_count + unique_buffer_count > 12) {
const std::string message_reset = "new subgraph is created. Impossible to schedule subgraph with " +
std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
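For intuition, a worked example with hypothetical numbers: a merged body with 6 Parameters, 2 Results, 2 hidden ports (non-scalar FQ Constants) and 3 estimated unique Buffers would need 6 + 2 + 2 + 3 = 13 > 12 data-pointer registers, so the new subgraph is rejected and reset.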

View File

@ -6,9 +6,10 @@
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/utils.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@ -17,17 +18,14 @@
namespace {
auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool {
// TODO: Add support for all element types supported by common tokenization
// return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0;
return t.get_element_type() == ngraph::element::f32 &&
t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
return t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
}
// TODO: Add support of FQ, Reshape?
auto is_supported_intermediate_op(const std::shared_ptr<ov::Node>& node) -> bool {
const auto is_intermediate_op = [](const std::shared_ptr<ov::Node>& node) {
return ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(node) ||
ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(node) ||
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
ov::is_type<ov::op::v1::Select>(node);
};
return is_intermediate_op(node) && ov::snippets::pass::TokenizeSnippets::AppropriateForSubgraph(node);
@ -40,9 +38,12 @@ auto is_valid_transpose(const std::shared_ptr<ov::opset1::Transpose>& node, std:
return false;
return transpose_pattern->cast_vector<int64_t>() == expected_order;
};
auto is_supported_transpose_tensor = [](const ov::descriptor::Tensor& t) {
return is_supported_tensor(t) && ov::snippets::pass::TokenizeSnippets::supported_element_types.count(t.get_element_type()) != 0;
};
return node && node->get_output_target_inputs(0).size() == 1 && node->get_shape().size() == 4 &&
valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_tensor(node->get_input_tensor(0));
valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_transpose_tensor(node->get_input_tensor(0));
}
auto tokenize_broadcast(const std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops) -> void {
@ -98,14 +99,15 @@ auto tokenize_reshape_around_softmax(std::shared_ptr<ov::Node>& interm_op,
ov::NodeVector& ordered_ops) -> bool {
reshape = ov::as_type_ptr<ov::opset1::Reshape>(interm_op);
if (reshape) {
const auto shape = reshape->get_input_shape(0);
if (shape.back() != reshape->get_output_shape(0).back() || reshape->get_output_target_inputs(0).size() != 1)
const auto in_shape = reshape->get_input_shape(0);
const auto out_shape = reshape->get_output_shape(0);
if (in_shape.back() != out_shape.back() || reshape->get_output_target_inputs(0).size() != 1)
return false;
ordered_ops.push_back(reshape);
interm_op = reshape->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
return true;
};
}
auto get_potential_body_params(const std::shared_ptr<ov::Node>& op) -> size_t {
size_t count = 0;
@ -124,43 +126,50 @@ auto get_potential_body_params(const std::shared_ptr<ov::Node>& op) -> size_t {
auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops,
size_t& hidden_virtual_ports_count, size_t& potential_body_params_count) -> bool {
// TODO: Add Reshape, FQ support
while (is_supported_intermediate_op(interm_op)) {
// All supported intermediate ops have only one output port
// To verify output element type is enough because all supported intermediate ops have the same output element type as input type
if (interm_op->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(interm_op->get_output_tensor(0)))
if (interm_op->get_output_target_inputs(0).size() != 1)
return false;
// Check for supported Broadcast op
// Check for supported ops on branches: Broadcast/Elementwise (for example, dequantize ops)
if (interm_op->get_input_size() > 1) {
tokenize_broadcast(interm_op, ordered_ops);
}
auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr<ov::Node>& op) {
return is_supported_intermediate_op(op) &&
ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin &&
std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end();
};
// To avoid an unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin-specific limitation),
// we should calculate the potential number of non-scalar Constants for FakeQuantize that will be moved up from the body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(interm_op)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
for (size_t i = 0; i < interm_op->get_input_size(); ++i) {
const size_t shift = ordered_ops.size();
auto parent = interm_op->get_input_node_shared_ptr(i);
while (is_supported_branch_op(parent)) {
// All supported ops have only one output port
if (parent->get_output_target_inputs(0).size() != 1)
break;
auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr<ov::Node>& op) {
return is_supported_intermediate_op(op) &&
ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin &&
std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end();
};
// Add node only if there are scalar constants on inputs because of plugin-specific limitation
bool are_weights_scalar = true;
const auto parent_count = parent->get_input_size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
for (size_t i = 0; i < interm_op->get_input_size(); ++i) {
const size_t shift = ordered_ops.size();
auto parent = interm_op->get_input_node_shared_ptr(i);
while (is_supported_branch_op(parent)) {
// All supported ops have only one output port
if (parent->get_output_target_inputs(0).size() != 1)
break;
// Add node only if there are scalar constants on inputs because of plugin-specific limitation
bool are_weights_scalar = true;
const auto parent_count = parent->get_input_size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar)
break;
ordered_ops.insert(ordered_ops.begin() + shift, parent);
// TODO [107731]: We assume the sequence of ops goes through input port 0.
// But what if a Select is here? Then the parent may not be on input port 0; another approach is needed.
if (parent->get_input_size() > 0)
parent = parent->get_input_node_shared_ptr(0);
}
ordered_ops.insert(ordered_ops.begin() + shift, parent);
// We think that sequence of ops goes through input port 0
// But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way?
parent = parent->get_input_node_shared_ptr(0);
}
}
@ -173,7 +182,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov:
};
} // namespace
ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsTokenization::Config& config) {
MATCHER_SCOPE(TokenizeMHASnippets);
auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
@ -184,14 +193,13 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeMHASnippets")
auto& pattern_to_output = m.get_pattern_value_map();
// Queries + Key + Values = 3 standard inputs of MHA
size_t potential_body_params_count = 3;
// After some transformations, the number of Constants created for some operations may differ
// from the actual number of Constants present during tokenization.
// To avoid an unsupported number of non-scalar Constants in the future (plugin-specific limitation),
// we should calculate the potential number of non-scalar Constants that will be moved up from the body.
// TODO: Update this variable when FQ is supported
size_t hidden_virtual_ports_count = 0;
// Queries + Key + Values = 3 standard inputs of MHA
size_t potential_body_params_count = 3;
// The count of potential unique Buffers - they are hidden virtual ports as well
// We should go through the Subgraph and calculate the potential count of non-inplace Buffers.
// Example:
@ -231,10 +239,20 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
!is_supported_tensor(matmul0->get_input_tensor(0)) || !is_supported_tensor(matmul0->get_input_tensor(1)))
return false;
if (transformation_callback(matmul0)) {
const auto matmul0_prc = op::Brgemm::get_output_type(matmul0->get_input_element_type(0),
matmul0->get_input_element_type(1));
if (matmul0_prc == element::undefined) {
return false;
}
// Between MatMul0 and Softmax there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape on both input and output.
// We need to check the precision to see whether one more register is needed for the Buffer
if (matmul0_prc.size() != ov::element::f32.size()) {
if (buffer_count < 2)
buffer_count++;
}
ordered_ops.push_back(matmul0);
auto interm_op = matmul0->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
@ -276,10 +294,28 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
return false;
const auto matmul1 = ov::as_type_ptr<ov::opset1::MatMul>(interm_op);
if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || matmul1->get_transpose_a() || matmul1->get_transpose_b() ||
!is_supported_tensor(matmul1->get_input_tensor(0)) || !is_supported_tensor(matmul1->get_input_tensor(1)))
if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 ||
matmul1->get_transpose_a() || matmul1->get_transpose_b())
return false;
const auto matmul1_out_type = op::Brgemm::get_output_type(matmul1->get_input_element_type(0),
matmul1->get_input_element_type(1));
if (matmul1_out_type == element::undefined ||
!is_supported_tensor(matmul1->get_input_tensor(0)) ||
!is_supported_tensor(matmul1->get_input_tensor(1)))
return false;
if (transformation_callback(matmul0)) {
return false;
}
// Between Softmax and MatMul1 there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape on both input and output.
// We need to check the precision to see whether one more register is needed for the Buffer
if (matmul1->get_input_element_type(0).size() != ov::element::f32.size()) {
buffer_count++;
}
/***********************/
/***** Transposes *****/
@ -287,29 +323,51 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
* We can add them into Subgraph body
*/
auto tokenize_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
: nullptr;
};
// The first input branch of MatMul0 should be executed before the second input branch of MatMul0,
// so we first insert Transpose1 at the beginning of ordered_ops and then Transpose0
bool are_weights_scalar = true;
// We can support several ops between MatMul0 with transposed_b and Transpose1 with the 0213 order (or without this Transpose1)
// only if these ops have scalar shapes on their other inputs.
// The ExplicitTransposeMatMulInputs transformation sets the supported order and transposed_b(false).
// We can allow calling this pass only if the ops have scalar shapes, to avoid shape mismatches
const auto is_transposed_b_0 = matmul0->get_transpose_b();
auto parent = matmul0->get_input_node_shared_ptr(1);
while (is_supported_intermediate_op(parent)) {
// All supported ops have only one output port
// To verify output element type is enough because all supported ops have the same output element type as input type
if (parent->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(parent->get_output_tensor(0)))
if (parent->get_output_target_inputs(0).size() != 1)
break;
const auto parent_count = parent->inputs().size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
// Only if MatMul0 has transposed_b do we have to tokenize scalar ops,
// to move the explicit Transpose from MatMul0 input 1 to a Parameter of the Subgraph body
if (is_transposed_b_0) {
const auto parent_count = parent->get_input_size();
bool are_weights_scalar = true;
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar) {
break;
}
}
// To avoid an unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin-specific limitation),
// we should calculate the potential number of non-scalar Constants for FakeQuantize that will be moved up from the body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(parent)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
potential_body_params_count += get_potential_body_params(parent);
ordered_ops.insert(ordered_ops.begin(), parent);
// We think that sequence of ops goes through input port 0
// But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way?
// TODO [107731]: Is it safe to always go through the 0th port?
parent = parent->get_input_node_shared_ptr(0);
}
auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
if (matmul0->get_transpose_b()) {
const auto transpose1 = tokenize_transpose(parent);
if (is_transposed_b_0) {
if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
// We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
// only if these ops have scalar shapes on other inputs.
@ -329,31 +387,63 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
}
}
// TODO: Add Reshape Support for all Transposes
// Add 3D support for all Transposes
const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
if (transpose1) {
// Between Transpose1 and MatMul0 there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape on both input and output.
// We need to check the precision to see whether one more register is needed for the Buffer
if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) {
buffer_count++;
}
}
const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0));
if (is_valid_transpose(transpose0, {0, 2, 1, 3})) {
ordered_ops.insert(ordered_ops.begin(), transpose0);
} else if (matmul0->get_transpose_b()) {
} else if (matmul0->get_transpose_a()) {
return false;
}
const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1));
if (is_valid_transpose(transpose2, {0, 2, 1, 3})) {
ordered_ops.push_back(transpose2);
}
ordered_ops.push_back(matmul1);
bool are_ops_after_matmul1 = false;
auto child = matmul1->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
// TODO: Add support Eltwises between MatMul1 and Transpose
// status = update_intermediate_supported_ops(child, ordered_ops);
// if (!status) {
// ordered_ops.push_back(child);
// }
while (is_supported_intermediate_op(child)) {
are_ops_after_matmul1 = true;
// All supported ops have only one output port
if (child->get_output_target_inputs(0).size() != 1)
break;
auto transpose3 = ov::as_type_ptr<ov::opset1::Transpose>(child);
if (is_valid_transpose(transpose3, {0, 2, 1, 3})) {
ordered_ops.push_back(transpose3);
// To avoid an unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin-specific limitation),
// we should calculate the potential number of non-scalar Constants for FakeQuantize that will be moved up from the body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(child)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
potential_body_params_count += get_potential_body_params(child);
// TODO [75567]: move this plugin-specific constraint to the plugin callback
// We cannot collapse the op into the Subgraph if the total count of potential Parameters and Results exceeds 12
if (potential_body_params_count + child->get_output_target_inputs(0).size() + hidden_virtual_ports_count + buffer_count > 12) {
break;
}
ordered_ops.push_back(child);
child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
// At the moment Snippets doesn't support nodes between MatMul1 and Transpose3 due to Loop and strided-calculation limitations
// MatMul1
// <Supported ops>
// Transpose3
if (!are_ops_after_matmul1) {
auto transpose3 = tokenize_transpose(child);
if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3
ordered_ops.push_back(transpose3);
}
}
/**********************/
@ -362,7 +452,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
/* ====== Subgraph creation ======= */
// TODO: move this plugin-specific constraint to the plugin callback
// TODO [75567]: move this plugin-specific constraint to the plugin callback
const auto last_node = ordered_ops.back();
if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + buffer_count > 12) {
return false;
@ -378,7 +468,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
const auto input = node->input(i);
const auto parent = input.get_source_output().get_node_shared_ptr();
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(parent);
if (constant && (ov::shape_size(input.get_shape()) == 1 || op::Subgraph::constant_input_should_be_inside_body(node))) {
if (constant && (ov::shape_size(input.get_shape()) == 1 ||
ov::is_type<ov::op::v0::FakeQuantize>(node) ||
op::Subgraph::constant_input_should_be_inside_body(node))) {
// If Constant has one consumer - target node, we add Constant to body_inputs
// If Constant has several consumers, we should check that all these consumers are inside Subgraph body
// and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we should
@ -454,6 +546,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
subgraph->get_rt_info()["originalLayersNames"] = fused_names;
subgraph->set_virtual_port_count(hidden_virtual_ports_count);
// mark the Subgraph as Completed so that common tokenization does not include any more nodes into the MHA Subgraph
SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);
return true;
/* ================================ */

View File

@ -7,6 +7,8 @@
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
namespace ov {
@ -18,6 +20,13 @@ void SetSnippetsNodeType(const std::shared_ptr<Node> &node, SnippetsNodeType nod
rt["SnippetsNodeType"] = nodeType;
}
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph> &node, SnippetsSubgraphType nodeType) {
if (node) {
auto &rt = node->get_rt_info();
rt["SnippetsSubgraphType"] = nodeType;
}
}
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node> &node) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsNodeType")
auto& rt = node->get_rt_info();
@ -27,6 +36,17 @@ SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node> &node) {
return rinfo->second.as<SnippetsNodeType>();
}
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph> &node) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsSubgraphType")
if (!node)
return SnippetsSubgraphType::NotSet;
auto &rt = node->get_rt_info();
const auto rinfo = rt.find("SnippetsSubgraphType");
if (rinfo == rt.end())
return SnippetsSubgraphType::NotSet;
return rinfo->second.as<SnippetsSubgraphType>();
}
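A usage sketch for the new marker (subgraph stands for a hypothetical std::shared_ptr<ov::snippets::op::Subgraph>): the MHA pass stamps the Subgraph right after creation, and common tokenization queries the mark before reusing the body:

    using namespace ov::snippets::pass;
    SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);
    if (GetSnippetsSubgraphType(subgraph) == SnippetsSubgraphType::Completed) {
        // common tokenization must not merge any more nodes into this Subgraph
    }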
void SetTopologicalOrder(const std::shared_ptr<Node> &node, int64_t order) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetTopologicalOrder")
auto& rt = node->get_rt_info();
@ -58,7 +78,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
manager.set_per_pass_validation(false);
manager.register_pass<EnumerateNodes>();
manager.register_pass<TokenizeMHASnippets>();
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeSnippets>();
manager.register_pass<CommonOptimizations>();
manager.run_passes(m);

View File

@ -5,6 +5,7 @@
#include <common_test_utils/ngraph_test_utils.hpp>
#include "lowering_utils.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
namespace ov {

View File

@ -8,6 +8,7 @@
#include <subgraph_fq.hpp>
#include <subgraph_converts.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
namespace ov {
namespace test {

View File

@ -6,6 +6,7 @@
#include <pass/mha_tokenization.hpp>
#include <subgraph_mha.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"
namespace ov {
@ -20,14 +21,23 @@ void TokenizeMHASnippetsTests::run() {
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
const auto& f = MHAFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
const auto &f = MHAFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
function = f.getOriginal();
function_ref = f.getReference();
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose) {
const auto& f = MHAMatMul0TransposeFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
const auto &f = MHAMatMul0TransposeFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
function = f.getOriginal();
function_ref = f.getReference();
run();
}
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
const auto &f = MHAINT8MatMulTypeRelaxedFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
function = f.getOriginal();
function_ref = f.getReference();
run();

View File

@ -172,7 +172,7 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
general_exprs.emplace_back(expr);
}
}
num_unique_buffer = unique_buffers.size();
num_unique_buffers = unique_buffers.size();
// Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two
// regs are used to calculate offsets for the data pointers
@ -198,15 +198,16 @@ void KernelEmitter::validate_arguments(const std::vector<size_t> &in,
IE_THROW() << "KernelEmitter got invalid number of inputs. Expected 0, got " << in.size();
if (!out.empty())
IE_THROW() << "KernelEmitter got invalid number of outputs. Expected 0, got " << out.size();
const auto num_params = num_inputs + num_outputs + num_unique_buffer;
const auto num_params = num_inputs + num_outputs + num_unique_buffers;
// The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount
if (data_ptr_regs_idx.size() != num_params)
IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsisnent with the number of allocated registers"
IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsistent with the number of allocated registers "
<< num_params << " data_ptr_regs_idx.size() = " << data_ptr_regs_idx.size();
}
void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, size_t num_buffer,
const Reg64& reg_indexes, const Reg64& reg_const_params, const std::vector<Reg64>& data_ptr_regs) const {
void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xbyak::Reg64& reg_const_params,
const std::vector<Xbyak::Reg64>& data_ptr_regs) const {
const auto num_params = num_inputs + num_outputs;
// Note that we don't need offset for the last dim, since it's handled directly by Tile emitter
const size_t offset_rank = jcp.master_shape.size() - 1;
std::vector<std::vector<size_t>> data_offsets(num_params, std::vector<size_t>{});
@ -267,7 +268,9 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, siz
// Vector "data_ptr_regs" is sorted by abstract regs.
// It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer]
// So we can initialize buffer register firstly as last value of vector "data_ptr_regs"
for (size_t i = 0; i < num_buffer; ++i) {
// NOTE: the Snippets Buffer scratchpad has a common data pointer for all Buffers (even those with different IDs).
// Memory accesses are covered by the correct offsets in each Buffer and the corresponding MemoryAccess ops
for (size_t i = 0; i < num_unique_buffers; ++i) {
h->mov(data_ptr_regs[num_params + i], h->ptr[reg_const_params + GET_OFF(buffer_scratchpad_ptr)]);
}
size_t i = 0;
@ -299,7 +302,7 @@ void KernelEmitter::emit_impl(const std::vector<size_t>& in,
std::vector<Reg64> data_ptr_regs;
transform_idxs_to_regs(data_ptr_regs_idx, data_ptr_regs);
init_data_pointers(num_inputs, num_inputs + num_outputs, num_unique_buffer, reg_indexes, reg_const_params, data_ptr_regs);
init_data_pointers(reg_indexes, reg_const_params, data_ptr_regs);
for (const auto& expression : body) {
const auto& emitter = expression->get_emitter();
std::vector<size_t> in_regs, out_regs;

View File

@ -87,13 +87,13 @@ private:
const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in,
const std::vector<size_t>& out) const override;
void init_data_pointers(size_t, size_t, size_t, const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector<Xbyak::Reg64>&) const;
void init_data_pointers(const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector<Xbyak::Reg64>&) const;
jit_snippets_compile_args jcp;
std::vector<size_t> gp_regs_pool;
size_t num_inputs;
size_t num_outputs;
size_t num_unique_buffer;
size_t num_unique_buffers;
// Vector of indices (length = input tensor rank) per every input and output that describes in which order
// corresponding tensor dimensions are accessed (default: consecutive dense, e.g. 0,1,2,3 for 4D tensor).
// Needed to calc i/o offsets.

View File

@ -427,7 +427,7 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr<Node> &node) {
bool isSuitableConvert(const std::shared_ptr<const Node>& node) {
if (!ov::is_type<ov::op::v0::Convert>(node))
return false;
auto hasResult = [](const std::shared_ptr<const Node>& node){
auto hasResult = [](const std::shared_ptr<const Node>& node) {
auto consumers = node->output(0).get_target_inputs();
bool findResult = false;
if (consumers.size() == 1) {
@ -449,13 +449,19 @@ bool isSuitableConvert(const std::shared_ptr<const Node>& node) {
return false;
}
}
auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
return ov::is_type<ov::op::v0::Constant>(op) ||
ov::is_type<ov::op::v0::Parameter>(op) ||
ov::is_type<ov::op::v0::Result>(op);
}
} // namespace
bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped);
int channelAxis = DEFAULT_AXIS;
for (auto &node : m->get_ordered_ops()) {
if (ov::is_type<ov::op::v0::Constant>(node) || ov::is_type<ov::op::v0::Result>(node))
if (is_skipped_op(node))
continue;
if (isSuitableConvolutionParent(node)) {
// Initiate fusing chain

View File

@ -108,6 +108,8 @@
// Snippets
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/pass/common_optimizations.hpp"
// Misc
@ -616,22 +618,58 @@ void Transformations::MainSnippets(void) {
!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions)
return;
// At the moment Snippets supports Transposes in the MHA pattern only in the FP32 case, since
// - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in the canonicalization stage
// - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in the precision-propagation stage
// Because of that, Transposes won't be fused into Brgemm
// TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
tokenization_config.mha_token_enable_transpose = !enableBF16;
ngraph::pass::Manager snippetsManager;
snippetsManager.set_per_pass_validation(false);
if (snippetsMode != Config::SnippetsMode::IgnoreCallback)
CPU_REGISTER_PASS_X64(snippetsManager, SnippetsMarkSkipped, enableBF16);
CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization);
CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);
// Tokenize MHA in quantized models or with BF16 only in tests.
// TODO [106921]: Enable the tokenization when ticket 106921 (blocking support for BRGEMM) is implemented
const bool onlyFloatSupported = snippetsMode != Config::SnippetsMode::IgnoreCallback;
const bool isMHASupported =
!enableBF16 && // TODO: Need to add BF16 support for MHA in Snippets
IMPLICATION(enableBF16, !onlyFloatSupported) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // MHA has BRGEMM that is supported only on AVX512 platforms
if (!isMHASupported) {
CPU_DISABLE_PASS_X64(snippetsManager, snippets::pass::TokenizeMHASnippets);
}
#if defined(OPENVINO_ARCH_X86_64)
auto is_supported_matmul = [onlyFloatSupported](const std::shared_ptr<const ov::Node>& n) {
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(n);
if (!matmul)
return false;
if (matmul->get_input_element_type(1) == ov::element::i8)
return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni);
if (matmul->get_input_element_type(0) == ov::element::bf16 &&
matmul->get_input_element_type(1) == ov::element::bf16)
return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
return true;
};
#endif // OPENVINO_ARCH_X86_64
if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
CPU_SET_CALLBACK_X64(snippetsManager,
[](const std::shared_ptr<const ov::Node>& n) -> bool {
const auto pshape = n->get_output_partial_shape(0);
[&](const std::shared_ptr<const ov::Node>& n) -> bool {
// Transformation callback is called on MatMul0
if (!is_supported_matmul(n))
return true;
// Search for MatMul1
auto child = n->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
while (!ov::is_type<const ov::op::v0::MatMul>(child)) {
child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
if (!is_supported_matmul(child))
return true;
const auto pshape = child->get_input_partial_shape(0);
const auto shape = pshape.get_shape();
const auto parallel_work_amount =
std::accumulate(shape.rbegin() + 2, shape.rend(), 1, std::multiplies<size_t>());
@ -662,18 +700,18 @@ void Transformations::MainSnippets(void) {
// todo: general tokenization flow is not currently supported for these operations.
// they can be tokenized only as a part of complex patterns
const bool is_disabled_tokenization = (ov::is_type<const ov::op::v1::Softmax>(n) ||
ov::is_type<const ov::op::v8::Softmax>(n) ||
ov::is_type<const ov::op::v0::MatMul>(n) ||
ov::is_type<const ov::op::v1::Transpose>(n) ||
ov::is_type<const ov::op::v1::Broadcast>(n) ||
ov::is_type<const ov::op::v3::Broadcast>(n));
ov::is_type<const ov::op::v8::Softmax>(n) ||
ov::is_type<const ov::op::v0::MatMul>(n) ||
ov::is_type<const ov::op::v1::Transpose>(n) ||
ov::is_type<const ov::op::v1::Broadcast>(n) ||
ov::is_type<const ov::op::v3::Broadcast>(n));
const auto& inputs = n->inputs();
// todo: clarify whether we can evaluate snippets on const paths
const bool has_only_const_inputs = std::all_of(inputs.begin(), inputs.end(),
[](const ov::Input<const ov::Node>& in) {
return ov::is_type<ov::op::v0::Constant>(
in.get_source_output().get_node_shared_ptr());
});
[](const ov::Input<const ov::Node>& in) {
return ov::is_type<ov::op::v0::Constant>(
in.get_source_output().get_node_shared_ptr());
});
// todo: clarify whether we can evaluate snippets on inputs with larger ranks
auto rank_is_too_large = [](const ov::descriptor::Tensor& t) {
// the callback is called after has_supported_in_out(), so it's safe to assume that the shapes are static

View File

@ -246,6 +246,9 @@ std::vector<std::string> disabledTestPatterns() {
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
// MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)");
retVector.emplace_back(R"(.*Snippets.*MatMul.*Quantized.*)");
retVector.emplace_back(R"(.*Snippets.*MHAFQ.*)");
retVector.emplace_back(R"(.*Snippets.*MHAINT8.*)");
}
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
//TODO: Issue 92895
@ -254,6 +257,7 @@ std::vector<std::string> disabledTestPatterns() {
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) {
// ignored for not supported bf16 platforms
retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16.*)");
retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)");
}
return retVector;

View File

@ -19,23 +19,7 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
{{1, 1, 37, 23}, {1, 2, 23, 33}},
{{1, 16, 384, 64}, {1, 16, 64, 384}}
};
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
std::vector<std::vector<element::Type>> prc = {
{element::f32, element::f32},
};
if (!only_fp32) {
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
prc.emplace_back(std::vector<element::Type>{element::i8, element::i8});
prc.emplace_back(std::vector<element::Type>{element::u8, element::i8});
}
// In Snippets MatMul BF16 is supported only on bf16/AMX platforms
if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
}
}
return prc;
}
static inline std::vector<std::vector<element::Type>> quantized_precisions() {
std::vector<std::vector<element::Type>> prc = {};
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
@ -46,6 +30,21 @@ static inline std::vector<std::vector<element::Type>> quantized_precisions() {
return prc;
}
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
std::vector<std::vector<element::Type>> prc = {
{element::f32, element::f32},
};
if (!only_fp32) {
auto quant = quantized_precisions();
std::copy(quant.begin(), quant.end(), std::back_inserter(prc));
// In Snippets MatMul BF16 is supported only on bf16/AMX platforms
if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
}
}
return prc;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul,
::testing::Combine(
::testing::ValuesIn(input_shapes),

View File

@ -4,7 +4,9 @@
#include "snippets/mha.hpp"
#include "common_test_utils/test_constants.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "ie_plugin_config.hpp"
#include "ie_system_conf.h"
namespace ov {
namespace test {
@ -15,22 +17,52 @@ namespace {
const std::vector<std::vector<ov::PartialShape>> inputShapes = {
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}},
};
static inline bool is_bf16_supported() {
return InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16();
}
static inline std::vector<std::vector<element::Type>> precision_f32(size_t count) {
std::vector<std::vector<element::Type>> prc;
prc.emplace_back(std::vector<element::Type>(count, element::f32));
return prc;
}
static inline std::vector<std::vector<element::Type>> precision_bf16(size_t count) {
std::vector<std::vector<element::Type>> prc;
if (is_bf16_supported())
prc.emplace_back(std::vector<element::Type>(count, element::bf16));
return prc;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn({false, true}),
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(precision_bf16(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(7), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(6), // MHA + 5 Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
// without broadcast
@ -44,64 +76,142 @@ const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
::testing::Combine(
::testing::ValuesIn(inputShapeSelect),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::ValuesIn(precision_f32(6)),
::testing::Values(ov::element::f32),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::Values(2), // Less + MHA
::testing::Values(2),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose_4D = {
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
};
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose_3D = {
{{12, 197, 64}, {12, 64, 197}, {12, 197, 64}},
{{12, 128, 100}, {12, 100, 128}, {12, 128, 100}}
};
static std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose(bool supports_3d = false) {
std::vector<std::vector<ov::PartialShape>> shapes = {
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
};
if (supports_3d) {
std::vector<std::vector<ov::PartialShape>> shapes_3d = {
{{12, 197, 64}, {12, 64, 197}, {12, 197, 64}},
{{12, 128, 100}, {12, 100, 128}, {12, 128, 100}}
};
shapes.insert(shapes.end(), shapes_3d.begin(), shapes_3d.end());
}
return shapes;
}
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs_4D, MHAWOTransposeOnInputs,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose()),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::Values(std::vector<ov::element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(true), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
InferenceEngine::PluginConfigParams::YES } };
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::Values(3),
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(cpuBF16PluginConfig)),
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAINT8MatMul, MHAINT8MatMul,
::testing::Combine(
::testing::ValuesIn(std::vector<std::vector<ov::PartialShape>>(inputShapes.begin(), inputShapes.begin() + 2)),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul
::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul, MHAFQAfterMatMul,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(3), // MHA + Transpose on output + Deq Mul
::testing::Values(2), // MHA + Deq Mul
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ,
::testing::Combine(
::testing::Values(std::vector<ov::PartialShape>{{1, 64, 12, 64}, {1, 64, 12, 64}, {1, 1, 1, 64}, {1, 64, 12, 64}}),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(7), // Transpose x2 + Subgraph x5
::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2x FQ on inputs
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

View File

@ -575,7 +575,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
::testing::ValuesIn(inputPrecisionsQuant),
::testing::ValuesIn(matMulIn0PrecisionsQuant),
::testing::ValuesIn(patternTypesQuant),
::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet
::testing::Values("MHA"),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);

View File

@ -12,8 +12,9 @@ namespace snippets {
typedef std::tuple<
std::vector<ov::PartialShape>, // Input shapes
bool, // With Multiply
std::vector<ov::element::Type>, // Input Element types
ov::element::Type, // Inference precision
bool, // With Multiply
size_t, // Expected num nodes
size_t, // Expected num subgraphs
std::string, // Target Device
@ -32,6 +33,7 @@ protected:
virtual void init_subgraph();
bool m_with_mul = false;
std::vector<ov::element::Type> m_input_types;
};
class MHASelect : public MHA {
@ -46,6 +48,22 @@ protected:
};
class MHAWOTranspose : public MHA {
protected:
void init_subgraph() override;
};
class MHAINT8MatMul : public MHA {
protected:
void init_subgraph() override;
};
class MHAFQAfterMatMul : public MHA {
protected:
void init_subgraph() override;
};
class MHAFQ : public MHA {
protected:
void init_subgraph() override;
};

View File

@ -15,16 +15,19 @@ namespace snippets {
std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
std::vector<ov::PartialShape> inputShapes;
bool withMul;
std::vector<ov::element::Type> elem_types;
ov::element::Type prc;
bool withMul;
std::string targetDevice;
size_t num_nodes, num_subgraphs;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;
std::tie(inputShapes, elem_types, prc, withMul, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;
std::ostringstream result;
for (size_t i = 0; i < inputShapes.size(); ++i)
result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_";
for (size_t i = 0; i < elem_types.size(); i++)
result << "T[" << i <<"]=" << elem_types[i] << "_";
result << "Mul=" << withMul << "_";
result << "PRC=" << prc << "_";
result << "#N=" << num_nodes << "_";
@ -45,13 +48,13 @@ void MHA::SetUp() {
std::vector<ov::PartialShape> inputShapes;
ov::element::Type prc;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, m_with_mul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));
init_subgraph();
configuration.insert(additionalConfig.begin(), additionalConfig.end());
if (additionalConfig.empty() && !configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
}
@ -59,7 +62,7 @@ void MHA::SetUp() {
setInferenceType(prc);
inType = outType = prc;
if (prc == ov::element::bf16)
abs_threshold = 0.3;
rel_threshold = 0.05f;
}
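// Editor's note: a relative threshold scales with the magnitude of the
// outputs, which suits bf16 (roughly 2-3 significant decimal digits) better
// than a fixed absolute bound.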
void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@ -68,13 +71,13 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
for (int i = 0; i < model_inputs.size(); ++i) {
const auto& model_input = model_inputs[i];
ov::Tensor tensor;
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), targetInputStaticShapes[i], 1.0f, 0.5f);
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256);
inputs.insert({model_input.get_node_shared_ptr(), tensor});
}
}
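// Editor's note (assumed create_and_fill_tensor semantics: range, start_from,
// resolution): the (2, -1, 256) arguments request values in [-1, 1) on a
// 1/256 grid. Such values are exactly representable on the low-precision
// paths, which keeps bf16/int8 runs comparable with the f32 reference - the
// previous normal-distribution fill gave no such guarantee.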
void MHA::init_subgraph() {
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_with_mul);
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul);
function = f.getOriginal();
}
@ -90,14 +93,14 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 5 + seed, -2, 10, seed);
seed++;
} else {
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), model_input.get_shape(), 1.0f, 0.5f);
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256);
}
inputs.insert({node_input, tensor});
}
}
void MHASelect::init_subgraph() {
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes);
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
}
@ -107,7 +110,22 @@ void MHAWOTransposeOnInputs::init_subgraph() {
}
void MHAWOTranspose::init_subgraph() {
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes);
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
}
void MHAINT8MatMul::init_subgraph() {
auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes);
function = f.getOriginal();
}
void MHAFQAfterMatMul::init_subgraph() {
auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes);
function = f.getOriginal();
}
void MHAFQ::init_subgraph() {
auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes);
function = f.getOriginal();
}
@ -134,6 +152,20 @@ TEST_P(MHAWOTranspose, CompareWithRefImpl) {
validateNumSubgraphs();
}
TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}
TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}
TEST_P(MHAFQ, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}
} // namespace snippets
} // namespace test

View File

@ -26,9 +26,9 @@ public:
explicit MatMulFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
verify_precisions(precisions);
validate_precisions(precisions);
}
static void verify_precisions(const std::vector<ov::element::Type>& precisions) {
static void validate_precisions(const std::vector<ov::element::Type>& precisions) {
NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types");
const bool is_f32 = ov::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]);
const bool is_int8 = ov::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8;
@ -62,7 +62,7 @@ public:
explicit MatMulBiasFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -70,7 +70,6 @@ protected:
std::vector<ov::element::Type> precisions;
};
// Quantized MatMul
// FQ[I8]
// Add
@ -79,7 +78,7 @@ public:
explicit MatMulBiasQuantizedFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -97,7 +96,7 @@ public:
explicit MatMulsQuantizedFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -121,7 +120,7 @@ public:
NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4,
"Only rank 4 input shapes are supported by this test");
NGRAPH_CHECK(transpose_position >=0 && transpose_position <= 2, "Got invalid transpose position");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -166,7 +165,7 @@ public:
explicit MatMulsQuantizedSoftmaxFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;

View File

@ -43,15 +43,17 @@ namespace snippets {
*/
class MHAFunction : public SnippetsFunctionBase {
public:
explicit MHAFunction(const std::vector<PartialShape>& inputShapes, bool with_mul = true)
: SnippetsFunctionBase(inputShapes), with_mul(with_mul) {
explicit MHAFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions, bool with_mul = true)
: SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
bool with_mul = true;
std::vector<ov::element::Type> precisions;
};
/* Graph:
@ -71,13 +73,16 @@ protected:
*/
class MHAMatMul0TransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
std::vector<ov::element::Type> precisions;
};
/* Graph:
@ -97,11 +102,15 @@ protected:
*/
class MHASelectFunction : public SnippetsFunctionBase {
public:
explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::vector<ov::element::Type> precisions;
};
/* Graph:
@ -137,13 +146,128 @@ protected:
*/
class MHAWOTransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::vector<ov::element::Type> precisions;
};
/* Graph:
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Reshape0
* Softmax
* Reshape1 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAFQAfterMatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAFQAfterMatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
/* Graph:
* FakeQuantize i8 FakeQuantize i8
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Reshape0
* Softmax
* Reshape1 FakeQuantize i8
* FakeQuantize u8 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAINT8MatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
/* Graph:
* Constant
* FakeQuantize u8 FakeQuantize u8 Convert
* Transpose0[0,2,1,3] Transpose1[0,2,3,1] Multiply
* \ \ /
* \ Multiply
* \ FakeQuantize f32
* \ /
* MatMul0
* FakeQuantize f32 FakeQuantize u8
* \ /
* Add
* Softmax Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize u8
* Transpose3[0,2,1,3]
* Note: Checks many different FakeQuantize ops (both quantized and floating-point) - buffers with different sizes and precisions
*/
class MHAFQFunction : public SnippetsFunctionBase {
public:
explicit MHAFQFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};
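// A minimal sketch of what the note above implies (an illustration under
// assumptions, not the actual Snippets implementation): once intermediate
// buffers carry their own element type, the allocated byte size depends on
// both shape and precision, so same-shaped u8 and f32 buffers differ by 4x.
inline size_t sketch_buffer_byte_size(const ov::Shape& shape, const ov::element::Type& type) {
    return ov::shape_size(shape) * type.size();
}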
// Only for tokenization! The graph is after LPT: contains TypeRelaxed ops
/* Graph:
* FakeQuantize i8 FakeQuantize i8
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Mul (DeQuantize)
* Reshape0
* Softmax
* Reshape1 FakeQuantize i8
* FakeQuantize u8 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
};
} // namespace snippets
} // namespace test
} // namespace ov

View File

@ -230,4 +230,4 @@ std::shared_ptr<ov::Model> MatMulsQuantizedSoftmaxFunction::initOriginal() const
} // namespace snippets
} // namespace test
} // namespace ov
} // namespace ov

View File

@ -7,16 +7,18 @@
#include "common_test_utils/data_utils.hpp"
#include <snippets/op/subgraph.hpp>
#include "ngraph_functions/builders.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "lpt_ngraph_functions/common/builders.hpp"
namespace ov {
namespace test {
namespace snippets {
std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
std::vector<ov::Shape> constantShapes;
@ -51,7 +53,7 @@ std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
std::shared_ptr<ov::Node> matmul_parent1 = transpose1;
if (with_mul) {
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true);
auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true);
matmul_parent1 = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulConst);
}
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, matmul_parent1, transA, transB);
@ -67,17 +69,17 @@ std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAFunction::initReference() const {
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
NodeVector subgraph_inputs = {data0, data1, data2, data3};
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()}));
@ -113,8 +115,8 @@ std::shared_ptr<ov::Model> MHAFunction::initReference() const {
std::shared_ptr<ov::Node> matmul_parent1 = transpose1;
if (with_mul) {
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true);
auto mulParam = std::make_shared<ngraph::opset1::Parameter>(precision, mulConst->get_shape());
auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true);
auto mulParam = std::make_shared<ngraph::opset1::Parameter>(precisions[1], mulConst->get_shape());
matmul_parent1 = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulParam);
subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param};
subgraph_inputs = {data0, data1, mulConst, data2, data3};
@ -135,10 +137,10 @@ std::shared_ptr<ov::Model> MHAFunction::initReference() const {
}
std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
std::vector<ov::Shape> constantShapes;
@ -157,7 +159,7 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], order);
std::vector<float> mulConstData(1);
auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true);
auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true);
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
@ -188,16 +190,16 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()}));
@ -214,7 +216,7 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], std::vector<int64_t>{0, 2, 1, 3});
std::vector<float> mulConstData(1);
auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true);
auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true);
ngraph::ParameterVector subgraphParams = {transpose0Param, transpose1Param, addParam, transpose2Param};
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
@ -250,12 +252,12 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
}
std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto less0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
auto less1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[5]);
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
auto less0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
auto less1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[4], input_shapes[4]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[5], input_shapes[5]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, less0Param, less1Param, transpose2Param};
std::vector<ov::Shape> constantShapes;
@ -288,7 +290,7 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[4], reshape1ConstData);
// Value is equal to '1' to avoid the situation e^(-1000) / sum(e^(-1000)) = 0/0 = NaN
auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
auto selectConst = ngraph::builder::makeConstant(precisions[2], ov::Shape{1}, std::vector<float>{1});
bool transA = false;
bool transB = false;
@ -344,9 +346,9 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
}
std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
ngraph::ParameterVector ngraphParam = {param0, param1, param2};
bool transA = false;
@ -359,6 +361,302 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
const auto shape_rank = input_shapes[0].get_shape().size();
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
auto fq0 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
const auto add = std::make_shared<ngraph::opset3::Add>(fq0, addParam);
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2, transA, transB);
auto fq1 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAINT8MatMulFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
const auto shape_rank = input_shapes[0].get_shape().size();
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
auto fq2 = ngraph::builder::makeFakeQuantize(transpose2Param, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
auto fq3 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
const auto add = std::make_shared<ngraph::opset3::Add>(fq3, addParam);
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
auto fq4 = ngraph::builder::makeFakeQuantize(reshape1, ov::element::f32, 256, {1},
{0}, {0.820726}, {0}, {0.820726});
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fq2, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(fq4, transpose2, transA, transB);
auto fq5 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq5, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAFQFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
const auto shape_rank = input_shapes[0].get_shape().size();
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
const auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
{-5.217694}, {6.661877}, {-5.217694}, {6.661877});
const auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
{-6.40245}, {6.45286}, {-6.40245}, {6.45286});
const auto fq_add = ngraph::builder::makeFakeQuantize(addParam, ov::element::f32, 256, {1},
{-1000}, {0}, {-1000}, {0});
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
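// Dequantize branch (descriptive comment): an i8 constant (127) is converted
// to f32 and scaled - the standard Convert -> Multiply dequantization
// subgraph produced by LPT; 127 * 0.00098425 is roughly 0.125.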
const auto mul_const = ngraph::builder::makeConstant(ov::element::i8, ov::Shape{1}, std::vector<int8_t>{127});
const auto convert = std::make_shared<ngraph::opset1::Convert>(mul_const, ov::element::f32);
const auto mul_deq_const = ngraph::builder::makeConstant(ov::element::f32, ov::Shape{1}, std::vector<float>{0.00098425});
const auto mul_deq = std::make_shared<ngraph::opset1::Multiply>(convert, mul_deq_const);
const auto mul = std::make_shared<ngraph::opset1::Multiply>(transpose1, mul_deq);
auto fq1_1 = ngraph::builder::makeFakeQuantize(mul, ov::element::f32, 256, {1},
{-0.8003067}, {0.8066083}, {-0.8003067}, {0.8066083});
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fq1_1, transA, transB);
auto fq2 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
{-14.50351}, {17.65645}, {-14.50351}, {17.65645});
const auto add = std::make_shared<ngraph::opset1::Add>(fq2, fq_add);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
auto fq3 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
{-1.895786}, {2.0028071}, {-1.895786}, {2.0028071});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq3, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() const {
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
const auto shape_rank = input_shapes[0].get_shape().size();
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::i8, fq_signed_params);
const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::i8, fq_signed_params);
const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::i8, fq_signed_params);
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
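// Descriptive comment: TypeRelaxed lets the wrapped op compute in f32 while
// the node keeps the quantized boundary types of its real inputs;
// TemporaryReplaceOutputType hands the constructor f32-typed views of the
// i8 producers so type validation passes. This mimics the post-LPT graph.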
const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);
const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
const auto add = std::make_shared<op::TypeRelaxed<ngraph::opset3::Add>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
const auto deq_mul = std::make_shared<op::TypeRelaxed<ngraph::opset3::Multiply>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
ov::op::TemporaryReplaceOutputType(deq, element::f32).get());
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(deq_mul, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fq2, transpose2Const);
const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq5, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() const {
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::i8, fq_signed_params);
const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::i8, fq_signed_params);
const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::i8, fq_signed_params);
NodeVector subgraph_inputs = {fq0, fq1, data2, fq2};
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
ov::ParameterVector subgraph_params = {transpose0Param, transpose1Param, addParam, transpose2Param};
const auto shape_rank = input_shapes[0].get_shape().size();
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);
const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
const auto add = std::make_shared<op::TypeRelaxed<ngraph::opset3::Add>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
const auto deq_mul = std::make_shared<op::TypeRelaxed<ngraph::opset3::Multiply>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
ov::op::TemporaryReplaceOutputType(deq, element::f32).get());
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(deq_mul, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(subgraph_inputs,
std::make_shared<ov::Model>(NodeVector{fq5}, subgraph_params));
// TODO: At the moment Snippets don't explicitly support Transpose.
// So we cannot collapse Transpose into the Subgraph if there are ops between MatMul1 and Transpose3
auto transpose3 = std::make_shared<ov::op::v1::Transpose>(subgraph, transpose3Const);
return std::make_shared<ov::Model>(NodeVector{transpose3}, ngraphParams);
}
} // namespace snippets
} // namespace test
} // namespace ov