[Snippets] Add support of MHA Tokenization for different precisions (#15647)
parent bdfa970c7a
commit eb3e6a65eb
@@ -28,6 +28,8 @@ public:
size_t get_offset_b() const { return get_input_offset(1); }
size_t get_offset_c() const { return get_output_offset(0); }

static ov::element::Type get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1);

void validate_and_infer_types() override;
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;

@@ -29,7 +29,7 @@ class Buffer : public ov::op::Op {
public:
OPENVINO_OP("Buffer", "SnippetsOpset");
Buffer() = default;
Buffer(const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Shape& shape, ov::element::Type element_type = ov::element::u8, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, const ov::Shape& shape, size_t id = 0);
Buffer(const ov::Output<ov::Node>& arg, int32_t allocation_rank = -1, size_t id = 0);

@@ -48,9 +48,10 @@ public:
int64_t get_offset() const { return m_offset; }
void set_id(size_t id) { m_id = id; }
void set_offset(int64_t offset) { m_offset = offset; }

size_t get_byte_size() const;

void set_element_type(ov::element::Type element_type);

bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; }
bool is_new_memory() const { return m_type == Type::NewMemory; }

@@ -59,6 +60,7 @@ private:
ov::Shape m_shape = {};
int64_t m_offset = 0;
size_t m_id = 0; // Default ID - 0. All Buffers are from the same set
ov::element::Type m_element_type = ov::element::u8; // u8 - default 1 byte
};

} // namespace op

@@ -136,7 +136,6 @@ public:
// should have explicit Constants even if they're non-scalar (Reshape, Transpose, Broadcast)
// This check returns True if Constant op which is input of this op should be inside Subgraph body
static auto constant_input_should_be_inside_body(const std::shared_ptr<ov::Node>& node) -> bool;

static bool check_broadcast(const std::shared_ptr<const ov::Node>& node) noexcept;
// Return estimated unique buffer count (upper bound). It's needed for tokenization
static auto get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t;
@@ -6,6 +6,7 @@

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "snippets/pass/tokenization.hpp"

namespace ov {
namespace snippets {

@@ -14,13 +15,34 @@ namespace pass {

/**
 * @interface TokenizeMHASnippets
 * @brief The pass tokenizes MHA-pattern into Subgraph
 *        TODO: Write pattern
 *        Pattern:             Transpose1
 *                                 |
 *              Transpose0  [Eltwise, Select]
 *                      \      /
 *                      MatMul0
 *                         |
 *           [Eltwise, Select, Reshape]
 *                         |
 *                      Softmax
 *                         |
 *           [Eltwise, Select, Reshape]  Transpose2
 *                                 \      /
 *                                 MatMul1
 *                                    |
 *                  [Eltwise, Select, Transpose3]
 * Notes:
 *  - Transposes can be missing
 *  - Transpose0, Transpose2 and Transpose3 may have only [0,2,1,3] order
 *  - Transpose1 may have only [0,2,3,1] order
 *  - [...] means any count of different nodes from the list. But:
 *    * Reshapes can be only explicitly around Softmax (Reshape -> Softmax -> Reshape)
 *    * After MatMul1 there may be only Transpose3 or any count of Eltwise, Select ops.
 * @ingroup snippets
 */
class TokenizeMHASnippets: public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("TokenizeMHASnippets", "0");
    TokenizeMHASnippets();
    TokenizeMHASnippets(const SnippetsTokenization::Config& config = {});
};

} // namespace pass
@@ -7,8 +7,7 @@

#include "openvino/pass/graph_rewrite.hpp"
#include "openvino/pass/pattern/matcher.hpp"

#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"

namespace ov {
namespace snippets {

@@ -19,8 +18,16 @@ namespace pass {

SkippedByPlugin - indicates that snippets can't include this node in a subgraph. Can be set by the Plugin via SetSnippetsNodeType(...).
*/
enum class SnippetsNodeType : int64_t {NotSet, SkippedByPlugin};
/*
NotSet - default value returned if the subgraph wasn't marked and snippets can include nodes in this subgraph
Completed - indicates that snippets can't include any nodes in this subgraph.
            It's used in separate tokenization passes, for example, tokenization by matcher (MHA Tokenization).
*/
enum class SnippetsSubgraphType : int64_t {NotSet, Completed};
void SetSnippetsNodeType(const std::shared_ptr<Node>&, SnippetsNodeType);
void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph>&, SnippetsSubgraphType);
SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node>&);
SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph>&);
void SetTopologicalOrder(const std::shared_ptr<Node>&, int64_t);
int64_t GetTopologicalOrder(const std::shared_ptr<const Node>&);
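
A quick illustrative sketch (not part of the patch) of how the new SnippetsSubgraphType marker is meant to be used: a matcher-based pass such as MHA tokenization marks its resulting Subgraph as Completed, and the common tokenization later queries the marker before trying to merge more nodes into that Subgraph. The helper function below is hypothetical; the Set/Get calls mirror the declarations above.

#include "snippets/op/subgraph.hpp"
#include "snippets/pass/tokenization.hpp"

bool can_extend_subgraph(const std::shared_ptr<ov::snippets::op::Subgraph>& subgraph) {
    using namespace ov::snippets::pass;
    // MHA tokenization marks the Subgraph it builds as Completed (see mha_tokenization.cpp below)
    SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);
    // Common tokenization then refuses to merge any more nodes into it
    return GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed;
}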
@@ -48,8 +55,26 @@ public:
*/
class SnippetsTokenization : public ov::pass::ModelPass {
public:
/**
 * @interface Config
 * @brief Allows to adjust tokenization passes
 * @ingroup snippets
 */
struct Config {
    Config(bool enable_transpose = true) : mha_token_enable_transpose(enable_transpose) {}

    // If False, the MHA Tokenization doesn't tokenize Transposes.
    // Otherwise, they may be fused into Subgraph if possible
    // TODO [106921]: Please remove when ticket 106921 is implemented
    bool mha_token_enable_transpose = true;
};

OPENVINO_RTTI("SnippetsTokenization", "0");
SnippetsTokenization(const Config& config) : m_config(config) {}
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;

private:
Config m_config{};
};
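
For orientation, a minimal usage sketch (illustrative, not from the diff) that wires the Config into the tokenization pipeline, mirroring what the CPU plugin change further down does when BF16 is enabled:

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/tokenization.hpp"

void run_snippets_tokenization(const std::shared_ptr<ov::Model>& model, bool enable_bf16) {
    // Transposes are kept out of MHA Subgraphs when BF16 is enabled (see the CPU plugin hunk below)
    ov::snippets::pass::SnippetsTokenization::Config config(/*enable_transpose=*/!enable_bf16);
    ov::pass::Manager manager;
    manager.register_pass<ov::snippets::pass::SnippetsTokenization>(config);
    manager.run_passes(model);
}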
@@ -124,8 +124,8 @@ bool AllocateBuffers::run(LinearIR& linear_ir) {

const auto current_allocated_memory_size = m_buffer_scratchpad_size - offset;
if (buffer_size > current_allocated_memory_size) {
m_buffer_scratchpad_size += (buffer_size - current_allocated_memory_size);
// Note: we don't update offset because we just add memory to needed size
allocate(buffer, expr, buffer_size);
continue;
}
propagate_offset(linear_ir, *expr_it, offset);
allocated_buffers.insert(expr);

@@ -100,9 +100,9 @@ bool AssignRegisters::run(LinearIR& linear_ir) {
// Otherwise WIN build fails with "IS_MANUALLY_ALLOCATED_REG cannot be implicitly captured because no default capture mode has been specified"
// the same problem with all the other lambdas in this file
auto enumerate_out_tensors = [=] (const ExpressionPtr& expr,
                                  decltype(regs_vec)& reg_map,
                                  const std::map<tensor, Reg>& manually_assigned_regs,
                                  size_t& counter) {
                                  decltype(regs_vec)& reg_map,
                                  const std::map<tensor, Reg>& manually_assigned_regs,
                                  size_t& counter) {
for (const auto& out_tensor : expr->get_output_port_connectors()) {
// Note that some ops might have identical input&output tensors (Result and Tile* for ex.)
// so we have to check that the tensor has not been enumerated already
@@ -62,24 +62,31 @@ std::shared_ptr<Node> Brgemm::clone_with_new_inputs(const OutputVector& new_args
lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
}

ov::element::Type Brgemm::get_output_type() const {
const auto element_type_a = get_input_element_type(0);
const auto element_type_b = get_input_element_type(1);
const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b);
const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b);
ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, const ov::element::Type& in_type1) {
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
if (is_f32 || is_bf16) {
return element::f32;
return element::f32;
} else if (is_int8) {
return element::i32;
} else {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
element_type_a.get_type_name() +
" and " +
element_type_b.get_type_name());
return element::undefined;
}
}

ov::element::Type Brgemm::get_output_type() const {
auto output_type = get_output_type(get_input_element_type(0), get_input_element_type(1));
if (output_type == element::undefined) {
OPENVINO_THROW("BrgemmCPU node has incompatible input element types: " +
get_input_element_type(0).get_type_name() +
" and " +
get_input_element_type(1).get_type_name());
}

return output_type;
}

std::vector<ov::PartialShape> Brgemm::get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const {
OPENVINO_ASSERT(inputs.size() == 2, "Brgemm::get_planar_input_shapes() expects 2 inputs");
return { utils::get_port_planar_shape(inputs[0]), utils::get_port_planar_shape(inputs[1]) };
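
To spell out the precision rule implemented above, here is a small illustrative check (assumptions: the snippets Brgemm header is on the include path and the static overload is publicly accessible, as declared in the header hunk at the top of this diff):

#include <cassert>
#include "snippets/op/brgemm.hpp"

void brgemm_output_type_examples() {
    using ov::snippets::op::Brgemm;
    assert(Brgemm::get_output_type(ov::element::f32, ov::element::f32) == ov::element::f32);
    assert(Brgemm::get_output_type(ov::element::bf16, ov::element::bf16) == ov::element::f32);  // BF16 accumulates in FP32
    assert(Brgemm::get_output_type(ov::element::u8, ov::element::i8) == ov::element::i32);      // int8 accumulates in i32
    assert(Brgemm::get_output_type(ov::element::i8, ov::element::i8) == ov::element::i32);
    // Any other combination is reported as undefined by the static overload (the member overload throws instead)
    assert(Brgemm::get_output_type(ov::element::f32, ov::element::i8) == ov::element::undefined);
}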
@@ -14,8 +14,8 @@ namespace snippets {
namespace op {


Buffer::Buffer(const ov::Shape& shape, size_t id)
    : Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id) {
Buffer::Buffer(const ov::Shape& shape, ov::element::Type element_type, size_t id)
    : Op(), m_type(Type::NewMemory), m_shape(shape), m_offset(0), m_id(id), m_element_type(std::move(element_type)) {
constructor_validate_and_infer_types();
}

@@ -40,26 +40,25 @@ bool Buffer::visit_attributes(AttributeVisitor& visitor) {
visitor.on_attribute("allocation_shape", m_shape);
visitor.on_attribute("offset", m_offset);
visitor.on_attribute("id", m_id);
visitor.on_attribute("element_type", m_element_type);
return true;
}

void Buffer::validate_and_infer_types() {
INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types);
ov::element::Type output_type;
ov::Shape output_shape;
if (m_type == Type::NewMemory) {
OPENVINO_ASSERT(get_input_size() == 0, "Buffer with newly allocated memory must not have arguments!");
output_shape = m_shape;
output_type = ov::element::u8; // 1Byte
} else if (m_type == Type::IntermediateMemory) {
const auto& input_shape = get_input_partial_shape(0);
OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape");
output_type = get_input_element_type(0);
m_element_type = get_input_element_type(0);
output_shape = input_shape.get_shape();
} else {
OPENVINO_THROW("Buffer supports only the following types: NewMemory and IntermediateMemory");
}
set_output_type(0, output_type, output_shape);
set_output_type(0, m_element_type, output_shape);
}

std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args) const {

@@ -67,7 +66,7 @@ std::shared_ptr<Node> Buffer::clone_with_new_inputs(const OutputVector& new_args
check_new_args_count(this, new_args);
std::shared_ptr<op::Buffer> new_buffer = nullptr;
if (m_type == Type::NewMemory) {
new_buffer = std::make_shared<Buffer>(m_shape, m_id);
new_buffer = std::make_shared<Buffer>(m_shape, m_element_type, m_id);
} else if (m_type == Type::IntermediateMemory) {
new_buffer = std::make_shared<Buffer>(new_args.at(0), m_shape, m_id);
} else {

@@ -82,6 +81,13 @@ size_t Buffer::get_byte_size() const {
return ov::shape_size(shape) * get_element_type().size();
}

void Buffer::set_element_type(ov::element::Type element_type) {
OPENVINO_ASSERT(is_new_memory(), "Only a Buffer with NewMemory can change its output precision!");
m_element_type = std::move(element_type);
// Apply the change
validate_and_infer_types();
}

} // namespace op
} // namespace snippets
} // namespace ov
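
A brief illustrative sketch (not part of the patch) of the new Buffer precision behaviour: a NewMemory Buffer takes an explicit element type (u8 by default) and may change it later via set_element_type(), while an IntermediateMemory Buffer always inherits the element type of its input, so calling set_element_type() on it would trip the assert above.

#include "openvino/op/parameter.hpp"
#include "snippets/op/buffer.hpp"

void buffer_precision_example() {
    // NewMemory Buffer: precision comes from the constructor and may be changed explicitly
    auto scratch = std::make_shared<ov::snippets::op::Buffer>(ov::Shape{1, 64}, ov::element::u8);
    scratch->set_element_type(ov::element::bf16);  // allowed: is_new_memory() == true

    // IntermediateMemory Buffer: precision is inferred from the producer output
    auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 64});
    auto intermediate = std::make_shared<ov::snippets::op::Buffer>(param->output(0), ov::Shape{1, 64});
    // intermediate->get_element_type() == ov::element::f32 here
}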
@@ -14,7 +14,6 @@
#include "snippets/pass/convert_constants.hpp"
#include "snippets/pass/convert_power_to_powerstatic.hpp"
#include "snippets/pass/transpose_decomposition.hpp"
#include "snippets/pass/transform_convert.hpp"
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "snippets/pass/fuse_transpose_brgemm.hpp"
#include "snippets/pass/set_softmax_ports.hpp"

@@ -75,12 +74,11 @@ auto snippets::op::Subgraph::is_domain_sensitive_op(const std::shared_ptr<ov::No
}

void snippets::op::Subgraph::init_config() {
auto update = [](bool& flag, bool status) { flag = flag || status; };
const auto ops = body_ptr()->get_ops();
for (const auto& op : ops) {
config.m_is_quantized = config.m_is_quantized ||
ov::is_type<ov::op::v0::FakeQuantize>(op);
config.m_has_domain_sensitive_ops = config.m_has_domain_sensitive_ops ||
is_domain_sensitive_op(op);
update(config.m_is_quantized, ov::is_type<ov::op::v0::FakeQuantize>(op));
update(config.m_has_domain_sensitive_ops, is_domain_sensitive_op(op));
}
}

@@ -93,6 +91,13 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
// and where will be Loops - we can just predict.
// Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
std::vector<size_t> used_precision_size;

auto push_prc_size = [&used_precision_size](size_t precision_size) {
    if (used_precision_size.empty() || used_precision_size.back() != precision_size) {
        used_precision_size.push_back(precision_size);
    }
};

for (const auto& op : ops) {
if (const auto transpose = ov::as_type_ptr<ov::op::v1::Transpose>(op)) {
// At the moment Transposes are supported only on Results and Parameters but

@@ -106,34 +111,23 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
}) ||
!ov::is_type<ov::op::v0::Parameter>(transpose->get_input_node_shared_ptr(0));
if (are_prev_or_next_ops) {
const auto prc_size = transpose->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(transpose->get_element_type().size());
}
} else if (ov::is_type<ov::op::v1::Softmax>(op) || ov::is_type<ov::op::v8::Softmax>(op)) {
// Softmax always uses 2 FP32 Buffers
const auto prc_size = ov::element::f32.size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
// Softmax always uses 2 FP32 Buffers after decomposition.
// They are inplace and the same so we can push precision size only once
push_prc_size(ov::element::f32.size());
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
// First input check is enough because MatMul requires the same prc size on inputs
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)) ||
!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1))) {
const auto prc_size = matmul->get_input_element_type(0).size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_input_element_type(0).size());
}

const auto consumers = matmul->get_output_target_inputs(0);
if (std::none_of(consumers.begin(), consumers.end(),
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
const auto prc_size = matmul->get_element_type().size();
if (used_precision_size.empty() || used_precision_size.back() != prc_size) {
used_precision_size.push_back(prc_size);
}
push_prc_size(matmul->get_element_type().size());
}
}
}
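
For clarity, a tiny standalone illustration (plain C++, not from the patch) of the deduplication idea behind push_prc_size: adjacent ops that keep the same precision size are assumed to reuse a Buffer inplace, so the estimate only grows when the precision size changes along the op sequence.

#include <cstddef>
#include <vector>

size_t estimated_unique_buffer_count_demo() {
    std::vector<size_t> used_precision_size;
    auto push_prc_size = [&used_precision_size](size_t precision_size) {
        if (used_precision_size.empty() || used_precision_size.back() != precision_size)
            used_precision_size.push_back(precision_size);
    };
    push_prc_size(4);  // e.g. Softmax: always FP32 (4 bytes)
    push_prc_size(4);  // next op with the same precision size: assumed inplace, no new Buffer
    push_prc_size(1);  // e.g. an int8 MatMul output: one more unique Buffer
    return used_precision_size.size();  // == 2
}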
@@ -63,11 +63,21 @@ auto is_supported_op(const std::shared_ptr<const Node> &n) -> bool {
const auto& transpose = as_type_ptr<const opset1::Transpose>(n);
const auto& out_shape = n->get_output_partial_shape(0);
if (transpose && out_shape.is_static()) {
const auto parent = transpose->get_input_node_shared_ptr(0);
const auto child = transpose->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
auto is_brgemm_case = ov::is_type<opset1::MatMul>(parent) || ov::is_type<opset1::MatMul>(child);
// Check whether the Transpose parent is a MatMul inside a Subgraph
if (const auto subgraph = ov::as_type_ptr<op::Subgraph>(parent)) {
const auto body = subgraph->body_ptr();
const auto subgraph_output = body->get_results()[transpose->input_value(0).get_index()]->get_input_node_shared_ptr(0);
is_brgemm_case = is_brgemm_case || ov::is_type<opset1::MatMul>(subgraph_output);
}

const auto& order = as_type_ptr<const opset1::Constant>(n->get_input_node_shared_ptr(1));
if (order) {
const auto order_value = order->cast_vector<int>();
return TransposeDecomposition::supported_cases.count(order_value) != 0 ||
FuseTransposeBrgemm::supported_cases.count(order_value) != 0;
return (TransposeDecomposition::supported_cases.count(order_value) != 0) ||
(is_brgemm_case && FuseTransposeBrgemm::supported_cases.count(order_value) != 0);
}
}
return false;

@@ -337,7 +347,7 @@ TokenizeSnippets::TokenizeSnippets() {

for (const auto& input_node : ov::as_node_vector(input_values)) {
if (auto subgraph = ov::as_type_ptr<op::Subgraph>(input_node)) {
if (!clones.count(input_node)) {
if (!clones.count(input_node) && GetSnippetsSubgraphType(subgraph) != SnippetsSubgraphType::Completed) {
auto f = subgraph->body().clone();
f->set_friendly_name(subgraph->body_ptr()->get_friendly_name());
clones[input_node] = f;

@@ -524,15 +534,18 @@ TokenizeSnippets::TokenizeSnippets() {
ResultVector body_results;
std::vector<std::set<Input<Node>>> subgraph_result_inputs;

ov::NodeVector new_body_ops;
ov::NodeVector ops_for_buffer_count;
for (auto subgraph : input_subgraphs) {
// we should summarize additional needed data count (non-scalar Constants and Buffers) from all input subgraphs
// because we will collapse them with our node and we should get the total count
const auto subgraph_ptr = ov::as_type_ptr<ov::snippets::op::Subgraph>(subgraph);
hidden_data_count += subgraph_ptr->get_virtual_port_count();
// Buffers can exist only in Subgraphs with domain sensitive ops which
// require intermediate memory for data repacking
// To avoid load time regressions, we verify only these Subgraphs with domain sensitive ops
if (subgraph_ptr->has_domain_sensitive_ops()) {
const auto ops = subgraph_ptr->body_ptr()->get_ordered_ops();
new_body_ops.insert(new_body_ops.end(), ops.begin(), ops.end());
ops_for_buffer_count.insert(ops_for_buffer_count.end(), ops.begin(), ops.end());
}

for (auto output : subgraph->outputs()) {

@@ -566,7 +579,7 @@ TokenizeSnippets::TokenizeSnippets() {
}

if (op::Subgraph::is_domain_sensitive_op(node)) {
new_body_ops.push_back(node);
ops_for_buffer_count.push_back(node);
}

for (auto output : node->outputs()) {

@@ -582,7 +595,7 @@ TokenizeSnippets::TokenizeSnippets() {
// At the moment, CPU Plugin has limitation for GPR registers: there are only 12 available registers.
// This limitation will be resolved once generator supports gprs spills [75622].
// TODO [75567]: move this plugin-specific constraint to the plugin callback
const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(new_body_ops);
const auto unique_buffer_count = op::Subgraph::get_estimated_buffer_count(ops_for_buffer_count);
if (body_parameters.size() + body_results.size() + hidden_data_count + unique_buffer_count > 12) {
const std::string message_reset = "new subgraph is created. Impossible to schedule subgraph with " +
std::to_string(body_parameters.size()) + " inputs, " + std::to_string(body_results.size()) + " outputs and " +
@@ -6,9 +6,10 @@


#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"
#include "snippets/op/subgraph.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/utils.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

@@ -17,17 +18,14 @@

namespace {
auto is_supported_tensor(const ov::descriptor::Tensor& t) -> bool {
// TODO: Add support of all supported by common tokenization element types
// return ov::snippets::pass::TokenizeSnippets::supported_element_types.count(input.get_element_type()) != 0;
return t.get_element_type() == ngraph::element::f32 &&
       t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
return t.get_partial_shape().is_static() && ov::snippets::utils::one_of(t.get_shape().size(), 3lu, 4lu);
}

// TODO: Add support of FQ, Reshape?
auto is_supported_intermediate_op(const std::shared_ptr<ov::Node>& node) -> bool {
const auto is_intermediate_op = [](const std::shared_ptr<ov::Node>& node) {
return ov::is_type<ov::op::util::UnaryElementwiseArithmetic>(node) ||
       ov::is_type<ov::op::util::BinaryElementwiseArithmetic>(node) ||
       ov::is_type<ov::op::v0::FakeQuantize>(node) ||
       ov::is_type<ov::op::v1::Select>(node);
};
return is_intermediate_op(node) && ov::snippets::pass::TokenizeSnippets::AppropriateForSubgraph(node);

@@ -40,9 +38,12 @@ auto is_valid_transpose(const std::shared_ptr<ov::opset1::Transpose>& node, std:
return false;
return transpose_pattern->cast_vector<int64_t>() == expected_order;
};
auto is_supported_transpose_tensor = [](const ov::descriptor::Tensor& t) {
return is_supported_tensor(t) && ov::snippets::pass::TokenizeSnippets::supported_element_types.count(t.get_element_type()) != 0;
};

return node && node->get_output_target_inputs(0).size() == 1 && node->get_shape().size() == 4 &&
       valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_tensor(node->get_input_tensor(0));
       valid_transpose_order(node->get_input_node_shared_ptr(1)) && is_supported_transpose_tensor(node->get_input_tensor(0));
}

auto tokenize_broadcast(const std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops) -> void {

@@ -98,14 +99,15 @@ auto tokenize_reshape_around_softmax(std::shared_ptr<ov::Node>& interm_op,
                                     ov::NodeVector& ordered_ops) -> bool {
reshape = ov::as_type_ptr<ov::opset1::Reshape>(interm_op);
if (reshape) {
const auto shape = reshape->get_input_shape(0);
if (shape.back() != reshape->get_output_shape(0).back() || reshape->get_output_target_inputs(0).size() != 1)
const auto in_shape = reshape->get_input_shape(0);
const auto out_shape = reshape->get_output_shape(0);
if (in_shape.back() != out_shape.back() || reshape->get_output_target_inputs(0).size() != 1)
return false;
ordered_ops.push_back(reshape);
interm_op = reshape->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
return true;
};
}

auto get_potential_body_params(const std::shared_ptr<ov::Node>& op) -> size_t {
size_t count = 0;
@@ -124,43 +126,50 @@ auto get_potential_body_params(const std::shared_ptr<ov::Node>& op) -> size_t {

auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov::NodeVector& ordered_ops,
                                       size_t& hidden_virtual_ports_count, size_t& potential_body_params_count) -> bool {
// TODO: Add Reshape, FQ support
while (is_supported_intermediate_op(interm_op)) {
// All supported intermediate ops have only one output port
// To verify output element type is enough because all supported intermediate ops have the same output element type as input type
if (interm_op->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(interm_op->get_output_tensor(0)))
if (interm_op->get_output_target_inputs(0).size() != 1)
return false;

// Check for supported Broadcast op
// Check for supported ops on branches: Broadcast/Elementwise (for example, dequantize ops)
if (interm_op->get_input_size() > 1) {
tokenize_broadcast(interm_op, ordered_ops);
}

auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr<ov::Node>& op) {
return is_supported_intermediate_op(op) &&
       ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin &&
       std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end();
};
// To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
// we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(interm_op)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}

for (size_t i = 0; i < interm_op->get_input_size(); ++i) {
const size_t shift = ordered_ops.size();
auto parent = interm_op->get_input_node_shared_ptr(i);
while (is_supported_branch_op(parent)) {
// All supported ops have only one output port
if (parent->get_output_target_inputs(0).size() != 1)
break;
auto is_supported_branch_op = [&ordered_ops](const std::shared_ptr<ov::Node>& op) {
return is_supported_intermediate_op(op) &&
       ov::snippets::pass::GetSnippetsNodeType(op) != ov::snippets::pass::SnippetsNodeType::SkippedByPlugin &&
       std::find(ordered_ops.begin(), ordered_ops.end(), op) == ordered_ops.end();
};

// Add node only if there are scalar constants on inputs because of plugin-specific limitation
bool are_weights_scalar = true;
const auto parent_count = parent->get_input_size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
for (size_t i = 0; i < interm_op->get_input_size(); ++i) {
const size_t shift = ordered_ops.size();
auto parent = interm_op->get_input_node_shared_ptr(i);
while (is_supported_branch_op(parent)) {
// All supported ops have only one output port
if (parent->get_output_target_inputs(0).size() != 1)
break;

// Add node only if there are scalar constants on inputs because of plugin-specific limitation
bool are_weights_scalar = true;
const auto parent_count = parent->get_input_size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar)
break;

ordered_ops.insert(ordered_ops.begin() + shift, parent);
// TODO [107731]: We think that sequence of ops goes through input port 0
// But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way?
if (parent->get_input_size() > 0)
parent = parent->get_input_node_shared_ptr(0);
}

ordered_ops.insert(ordered_ops.begin() + shift, parent);
// We think that sequence of ops goes through input port 0
// But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way?
parent = parent->get_input_node_shared_ptr(0);
}
}

@@ -173,7 +182,7 @@ auto update_intermediate_supported_ops(std::shared_ptr<ov::Node>& interm_op, ov:
};
} // namespace

ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsTokenization::Config& config) {
MATCHER_SCOPE(TokenizeMHASnippets);

auto m_matmul0 = std::make_shared<ov::opset1::MatMul>(ov::pass::pattern::any_input(ov::pass::pattern::has_static_shape()),
@@ -184,14 +193,13 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::TokenizeMHASnippets")
auto& pattern_to_output = m.get_pattern_value_map();

// Queries + Key + Values = 3 standard inputs of MHA
size_t potential_body_params_count = 3;
// After some transformations, a different number of Constants for some operations may be created
// than the actual number of Constants during tokenization.
// To avoid unsupported number of non-scalar Constants in the future (plugin specific limitation)
// we should calculate potential number of non-scalar Constants that will be moved up from body.
// TODO: Need to update this variable when FQ is supported
size_t hidden_virtual_ports_count = 0;
// Queries + Key + Values = 3 standard inputs of MHA
size_t potential_body_params_count = 3;
// The count of potential unique Buffers - it's hidden virtual ports as well
// We should go through Subgraph and calculate potential non-inplace Buffers count.
// Example:

@@ -231,10 +239,20 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
!is_supported_tensor(matmul0->get_input_tensor(0)) || !is_supported_tensor(matmul0->get_input_tensor(1)))
return false;

if (transformation_callback(matmul0)) {
const auto matmul0_prc = op::Brgemm::get_output_type(matmul0->get_input_element_type(0),
                                                     matmul0->get_input_element_type(1));
if (matmul0_prc == element::undefined) {
return false;
}

// Between MatMul0 and Softmax there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape both on input and output.
// Need to check the precision to see if we need one more register for the Buffer
if (matmul0_prc.size() != ov::element::f32.size()) {
if (buffer_count < 2)
buffer_count++;
}

ordered_ops.push_back(matmul0);

auto interm_op = matmul0->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
@@ -276,10 +294,28 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
return false;

const auto matmul1 = ov::as_type_ptr<ov::opset1::MatMul>(interm_op);
if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 || matmul1->get_transpose_a() || matmul1->get_transpose_b() ||
    !is_supported_tensor(matmul1->get_input_tensor(0)) || !is_supported_tensor(matmul1->get_input_tensor(1)))
if (!matmul1 || matmul1->get_output_target_inputs(0).size() != 1 ||
    matmul1->get_transpose_a() || matmul1->get_transpose_b())
return false;

const auto matmul1_out_type = op::Brgemm::get_output_type(matmul1->get_input_element_type(0),
                                                          matmul1->get_input_element_type(1));
if (matmul1_out_type == element::undefined ||
    !is_supported_tensor(matmul1->get_input_tensor(0)) ||
    !is_supported_tensor(matmul1->get_input_tensor(1)))
return false;

if (transformation_callback(matmul0)) {
return false;
}

// Between Softmax and MatMul1 there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape both on input and output.
// Need to check the precision to see if we need one more register for the Buffer
if (matmul1->get_input_element_type(0).size() != ov::element::f32.size()) {
buffer_count++;
}

/***********************/

/***** Transposes *****/
@@ -287,29 +323,51 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
* We can add them into Subgraph body
*/

auto tokenize_transpose = [config](const std::shared_ptr<ov::Node>& node) -> std::shared_ptr<ov::opset1::Transpose> {
return config.mha_token_enable_transpose ? ov::as_type_ptr<ov::opset1::Transpose>(node)
                                          : nullptr;
};

// First input branch of MatMul0 should be executed before second input branch of MatMul0,
// so firstly we insert Transpose1 on the beginning of ordered_ops and then Transpose0
bool are_weights_scalar = true;
// We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order (or without this Transpose1)
// only if these ops have scalar shapes on other inputs.
// There is transformation ExplicitTransposeMatMulInputs that sets the supported order and transposed_b(false).
// We can allow calling this pass only if ops have scalar shapes to avoid shape mismatching
const auto is_transposed_b_0 = matmul0->get_transpose_b();
auto parent = matmul0->get_input_node_shared_ptr(1);
while (is_supported_intermediate_op(parent)) {
// All supported ops have only one output port
// To verify output element type is enough because all supported ops have the same output element type as input type
if (parent->get_output_target_inputs(0).size() != 1 || !is_supported_tensor(parent->get_output_tensor(0)))
if (parent->get_output_target_inputs(0).size() != 1)
break;

const auto parent_count = parent->inputs().size();
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
// Only if MatMul0 has transposed_b, we have to tokenize scalar ops
// to move explicit Transpose from MatMul0 input_1 to Parameter of Subgraph body
if (is_transposed_b_0) {
const auto parent_count = parent->get_input_size();
bool are_weights_scalar = true;
for (size_t i = 1; i < parent_count; ++i) {
are_weights_scalar = are_weights_scalar && ov::shape_size(parent->get_input_shape(i)) == 1;
}
if (!are_weights_scalar) {
break;
}
}

// To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
// we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(parent)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
potential_body_params_count += get_potential_body_params(parent);
ordered_ops.insert(ordered_ops.begin(), parent);
// We think that sequence of ops goes through input port 0
// But can be Select here? If it can be, parent shouldn't be on input port 0. Need another way?
// TODO [107731] To go always through 0-th port - is it safe?
parent = parent->get_input_node_shared_ptr(0);
}

auto transpose1 = ov::as_type_ptr<ov::opset1::Transpose>(parent);
if (matmul0->get_transpose_b()) {
const auto transpose1 = tokenize_transpose(parent);
if (is_transposed_b_0) {
if (is_valid_transpose(transpose1, {0, 2, 1, 3})) {
// We can support several ops between MatMul0 with transposed_b and Transpose1 with 0213 order
// only if these ops have scalar shapes on other inputs.
@@ -329,31 +387,63 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
}
}

// TODO: Add Reshape Support for all Transposes
// Add 3D support for all Transposes
const auto transpose0 = ov::as_type_ptr<ov::opset1::Transpose>(matmul0->get_input_node_shared_ptr(0));
if (transpose1) {
// Between Transpose1 and MatMul0 there will be one Loop because of the LoopFusing optimization.
// The Loop will have one Buffer with the same shape both on input and output.
// Need to check the precision to see if we need one more register for the Buffer
if (matmul0->get_input_element_type(1).size() != transpose1->get_output_element_type(0).size()) {
buffer_count++;
}
}

const auto transpose0 = tokenize_transpose(matmul0->get_input_node_shared_ptr(0));
if (is_valid_transpose(transpose0, {0, 2, 1, 3})) {
ordered_ops.insert(ordered_ops.begin(), transpose0);
} else if (matmul0->get_transpose_b()) {
} else if (matmul0->get_transpose_a()) {
return false;
}

const auto transpose2 = ov::as_type_ptr<ov::opset1::Transpose>(matmul1->get_input_node_shared_ptr(1));
const auto transpose2 = tokenize_transpose(matmul1->get_input_node_shared_ptr(1));
if (is_valid_transpose(transpose2, {0, 2, 1, 3})) {
ordered_ops.push_back(transpose2);
}
ordered_ops.push_back(matmul1);

bool are_ops_after_matmul1 = false;
auto child = matmul1->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
// TODO: Add support for Eltwises between MatMul1 and Transpose
// status = update_intermediate_supported_ops(child, ordered_ops);
// if (!status) {
//     ordered_ops.push_back(child);
// }
while (is_supported_intermediate_op(child)) {
are_ops_after_matmul1 = true;
// All supported ops have only one output port
if (child->get_output_target_inputs(0).size() != 1)
break;

auto transpose3 = ov::as_type_ptr<ov::opset1::Transpose>(child);
if (is_valid_transpose(transpose3, {0, 2, 1, 3})) {
ordered_ops.push_back(transpose3);
// To avoid unsupported number of non-scalar Constants in the future after FakeQuantize decomposition (plugin specific limitation)
// we should calculate potential number of non-scalar Constants for FakeQuantize that will be moved up from body.
if (const auto fq_node = ov::as_type_ptr<ov::op::v0::FakeQuantize>(child)) {
hidden_virtual_ports_count += ov::snippets::utils::get_non_scalar_constant_count_for_fq(fq_node);
}
potential_body_params_count += get_potential_body_params(child);

// TODO [75567]: move this plugin-specific constraint to the plugin callback
// We cannot collapse an op into the Subgraph if the count of potential Parameters and Results is higher than 12
if (potential_body_params_count + child->get_output_target_inputs(0).size() + hidden_virtual_ports_count + buffer_count > 12) {
break;
}

ordered_ops.push_back(child);
child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}

// At the moment Snippets don't support nodes between MatMul1 and Transpose3 due to Loop and strided calculations limitations
//     MatMul1
//  <Supported ops>
//    Transpose3
if (!are_ops_after_matmul1) {
auto transpose3 = tokenize_transpose(child);
if (is_valid_transpose(transpose3, {0, 2, 1, 3}) &&
    transpose3->get_input_element_type(0) == matmul1_out_type) { // To avoid Convert between MatMul1 and Transpose3
ordered_ops.push_back(transpose3);
}
}

/**********************/

@@ -362,7 +452,7 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {

/* ====== Subgraph creation ======= */

// TODO: move this plugin-specific constraint to the plugin callback
// TODO [75567]: move this plugin-specific constraint to the plugin callback
const auto last_node = ordered_ops.back();
if (potential_body_params_count + last_node->get_output_size() + hidden_virtual_ports_count + buffer_count > 12) {
return false;
@@ -378,7 +468,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
const auto input = node->input(i);
const auto parent = input.get_source_output().get_node_shared_ptr();
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(parent);
if (constant && (ov::shape_size(input.get_shape()) == 1 || op::Subgraph::constant_input_should_be_inside_body(node))) {
if (constant && (ov::shape_size(input.get_shape()) == 1 ||
                 ov::is_type<ov::op::v0::FakeQuantize>(node) ||
                 op::Subgraph::constant_input_should_be_inside_body(node))) {
// If Constant has one consumer - target node, we add Constant to body_inputs
// If Constant has several consumers, we should check that all these consumers are inside Subgraph body
// and if all of them are inside body, we can explicitly add Constant to the body_inputs, otherwise we should

@@ -454,6 +546,9 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets() {
subgraph->get_rt_info()["originalLayersNames"] = fused_names;
subgraph->set_virtual_port_count(hidden_virtual_ports_count);

// mark the Subgraph as Completed to not allow Snippets to include any nodes into the MHA Subgraph in common Tokenization
SetSnippetsSubgraphType(subgraph, SnippetsSubgraphType::Completed);

return true;

/* ================================ */
@@ -7,6 +7,8 @@
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/common_optimizations.hpp"
#include "openvino/pass/manager.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"


namespace ov {

@@ -18,6 +20,13 @@ void SetSnippetsNodeType(const std::shared_ptr<Node> &node, SnippetsNodeType nod
rt["SnippetsNodeType"] = nodeType;
}

void SetSnippetsSubgraphType(const std::shared_ptr<op::Subgraph> &node, SnippetsSubgraphType nodeType) {
if (node) {
auto &rt = node->get_rt_info();
rt["SnippetsSubgraphType"] = nodeType;
}
}

SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node> &node) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsNodeType")
auto& rt = node->get_rt_info();

@@ -27,6 +36,17 @@ SnippetsNodeType GetSnippetsNodeType(const std::shared_ptr<const Node> &node) {
return rinfo->second.as<SnippetsNodeType>();
}

SnippetsSubgraphType GetSnippetsSubgraphType(const std::shared_ptr<const op::Subgraph> &node) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::GetSnippetsSubgraphType")
if (!node)
return SnippetsSubgraphType::NotSet;
auto &rt = node->get_rt_info();
const auto rinfo = rt.find("SnippetsSubgraphType");
if (rinfo == rt.end())
return SnippetsSubgraphType::NotSet;
return rinfo->second.as<SnippetsSubgraphType>();
}

void SetTopologicalOrder(const std::shared_ptr<Node> &node, int64_t order) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetTopologicalOrder")
auto& rt = node->get_rt_info();

@@ -58,7 +78,7 @@ bool SnippetsTokenization::run_on_model(const std::shared_ptr<ov::Model>& m) {
manager.set_per_pass_validation(false);

manager.register_pass<EnumerateNodes>();
manager.register_pass<TokenizeMHASnippets>();
manager.register_pass<TokenizeMHASnippets>(m_config);
manager.register_pass<TokenizeSnippets>();
manager.register_pass<CommonOptimizations>();
manager.run_passes(m);
@@ -5,6 +5,7 @@
#include <common_test_utils/ngraph_test_utils.hpp>
#include "lowering_utils.hpp"
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"


namespace ov {

@@ -8,6 +8,7 @@
#include <subgraph_fq.hpp>
#include <subgraph_converts.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/collapse_subgraph.hpp"

namespace ov {
namespace test {

@@ -6,6 +6,7 @@
#include <pass/mha_tokenization.hpp>
#include <subgraph_mha.hpp>
#include "snippets/pass/tokenization.hpp"
#include "snippets/pass/mha_tokenization.hpp"
#include "snippets/pass/explicit_transpose_matmul_inputs.hpp"

namespace ov {

@@ -20,14 +21,23 @@ void TokenizeMHASnippetsTests::run() {
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA) {
const auto& f = MHAFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
const auto &f = MHAFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
                            std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
function = f.getOriginal();
function_ref = f.getReference();
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose) {
const auto& f = MHAMatMul0TransposeFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
const auto &f = MHAMatMul0TransposeFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
                                            std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}));
function = f.getOriginal();
function_ref = f.getReference();
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
const auto &f = MHAINT8MatMulTypeRelaxedFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
function = f.getOriginal();
function_ref = f.getReference();
run();
@@ -172,7 +172,7 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
general_exprs.emplace_back(expr);
}
}
num_unique_buffer = unique_buffers.size();
num_unique_buffers = unique_buffers.size();

// Note that we can't use reg_indexes_idx or reg_const_params_idx to store data pointers because these two
// regs are used to calculate offsets for the data pointers

@@ -198,15 +198,16 @@ void KernelEmitter::validate_arguments(const std::vector<size_t> &in,
IE_THROW() << "KernelEmitter got invalid number of inputs. Expected 0, got " << in.size();
if (!out.empty())
IE_THROW() << "KernelEmitter got invalid number of outputs. Expected 0, got " << out.size();
const auto num_params = num_inputs + num_outputs + num_unique_buffer;
const auto num_params = num_inputs + num_outputs + num_unique_buffers;
// The number of used gpr may be >= num_params since LoopBegin+LoopEnd could also use gpr to store work_amount
if (data_ptr_regs_idx.size() != num_params)
IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsisnent with the number of allocated registers"
IE_THROW() << "KernelEmitter: number of inputs and outputs is inconsistent with the number of allocated registers "
           << num_params << " data_ptr_regs_idx.size() = " << data_ptr_regs_idx.size();
}

void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, size_t num_buffer,
                                       const Reg64& reg_indexes, const Reg64& reg_const_params, const std::vector<Reg64>& data_ptr_regs) const {
void KernelEmitter::init_data_pointers(const Xbyak::Reg64& reg_indexes, const Xbyak::Reg64& reg_const_params,
                                       const std::vector<Xbyak::Reg64>& data_ptr_regs) const {
const auto num_params = num_inputs + num_outputs;
// Note that we don't need offset for the last dim, since it's handled directly by Tile emitter
const size_t offset_rank = jcp.master_shape.size() - 1;
std::vector<std::vector<size_t>> data_offsets(num_params, std::vector<size_t>{});

@@ -267,7 +268,9 @@ void KernelEmitter::init_data_pointers(size_t num_inputs, size_t num_params, siz
// Vector "data_ptr_regs" is sorted by abstract regs.
// It means that the vector contains the physical registers in order [src, .., src, dst, .., dst, buffer]
// So we can initialize buffer register firstly as last value of vector "data_ptr_regs"
for (size_t i = 0; i < num_buffer; ++i) {
// NOTE: Snippets Buffer Scratchpad has the common data pointer for all Buffers (even with different ID).
// Memory accesses are covered by the correct offsets in each Buffer and the corresponding MemoryAccess ops
for (size_t i = 0; i < num_unique_buffers; ++i) {
h->mov(data_ptr_regs[num_params + i], h->ptr[reg_const_params + GET_OFF(buffer_scratchpad_ptr)]);
}
size_t i = 0;

@@ -299,7 +302,7 @@ void KernelEmitter::emit_impl(const std::vector<size_t>& in,
std::vector<Reg64> data_ptr_regs;
transform_idxs_to_regs(data_ptr_regs_idx, data_ptr_regs);

init_data_pointers(num_inputs, num_inputs + num_outputs, num_unique_buffer, reg_indexes, reg_const_params, data_ptr_regs);
init_data_pointers(reg_indexes, reg_const_params, data_ptr_regs);
for (const auto& expression : body) {
const auto& emitter = expression->get_emitter();
std::vector<size_t> in_regs, out_regs;

@@ -87,13 +87,13 @@ private:
                        const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in,
               const std::vector<size_t>& out) const override;
void init_data_pointers(size_t, size_t, size_t, const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector<Xbyak::Reg64>&) const;
void init_data_pointers(const Xbyak::Reg64&, const Xbyak::Reg64&, const std::vector<Xbyak::Reg64>&) const;

jit_snippets_compile_args jcp;
std::vector<size_t> gp_regs_pool;
size_t num_inputs;
size_t num_outputs;
size_t num_unique_buffer;
size_t num_unique_buffers;
// Vector of indices (length = input tensor rank) per every input and output that describes in which order
// corresponding tensor dimensions are accessed (default: consecutive dense, e.g. 0,1,2,3 for 4D tensor).
// Needed to calc i/o offsets.
@@ -427,7 +427,7 @@ void MarkSubgraphOpAsSkipped(const std::shared_ptr<Node> &node) {
bool isSuitableConvert(const std::shared_ptr<const Node>& node) {
if (!ov::is_type<ov::op::v0::Convert>(node))
return false;
auto hasResult = [](const std::shared_ptr<const Node>& node){
auto hasResult = [](const std::shared_ptr<const Node>& node) {
auto consumers = node->output(0).get_target_inputs();
bool findResult = false;
if (consumers.size() == 1) {

@@ -449,13 +449,19 @@ bool isSuitableConvert(const std::shared_ptr<const Node>& node) {
return false;
}
}

auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
return ov::is_type<ov::op::v0::Constant>(op) ||
       ov::is_type<ov::op::v0::Parameter>(op) ||
       ov::is_type<ov::op::v0::Result>(op);
}
} // namespace

bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
RUN_ON_MODEL_SCOPE(SnippetsMarkSkipped);
int channelAxis = DEFAULT_AXIS;
for (auto &node : m->get_ordered_ops()) {
if (ov::is_type<ov::op::v0::Constant>(node) || ov::is_type<ov::op::v0::Result>(node))
if (is_skipped_op(node))
continue;
if (isSuitableConvolutionParent(node)) {
// Initiate fusing chain
@ -108,6 +108,8 @@
|
||||
|
||||
// Snippets
|
||||
#include "snippets/pass/tokenization.hpp"
|
||||
#include "snippets/pass/mha_tokenization.hpp"
|
||||
#include "snippets/pass/collapse_subgraph.hpp"
|
||||
#include "snippets/pass/common_optimizations.hpp"
|
||||
|
||||
// Misc
|
||||
@ -616,22 +618,58 @@ void Transformations::MainSnippets(void) {
!dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions)
return;

// At the moment Snippets supports Transposes in MHA pattern only in FP32 case since
// - ConvertSaturation[BF16->FP32] will be inserted after Parameters and before Transposes in canonicalization stage
// - ConvertSaturation[FP32->BF16] will be inserted after Transposes and before Brgemm in precision propagation stage
// Because of that Transposes won't be fused into Brgemm
// TODO [111813]: Need to update this pipeline to avoid Converts between Transposes and Brgemm on inputs
ov::snippets::pass::SnippetsTokenization::Config tokenization_config;
tokenization_config.mha_token_enable_transpose = !enableBF16;

ngraph::pass::Manager snippetsManager;
snippetsManager.set_per_pass_validation(false);
if (snippetsMode != Config::SnippetsMode::IgnoreCallback)
CPU_REGISTER_PASS_X64(snippetsManager, SnippetsMarkSkipped, enableBF16);
CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization);
CPU_REGISTER_PASS_X64(snippetsManager, snippets::pass::SnippetsTokenization, tokenization_config);

// Tokenize MHA in quantized model or with BF16 only in tests.
// TODO [106921]: Please enable the tokenization when the ticket 106921 with blocking support for BRGEMM will be implemented
const bool onlyFloatSupported = snippetsMode != Config::SnippetsMode::IgnoreCallback;
const bool isMHASupported =
!enableBF16 && // TODO: Need to add BF16 support for MHA in Snippets
IMPLICATION(enableBF16, !onlyFloatSupported) &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); // MHA has BRGEMM that is supported only on AVX512 platforms
if (!isMHASupported) {
CPU_DISABLE_PASS_X64(snippetsManager, snippets::pass::TokenizeMHASnippets);
}

#if defined(OPENVINO_ARCH_X86_64)
auto is_supported_matmul = [onlyFloatSupported](const std::shared_ptr<const ov::Node>& n) {
const auto matmul = ov::as_type_ptr<const ov::op::v0::MatMul>(n);
if (!matmul)
return false;
if (matmul->get_input_element_type(1) == ov::element::i8)
return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni);
if (matmul->get_input_element_type(0) == ov::element::bf16 &&
matmul->get_input_element_type(1) == ov::element::bf16)
return !onlyFloatSupported && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
return true;
};
#endif // OPENVINO_ARCH_X86_64

if (snippetsMode != Config::SnippetsMode::IgnoreCallback) {
CPU_SET_CALLBACK_X64(snippetsManager,
[](const std::shared_ptr<const ov::Node>& n) -> bool {
const auto pshape = n->get_output_partial_shape(0);
[&](const std::shared_ptr<const ov::Node>& n) -> bool {
// Transformation callback is called on MatMul0
if (!is_supported_matmul(n))
return true;
// Search for MatMul1
auto child = n->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
while (!ov::is_type<const ov::op::v0::MatMul>(child)) {
child = child->get_output_target_inputs(0).begin()->get_node()->shared_from_this();
}
if (!is_supported_matmul(child))
return true;
const auto pshape = child->get_input_partial_shape(0);
const auto shape = pshape.get_shape();
const auto parallel_work_amount =
std::accumulate(shape.rbegin() + 2, shape.rend(), 1, std::multiplies<size_t>());

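A minimal sketch (not part of the commit) of how the new tokenization config is meant to be consumed; it assumes the SnippetsTokenization pass constructor accepts the config, as the CPU_REGISTER_PASS_X64 call above implies, and that `model` is a std::shared_ptr<ov::Model>:

// Illustrative only: disable Transpose tokenization inside MHA (e.g. when BF16 is enforced).
ov::snippets::pass::SnippetsTokenization::Config config;
config.mha_token_enable_transpose = false;
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
manager.register_pass<ov::snippets::pass::SnippetsTokenization>(config);
manager.run_passes(model);
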
@ -662,18 +700,18 @@ void Transformations::MainSnippets(void) {
// todo: general tokenization flow is not currently supported for these operations.
// they can be tokenized only as a part of complex patterns
const bool is_disabled_tokenization = (ov::is_type<const ov::op::v1::Softmax>(n) ||
ov::is_type<const ov::op::v8::Softmax>(n) ||
ov::is_type<const ov::op::v0::MatMul>(n) ||
ov::is_type<const ov::op::v1::Transpose>(n) ||
ov::is_type<const ov::op::v1::Broadcast>(n) ||
ov::is_type<const ov::op::v3::Broadcast>(n));
ov::is_type<const ov::op::v8::Softmax>(n) ||
ov::is_type<const ov::op::v0::MatMul>(n) ||
ov::is_type<const ov::op::v1::Transpose>(n) ||
ov::is_type<const ov::op::v1::Broadcast>(n) ||
ov::is_type<const ov::op::v3::Broadcast>(n));
const auto& inputs = n->inputs();
// todo: clarify whether we can evaluate snippets on const paths
const bool has_only_const_inputs = std::all_of(inputs.begin(), inputs.end(),
[](const ov::Input<const ov::Node>& in) {
return ov::is_type<ov::op::v0::Constant>(
in.get_source_output().get_node_shared_ptr());
});
[](const ov::Input<const ov::Node>& in) {
return ov::is_type<ov::op::v0::Constant>(
in.get_source_output().get_node_shared_ptr());
});
// todo: clarify whether we can evaluate snippets on inputs with larger ranks
auto rank_is_too_large = [](const ov::descriptor::Tensor& t) {
// callback is called has_supported_in_out(), so it's safe to assume that the shapes are static

@ -246,6 +246,9 @@ std::vector<std::string> disabledTestPatterns() {
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
// MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions
retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)");
retVector.emplace_back(R"(.*Snippets.*MatMul.*Quantized.*)");
retVector.emplace_back(R"(.*Snippets.*MHAFQ.*)");
retVector.emplace_back(R"(.*Snippets.*MHAINT8.*)");
}
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_int8())
//TODO: Issue 92895
@ -254,6 +257,7 @@ std::vector<std::string> disabledTestPatterns() {
if (!InferenceEngine::with_cpu_x86_avx512_core_amx_bf16() && !InferenceEngine::with_cpu_x86_bfloat16()) {
// ignored for not supported bf16 platforms
retVector.emplace_back(R"(.*smoke_Snippets_EnforcePrecision_bf16.*)");
retVector.emplace_back(R"(.*smoke_Snippets_MHAWOTransposeEnforceBF16.*)");
}

return retVector;

@ -19,23 +19,7 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
{{1, 1, 37, 23}, {1, 2, 23, 33}},
{{1, 16, 384, 64}, {1, 16, 64, 384}}
};
static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
std::vector<std::vector<element::Type>> prc = {
{element::f32, element::f32},
};
if (!only_fp32) {
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) {
prc.emplace_back(std::vector<element::Type>{element::i8, element::i8});
prc.emplace_back(std::vector<element::Type>{element::u8, element::i8});
}
// In Snippets MatMul BF16 is supported only on bf16/AMX platforms
if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
}
}
return prc;
}

static inline std::vector<std::vector<element::Type>> quantized_precisions() {
std::vector<std::vector<element::Type>> prc = {};
// In Snippets MatMul INT8 is supported only on VNNI/AMX platforms
@ -46,6 +30,21 @@ static inline std::vector<std::vector<element::Type>> quantized_precisions() {
return prc;
}

static inline std::vector<std::vector<element::Type>> precisions(bool only_fp32 = true) {
std::vector<std::vector<element::Type>> prc = {
{element::f32, element::f32},
};
if (!only_fp32) {
auto quant = quantized_precisions();
std::copy(quant.begin(), quant.end(), std::back_inserter(prc));
// In Snippets MatMul BF16 is supported only on bf16/AMX platforms
if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) {
prc.emplace_back(std::vector<element::Type>{element::bf16, element::bf16});
}
}
return prc;
}

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul,
::testing::Combine(
::testing::ValuesIn(input_shapes),

@ -4,7 +4,9 @@
#include "snippets/mha.hpp"
#include "common_test_utils/test_constants.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "ie_plugin_config.hpp"
#include "ie_system_conf.h"

namespace ov {
namespace test {
@ -15,22 +17,52 @@ namespace {
const std::vector<std::vector<ov::PartialShape>> inputShapes = {
{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 1, 1}, {1, 128, 16, 64}},
{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 1, 1, 128}, {1, 128, 16, 64}},
{{2, 68, 6, 92}, {2, 68, 6, 92}, {1, 1, 68, 68}, {2, 68, 6, 92}},
{{1, 58, 16, 34}, {1, 58, 16, 34}, {1, 1, 1, 58}, {1, 58, 16, 34}},
};

static inline bool is_bf16_supported() {
return InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16();
}

static inline std::vector<std::vector<element::Type>> precision_f32(size_t count) {
std::vector<std::vector<element::Type>> prc;
prc.emplace_back(std::vector<element::Type>(count, element::f32));
return prc;
}

static inline std::vector<std::vector<element::Type>> precision_bf16(size_t count) {
std::vector<std::vector<element::Type>> prc;
if (is_bf16_supported())
prc.emplace_back(std::vector<element::Type>(count, element::bf16));
return prc;
}

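Usage note (not part of the diff): these helpers are what make each suite below platform-aware, since an empty precision list simply produces no test cases.

// Illustrative only:
const auto fp32_types = precision_f32(4);   // always a single entry: {f32, f32, f32, f32}
const auto bf16_types = precision_bf16(4);  // empty on platforms without BF16 support, so no BF16 cases are generated
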
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn({false, true}),
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
MHA::getTestCaseName);
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(precision_f32(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHA,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::ValuesIn(precision_bf16(4)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({false, true}),
::testing::Values(7), // MHA + 5 Converts + 1 Transpose on output
::testing::Values(6), // MHA + 5 Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
// without broadcast
@ -44,64 +76,142 @@ const std::vector<std::vector<ov::PartialShape>> inputShapeSelect = {
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHA, MHASelect,
::testing::Combine(
::testing::ValuesIn(inputShapeSelect),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::ValuesIn(precision_f32(6)),
::testing::Values(ov::element::f32),
::testing::Values(false), // Need to support True for graph builder in tests
::testing::Values(2), // Less + MHA
::testing::Values(2),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose_4D = {
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
};
const std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose_3D = {
{{12, 197, 64}, {12, 64, 197}, {12, 197, 64}},
{{12, 128, 100}, {12, 100, 128}, {12, 128, 100}}
};

static std::vector<std::vector<ov::PartialShape>> inputShapesWOTranspose(bool supports_3d = false) {
std::vector<std::vector<ov::PartialShape>> shapes = {
{{1, 12, 197, 64}, {1, 12, 64, 197}, {1, 12, 197, 64}},
{{1, 12, 12, 64}, {1, 12, 64, 48}, {1, 12, 48, 64}}
};
if (supports_3d) {
std::vector<std::vector<ov::PartialShape>> shapes_3d = {
{{12, 197, 64}, {12, 64, 197}, {12, 197, 64}},
{{12, 128, 100}, {12, 100, 128}, {12, 128, 100}}
};
shapes.insert(shapes.end(), shapes_3d.begin(), shapes_3d.end());
}
return shapes;
}

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs, MHAWOTransposeOnInputs,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeOnInputs_4D, MHAWOTransposeOnInputs,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose()),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::Values(std::vector<ov::element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(true), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

const std::map<std::string, std::string> cpuBF16PluginConfig = { { InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16,
InferenceEngine::PluginConfigParams::YES } };

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHABF16, MHAWOTranspose,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(1),
::testing::Values(1),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeBF16_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_bf16(3)),
::testing::Values(ov::element::f32),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_4D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose_4D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::Values(3),
::testing::Values(0), // CPU plugin doesn't support MHA pattern via Snippets on bf16
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(cpuBF16PluginConfig)),
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTranspose, MHAWOTranspose,
INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAWOTransposeEnforceBF16_3D, MHAWOTranspose,
::testing::Combine(
::testing::ValuesIn(inputShapesWOTranspose(true)),
::testing::ValuesIn(inputShapesWOTranspose_3D),
::testing::ValuesIn(precision_f32(3)),
::testing::Values(ov::element::bf16),
::testing::ValuesIn({true}), // Need to support False for graph builder in tests
::testing::Values(ov::element::f32),
::testing::Values(1),
::testing::Values(1),
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(5), // MHA + 4 extra Converts on inputs and output
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(std::map<std::string, std::string>{})),
::testing::Values(CPUTestUtils::cpuBF16PluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAINT8MatMul, MHAINT8MatMul,
::testing::Combine(
::testing::ValuesIn(std::vector<std::vector<ov::PartialShape>>(inputShapes.begin(), inputShapes.begin() + 2)),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(6), // FQx3 on inputs + MHA + Transpose on output + Deq Mul
::testing::Values(5), // FQx3 on inputs + MHA + Deq Mul
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQAfterMatMul, MHAFQAfterMatMul,
::testing::Combine(
::testing::ValuesIn(inputShapes),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(3), // MHA + Transpose on output + Deq Mul
::testing::Values(2), // MHA + Deq Mul
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MHAFQ, MHAFQ,
::testing::Combine(
::testing::Values(std::vector<ov::PartialShape>{{1, 64, 12, 64}, {1, 64, 12, 64}, {1, 1, 1, 64}, {1, 64, 12, 64}}),
::testing::Values(std::vector<element::Type>{}),
::testing::Values(ov::element::f32),
::testing::Values(false), // The graph doesn't contain Multiply
::testing::Values(7), // Transposex2 + Subgraphsx5
::testing::Values(5), // MHA + Deq Mul on output + Deqs on inputs + 2 xFQ on inputs
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig)),
MHA::getTestCaseName);

@ -575,7 +575,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
::testing::ValuesIn(inputPrecisionsQuant),
::testing::ValuesIn(matMulIn0PrecisionsQuant),
::testing::ValuesIn(patternTypesQuant),
::testing::Values("MHA"), // Snippets don't support Quantized MHA pattern yet
::testing::Values("MHA"),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);

@ -12,8 +12,9 @@ namespace snippets {
typedef std::tuple<
std::vector<ov::PartialShape>, // Input shapes
bool, // With Multiply
std::vector<ov::element::Type>, // Input Element types
ov::element::Type, // Inference precision
bool, // With Multiply
size_t, // Expected num nodes
size_t, // Expected num subgraphs
std::string, // Target Device
@ -32,6 +33,7 @@ protected:
virtual void init_subgraph();

bool m_with_mul = false;
std::vector<ov::element::Type> m_input_types;
};

class MHASelect : public MHA {
@ -46,6 +48,22 @@ protected:
};

class MHAWOTranspose : public MHA {
protected:
void init_subgraph() override;
};

class MHAINT8MatMul : public MHA {
protected:
void init_subgraph() override;
};

class MHAFQAfterMatMul : public MHA {
protected:
void init_subgraph() override;
};

class MHAFQ : public MHA {
protected:
void init_subgraph() override;
};

@ -15,16 +15,19 @@ namespace snippets {
std::string MHA::getTestCaseName(testing::TestParamInfo<ov::test::snippets::MHAParams> obj) {
std::vector<ov::PartialShape> inputShapes;
bool withMul;
std::vector<ov::element::Type> elem_types;
ov::element::Type prc;
bool withMul;
std::string targetDevice;
size_t num_nodes, num_subgraphs;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, withMul, prc, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;
std::tie(inputShapes, elem_types, prc, withMul, num_nodes, num_subgraphs, targetDevice, additionalConfig) = obj.param;

std::ostringstream result;
for (size_t i = 0; i < inputShapes.size(); ++i)
result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({inputShapes[i]}) << "_";
for (size_t i = 0; i < elem_types.size(); i++)
result << "T[" << i <<"]=" << elem_types[i] << "_";
result << "Mul=" << withMul << "_";
result << "PRC=" << prc << "_";
result << "#N=" << num_nodes << "_";
@ -45,13 +48,13 @@ void MHA::SetUp() {
std::vector<ov::PartialShape> inputShapes;
ov::element::Type prc;
std::map<std::string, std::string> additionalConfig;
std::tie(inputShapes, m_with_mul, prc, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
std::tie(inputShapes, m_input_types, prc, m_with_mul, ref_num_nodes, ref_num_subgraphs, targetDevice, additionalConfig) = this->GetParam();
init_input_shapes(static_partial_shapes_to_test_representation(inputShapes));

init_subgraph();

configuration.insert(additionalConfig.begin(), additionalConfig.end());
if (additionalConfig.empty() && !configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) {
configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE,
InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK});
}
@ -59,7 +62,7 @@ void MHA::SetUp() {
setInferenceType(prc);
inType = outType = prc;
if (prc == ov::element::bf16)
abs_threshold = 0.3;
rel_threshold = 0.05f;
}

void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) {
@ -68,13 +71,13 @@ void MHA::generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticSha
for (int i = 0; i < model_inputs.size(); ++i) {
const auto& model_input = model_inputs[i];
ov::Tensor tensor;
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), targetInputStaticShapes[i], 1.0f, 0.5f);
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256);
inputs.insert({model_input.get_node_shared_ptr(), tensor});
}
}

void MHA::init_subgraph() {
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_with_mul);
auto f = ov::test::snippets::MHAFunction(inputDynamicShapes, m_input_types, m_with_mul);
function = f.getOriginal();
}

@ -90,14 +93,14 @@ void MHASelect::generate_inputs(const std::vector<ngraph::Shape>& targetInputSta
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 5 + seed, -2, 10, seed);
seed++;
} else {
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(model_input.get_element_type(), model_input.get_shape(), 1.0f, 0.5f);
tensor = ov::test::utils::create_and_fill_tensor(model_input.get_element_type(), model_input.get_shape(), 2, -1, 256);
}
inputs.insert({node_input, tensor});
}
}

void MHASelect::init_subgraph() {
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes);
auto f = ov::test::snippets::MHASelectFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
}

@ -107,7 +110,22 @@ void MHAWOTransposeOnInputs::init_subgraph() {
}

void MHAWOTranspose::init_subgraph() {
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes);
auto f = ov::test::snippets::MHAWOTransposeFunction(inputDynamicShapes, m_input_types);
function = f.getOriginal();
}

void MHAINT8MatMul::init_subgraph() {
auto f = ov::test::snippets::MHAINT8MatMulFunction(inputDynamicShapes);
function = f.getOriginal();
}

void MHAFQAfterMatMul::init_subgraph() {
auto f = ov::test::snippets::MHAFQAfterMatMulFunction(inputDynamicShapes);
function = f.getOriginal();
}

void MHAFQ::init_subgraph() {
auto f = ov::test::snippets::MHAFQFunction(inputDynamicShapes);
function = f.getOriginal();
}

@ -134,6 +152,20 @@ TEST_P(MHAWOTranspose, CompareWithRefImpl) {
validateNumSubgraphs();
}

TEST_P(MHAINT8MatMul, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}

TEST_P(MHAFQAfterMatMul, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}

TEST_P(MHAFQ, CompareWithRefImpl) {
run();
validateNumSubgraphs();
}

} // namespace snippets
} // namespace test

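For clarity, a hypothetical usage sketch (not part of the commit) of the extended MHAFunction builder that the tests above rely on; the shapes and element types are illustrative:

// Assumed to mirror MHAFunction(inputShapes, precisions, with_mul) declared in the subgraph headers below.
std::vector<ov::PartialShape> shapes = {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}};
std::vector<ov::element::Type> types(4, ov::element::bf16);
auto builder = ov::test::snippets::MHAFunction(shapes, types, true);
std::shared_ptr<ov::Model> model = builder.getOriginal();
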
@ -26,9 +26,9 @@ public:
explicit MatMulFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes");
verify_precisions(precisions);
validate_precisions(precisions);
}
static void verify_precisions(const std::vector<ov::element::Type>& precisions) {
static void validate_precisions(const std::vector<ov::element::Type>& precisions) {
NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types");
const bool is_f32 = ov::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]);
const bool is_int8 = ov::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8;
@ -62,7 +62,7 @@ public:
explicit MatMulBiasFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -70,7 +70,6 @@ protected:
std::vector<ov::element::Type> precisions;
};

// Quantized MatMul
//  FQ[I8]
//  Add
@ -79,7 +78,7 @@ public:
explicit MatMulBiasQuantizedFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -97,7 +96,7 @@ public:
explicit MatMulsQuantizedFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -121,7 +120,7 @@ public:
NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4,
"Only rank 4 input shapes are supported by this test");
NGRAPH_CHECK(transpose_position >=0 && transpose_position <= 2, "Got invalid transpose position");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
@ -166,7 +165,7 @@ public:
explicit MatMulsQuantizedSoftmaxFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
MatMulFunction::verify_precisions(precisions);
MatMulFunction::validate_precisions(precisions);
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;

@ -43,15 +43,17 @@ namespace snippets {
*/
class MHAFunction : public SnippetsFunctionBase {
public:
explicit MHAFunction(const std::vector<PartialShape>& inputShapes, bool with_mul = true)
: SnippetsFunctionBase(inputShapes), with_mul(with_mul) {
explicit MHAFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions, bool with_mul = true)
: SnippetsFunctionBase(inputShapes), with_mul(with_mul), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;

bool with_mul = true;
std::vector<ov::element::Type> precisions;
};

/* Graph:
@ -71,13 +73,16 @@ protected:
*/
class MHAMatMul0TransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
explicit MHAMatMul0TransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 4, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;

std::vector<ov::element::Type> precisions;
};

/* Graph:
@ -97,11 +102,15 @@ protected:
*/
class MHASelectFunction : public SnippetsFunctionBase {
public:
explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
explicit MHASelectFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 6, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 6, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;

std::vector<ov::element::Type> precisions;
};

/* Graph:
@ -137,13 +146,128 @@ protected:
*/
class MHAWOTransposeFunction : public SnippetsFunctionBase {
public:
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes) : SnippetsFunctionBase(inputShapes) {
explicit MHAWOTransposeFunction(const std::vector<PartialShape>& inputShapes, const std::vector<ov::element::Type>& precisions)
: SnippetsFunctionBase(inputShapes), precisions(precisions) {
NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes");
NGRAPH_CHECK(precisions.size() == 3, "Got invalid number of input precisions");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;

std::vector<ov::element::Type> precisions;
};

/* Graph:
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Reshape0
* Softmax
* Reshape1 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAFQAfterMatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAFQAfterMatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};

/* Graph:
* FakeQuantize i8 FakeQuantize i8
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Reshape0
* Softmax
* Reshape1 FakeQuantize i8
* FakeQuantize u8 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAINT8MatMulFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};

/* Graph:
*                                  Constant
* FakeQuantize u8  FakeQuantize u8 Convert
* Transpose0[0,2,1,3] Transpose1[0,2,3,1] Multiply
* \ \ /
* \ Multiply
* \ FakeQuantize f32
* \ /
* MatMul0
* FakeQuantize f32 FakeQuantize u8
* \ /
* Add
* Softmax Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize u8
* Transpose3[0,2,1,3]
* Note: Check a lot of different FQ (the both quantized and floating) - buffers with different size and precision
*/
class MHAFQFunction : public SnippetsFunctionBase {
public:
explicit MHAFQFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
};

// Only for tokenization! The graph is after LPT: contains TypeRelaxed ops
/* Graph:
* FakeQuantize i8 FakeQuantize i8
* Transpose0[0,2,1,3] Transpose1[0,2,3,1]
* \ /
* MatMul0
* FakeQuantize i8
* \ /
* Add
* Mul (DeQuantize)
* Reshape0
* Softmax
* Reshape1 FakeQuantize i8
* FakeQuantize u8 Transpose2[0,2,1,3]
* \ /
* MatMul1
* FakeQuantize i8
* Transpose3[0,2,1,3]
*/
class MHAINT8MatMulTypeRelaxedFunction : public SnippetsFunctionBase {
public:
explicit MHAINT8MatMulTypeRelaxedFunction(const std::vector<PartialShape>& inputShapes)
: SnippetsFunctionBase(inputShapes) {
NGRAPH_CHECK(input_shapes.size() == 4, "Got invalid number of input shapes");
}
protected:
std::shared_ptr<ov::Model> initOriginal() const override;
std::shared_ptr<ov::Model> initReference() const override;
};

} // namespace snippets
} // namespace test
} // namespace ov

@ -230,4 +230,4 @@ std::shared_ptr<ov::Model> MatMulsQuantizedSoftmaxFunction::initOriginal() const
} // namespace snippets
} // namespace test
} // namespace ov
} // namespace ov

@ -7,16 +7,18 @@
|
||||
#include "common_test_utils/data_utils.hpp"
|
||||
#include <snippets/op/subgraph.hpp>
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ov_ops/type_relaxed.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace test {
|
||||
namespace snippets {
|
||||
|
||||
std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
|
||||
|
||||
std::vector<ov::Shape> constantShapes;
|
||||
@ -51,7 +53,7 @@ std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
|
||||
std::shared_ptr<ov::Node> matmul_parent1 = transpose1;
|
||||
if (with_mul) {
|
||||
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
|
||||
auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true);
|
||||
auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true);
|
||||
matmul_parent1 = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulConst);
|
||||
}
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, matmul_parent1, transA, transB);
|
||||
@ -67,17 +69,17 @@ std::shared_ptr<ov::Model> MHAFunction::initOriginal() const {
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
std::shared_ptr<ov::Model> MHAFunction::initReference() const {
|
||||
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
|
||||
NodeVector subgraph_inputs = {data0, data1, data2, data3};
|
||||
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
|
||||
std::vector<ov::Shape> constantShapes;
|
||||
constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()}));
|
||||
@ -113,8 +115,8 @@ std::shared_ptr<ov::Model> MHAFunction::initReference() const {
|
||||
std::shared_ptr<ov::Node> matmul_parent1 = transpose1;
|
||||
if (with_mul) {
|
||||
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
|
||||
auto mulConst = ngraph::builder::makeConstant(precision, constantShapes[2], mulConstData, true);
|
||||
auto mulParam = std::make_shared<ngraph::opset1::Parameter>(precision, mulConst->get_shape());
|
||||
auto mulConst = ngraph::builder::makeConstant(precisions[1], constantShapes[2], mulConstData, true);
|
||||
auto mulParam = std::make_shared<ngraph::opset1::Parameter>(precisions[1], mulConst->get_shape());
|
||||
matmul_parent1 = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulParam);
|
||||
subgraph_params = {transpose0Param, transpose1Param, mulParam, addParam, transpose2Param};
|
||||
subgraph_inputs = {data0, data1, mulConst, data2, data3};
|
||||
@ -135,10 +137,10 @@ std::shared_ptr<ov::Model> MHAFunction::initReference() const {
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
|
||||
|
||||
std::vector<ov::Shape> constantShapes;
|
||||
@ -157,7 +159,7 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
|
||||
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], order);
|
||||
|
||||
std::vector<float> mulConstData(1);
|
||||
auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true);
|
||||
auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true);
|
||||
|
||||
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
|
||||
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
|
||||
@ -188,16 +190,16 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initOriginal() const {
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
|
||||
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto data0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto data1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto data2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto data3 = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};
|
||||
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
|
||||
std::vector<ov::Shape> constantShapes;
|
||||
constantShapes.push_back(ov::Shape({input_shapes[0].get_shape().size()}));
|
||||
@ -214,7 +216,7 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
|
||||
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[6], std::vector<int64_t>{0, 2, 1, 3});
|
||||
|
||||
std::vector<float> mulConstData(1);
|
||||
auto mulConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, mulConstData, true);
|
||||
auto mulConst = ngraph::builder::makeConstant(precisions[1], ov::Shape{1}, mulConstData, true);
|
||||
ngraph::ParameterVector subgraphParams = {transpose0Param, transpose1Param, addParam, transpose2Param};
|
||||
|
||||
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
|
||||
@ -250,12 +252,12 @@ std::shared_ptr<ov::Model> MHAMatMul0TransposeFunction::initReference() const {
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto less0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
auto less1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[4]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[5]);
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
auto less0Param = std::make_shared<ngraph::opset1::Parameter>(precisions[3], input_shapes[3]);
|
||||
auto less1Param = std::make_shared<ngraph::opset1::Parameter>(precisions[4], input_shapes[4]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precisions[5], input_shapes[5]);
|
||||
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, less0Param, less1Param, transpose2Param};
|
||||
|
||||
std::vector<ov::Shape> constantShapes;
|
||||
@ -288,7 +290,7 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
|
||||
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
|
||||
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, constantShapes[4], reshape1ConstData);
|
||||
// Value is equal to '1' - to avoid situation e^(-1000) / (sum(e^(-1000)) = 0/0 = NAN
|
||||
auto selectConst = ngraph::builder::makeConstant(precision, ov::Shape{1}, std::vector<float>{1});
|
||||
auto selectConst = ngraph::builder::makeConstant(precisions[2], ov::Shape{1}, std::vector<float>{1});
|
||||
|
||||
float transA = false;
|
||||
float transB = false;
|
||||
@ -344,9 +346,9 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
|
||||
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto param0 = std::make_shared<ngraph::opset1::Parameter>(precisions[0], input_shapes[0]);
|
||||
auto param1 = std::make_shared<ngraph::opset1::Parameter>(precisions[1], input_shapes[1]);
|
||||
auto param2 = std::make_shared<ngraph::opset1::Parameter>(precisions[2], input_shapes[2]);
|
||||
ngraph::ParameterVector ngraphParam = {param0, param1, param2};
|
||||
|
||||
float transA = false;
|
||||
@ -359,6 +361,302 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
|
||||
std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
|
||||
ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};
|
||||
|
||||
const auto shape_rank = input_shapes[0].get_shape().size();
|
||||
auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
|
||||
auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
|
||||
|
||||
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
|
||||
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
|
||||
-1};
|
||||
auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);
|
||||
|
||||
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
|
||||
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
|
||||
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
|
||||
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
|
||||
auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);
|
||||
|
||||
float transA = false;
|
||||
float transB = false;
|
||||
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
|
||||
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
|
||||
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
|
||||
auto fq0 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
|
||||
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
|
||||
const auto add = std::make_shared<ngraph::opset3::Add>(fq0, addParam);
|
||||
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
|
||||
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
|
||||
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
|
||||
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2, transA, transB);
|
||||
auto fq1 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
|
||||
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
|
||||
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose3Const);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
|
||||
return std::make_shared<ov::Model>(results, ngraphParam, "mha");
|
||||
}
|
||||
std::shared_ptr<ov::Model> MHAINT8MatMulFunction::initOriginal() const {
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
|
||||
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
|
||||
auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
|
||||
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};

    const auto shape_rank = input_shapes[0].get_shape().size();
    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});

    // Reshape the 4D MatMul0+Add result to 2D so Softmax runs over the last axis, then restore the 4D shape
    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
                                                                   input_shapes[0].get_shape()[1] *
                                                                   input_shapes[0].get_shape()[2]),
                                              -1};
    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);

    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

    auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
    auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
    auto fq2 = ngraph::builder::makeFakeQuantize(transpose2Param, ov::element::f32, 256, {1},
                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});

    bool transA = false;
    bool transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
    auto fq3 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
    const auto add = std::make_shared<ngraph::opset3::Add>(fq3, addParam);
    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
    auto fq4 = ngraph::builder::makeFakeQuantize(reshape1, ov::element::f32, 256, {1},
                                                 {0}, {0.820726}, {0}, {0.820726});
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fq2, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(fq4, transpose2, transA, transB);
    auto fq5 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
                                                 {-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq5, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
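// Builds the original (not yet tokenized) MHA graph with FakeQuantize ops on the Transpose0/Transpose1
// inputs and on the Add bias input. The second MatMul0 input additionally goes through a small
// dequantization chain (i8 constant -> Convert -> Multiply) before MatMul0, Softmax is applied
// directly on the 4D tensor over axis 3, and the MatMul1 result is re-quantized before the final Transpose.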
std::shared_ptr<ov::Model> MHAFQFunction::initOriginal() const {
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};

    const auto shape_rank = input_shapes[0].get_shape().size();
    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});

    const auto fq0 = ngraph::builder::makeFakeQuantize(transpose0Param, ov::element::f32, 256, {1},
                                                       {-5.217694}, {6.661877}, {-5.217694}, {6.661877});
    const auto fq1 = ngraph::builder::makeFakeQuantize(transpose1Param, ov::element::f32, 256, {1},
                                                       {-6.40245}, {6.45286}, {-6.40245}, {6.45286});
    const auto fq_add = ngraph::builder::makeFakeQuantize(addParam, ov::element::f32, 256, {1},
                                                          {-1000}, {0}, {-1000}, {0});

    bool transA = false;
    bool transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    // Dequantization chain on the second MatMul0 input: i8 constant converted to f32 and scaled
    const auto mul_const = ngraph::builder::makeConstant(ov::element::i8, ov::Shape{1}, std::vector<int8_t>{127});
    const auto convert = std::make_shared<ngraph::opset1::Convert>(mul_const, ov::element::f32);
    const auto mul_deq_const = ngraph::builder::makeConstant(ov::element::f32, ov::Shape{1}, std::vector<float>{0.00098425});
    const auto mul_deq = std::make_shared<ngraph::opset1::Multiply>(convert, mul_deq_const);
    const auto mul = std::make_shared<ngraph::opset1::Multiply>(transpose1, mul_deq);
    auto fq1_1 = ngraph::builder::makeFakeQuantize(mul, ov::element::f32, 256, {1},
                                                   {-0.8003067}, {0.8066083}, {-0.8003067}, {0.8066083});
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fq1_1, transA, transB);
    auto fq2 = ngraph::builder::makeFakeQuantize(matMul0, ov::element::f32, 256, {1},
                                                 {-14.50351}, {17.65645}, {-14.50351}, {17.65645});
    const auto add = std::make_shared<ngraph::opset1::Add>(fq2, fq_add);
    // Softmax is applied directly on the 4D tensor (no Reshape pair in this variant)
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
    auto fq3 = ngraph::builder::makeFakeQuantize(matMul1, ov::element::f32, 256, {1},
                                                 {-1.895786}, {2.0028071}, {-1.895786}, {2.0028071});
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq3, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
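// Builds the original INT8 MHA graph using TypeRelaxed ops: inputs and intermediate results are
// quantized to i8/u8 via makeFakeQuantizeTypeRelaxed, while MatMul, Add and Multiply are wrapped
// in op::TypeRelaxed so their inference is still performed in f32. The Softmax is surrounded by a
// Reshape pair (flatten to 2D, Softmax over the last axis, restore the 4D shape).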
std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() const {
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
    ngraph::ParameterVector ngraphParam = {transpose0Param, transpose1Param, addParam, transpose2Param};

    const auto shape_rank = input_shapes[0].get_shape().size();
    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});

    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
                                                                   input_shapes[0].get_shape()[1] *
                                                                   input_shapes[0].get_shape()[2]),
                                              -1};
    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);

    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

    const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
    const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::i8, fq_signed_params);
    const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::i8, fq_signed_params);
    const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::i8, fq_signed_params);

    bool transA = false;
    bool transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
    // TypeRelaxed ops infer in f32 regardless of the quantized (i8/u8) types carried on their inputs
    const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);

    const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
    const auto add = std::make_shared<op::TypeRelaxed<ngraph::opset3::Add>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
    const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
    const auto deq_mul = std::make_shared<op::TypeRelaxed<ngraph::opset3::Multiply>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(deq, element::f32).get());

    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);

    const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
    const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);

    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fq2, transpose2Const);
    const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
    const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq5, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ov::Model>(results, ngraphParam, "mha");
}
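// Builds the expected graph after MHA tokenization: the per-input FakeQuantize ops remain outside,
// their outputs feed an ov::snippets::op::Subgraph whose body contains the Transposes, the TypeRelaxed
// MatMuls/Add, the Reshape-Softmax-Reshape chain and the intermediate FakeQuantize ops; only the
// final Transpose3 stays outside the Subgraph (see the TODO note at the end of the function).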
std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() const {
    auto data0 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto data1 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto data2 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto data3 = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
    ngraph::ParameterVector ngraphParams = {data0, data1, data2, data3};

    const auto fq_signed_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
    const auto fq0 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::i8, fq_signed_params);
    const auto fq1 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::i8, fq_signed_params);
    const auto fq2 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::i8, fq_signed_params);
    // The input FakeQuantize ops stay outside the Subgraph; their outputs become the Subgraph inputs
    NodeVector subgraph_inputs = {fq0, fq1, data2, fq2};

    // Body parameters of the Subgraph (mirror the external inputs)
    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[0]);
    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[1]);
    auto addParam = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[2]);
    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(precision, input_shapes[3]);
    ov::ParameterVector subgraph_params = {transpose0Param, transpose1Param, addParam, transpose2Param};

    const auto shape_rank = input_shapes[0].get_shape().size();
    auto transpose0Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose1Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 3, 1});
    auto transpose2Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
    auto transpose3Const = ngraph::builder::makeConstant(ngraph::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});

    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
                                                                   input_shapes[0].get_shape()[1] *
                                                                   input_shapes[0].get_shape()[2]),
                                              -1};
    auto reshape0Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape0ConstData.size()}, reshape0ConstData);

    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[2]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1]),
                                              static_cast<int64_t>(input_shapes[0].get_shape()[1])};
    auto reshape1Const = ngraph::builder::makeConstant(ngraph::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

    bool transA = false;
    bool transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
    const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);

    const auto fq3 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
    const auto add = std::make_shared<op::TypeRelaxed<ngraph::opset3::Add>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(fq3, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(addParam, element::f32).get());
    const auto deq = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1}, std::vector<float>{0.1122});
    const auto deq_mul = std::make_shared<op::TypeRelaxed<ngraph::opset3::Multiply>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(deq, element::f32).get());

    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);

    const auto fq_unsigned_params = ngraph::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
    const auto fq4 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);

    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
        std::vector<element::Type>{ element::f32, element::f32 },
        std::vector<element::Type>{ element::f32 },
        ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
        ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
    const auto fq5 = ngraph::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);

    auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(subgraph_inputs,
                                                                 std::make_shared<ov::Model>(NodeVector{fq5}, subgraph_params));
    // TODO: At the moment Snippets don't support explicit Transpose ops,
    //       so Transpose3 cannot be collapsed into the Subgraph when there are ops between MatMul1 and Transpose3.
    auto transpose3 = std::make_shared<ov::op::v1::Transpose>(subgraph, transpose3Const);

    return std::make_shared<ov::Model>(NodeVector{transpose3}, ngraphParams);
}
} // namespace snippets
} // namespace test
} // namespace ov
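For reference, a minimal sketch of how such a builder is typically driven in a tokenization test. It assumes a SnippetsFunctionBase-style interface (a constructor taking the input shapes and getOriginal()/getReference() wrappers over the init* methods above); the shapes below are illustrative only.

    // Assumptions: constructor taking std::vector<ov::PartialShape>, plus getOriginal()/getReference()
    // helpers that validate and call the init* methods shown above; shapes are illustrative.
    std::vector<ov::PartialShape> shapes = {{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}};
    ov::test::snippets::MHAINT8MatMulTypeRelaxedFunction builder(shapes);
    std::shared_ptr<ov::Model> original  = builder.getOriginal();   // graph before MHA tokenization
    std::shared_ptr<ov::Model> reference = builder.getReference();  // expected graph with the Subgraph op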